SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Sergey Lisitsyn 00008 * Written (W) 2012 Heiko Strathmann 00009 * 00010 */ 00011 00012 #ifndef __CROSSVALIDATIONOUTPUT_H_ 00013 #define __CROSSVALIDATIONOUTPUT_H_ 00014 00015 #include <shogun/base/SGObject.h> 00016 00017 namespace shogun 00018 { 00019 00020 class CMachine; 00021 class CLabels; 00022 class CEvaluation; 00023 00040 class CCrossValidationOutput: public CSGObject 00041 { 00042 public: 00043 00045 CCrossValidationOutput() : CSGObject() 00046 { 00047 m_current_run_index=0; 00048 m_current_fold_index=0; 00049 m_num_runs=0; 00050 m_num_folds=0; 00051 } 00052 00054 virtual ~CCrossValidationOutput() {} 00055 00057 virtual const char* get_name() const=0; 00058 00064 virtual void init_num_runs(index_t num_runs, const char* prefix="") 00065 { 00066 m_num_runs=num_runs; 00067 } 00068 00073 virtual void init_num_folds(index_t num_folds, const char* prefix="") 00074 { 00075 m_num_folds=num_folds; 00076 } 00077 00081 virtual void init_expose_labels(CLabels* labels) { } 00082 00084 virtual void post_init() { } 00085 00091 virtual void update_run_index(index_t run_index, 00092 const char* prefix="") 00093 { 00094 m_current_run_index=run_index; 00095 } 00096 00102 virtual void update_fold_index(index_t fold_index, 00103 const char* prefix="") 00104 { 00105 m_current_fold_index=fold_index; 00106 } 00107 00113 virtual void update_train_indices(SGVector<index_t> indices, 00114 const char* prefix="") {} 00115 00121 virtual void update_test_indices(SGVector<index_t> indices, 00122 const char* prefix="") {} 00123 00129 virtual void update_trained_machine(CMachine* machine, 00130 const char* prefix="") {} 00131 00137 virtual void update_test_result(CLabels* results, 00138 const char* prefix="") {} 00139 00145 virtual void update_test_true_result(CLabels* results, 00146 const char* prefix="") {} 00147 00150 virtual void post_update_results() {} 00151 00157 virtual void update_evaluation_result(float64_t result, 00158 const char* prefix="") {} 00159 00160 protected: 00162 index_t m_current_run_index; 00163 00165 index_t m_current_fold_index; 00166 00168 index_t m_num_runs; 00169 00171 index_t m_num_folds; 00172 }; 00173 00174 } 00175 00176 #endif /* __CROSSVALIDATIONOUTPUT_H_ */