SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Copyright (C) 2012 Sergey Lisitsyn, Heiko Strathmann 00008 */ 00009 00010 #include <shogun/evaluation/CrossValidationMulticlassStorage.h> 00011 #include <shogun/evaluation/ROCEvaluation.h> 00012 #include <shogun/evaluation/PRCEvaluation.h> 00013 #include <shogun/evaluation/MulticlassAccuracy.h> 00014 00015 using namespace shogun; 00016 00017 CCrossValidationMulticlassStorage::CCrossValidationMulticlassStorage(bool compute_ROC, bool compute_PRC, bool compute_conf_matrices) : 00018 CCrossValidationOutput() 00019 { 00020 m_initialized = false; 00021 m_compute_ROC = compute_ROC; 00022 m_compute_PRC = compute_PRC; 00023 m_compute_conf_matrices = compute_conf_matrices; 00024 m_pred_labels = NULL; 00025 m_true_labels = NULL; 00026 m_num_classes = 0; 00027 m_binary_evaluations = new CDynamicObjectArray(); 00028 } 00029 00030 00031 CCrossValidationMulticlassStorage::~CCrossValidationMulticlassStorage() 00032 { 00033 if (m_compute_ROC) 00034 { 00035 for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++) 00036 m_fold_ROC_graphs[i].~SGMatrix<float64_t>(); 00037 00038 SG_FREE(m_fold_ROC_graphs); 00039 } 00040 00041 if (m_compute_PRC) 00042 { 00043 for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++) 00044 m_fold_PRC_graphs[i].~SGMatrix<float64_t>(); 00045 00046 SG_FREE(m_fold_PRC_graphs); 00047 } 00048 00049 if (m_compute_conf_matrices) 00050 { 00051 for (int32_t i=0; i<m_num_folds*m_num_runs; i++) 00052 m_conf_matrices[i].~SGMatrix<int32_t>(); 00053 00054 SG_FREE(m_conf_matrices); 00055 } 00056 00057 SG_UNREF(m_binary_evaluations); 00058 }; 00059 00060 00061 void CCrossValidationMulticlassStorage::post_init() 00062 { 00063 if (m_initialized) 00064 SG_ERROR("CrossValidationMulticlassStorage was already initialized once\n"); 00065 00066 if (m_compute_ROC) 00067 { 00068 SG_DEBUG("Allocating %d ROC graphs\n", m_num_folds*m_num_runs*m_num_classes); 00069 m_fold_ROC_graphs = SG_MALLOC(SGMatrix<float64_t>, m_num_folds*m_num_runs*m_num_classes); 00070 for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++) 00071 new (&m_fold_ROC_graphs[i]) SGMatrix<float64_t>(); 00072 } 00073 00074 if (m_compute_PRC) 00075 { 00076 SG_DEBUG("Allocating %d PRC graphs\n", m_num_folds*m_num_runs*m_num_classes); 00077 m_fold_PRC_graphs = SG_MALLOC(SGMatrix<float64_t>, m_num_folds*m_num_runs*m_num_classes); 00078 for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++) 00079 new (&m_fold_PRC_graphs[i]) SGMatrix<float64_t>(); 00080 } 00081 00082 if (m_binary_evaluations->get_num_elements()) 00083 m_evaluations_results = SGVector<float64_t>(m_num_folds*m_num_runs*m_num_classes*m_binary_evaluations->get_num_elements()); 00084 00085 m_accuracies = SGVector<float64_t>(m_num_folds*m_num_runs); 00086 00087 if (m_compute_conf_matrices) 00088 { 00089 m_conf_matrices = SG_MALLOC(SGMatrix<int32_t>, m_num_folds*m_num_runs); 00090 for (int32_t i=0; i<m_num_folds*m_num_runs; i++) 00091 new (&m_conf_matrices[i]) SGMatrix<int32_t>(); 00092 } 00093 00094 m_initialized = true; 00095 } 00096 00097 void CCrossValidationMulticlassStorage::init_expose_labels(CLabels* labels) 00098 { 00099 ASSERT((CMulticlassLabels*)labels); 00100 m_num_classes = ((CMulticlassLabels*)labels)->get_num_classes(); 00101 } 00102 00103 void CCrossValidationMulticlassStorage::post_update_results() 00104 { 00105 CROCEvaluation eval_ROC; 00106 CPRCEvaluation eval_PRC; 00107 int32_t n_evals = m_binary_evaluations->get_num_elements(); 00108 for (int32_t c=0; c<m_num_classes; c++) 00109 { 00110 SG_DEBUG("Computing ROC for run %d fold %d class %d", m_current_run_index, m_current_fold_index, c); 00111 CBinaryLabels* pred_labels_binary = m_pred_labels->get_binary_for_class(c); 00112 CBinaryLabels* true_labels_binary = m_true_labels->get_binary_for_class(c); 00113 if (m_compute_ROC) 00114 { 00115 eval_ROC.evaluate(pred_labels_binary, true_labels_binary); 00116 m_fold_ROC_graphs[m_current_run_index*m_num_folds*m_num_classes+m_current_fold_index*m_num_classes+c] = 00117 eval_ROC.get_ROC(); 00118 } 00119 if (m_compute_PRC) 00120 { 00121 eval_PRC.evaluate(pred_labels_binary, true_labels_binary); 00122 m_fold_PRC_graphs[m_current_run_index*m_num_folds*m_num_classes+m_current_fold_index*m_num_classes+c] = 00123 eval_PRC.get_PRC(); 00124 } 00125 00126 for (int32_t i=0; i<n_evals; i++) 00127 { 00128 CBinaryClassEvaluation* evaluator = (CBinaryClassEvaluation*)m_binary_evaluations->get_element_safe(i); 00129 m_evaluations_results[m_current_run_index*m_num_folds*m_num_classes*n_evals+m_current_fold_index*m_num_classes*n_evals+c*n_evals+i] = 00130 evaluator->evaluate(pred_labels_binary, true_labels_binary); 00131 SG_UNREF(evaluator); 00132 } 00133 00134 SG_UNREF(pred_labels_binary); 00135 SG_UNREF(true_labels_binary); 00136 } 00137 CMulticlassAccuracy accuracy; 00138 00139 m_accuracies[m_current_run_index*m_num_folds+m_current_fold_index] = accuracy.evaluate(m_pred_labels, m_true_labels); 00140 00141 if (m_compute_conf_matrices) 00142 { 00143 m_conf_matrices[m_current_run_index*m_num_folds+m_current_fold_index] = CMulticlassAccuracy::get_confusion_matrix(m_pred_labels, m_true_labels); 00144 } 00145 } 00146 00147 void CCrossValidationMulticlassStorage::update_test_result(CLabels* results, const char* prefix) 00148 { 00149 m_pred_labels = (CMulticlassLabels*)results; 00150 } 00151 00152 void CCrossValidationMulticlassStorage::update_test_true_result(CLabels* results, const char* prefix) 00153 { 00154 m_true_labels = (CMulticlassLabels*)results; 00155 } 00156