SHOGUN  v2.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
CrossValidationMulticlassStorage.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Copyright (C) 2012 Sergey Lisitsyn, Heiko Strathmann
00008  */
00009 
00010 #include <shogun/evaluation/CrossValidationMulticlassStorage.h>
00011 #include <shogun/evaluation/ROCEvaluation.h>
00012 #include <shogun/evaluation/PRCEvaluation.h>
00013 #include <shogun/evaluation/MulticlassAccuracy.h>
00014 
00015 using namespace shogun;
00016 
00017 CCrossValidationMulticlassStorage::CCrossValidationMulticlassStorage(bool compute_ROC, bool compute_PRC, bool compute_conf_matrices) : 
00018     CCrossValidationOutput()
00019 {
00020     m_initialized = false;
00021     m_compute_ROC = compute_ROC;
00022     m_compute_PRC = compute_PRC;
00023     m_compute_conf_matrices = compute_conf_matrices;
00024     m_pred_labels = NULL;
00025     m_true_labels = NULL;
00026     m_num_classes = 0;
00027     m_binary_evaluations = new CDynamicObjectArray();
00028 }
00029 
00030 
00031 CCrossValidationMulticlassStorage::~CCrossValidationMulticlassStorage()
00032 {
00033     if (m_compute_ROC)
00034     {
00035         for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++)
00036             m_fold_ROC_graphs[i].~SGMatrix<float64_t>();
00037 
00038         SG_FREE(m_fold_ROC_graphs);
00039     }
00040 
00041     if (m_compute_PRC)
00042     {
00043         for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++)
00044             m_fold_PRC_graphs[i].~SGMatrix<float64_t>();
00045 
00046         SG_FREE(m_fold_PRC_graphs);
00047     }
00048 
00049     if (m_compute_conf_matrices)
00050     {
00051         for (int32_t i=0; i<m_num_folds*m_num_runs; i++)
00052             m_conf_matrices[i].~SGMatrix<int32_t>();
00053     
00054         SG_FREE(m_conf_matrices);
00055     }
00056 
00057     SG_UNREF(m_binary_evaluations);
00058 };
00059 
00060 
00061 void CCrossValidationMulticlassStorage::post_init()
00062 {
00063     if (m_initialized)
00064         SG_ERROR("CrossValidationMulticlassStorage was already initialized once\n");
00065 
00066     if (m_compute_ROC)
00067     {
00068         SG_DEBUG("Allocating %d ROC graphs\n", m_num_folds*m_num_runs*m_num_classes);
00069         m_fold_ROC_graphs = SG_MALLOC(SGMatrix<float64_t>, m_num_folds*m_num_runs*m_num_classes);
00070         for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++)
00071             new (&m_fold_ROC_graphs[i]) SGMatrix<float64_t>();
00072     }
00073 
00074     if (m_compute_PRC)
00075     {
00076         SG_DEBUG("Allocating %d PRC graphs\n", m_num_folds*m_num_runs*m_num_classes);
00077         m_fold_PRC_graphs = SG_MALLOC(SGMatrix<float64_t>, m_num_folds*m_num_runs*m_num_classes);
00078         for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++)
00079             new (&m_fold_PRC_graphs[i]) SGMatrix<float64_t>();
00080     }
00081 
00082     if (m_binary_evaluations->get_num_elements())
00083         m_evaluations_results = SGVector<float64_t>(m_num_folds*m_num_runs*m_num_classes*m_binary_evaluations->get_num_elements());
00084 
00085     m_accuracies = SGVector<float64_t>(m_num_folds*m_num_runs);
00086 
00087     if (m_compute_conf_matrices)
00088     {
00089         m_conf_matrices = SG_MALLOC(SGMatrix<int32_t>, m_num_folds*m_num_runs);
00090         for (int32_t i=0; i<m_num_folds*m_num_runs; i++)
00091             new (&m_conf_matrices[i]) SGMatrix<int32_t>();
00092     }
00093 
00094     m_initialized = true;
00095 }
00096 
00097 void CCrossValidationMulticlassStorage::init_expose_labels(CLabels* labels)
00098 {
00099     ASSERT((CMulticlassLabels*)labels);
00100     m_num_classes = ((CMulticlassLabels*)labels)->get_num_classes();
00101 }
00102 
00103 void CCrossValidationMulticlassStorage::post_update_results()
00104 {
00105     CROCEvaluation eval_ROC;
00106     CPRCEvaluation eval_PRC;
00107     int32_t n_evals = m_binary_evaluations->get_num_elements();
00108     for (int32_t c=0; c<m_num_classes; c++)
00109     {
00110         SG_DEBUG("Computing ROC for run %d fold %d class %d", m_current_run_index, m_current_fold_index, c);
00111         CBinaryLabels* pred_labels_binary = m_pred_labels->get_binary_for_class(c);
00112         CBinaryLabels* true_labels_binary = m_true_labels->get_binary_for_class(c);
00113         if (m_compute_ROC)
00114         {
00115             eval_ROC.evaluate(pred_labels_binary, true_labels_binary);
00116             m_fold_ROC_graphs[m_current_run_index*m_num_folds*m_num_classes+m_current_fold_index*m_num_classes+c] = 
00117                 eval_ROC.get_ROC();
00118         }
00119         if (m_compute_PRC)
00120         {
00121             eval_PRC.evaluate(pred_labels_binary, true_labels_binary);
00122             m_fold_PRC_graphs[m_current_run_index*m_num_folds*m_num_classes+m_current_fold_index*m_num_classes+c] = 
00123                 eval_PRC.get_PRC();
00124         }
00125         
00126         for (int32_t i=0; i<n_evals; i++)
00127         {
00128             CBinaryClassEvaluation* evaluator = (CBinaryClassEvaluation*)m_binary_evaluations->get_element_safe(i);
00129             m_evaluations_results[m_current_run_index*m_num_folds*m_num_classes*n_evals+m_current_fold_index*m_num_classes*n_evals+c*n_evals+i] =
00130                 evaluator->evaluate(pred_labels_binary, true_labels_binary);
00131             SG_UNREF(evaluator);
00132         }
00133 
00134         SG_UNREF(pred_labels_binary);
00135         SG_UNREF(true_labels_binary);
00136     }
00137     CMulticlassAccuracy accuracy;
00138 
00139     m_accuracies[m_current_run_index*m_num_folds+m_current_fold_index] = accuracy.evaluate(m_pred_labels, m_true_labels);
00140     
00141     if (m_compute_conf_matrices)
00142     {
00143         m_conf_matrices[m_current_run_index*m_num_folds+m_current_fold_index] = CMulticlassAccuracy::get_confusion_matrix(m_pred_labels, m_true_labels);
00144     }
00145 }
00146 
00147 void CCrossValidationMulticlassStorage::update_test_result(CLabels* results, const char* prefix)
00148 {
00149     m_pred_labels = (CMulticlassLabels*)results;
00150 }
00151 
00152 void CCrossValidationMulticlassStorage::update_test_true_result(CLabels* results, const char* prefix)
00153 {
00154     m_true_labels = (CMulticlassLabels*)results;
00155 }
00156 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation