SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2011-2012 Heiko Strathmann 00008 * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society 00009 */ 00010 00011 #ifndef __CROSSVALIDATION_H_ 00012 #define __CROSSVALIDATION_H_ 00013 00014 #include <shogun/evaluation/EvaluationResult.h> 00015 #include <shogun/evaluation/MachineEvaluation.h> 00016 00017 namespace shogun 00018 { 00019 00020 class CMachineEvaluation; 00021 class CCrossValidationOutput; 00022 class CList; 00023 00029 class CCrossValidationResult : public CEvaluationResult 00030 { 00031 public: 00032 CCrossValidationResult() 00033 { 00034 mean = 0; 00035 has_conf_int = 0; 00036 conf_int_low = 0; 00037 conf_int_up = 0; 00038 conf_int_alpha = 0; 00039 } 00040 00046 virtual EEvaluationResultType get_result_type() 00047 { 00048 return CROSSVALIDATION_RESULT; 00049 } 00050 00056 virtual const char* get_name() const { return "CrossValidationResult"; } 00057 00059 virtual void print_result() 00060 { 00061 if (has_conf_int) 00062 { 00063 SG_SPRINT("[%f,%f] with alpha=%f, mean=%f\n", conf_int_low, 00064 conf_int_up, conf_int_alpha, mean); 00065 } 00066 else 00067 SG_SPRINT("%f\n", mean); 00068 } 00069 00070 public: 00072 float64_t mean; 00074 bool has_conf_int; 00076 float64_t conf_int_low; 00078 float64_t conf_int_up; 00080 float64_t conf_int_alpha; 00081 00082 }; 00083 00109 class CCrossValidation: public CMachineEvaluation 00110 { 00111 public: 00113 CCrossValidation(); 00114 00123 CCrossValidation(CMachine* machine, CFeatures* features, CLabels* labels, 00124 CSplittingStrategy* splitting_strategy, 00125 CEvaluation* evaluation_criterion, bool autolock=true); 00126 00134 CCrossValidation(CMachine* machine, CLabels* labels, 00135 CSplittingStrategy* splitting_strategy, 00136 CEvaluation* evaluation_criterion, bool autolock=true); 00137 00139 virtual ~CCrossValidation(); 00140 00142 void set_num_runs(int32_t num_runs); 00143 00145 void set_conf_int_alpha(float64_t m_conf_int_alpha); 00146 00148 virtual CEvaluationResult* evaluate(); 00149 00155 void add_cross_validation_output( 00156 CCrossValidationOutput* cross_validation_output); 00157 00159 inline virtual const char* get_name() const 00160 { 00161 return "CrossValidation"; 00162 } 00163 00164 private: 00165 void init(); 00166 00167 protected: 00176 virtual float64_t evaluate_one_run(); 00177 00179 int32_t m_num_runs; 00181 float64_t m_conf_int_alpha; 00182 00184 CList* m_xval_outputs; 00185 }; 00186 00187 } 00188 00189 #endif /* __CROSSVALIDATION_H_ */