SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2011 Sergey Lisitsyn 00008 * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society 00009 */ 00010 00011 #include <shogun/evaluation/PRCEvaluation.h> 00012 #include <shogun/labels/RegressionLabels.h> 00013 #include <shogun/labels/BinaryLabels.h> 00014 #include <shogun/mathematics/Math.h> 00015 00016 using namespace shogun; 00017 00018 CPRCEvaluation::~CPRCEvaluation() 00019 { 00020 } 00021 00022 float64_t CPRCEvaluation::evaluate(CLabels* predicted, CLabels* ground_truth) 00023 { 00024 ASSERT(predicted && ground_truth); 00025 ASSERT(predicted->get_num_labels()==ground_truth->get_num_labels()); 00026 ASSERT(predicted->get_label_type()==LT_BINARY); 00027 ASSERT(ground_truth->get_label_type()==LT_BINARY); 00028 ground_truth->ensure_valid(); 00029 00030 // number of true positive examples 00031 float64_t tp = 0.0; 00032 int32_t i; 00033 00034 // total number of positive labels in predicted 00035 int32_t pos_count=0; 00036 00037 // initialize number of labels and labels 00038 SGVector<float64_t> orig_labels = predicted->get_confidences(); 00039 int32_t length = orig_labels.vlen; 00040 float64_t* labels = SGVector<float64_t>::clone_vector(orig_labels.vector, length); 00041 00042 // get indexes for sort 00043 int32_t* idxs = SG_MALLOC(int32_t, length); 00044 for(i=0; i<length; i++) 00045 idxs[i] = i; 00046 00047 // sort indexes by labels ascending 00048 CMath::qsort_backward_index(labels,idxs,length); 00049 00050 // clean and initialize graph and auPRC 00051 SG_FREE(labels); 00052 m_PRC_graph = SGMatrix<float64_t>(2,length); 00053 m_thresholds = SGVector<float64_t>(length); 00054 m_auPRC = 0.0; 00055 00056 // get total numbers of positive and negative labels 00057 for (i=0; i<length; i++) 00058 { 00059 if (ground_truth->get_confidence(i) > 0) 00060 pos_count++; 00061 } 00062 00063 // assure number of positive examples is >0 00064 ASSERT(pos_count>0); 00065 00066 // create PRC curve 00067 for (i=0; i<length; i++) 00068 { 00069 // update number of true positive examples 00070 if (ground_truth->get_confidence(idxs[i]) > 0) 00071 tp += 1.0; 00072 00073 // precision (x) 00074 m_PRC_graph[2*i] = tp/float64_t(i+1); 00075 // recall (y) 00076 m_PRC_graph[2*i+1] = tp/float64_t(pos_count); 00077 00078 m_thresholds[i]= predicted->get_confidence(idxs[i]); 00079 } 00080 00081 // calc auRPC using area under curve 00082 m_auPRC = CMath::area_under_curve(m_PRC_graph.matrix,length,true); 00083 00084 // set computed indicator 00085 m_computed = true; 00086 00087 return m_auPRC; 00088 } 00089 00090 SGMatrix<float64_t> CPRCEvaluation::get_PRC() 00091 { 00092 if (!m_computed) 00093 SG_ERROR("Uninitialized, please call evaluate first"); 00094 00095 return m_PRC_graph; 00096 } 00097 00098 SGVector<float64_t> CPRCEvaluation::get_thresholds() 00099 { 00100 if (!m_computed) 00101 SG_ERROR("Uninitialized, please call evaluate first"); 00102 00103 return m_thresholds; 00104 } 00105 00106 float64_t CPRCEvaluation::get_auPRC() 00107 { 00108 if (!m_computed) 00109 SG_ERROR("Uninitialized, please call evaluate first"); 00110 00111 return m_auPRC; 00112 }