SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2011 Sergey Lisitsyn 00008 * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society 00009 */ 00010 00011 #include <shogun/evaluation/ROCEvaluation.h> 00012 #include <shogun/mathematics/Math.h> 00013 00014 using namespace shogun; 00015 00016 CROCEvaluation::~CROCEvaluation() 00017 { 00018 } 00019 00020 float64_t CROCEvaluation::evaluate(CLabels* predicted, CLabels* ground_truth) 00021 { 00022 return evaluate_roc(predicted,ground_truth); 00023 } 00024 00025 float64_t CROCEvaluation::evaluate_roc(CLabels* predicted, CLabels* ground_truth) 00026 { 00027 ASSERT(predicted && ground_truth); 00028 ASSERT(predicted->get_num_labels()==ground_truth->get_num_labels()); 00029 ASSERT(predicted->get_label_type()==LT_BINARY); 00030 ASSERT(ground_truth->get_label_type()==LT_BINARY); 00031 ground_truth->ensure_valid(); 00032 00033 // assume threshold as negative infinity 00034 float64_t threshold = CMath::ALMOST_NEG_INFTY; 00035 // false positive rate 00036 float64_t fp = 0.0; 00037 // true positive rate 00038 float64_t tp=0.0; 00039 00040 int32_t i; 00041 // total number of positive labels in predicted 00042 int32_t pos_count=0; 00043 int32_t neg_count=0; 00044 00045 // initialize number of labels and labels 00046 SGVector<float64_t> orig_labels(predicted->get_num_labels()); 00047 int32_t length = orig_labels.vlen; 00048 for (i=0; i<length; i++) 00049 orig_labels[i] = predicted->get_confidence(i); 00050 float64_t* labels = SGVector<float64_t>::clone_vector(orig_labels.vector, length); 00051 00052 // get sorted indexes 00053 int32_t* idxs = SG_MALLOC(int32_t, length); 00054 for(i=0; i<length; i++) 00055 idxs[i] = i; 00056 00057 CMath::qsort_backward_index(labels,idxs,length); 00058 00059 // number of different predicted labels 00060 int32_t diff_count=1; 00061 00062 // get number of different labels 00063 for (i=0; i<length-1; i++) 00064 { 00065 if (labels[i] != labels[i+1]) 00066 diff_count++; 00067 } 00068 00069 SG_FREE(labels); 00070 00071 // initialize graph and auROC 00072 m_ROC_graph = SGMatrix<float64_t>(2,diff_count+1); 00073 m_thresholds = SGVector<float64_t>(length); 00074 m_auROC = 0.0; 00075 00076 // get total numbers of positive and negative labels 00077 for(i=0; i<length; i++) 00078 { 00079 if (ground_truth->get_confidence(i) >= 0) 00080 pos_count++; 00081 else 00082 neg_count++; 00083 } 00084 00085 // assure both number of positive and negative examples is >0 00086 ASSERT(pos_count>0 && neg_count>0); 00087 00088 int32_t j = 0; 00089 float64_t label; 00090 00091 // create ROC curve and calculate auROC 00092 for(i=0; i<length; i++) 00093 { 00094 label = predicted->get_confidence(idxs[i]); 00095 00096 if (label != threshold) 00097 { 00098 threshold = label; 00099 m_ROC_graph[2*j] = fp/neg_count; 00100 m_ROC_graph[2*j+1] = tp/pos_count; 00101 j++; 00102 } 00103 00104 m_thresholds[i]=threshold; 00105 00106 if (ground_truth->get_confidence(idxs[i]) > 0) 00107 tp+=1.0; 00108 else 00109 fp+=1.0; 00110 } 00111 00112 // add (1,1) to ROC curve 00113 m_ROC_graph[2*diff_count] = 1.0; 00114 m_ROC_graph[2*diff_count+1] = 1.0; 00115 00116 // calc auROC using area under curve 00117 m_auROC = CMath::area_under_curve(m_ROC_graph.matrix,diff_count+1,false); 00118 00119 m_computed = true; 00120 00121 return m_auROC; 00122 } 00123 00124 SGMatrix<float64_t> CROCEvaluation::get_ROC() 00125 { 00126 if (!m_computed) 00127 SG_ERROR("Uninitialized, please call evaluate first"); 00128 00129 return m_ROC_graph; 00130 } 00131 00132 SGVector<float64_t> CROCEvaluation::get_thresholds() 00133 { 00134 if (!m_computed) 00135 SG_ERROR("Uninitialized, please call evaluate first"); 00136 00137 return m_thresholds; 00138 } 00139 00140 float64_t CROCEvaluation::get_auROC() 00141 { 00142 if (!m_computed) 00143 SG_ERROR("Uninitialized, please call evaluate first"); 00144 00145 return m_auROC; 00146 }