SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Copyright (C) 2012 Sergey Lisitsyn 00008 */ 00009 00010 #include <shogun/evaluation/MulticlassOVREvaluation.h> 00011 #include <shogun/evaluation/ROCEvaluation.h> 00012 #include <shogun/evaluation/PRCEvaluation.h> 00013 #include <shogun/labels/MulticlassLabels.h> 00014 #include <shogun/mathematics/Statistics.h> 00015 00016 using namespace shogun; 00017 00018 CMulticlassOVREvaluation::CMulticlassOVREvaluation() : 00019 CEvaluation(), m_binary_evaluation(NULL), m_graph_results(NULL), m_num_graph_results(0) 00020 { 00021 } 00022 00023 CMulticlassOVREvaluation::CMulticlassOVREvaluation(CBinaryClassEvaluation* binary_evaluation) : 00024 CEvaluation(), m_binary_evaluation(NULL), m_graph_results(NULL), m_num_graph_results(0) 00025 { 00026 set_binary_evaluation(binary_evaluation); 00027 } 00028 00029 CMulticlassOVREvaluation::~CMulticlassOVREvaluation() 00030 { 00031 SG_UNREF(m_binary_evaluation); 00032 if (m_graph_results) 00033 { 00034 for (int32_t i=0; i<m_num_graph_results; i++) 00035 m_graph_results[i].~SGMatrix<float64_t>(); 00036 SG_FREE(m_graph_results); 00037 } 00038 } 00039 00040 float64_t CMulticlassOVREvaluation::evaluate(CLabels* predicted, CLabels* ground_truth) 00041 { 00042 ASSERT(m_binary_evaluation); 00043 ASSERT(predicted); 00044 ASSERT(ground_truth); 00045 int32_t n_labels = predicted->get_num_labels(); 00046 ASSERT(n_labels); 00047 CMulticlassLabels* predicted_mc = (CMulticlassLabels*)predicted; 00048 CMulticlassLabels* ground_truth_mc = (CMulticlassLabels*)ground_truth; 00049 int32_t n_classes = predicted_mc->get_multiclass_confidences(0).size(); 00050 ASSERT(n_classes>0); 00051 m_last_results = SGVector<float64_t>(n_classes); 00052 00053 SGMatrix<float64_t> all(n_labels,n_classes); 00054 for (int32_t i=0; i<n_labels; i++) 00055 { 00056 SGVector<float64_t> confs = predicted_mc->get_multiclass_confidences(i); 00057 for (int32_t j=0; j<n_classes; j++) 00058 { 00059 all(i,j) = confs[j]; 00060 } 00061 } 00062 if (dynamic_cast<CROCEvaluation*>(m_binary_evaluation) || dynamic_cast<CPRCEvaluation*>(m_binary_evaluation)) 00063 { 00064 for (int32_t i=0; i<m_num_graph_results; i++) 00065 m_graph_results[i].~SGMatrix<float64_t>(); 00066 SG_FREE(m_graph_results); 00067 m_graph_results = SG_MALLOC(SGMatrix<float64_t>, n_classes); 00068 m_num_graph_results = n_classes; 00069 } 00070 for (int32_t c=0; c<n_classes; c++) 00071 { 00072 CLabels* pred = new CBinaryLabels(SGVector<float64_t>(all.get_column_vector(c),n_labels,false)); 00073 SGVector<float64_t> gt_vec(n_labels); 00074 for (int32_t i=0; i<n_labels; i++) 00075 { 00076 if (ground_truth_mc->get_label(i)==c) 00077 gt_vec[i] = +1.0; 00078 else 00079 gt_vec[i] = -1.0; 00080 } 00081 CLabels* gt = new CBinaryLabels(gt_vec); 00082 m_last_results[c] = m_binary_evaluation->evaluate(pred, gt); 00083 00084 if (dynamic_cast<CROCEvaluation*>(m_binary_evaluation)) 00085 { 00086 new (&m_graph_results[c]) SGMatrix<float64_t>(); 00087 m_graph_results[c] = ((CROCEvaluation*)m_binary_evaluation)->get_ROC(); 00088 } 00089 if (dynamic_cast<CPRCEvaluation*>(m_binary_evaluation)) 00090 { 00091 new (&m_graph_results[c]) SGMatrix<float64_t>(); 00092 m_graph_results[c] = ((CPRCEvaluation*)m_binary_evaluation)->get_PRC(); 00093 } 00094 } 00095 return CStatistics::mean(m_last_results); 00096 }