// OpenCV 2.2.0 — machine learning module header (ml.hpp), recovered from a doc-export dump.
00001 /*M/////////////////////////////////////////////////////////////////////////////////////// 00002 // 00003 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. 00004 // 00005 // By downloading, copying, installing or using the software you agree to this license. 00006 // If you do not agree to this license, do not download, install, 00007 // copy or use the software. 00008 // 00009 // 00010 // Intel License Agreement 00011 // 00012 // Copyright (C) 2000, Intel Corporation, all rights reserved. 00013 // Third party copyrights are property of their respective owners. 00014 // 00015 // Redistribution and use in source and binary forms, with or without modification, 00016 // are permitted provided that the following conditions are met: 00017 // 00018 // * Redistribution's of source code must retain the above copyright notice, 00019 // this list of conditions and the following disclaimer. 00020 // 00021 // * Redistribution's in binary form must reproduce the above copyright notice, 00022 // this list of conditions and the following disclaimer in the documentation 00023 // and/or other materials provided with the distribution. 00024 // 00025 // * The name of Intel Corporation may not be used to endorse or promote products 00026 // derived from this software without specific prior written permission. 00027 // 00028 // This software is provided by the copyright holders and contributors "as is" and 00029 // any express or implied warranties, including, but not limited to, the implied 00030 // warranties of merchantability and fitness for a particular purpose are disclaimed. 
00031 // In no event shall the Intel Corporation or contributors be liable for any direct, 00032 // indirect, incidental, special, exemplary, or consequential damages 00033 // (including, but not limited to, procurement of substitute goods or services; 00034 // loss of use, data, or profits; or business interruption) however caused 00035 // and on any theory of liability, whether in contract, strict liability, 00036 // or tort (including negligence or otherwise) arising in any way out of 00037 // the use of this software, even if advised of the possibility of such damage. 00038 // 00039 //M*/ 00040 00041 #ifndef __OPENCV_ML_HPP__ 00042 #define __OPENCV_ML_HPP__ 00043 00044 // disable deprecation warning which appears in VisualStudio 8.0 00045 #if _MSC_VER >= 1400 00046 #pragma warning( disable : 4996 ) 00047 #endif 00048 00049 #ifndef SKIP_INCLUDES 00050 00051 #include "opencv2/core/core.hpp" 00052 #include <limits.h> 00053 00054 #if defined WIN32 || defined _WIN32 00055 #include <windows.h> 00056 #endif 00057 00058 #else // SKIP_INCLUDES 00059 00060 #if defined WIN32 || defined _WIN32 00061 #define CV_CDECL __cdecl 00062 #define CV_STDCALL __stdcall 00063 #else 00064 #define CV_CDECL 00065 #define CV_STDCALL 00066 #endif 00067 00068 #ifndef CV_EXTERN_C 00069 #ifdef __cplusplus 00070 #define CV_EXTERN_C extern "C" 00071 #define CV_DEFAULT(val) = val 00072 #else 00073 #define CV_EXTERN_C 00074 #define CV_DEFAULT(val) 00075 #endif 00076 #endif 00077 00078 #ifndef CV_EXTERN_C_FUNCPTR 00079 #ifdef __cplusplus 00080 #define CV_EXTERN_C_FUNCPTR(x) extern "C" { typedef x; } 00081 #else 00082 #define CV_EXTERN_C_FUNCPTR(x) typedef x 00083 #endif 00084 #endif 00085 00086 #ifndef CV_INLINE 00087 #if defined __cplusplus 00088 #define CV_INLINE inline 00089 #elif (defined WIN32 || defined _WIN32) && !defined __GNUC__ 00090 #define CV_INLINE __inline 00091 #else 00092 #define CV_INLINE static 00093 #endif 00094 #endif /* CV_INLINE */ 00095 00096 #if (defined WIN32 || defined 
_WIN32) && defined CVAPI_EXPORTS 00097 #define CV_EXPORTS __declspec(dllexport) 00098 #else 00099 #define CV_EXPORTS 00100 #endif 00101 00102 #ifndef CVAPI 00103 #define CVAPI(rettype) CV_EXTERN_C CV_EXPORTS rettype CV_CDECL 00104 #endif 00105 00106 #endif // SKIP_INCLUDES 00107 00108 00109 #ifdef __cplusplus 00110 00111 // Apple defines a check() macro somewhere in the debug headers 00112 // that interferes with a method definiton in this header 00113 #undef check 00114 00115 /****************************************************************************************\ 00116 * Main struct definitions * 00117 \****************************************************************************************/ 00118 00119 /* log(2*PI) */ 00120 #define CV_LOG2PI (1.8378770664093454835606594728112) 00121 00122 /* columns of <trainData> matrix are training samples */ 00123 #define CV_COL_SAMPLE 0 00124 00125 /* rows of <trainData> matrix are training samples */ 00126 #define CV_ROW_SAMPLE 1 00127 00128 #define CV_IS_ROW_SAMPLE(flags) ((flags) & CV_ROW_SAMPLE) 00129 00130 struct CvVectors 00131 { 00132 int type; 00133 int dims, count; 00134 CvVectors* next; 00135 union 00136 { 00137 uchar** ptr; 00138 float** fl; 00139 double** db; 00140 } data; 00141 }; 00142 00143 #if 0 00144 /* A structure, representing the lattice range of statmodel parameters. 00145 It is used for optimizing statmodel parameters by cross-validation method. 00146 The lattice is logarithmic, so <step> must be greater then 1. */ 00147 typedef struct CvParamLattice 00148 { 00149 double min_val; 00150 double max_val; 00151 double step; 00152 } 00153 CvParamLattice; 00154 00155 CV_INLINE CvParamLattice cvParamLattice( double min_val, double max_val, 00156 double log_step ) 00157 { 00158 CvParamLattice pl; 00159 pl.min_val = MIN( min_val, max_val ); 00160 pl.max_val = MAX( min_val, max_val ); 00161 pl.step = MAX( log_step, 1. 
); 00162 return pl; 00163 } 00164 00165 CV_INLINE CvParamLattice cvDefaultParamLattice( void ) 00166 { 00167 CvParamLattice pl = {0,0,0}; 00168 return pl; 00169 } 00170 #endif 00171 00172 /* Variable type */ 00173 #define CV_VAR_NUMERICAL 0 00174 #define CV_VAR_ORDERED 0 00175 #define CV_VAR_CATEGORICAL 1 00176 00177 #define CV_TYPE_NAME_ML_SVM "opencv-ml-svm" 00178 #define CV_TYPE_NAME_ML_KNN "opencv-ml-knn" 00179 #define CV_TYPE_NAME_ML_NBAYES "opencv-ml-bayesian" 00180 #define CV_TYPE_NAME_ML_EM "opencv-ml-em" 00181 #define CV_TYPE_NAME_ML_BOOSTING "opencv-ml-boost-tree" 00182 #define CV_TYPE_NAME_ML_TREE "opencv-ml-tree" 00183 #define CV_TYPE_NAME_ML_ANN_MLP "opencv-ml-ann-mlp" 00184 #define CV_TYPE_NAME_ML_CNN "opencv-ml-cnn" 00185 #define CV_TYPE_NAME_ML_RTREES "opencv-ml-random-trees" 00186 #define CV_TYPE_NAME_ML_GBT "opencv-ml-gradient-boosting-trees" 00187 00188 #define CV_TRAIN_ERROR 0 00189 #define CV_TEST_ERROR 1 00190 00191 class CV_EXPORTS_W CvStatModel 00192 { 00193 public: 00194 CvStatModel(); 00195 virtual ~CvStatModel(); 00196 00197 virtual void clear(); 00198 00199 CV_WRAP virtual void save( const char* filename, const char* name=0 ) const; 00200 CV_WRAP virtual void load( const char* filename, const char* name=0 ); 00201 00202 virtual void write( CvFileStorage* storage, const char* name ) const; 00203 virtual void read( CvFileStorage* storage, CvFileNode* node ); 00204 00205 protected: 00206 const char* default_model_name; 00207 }; 00208 00209 /****************************************************************************************\ 00210 * Normal Bayes Classifier * 00211 \****************************************************************************************/ 00212 00213 /* The structure, representing the grid range of statmodel parameters. 00214 It is used for optimizing statmodel accuracy by varying model parameters, 00215 the accuracy estimate being computed by cross-validation. 
00216 The grid is logarithmic, so <step> must be greater then 1. */ 00217 00218 class CvMLData; 00219 00220 struct CV_EXPORTS_W_MAP CvParamGrid 00221 { 00222 // SVM params type 00223 enum { SVM_C=0, SVM_GAMMA=1, SVM_P=2, SVM_NU=3, SVM_COEF=4, SVM_DEGREE=5 }; 00224 00225 CvParamGrid() 00226 { 00227 min_val = max_val = step = 0; 00228 } 00229 00230 CvParamGrid( double _min_val, double _max_val, double log_step ) 00231 { 00232 min_val = _min_val; 00233 max_val = _max_val; 00234 step = log_step; 00235 } 00236 //CvParamGrid( int param_id ); 00237 bool check() const; 00238 00239 CV_PROP_RW double min_val; 00240 CV_PROP_RW double max_val; 00241 CV_PROP_RW double step; 00242 }; 00243 00244 class CV_EXPORTS_W CvNormalBayesClassifier : public CvStatModel 00245 { 00246 public: 00247 CV_WRAP CvNormalBayesClassifier(); 00248 virtual ~CvNormalBayesClassifier(); 00249 00250 CvNormalBayesClassifier( const CvMat* trainData, const CvMat* responses, 00251 const CvMat* varIdx=0, const CvMat* sampleIdx=0 ); 00252 00253 virtual bool train( const CvMat* trainData, const CvMat* responses, 00254 const CvMat* varIdx = 0, const CvMat* sampleIdx=0, bool update=false ); 00255 00256 virtual float predict( const CvMat* samples, CV_OUT CvMat* results=0 ) const; 00257 CV_WRAP virtual void clear(); 00258 00259 #ifndef SWIG 00260 CV_WRAP CvNormalBayesClassifier( const cv::Mat& trainData, const cv::Mat& responses, 00261 const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat() ); 00262 CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses, 00263 const cv::Mat& varIdx = cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(), 00264 bool update=false ); 00265 CV_WRAP virtual float predict( const cv::Mat& samples, CV_OUT cv::Mat* results=0 ) const; 00266 #endif 00267 00268 virtual void write( CvFileStorage* storage, const char* name ) const; 00269 virtual void read( CvFileStorage* storage, CvFileNode* node ); 00270 00271 protected: 00272 int var_count, var_all; 00273 CvMat* 
var_idx; 00274 CvMat* cls_labels; 00275 CvMat** count; 00276 CvMat** sum; 00277 CvMat** productsum; 00278 CvMat** avg; 00279 CvMat** inv_eigen_values; 00280 CvMat** cov_rotate_mats; 00281 CvMat* c; 00282 }; 00283 00284 00285 /****************************************************************************************\ 00286 * K-Nearest Neighbour Classifier * 00287 \****************************************************************************************/ 00288 00289 // k Nearest Neighbors 00290 class CV_EXPORTS_W CvKNearest : public CvStatModel 00291 { 00292 public: 00293 00294 CV_WRAP CvKNearest(); 00295 virtual ~CvKNearest(); 00296 00297 CvKNearest( const CvMat* trainData, const CvMat* responses, 00298 const CvMat* sampleIdx=0, bool isRegression=false, int max_k=32 ); 00299 00300 virtual bool train( const CvMat* trainData, const CvMat* responses, 00301 const CvMat* sampleIdx=0, bool is_regression=false, 00302 int maxK=32, bool updateBase=false ); 00303 00304 virtual float find_nearest( const CvMat* samples, int k, CV_OUT CvMat* results=0, 00305 const float** neighbors=0, CV_OUT CvMat* neighborResponses=0, CV_OUT CvMat* dist=0 ) const; 00306 00307 #ifndef SWIG 00308 CV_WRAP CvKNearest( const cv::Mat& trainData, const cv::Mat& responses, 00309 const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false, int max_k=32 ); 00310 00311 CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses, 00312 const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false, 00313 int maxK=32, bool updateBase=false ); 00314 00315 virtual float find_nearest( const cv::Mat& samples, int k, cv::Mat* results=0, 00316 const float** neighbors=0, cv::Mat* neighborResponses=0, 00317 cv::Mat* dist=0 ) const; 00318 CV_WRAP virtual float find_nearest( const cv::Mat& samples, int k, CV_OUT cv::Mat& results, 00319 CV_OUT cv::Mat& neighborResponses, CV_OUT cv::Mat& dists) const; 00320 #endif 00321 00322 virtual void clear(); 00323 int get_max_k() const; 00324 int get_var_count() 
const; 00325 int get_sample_count() const; 00326 bool is_regression() const; 00327 00328 protected: 00329 00330 virtual float write_results( int k, int k1, int start, int end, 00331 const float* neighbor_responses, const float* dist, CvMat* _results, 00332 CvMat* _neighbor_responses, CvMat* _dist, Cv32suf* sort_buf ) const; 00333 00334 virtual void find_neighbors_direct( const CvMat* _samples, int k, int start, int end, 00335 float* neighbor_responses, const float** neighbors, float* dist ) const; 00336 00337 00338 int max_k, var_count; 00339 int total; 00340 bool regression; 00341 CvVectors* samples; 00342 }; 00343 00344 /****************************************************************************************\ 00345 * Support Vector Machines * 00346 \****************************************************************************************/ 00347 00348 // SVM training parameters 00349 struct CV_EXPORTS_W_MAP CvSVMParams 00350 { 00351 CvSVMParams(); 00352 CvSVMParams( int _svm_type, int _kernel_type, 00353 double _degree, double _gamma, double _coef0, 00354 double Cvalue, double _nu, double _p, 00355 CvMat* _class_weights, CvTermCriteria _term_crit ); 00356 00357 CV_PROP_RW int svm_type; 00358 CV_PROP_RW int kernel_type; 00359 CV_PROP_RW double degree; // for poly 00360 CV_PROP_RW double gamma; // for poly/rbf/sigmoid 00361 CV_PROP_RW double coef0; // for poly/sigmoid 00362 00363 CV_PROP_RW double C; // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR 00364 CV_PROP_RW double nu; // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR 00365 CV_PROP_RW double p; // for CV_SVM_EPS_SVR 00366 CvMat* class_weights; // for CV_SVM_C_SVC 00367 CV_PROP_RW CvTermCriteria term_crit; // termination criteria 00368 }; 00369 00370 00371 struct CV_EXPORTS CvSVMKernel 00372 { 00373 typedef void (CvSVMKernel::*Calc)( int vec_count, int vec_size, const float** vecs, 00374 const float* another, float* results ); 00375 CvSVMKernel(); 00376 CvSVMKernel( const CvSVMParams* params, Calc 
_calc_func ); 00377 virtual bool create( const CvSVMParams* params, Calc _calc_func ); 00378 virtual ~CvSVMKernel(); 00379 00380 virtual void clear(); 00381 virtual void calc( int vcount, int n, const float** vecs, const float* another, float* results ); 00382 00383 const CvSVMParams* params; 00384 Calc calc_func; 00385 00386 virtual void calc_non_rbf_base( int vec_count, int vec_size, const float** vecs, 00387 const float* another, float* results, 00388 double alpha, double beta ); 00389 00390 virtual void calc_linear( int vec_count, int vec_size, const float** vecs, 00391 const float* another, float* results ); 00392 virtual void calc_rbf( int vec_count, int vec_size, const float** vecs, 00393 const float* another, float* results ); 00394 virtual void calc_poly( int vec_count, int vec_size, const float** vecs, 00395 const float* another, float* results ); 00396 virtual void calc_sigmoid( int vec_count, int vec_size, const float** vecs, 00397 const float* another, float* results ); 00398 }; 00399 00400 00401 struct CvSVMKernelRow 00402 { 00403 CvSVMKernelRow* prev; 00404 CvSVMKernelRow* next; 00405 float* data; 00406 }; 00407 00408 00409 struct CvSVMSolutionInfo 00410 { 00411 double obj; 00412 double rho; 00413 double upper_bound_p; 00414 double upper_bound_n; 00415 double r; // for Solver_NU 00416 }; 00417 00418 class CV_EXPORTS CvSVMSolver 00419 { 00420 public: 00421 typedef bool (CvSVMSolver::*SelectWorkingSet)( int& i, int& j ); 00422 typedef float* (CvSVMSolver::*GetRow)( int i, float* row, float* dst, bool existed ); 00423 typedef void (CvSVMSolver::*CalcRho)( double& rho, double& r ); 00424 00425 CvSVMSolver(); 00426 00427 CvSVMSolver( int count, int var_count, const float** samples, schar* y, 00428 int alpha_count, double* alpha, double Cp, double Cn, 00429 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row, 00430 SelectWorkingSet select_working_set, CalcRho calc_rho ); 00431 virtual bool create( int count, int var_count, const float** samples, 
schar* y, 00432 int alpha_count, double* alpha, double Cp, double Cn, 00433 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row, 00434 SelectWorkingSet select_working_set, CalcRho calc_rho ); 00435 virtual ~CvSVMSolver(); 00436 00437 virtual void clear(); 00438 virtual bool solve_generic( CvSVMSolutionInfo& si ); 00439 00440 virtual bool solve_c_svc( int count, int var_count, const float** samples, schar* y, 00441 double Cp, double Cn, CvMemStorage* storage, 00442 CvSVMKernel* kernel, double* alpha, CvSVMSolutionInfo& si ); 00443 virtual bool solve_nu_svc( int count, int var_count, const float** samples, schar* y, 00444 CvMemStorage* storage, CvSVMKernel* kernel, 00445 double* alpha, CvSVMSolutionInfo& si ); 00446 virtual bool solve_one_class( int count, int var_count, const float** samples, 00447 CvMemStorage* storage, CvSVMKernel* kernel, 00448 double* alpha, CvSVMSolutionInfo& si ); 00449 00450 virtual bool solve_eps_svr( int count, int var_count, const float** samples, const float* y, 00451 CvMemStorage* storage, CvSVMKernel* kernel, 00452 double* alpha, CvSVMSolutionInfo& si ); 00453 00454 virtual bool solve_nu_svr( int count, int var_count, const float** samples, const float* y, 00455 CvMemStorage* storage, CvSVMKernel* kernel, 00456 double* alpha, CvSVMSolutionInfo& si ); 00457 00458 virtual float* get_row_base( int i, bool* _existed ); 00459 virtual float* get_row( int i, float* dst ); 00460 00461 int sample_count; 00462 int var_count; 00463 int cache_size; 00464 int cache_line_size; 00465 const float** samples; 00466 const CvSVMParams* params; 00467 CvMemStorage* storage; 00468 CvSVMKernelRow lru_list; 00469 CvSVMKernelRow* rows; 00470 00471 int alpha_count; 00472 00473 double* G; 00474 double* alpha; 00475 00476 // -1 - lower bound, 0 - free, 1 - upper bound 00477 schar* alpha_status; 00478 00479 schar* y; 00480 double* b; 00481 float* buf[2]; 00482 double eps; 00483 int max_iter; 00484 double C[2]; // C[0] == Cn, C[1] == Cp 00485 CvSVMKernel* 
kernel; 00486 00487 SelectWorkingSet select_working_set_func; 00488 CalcRho calc_rho_func; 00489 GetRow get_row_func; 00490 00491 virtual bool select_working_set( int& i, int& j ); 00492 virtual bool select_working_set_nu_svm( int& i, int& j ); 00493 virtual void calc_rho( double& rho, double& r ); 00494 virtual void calc_rho_nu_svm( double& rho, double& r ); 00495 00496 virtual float* get_row_svc( int i, float* row, float* dst, bool existed ); 00497 virtual float* get_row_one_class( int i, float* row, float* dst, bool existed ); 00498 virtual float* get_row_svr( int i, float* row, float* dst, bool existed ); 00499 }; 00500 00501 00502 struct CvSVMDecisionFunc 00503 { 00504 double rho; 00505 int sv_count; 00506 double* alpha; 00507 int* sv_index; 00508 }; 00509 00510 00511 // SVM model 00512 class CV_EXPORTS_W CvSVM : public CvStatModel 00513 { 00514 public: 00515 // SVM type 00516 enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 }; 00517 00518 // SVM kernel type 00519 enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3 }; 00520 00521 // SVM params type 00522 enum { C=0, GAMMA=1, P=2, NU=3, COEF=4, DEGREE=5 }; 00523 00524 CV_WRAP CvSVM(); 00525 virtual ~CvSVM(); 00526 00527 CvSVM( const CvMat* trainData, const CvMat* responses, 00528 const CvMat* varIdx=0, const CvMat* sampleIdx=0, 00529 CvSVMParams params=CvSVMParams() ); 00530 00531 virtual bool train( const CvMat* trainData, const CvMat* responses, 00532 const CvMat* varIdx=0, const CvMat* sampleIdx=0, 00533 CvSVMParams params=CvSVMParams() ); 00534 00535 virtual bool train_auto( const CvMat* trainData, const CvMat* responses, 00536 const CvMat* varIdx, const CvMat* sampleIdx, CvSVMParams params, 00537 int kfold = 10, 00538 CvParamGrid Cgrid = get_default_grid(CvSVM::C), 00539 CvParamGrid gammaGrid = get_default_grid(CvSVM::GAMMA), 00540 CvParamGrid pGrid = get_default_grid(CvSVM::P), 00541 CvParamGrid nuGrid = get_default_grid(CvSVM::NU), 00542 CvParamGrid coeffGrid = get_default_grid(CvSVM::COEF), 
00543 CvParamGrid degreeGrid = get_default_grid(CvSVM::DEGREE), 00544 bool balanced=false ); 00545 00546 virtual float predict( const CvMat* sample, bool returnDFVal=false ) const; 00547 00548 #ifndef SWIG 00549 CV_WRAP CvSVM( const cv::Mat& trainData, const cv::Mat& responses, 00550 const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(), 00551 CvSVMParams params=CvSVMParams() ); 00552 00553 CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses, 00554 const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(), 00555 CvSVMParams params=CvSVMParams() ); 00556 00557 CV_WRAP virtual bool train_auto( const cv::Mat& trainData, const cv::Mat& responses, 00558 const cv::Mat& varIdx, const cv::Mat& sampleIdx, CvSVMParams params, 00559 int k_fold = 10, 00560 CvParamGrid Cgrid = CvSVM::get_default_grid(CvSVM::C), 00561 CvParamGrid gammaGrid = CvSVM::get_default_grid(CvSVM::GAMMA), 00562 CvParamGrid pGrid = CvSVM::get_default_grid(CvSVM::P), 00563 CvParamGrid nuGrid = CvSVM::get_default_grid(CvSVM::NU), 00564 CvParamGrid coeffGrid = CvSVM::get_default_grid(CvSVM::COEF), 00565 CvParamGrid degreeGrid = CvSVM::get_default_grid(CvSVM::DEGREE), 00566 bool balanced=false); 00567 CV_WRAP virtual float predict( const cv::Mat& sample, bool returnDFVal=false ) const; 00568 #endif 00569 00570 CV_WRAP virtual int get_support_vector_count() const; 00571 virtual const float* get_support_vector(int i) const; 00572 virtual CvSVMParams get_params() const { return params; }; 00573 CV_WRAP virtual void clear(); 00574 00575 static CvParamGrid get_default_grid( int param_id ); 00576 00577 virtual void write( CvFileStorage* storage, const char* name ) const; 00578 virtual void read( CvFileStorage* storage, CvFileNode* node ); 00579 CV_WRAP int get_var_count() const { return var_idx ? 
var_idx->cols : var_all; } 00580 00581 protected: 00582 00583 virtual bool set_params( const CvSVMParams& params ); 00584 virtual bool train1( int sample_count, int var_count, const float** samples, 00585 const void* responses, double Cp, double Cn, 00586 CvMemStorage* _storage, double* alpha, double& rho ); 00587 virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples, 00588 const CvMat* responses, CvMemStorage* _storage, double* alpha ); 00589 virtual void create_kernel(); 00590 virtual void create_solver(); 00591 00592 virtual float predict( const float* row_sample, int row_len, bool returnDFVal=false ) const; 00593 00594 virtual void write_params( CvFileStorage* fs ) const; 00595 virtual void read_params( CvFileStorage* fs, CvFileNode* node ); 00596 00597 CvSVMParams params; 00598 CvMat* class_labels; 00599 int var_all; 00600 float** sv; 00601 int sv_total; 00602 CvMat* var_idx; 00603 CvMat* class_weights; 00604 CvSVMDecisionFunc* decision_func; 00605 CvMemStorage* storage; 00606 00607 CvSVMSolver* solver; 00608 CvSVMKernel* kernel; 00609 }; 00610 00611 /****************************************************************************************\ 00612 * Expectation - Maximization * 00613 \****************************************************************************************/ 00614 00615 struct CV_EXPORTS_W_MAP CvEMParams 00616 { 00617 CvEMParams() : nclusters(10), cov_mat_type(1/*CvEM::COV_MAT_DIAGONAL*/), 00618 start_step(0/*CvEM::START_AUTO_STEP*/), probs(0), weights(0), means(0), covs(0) 00619 { 00620 term_crit=cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON ); 00621 } 00622 00623 CvEMParams( int _nclusters, int _cov_mat_type=1/*CvEM::COV_MAT_DIAGONAL*/, 00624 int _start_step=0/*CvEM::START_AUTO_STEP*/, 00625 CvTermCriteria _term_crit=cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON), 00626 const CvMat* _probs=0, const CvMat* _weights=0, const CvMat* _means=0, const CvMat** _covs=0 ) : 
00627 nclusters(_nclusters), cov_mat_type(_cov_mat_type), start_step(_start_step), 00628 probs(_probs), weights(_weights), means(_means), covs(_covs), term_crit(_term_crit) 00629 {} 00630 00631 CV_PROP_RW int nclusters; 00632 CV_PROP_RW int cov_mat_type; 00633 CV_PROP_RW int start_step; 00634 const CvMat* probs; 00635 const CvMat* weights; 00636 const CvMat* means; 00637 const CvMat** covs; 00638 CV_PROP_RW CvTermCriteria term_crit; 00639 }; 00640 00641 00642 class CV_EXPORTS_W CvEM : public CvStatModel 00643 { 00644 public: 00645 // Type of covariation matrices 00646 enum { COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2 }; 00647 00648 // The initial step 00649 enum { START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0 }; 00650 00651 CV_WRAP CvEM(); 00652 CvEM( const CvMat* samples, const CvMat* sampleIdx=0, 00653 CvEMParams params=CvEMParams(), CvMat* labels=0 ); 00654 //CvEM (CvEMParams params, CvMat * means, CvMat ** covs, CvMat * weights, 00655 // CvMat * probs, CvMat * log_weight_div_det, CvMat * inv_eigen_values, CvMat** cov_rotate_mats); 00656 00657 virtual ~CvEM(); 00658 00659 virtual bool train( const CvMat* samples, const CvMat* sampleIdx=0, 00660 CvEMParams params=CvEMParams(), CvMat* labels=0 ); 00661 00662 virtual float predict( const CvMat* sample, CV_OUT CvMat* probs ) const; 00663 00664 #ifndef SWIG 00665 CV_WRAP CvEM( const cv::Mat& samples, const cv::Mat& sampleIdx=cv::Mat(), 00666 CvEMParams params=CvEMParams() ); 00667 00668 CV_WRAP virtual bool train( const cv::Mat& samples, 00669 const cv::Mat& sampleIdx=cv::Mat(), 00670 CvEMParams params=CvEMParams(), 00671 CV_OUT cv::Mat* labels=0 ); 00672 00673 CV_WRAP virtual float predict( const cv::Mat& sample, CV_OUT cv::Mat* probs=0 ) const; 00674 00675 CV_WRAP int getNClusters() const; 00676 CV_WRAP cv::Mat getMeans() const; 00677 CV_WRAP void getCovs(CV_OUT std::vector<cv::Mat>& covs) const; 00678 CV_WRAP cv::Mat getWeights() const; 00679 CV_WRAP cv::Mat getProbs() const; 00680 00681 
CV_WRAP inline double getLikelihood() const { return log_likelihood; }; 00682 #endif 00683 00684 CV_WRAP virtual void clear(); 00685 00686 int get_nclusters() const; 00687 const CvMat* get_means() const; 00688 const CvMat** get_covs() const; 00689 const CvMat* get_weights() const; 00690 const CvMat* get_probs() const; 00691 00692 inline double get_log_likelihood () const { return log_likelihood; }; 00693 00694 // inline const CvMat * get_log_weight_div_det () const { return log_weight_div_det; }; 00695 // inline const CvMat * get_inv_eigen_values () const { return inv_eigen_values; }; 00696 // inline const CvMat ** get_cov_rotate_mats () const { return cov_rotate_mats; }; 00697 00698 protected: 00699 00700 virtual void set_params( const CvEMParams& params, 00701 const CvVectors& train_data ); 00702 virtual void init_em( const CvVectors& train_data ); 00703 virtual double run_em( const CvVectors& train_data ); 00704 virtual void init_auto( const CvVectors& samples ); 00705 virtual void kmeans( const CvVectors& train_data, int nclusters, 00706 CvMat* labels, CvTermCriteria criteria, 00707 const CvMat* means ); 00708 CvEMParams params; 00709 double log_likelihood; 00710 00711 CvMat* means; 00712 CvMat** covs; 00713 CvMat* weights; 00714 CvMat* probs; 00715 00716 CvMat* log_weight_div_det; 00717 CvMat* inv_eigen_values; 00718 CvMat** cov_rotate_mats; 00719 }; 00720 00721 /****************************************************************************************\ 00722 * Decision Tree * 00723 \****************************************************************************************/\ 00724 struct CvPair16u32s 00725 { 00726 unsigned short* u; 00727 int* i; 00728 }; 00729 00730 00731 #define CV_DTREE_CAT_DIR(idx,subset) \ 00732 (2*((subset[(idx)>>5]&(1 << ((idx) & 31)))==0)-1) 00733 00734 struct CvDTreeSplit 00735 { 00736 int var_idx; 00737 int condensed_idx; 00738 int inversed; 00739 float quality; 00740 CvDTreeSplit* next; 00741 union 00742 { 00743 int subset[2]; 00744 
struct 00745 { 00746 float c; 00747 int split_point; 00748 } 00749 ord; 00750 }; 00751 }; 00752 00753 struct CvDTreeNode 00754 { 00755 int class_idx; 00756 int Tn; 00757 double value; 00758 00759 CvDTreeNode* parent; 00760 CvDTreeNode* left; 00761 CvDTreeNode* right; 00762 00763 CvDTreeSplit* split; 00764 00765 int sample_count; 00766 int depth; 00767 int* num_valid; 00768 int offset; 00769 int buf_idx; 00770 double maxlr; 00771 00772 // global pruning data 00773 int complexity; 00774 double alpha; 00775 double node_risk, tree_risk, tree_error; 00776 00777 // cross-validation pruning data 00778 int* cv_Tn; 00779 double* cv_node_risk; 00780 double* cv_node_error; 00781 00782 int get_num_valid(int vi) { return num_valid ? num_valid[vi] : sample_count; } 00783 void set_num_valid(int vi, int n) { if( num_valid ) num_valid[vi] = n; } 00784 }; 00785 00786 00787 struct CV_EXPORTS_W_MAP CvDTreeParams 00788 { 00789 CV_PROP_RW int max_categories; 00790 CV_PROP_RW int max_depth; 00791 CV_PROP_RW int min_sample_count; 00792 CV_PROP_RW int cv_folds; 00793 CV_PROP_RW bool use_surrogates; 00794 CV_PROP_RW bool use_1se_rule; 00795 CV_PROP_RW bool truncate_pruned_tree; 00796 CV_PROP_RW float regression_accuracy; 00797 const float* priors; 00798 00799 CvDTreeParams() : max_categories(10), max_depth(INT_MAX), min_sample_count(10), 00800 cv_folds(10), use_surrogates(true), use_1se_rule(true), 00801 truncate_pruned_tree(true), regression_accuracy(0.01f), priors(0) 00802 {} 00803 00804 CvDTreeParams( int _max_depth, int _min_sample_count, 00805 float _regression_accuracy, bool _use_surrogates, 00806 int _max_categories, int _cv_folds, 00807 bool _use_1se_rule, bool _truncate_pruned_tree, 00808 const float* _priors ) : 00809 max_categories(_max_categories), max_depth(_max_depth), 00810 min_sample_count(_min_sample_count), cv_folds (_cv_folds), 00811 use_surrogates(_use_surrogates), use_1se_rule(_use_1se_rule), 00812 truncate_pruned_tree(_truncate_pruned_tree), 00813 
regression_accuracy(_regression_accuracy),
        priors(_priors)
    {}
};


/* Preprocessed training data shared by all the tree-based models in this
   module (decision trees, random forests, boosted trees): responses,
   variable types, category maps and per-node work buffers. One instance
   may be shared between several trees (see the `shared` flag). */
struct CV_EXPORTS CvDTreeTrainData
{
    CvDTreeTrainData();
    CvDTreeTrainData( const CvMat* trainData, int tflag,
                      const CvMat* responses, const CvMat* varIdx=0,
                      const CvMat* sampleIdx=0, const CvMat* varType=0,
                      const CvMat* missingDataMask=0,
                      const CvDTreeParams& params=CvDTreeParams(),
                      bool _shared=false, bool _add_labels=false );
    virtual ~CvDTreeTrainData();

    // (Re)initializes the structure from raw training data; pass
    // _update_data=true to update an already-built data set in place.
    virtual void set_data( const CvMat* trainData, int tflag,
                           const CvMat* responses, const CvMat* varIdx=0,
                           const CvMat* sampleIdx=0, const CvMat* varType=0,
                           const CvMat* missingDataMask=0,
                           const CvDTreeParams& params=CvDTreeParams(),
                           bool _shared=false, bool _add_labels=false,
                           bool _update_data=false );
    virtual void do_responses_copy();

    // Extracts (values, missing mask, responses) for the selected samples.
    virtual void get_vectors( const CvMat* _subsample_idx,
                              float* values, uchar* missing, float* responses,
                              bool get_class_idx=false );

    virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );

    virtual void write_params( CvFileStorage* fs ) const;
    virtual void read_params( CvFileStorage* fs, CvFileNode* node );

    // release all the data
    virtual void clear();

    int get_num_classes() const;
    int get_var_type(int vi) const;
    int get_work_var_count() const {return work_var_count;}

    // Per-node data accessors; the caller supplies the output buffers.
    virtual const float* get_ord_responses( CvDTreeNode* n, float* values_buf, int* sample_indices_buf );
    virtual const int* get_class_labels( CvDTreeNode* n, int* labels_buf );
    virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
    virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
    virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
    virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf,
                                   int* sorted_indices_buf,
                                   const float** ord_values, const int** sorted_indices,
                                   int* sample_indices_buf );
    virtual int get_child_buf_idx( CvDTreeNode* n );

    virtual bool set_params( const CvDTreeParams& params );
    virtual CvDTreeNode* new_node( CvDTreeNode* parent, int count,
                                   int storage_idx, int offset );

    virtual CvDTreeSplit* new_split_ord( int vi, float cmp_val,
                                         int split_point, int inversed, float quality );
    virtual CvDTreeSplit* new_split_cat( int vi, float quality );
    virtual void free_node_data( CvDTreeNode* node );
    virtual void free_train_data();
    virtual void free_node( CvDTreeNode* node );

    // basic counts describing the training set
    int sample_count, var_all, var_count, max_c_count;
    int ord_var_count, cat_var_count, work_var_count;
    bool have_labels, have_priors;
    bool is_classifier;
    int tflag; // sample layout: row-per-sample or column-per-sample

    const CvMat* train_data;
    const CvMat* responses;
    CvMat* responses_copy; // used in Boosting

    int buf_count, buf_size;
    bool shared;
    int is_buf_16u;

    // category bookkeeping for categorical variables
    CvMat* cat_count;
    CvMat* cat_ofs;
    CvMat* cat_map;

    CvMat* counts;
    CvMat* buf;
    CvMat* direction;
    CvMat* split_buf;

    CvMat* var_idx;
    CvMat* var_type; // i-th element =
                     // k<0  - ordered
                     // k>=0 - categorical, see k-th element of cat_* arrays
    CvMat* priors;
    CvMat* priors_mult;

    CvDTreeParams params;

    // storages/heaps backing the dynamically allocated tree structures
    CvMemStorage* tree_storage;
    CvMemStorage* temp_storage;

    CvDTreeNode* data_root;

    CvSet* node_heap;
    CvSet* split_heap;
    CvSet* cv_heap;
    CvSet* nv_heap;

    cv::RNG* rng;
};

class CvDTree;
class CvForestTree;

namespace cv
{
    struct DTreeBestSplitFinder;
    struct ForestTreeBestSplitFinder;
}

/* Single decision-tree model (classification or regression). */
class CV_EXPORTS_W CvDTree : public CvStatModel
{
public:
    CV_WRAP CvDTree();
    virtual ~CvDTree();

    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvDTreeParams params=CvDTreeParams() );

    virtual bool train( CvMLData* trainData, CvDTreeParams params=CvDTreeParams() );

    // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
    virtual float calc_error( CvMLData* trainData, int type, std::vector<float> *resp = 0 );

    virtual bool train( CvDTreeTrainData* trainData, const CvMat* subsampleIdx );

    virtual CvDTreeNode* predict( const CvMat* sample, const CvMat* missingDataMask=0,
                                  bool preprocessedInput=false ) const;

#ifndef SWIG
    // new-style C++ (cv::Mat) overloads
    CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
                                const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
                                const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
                                const cv::Mat& missingDataMask=cv::Mat(),
                                CvDTreeParams params=CvDTreeParams() );

    CV_WRAP virtual CvDTreeNode* predict( const cv::Mat& sample, const cv::Mat& missingDataMask=cv::Mat(),
                                          bool preprocessedInput=false ) const;
    CV_WRAP virtual cv::Mat getVarImportance();
#endif

    virtual const CvMat* get_var_importance();
    CV_WRAP virtual void clear();

    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void write( CvFileStorage* fs, const char* name ) const;

    // special read & write methods for trees in the tree ensembles
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvDTreeTrainData* data );
    virtual void write( CvFileStorage* fs ) const;

    const CvDTreeNode* get_root() const;
    int get_pruned_tree_idx() const;
    CvDTreeTrainData* get_data();

protected:
    friend struct cv::DTreeBestSplitFinder;

    virtual bool do_train( const CvMat* _subsample_idx );

    // tree-growing internals
    virtual void try_split_node( CvDTreeNode* n );
    virtual void split_node_data( CvDTreeNode* n );
    virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
    virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
    virtual double calc_node_dir( CvDTreeNode* node );
    virtual void complete_node_dir( CvDTreeNode* node );
    virtual void cluster_categories( const int* vectors, int vector_count,
        int var_count, int* sums, int k, int* cluster_labels );

    virtual void calc_node_value( CvDTreeNode* node );

    // cross-validation-based pruning internals
    virtual void prune_cv();
    virtual double update_tree_rnc( int T, int fold );
    virtual int cut_tree( int T, int fold, double min_alpha );
    virtual void free_prune_data(bool cut_tree);
    virtual void free_tree();

    // (de)serialization of the tree structure
    virtual void write_node( CvFileStorage* fs, CvDTreeNode* node ) const;
    virtual void write_split( CvFileStorage* fs, CvDTreeSplit* split ) const;
    virtual CvDTreeNode* read_node( CvFileStorage* fs, CvFileNode* node, CvDTreeNode* parent );
    virtual CvDTreeSplit* read_split( CvFileStorage* fs, CvFileNode* node );
    virtual void write_tree_nodes( CvFileStorage* fs ) const;
    virtual void read_tree_nodes( CvFileStorage* fs, CvFileNode* node );

    CvDTreeNode* root;
    CvMat* var_importance;
    CvDTreeTrainData* data;

public:
    int pruned_tree_idx;
};


/****************************************************************************************\
*                                Random Trees Classifier                                 *
\****************************************************************************************/

class CvRTrees;

/* One tree of a random forest; trained and queried by CvRTrees. */
class CV_EXPORTS CvForestTree: public CvDTree
{
public:
    CvForestTree();
    virtual ~CvForestTree();

    virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx, CvRTrees* forest );

    virtual int get_var_count() const {return data ? data->var_count : 0;}
    virtual void read( CvFileStorage* fs, CvFileNode* node, CvRTrees* forest, CvDTreeTrainData* _data );

    /* dummy methods to avoid warnings: BEGIN */
    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvDTreeParams params=CvDTreeParams() );

    virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );
    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvDTreeTrainData* data );
    /* dummy methods to avoid warnings: END */

protected:
    friend struct cv::ForestTreeBestSplitFinder;

    virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
    CvRTrees* forest; // back-pointer to the owning forest
};


/* Training parameters of CvRTrees; extends the single-tree parameters. */
struct CV_EXPORTS_W_MAP CvRTParams : public CvDTreeParams
{
    //Parameters for the forest
    CV_PROP_RW bool calc_var_importance; // true <=> RF processes variable importance
    CV_PROP_RW int nactive_vars;
    CV_PROP_RW CvTermCriteria term_crit;
    // Defaults: max_depth=5, min_sample_count=10, up to 50 trees,
    // forest accuracy 0.1.
    CvRTParams() : CvDTreeParams( 5, 10, 0, false, 10, 0, false, false, 0 ),
        calc_var_importance(false), nactive_vars(0)
    {
        term_crit = cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 50, 0.1 );
    }

    CvRTParams( int _max_depth, int _min_sample_count,
                float _regression_accuracy, bool _use_surrogates,
                int _max_categories, const float* _priors, bool _calc_var_importance,
                int _nactive_vars, int max_num_of_trees_in_the_forest,
                float forest_accuracy, int termcrit_type ) :
        CvDTreeParams( _max_depth, _min_sample_count, _regression_accuracy,
                       _use_surrogates, _max_categories, 0,
                       false, false, _priors ),
        calc_var_importance(_calc_var_importance),
        nactive_vars(_nactive_vars)
    {
        term_crit = cvTermCriteria(termcrit_type,
            max_num_of_trees_in_the_forest, forest_accuracy);
    }
};


/* Random forest model: an ensemble of CvForestTree instances. */
class CV_EXPORTS_W CvRTrees : public CvStatModel
{
public:
    CV_WRAP CvRTrees();
    virtual ~CvRTrees();
    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvRTParams params=CvRTParams() );

    virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
    virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const;
    virtual float predict_prob( const CvMat* sample, const CvMat* missing = 0 ) const;

#ifndef SWIG
    // new-style C++ (cv::Mat) overloads
    CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
                                const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
                                const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
                                const cv::Mat& missingDataMask=cv::Mat(),
                                CvRTParams params=CvRTParams() );
    CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
    CV_WRAP virtual float predict_prob( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
    CV_WRAP virtual cv::Mat getVarImportance();
#endif

    CV_WRAP virtual void clear();

    virtual const CvMat* get_var_importance();
    virtual float get_proximity( const CvMat* sample1, const CvMat* sample2,
                                 const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const;

    virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}

    virtual float get_train_error();

    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void write( CvFileStorage* fs, const char* name ) const;

    CvMat* get_active_var_mask();
    CvRNG* get_rng();

    int get_tree_count() const;
    CvForestTree* get_tree(int i) const;

protected:

    virtual bool grow_forest( const CvTermCriteria term_crit );

    // array of the trees of the forest
    CvForestTree** trees;
    CvDTreeTrainData* data;
    int ntrees;
    int nclasses;
    double oob_error; // out-of-bag error estimate
    CvMat* var_importance;
    int nsamples;

    cv::RNG* rng;
    CvMat* active_var_mask;
};

/****************************************************************************************\
*                           Extremely randomized trees Classifier                        *
\****************************************************************************************/
/* Training-data wrapper for extremely randomized trees; overrides the
   per-node accessors of CvDTreeTrainData and keeps the missing-value mask. */
struct CV_EXPORTS CvERTreeTrainData : public CvDTreeTrainData
{
    virtual void set_data( const CvMat* trainData, int tflag,
                           const CvMat* responses, const CvMat* varIdx=0,
                           const CvMat* sampleIdx=0, const CvMat* varType=0,
                           const CvMat* missingDataMask=0,
                           const CvDTreeParams& params=CvDTreeParams(),
                           bool _shared=false, bool _add_labels=false,
                           bool _update_data=false );
    virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf,
                                   int* missing_buf,
                                   const float** ord_values, const int** missing,
                                   int* sample_buf = 0 );
    virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
    virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
    virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
    virtual void get_vectors( const CvMat* _subsample_idx, float* values, uchar* missing,
                              float* responses, bool get_class_idx=false );
    virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
    const CvMat* missing_mask;
};

/* One extremely randomized tree; overrides the split-search routines. */
class CV_EXPORTS CvForestERTree : public CvForestTree
{
protected:
    virtual double calc_node_dir( CvDTreeNode* node );
    virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual void split_node_data( CvDTreeNode* n );
};

/* Extremely randomized trees ensemble; reuses the CvRTrees interface. */
class CV_EXPORTS_W CvERTrees : public CvRTrees
{
public:
    CV_WRAP CvERTrees();
    virtual ~CvERTrees();
    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvRTParams params=CvRTParams());
#ifndef SWIG
    CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
                                const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
                                const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
                                const cv::Mat& missingDataMask=cv::Mat(),
                                CvRTParams params=CvRTParams());
#endif
    virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
protected:
    virtual bool grow_forest( const CvTermCriteria term_crit );
};


/****************************************************************************************\
*                                  Boosted tree classifier                               *
\****************************************************************************************/

/* Training parameters of CvBoost; extends the single-tree parameters. */
struct CV_EXPORTS_W_MAP CvBoostParams : public CvDTreeParams
{
    CV_PROP_RW int boost_type;          // one of CvBoost::{DISCRETE,REAL,LOGIT,GENTLE}
    CV_PROP_RW int weak_count;
    CV_PROP_RW int split_criteria;      // one of CvBoost::{DEFAULT,GINI,MISCLASS,SQERR}
    CV_PROP_RW double weight_trim_rate;

    CvBoostParams();
    CvBoostParams( int boost_type, int weak_count, double weight_trim_rate,
                   int max_depth, bool use_surrogates, const float* priors );
};


class CvBoost;

/* A weak tree classifier of a boosted ensemble; trained by CvBoost. */
class CV_EXPORTS CvBoostTree: public CvDTree
{
public:
    CvBoostTree();
    virtual ~CvBoostTree();

    virtual bool train( CvDTreeTrainData* trainData,
                        const CvMat* subsample_idx, CvBoost* ensemble );

    virtual void scale( double s );
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvBoost* ensemble, CvDTreeTrainData* _data );
    virtual void clear();

    /* dummy methods to avoid warnings: BEGIN */
    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvDTreeParams params=CvDTreeParams() );
    virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );

    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvDTreeTrainData* data );
    /* dummy methods to avoid warnings: END */

protected:

    virtual void try_split_node( CvDTreeNode* n );
    virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual void calc_node_value( CvDTreeNode* n );
    virtual double calc_node_dir( CvDTreeNode* n );

    CvBoost* ensemble; // back-pointer to the owning ensemble
};


/* Boosted-tree ensemble classifier. */
class CV_EXPORTS_W CvBoost : public CvStatModel
{
public:
    // Boosting type
    enum { DISCRETE=0, REAL=1, LOGIT=2, GENTLE=3 };

    // Splitting criteria
    enum { DEFAULT=0, GINI=1, MISCLASS=3, SQERR=4 };

    CV_WRAP CvBoost();
    virtual ~CvBoost();

    CvBoost( const CvMat* trainData, int tflag,
             const CvMat* responses, const CvMat* varIdx=0,
             const CvMat* sampleIdx=0, const CvMat* varType=0,
             const CvMat* missingDataMask=0,
             CvBoostParams params=CvBoostParams() );

    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvBoostParams params=CvBoostParams(),
                        bool update=false );

    virtual bool train( CvMLData* data,
                        CvBoostParams params=CvBoostParams(),
                        bool update=false );

    virtual float predict( const CvMat* sample, const CvMat* missing=0,
                           CvMat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ,
                           bool raw_mode=false, bool return_sum=false ) const;

#ifndef SWIG
    // new-style C++ (cv::Mat) overloads
    CV_WRAP CvBoost( const cv::Mat& trainData, int tflag,
                     const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
                     const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
                     const cv::Mat& missingDataMask=cv::Mat(),
                     CvBoostParams params=CvBoostParams() );

    CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
                                const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
                                const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
                                const cv::Mat& missingDataMask=cv::Mat(),
                                CvBoostParams params=CvBoostParams(),
                                bool update=false );

    CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(),
                                   const cv::Range& slice=cv::Range::all(), bool rawMode=false,
                                   bool returnSum=false ) const;
#endif

    virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}

    CV_WRAP virtual void prune( CvSlice slice );

    CV_WRAP virtual void clear();

    virtual void write( CvFileStorage* storage, const char* name ) const;
    virtual void read( CvFileStorage* storage, CvFileNode* node );
    virtual const CvMat* get_active_vars(bool absolute_idx=true);

    CvSeq* get_weak_predictors();

    CvMat* get_weights();
    CvMat* get_subtree_weights();
    CvMat* get_weak_response();
    const CvBoostParams& get_params() const;
    const CvDTreeTrainData* get_data() const;

protected:

    virtual bool set_params( const CvBoostParams& params );
    virtual void update_weights( CvBoostTree* tree );
    virtual void trim_weights();
    virtual void write_params( CvFileStorage* fs ) const;
    virtual void read_params( CvFileStorage* fs, CvFileNode* node );

    CvDTreeTrainData* data;
    CvBoostParams params;
    CvSeq* weak; // sequence of the trained weak classifiers (CvBoostTree*)

    CvMat* active_vars;
    CvMat* active_vars_abs;
    bool have_active_cat_vars;

    // per-sample training state
    CvMat* orig_response;
    CvMat* sum_response;
    CvMat* weak_eval;
    CvMat* subsample_mask;
    CvMat* weights;
    CvMat* subtree_weights;
    bool have_subsample;
};


/****************************************************************************************\
*                                 Gradient Boosted Trees                                 *
\****************************************************************************************/

// DataType: STRUCT CvGBTreesParams
// Parameters of GBT (Gradient Boosted trees model), including single
// tree settings and ensemble parameters.
//
// weak_count          - count of trees in the ensemble
// loss_function_type  - loss function used for ensemble training
// subsample_portion   - portion of the whole training set used for
//                       every single tree training.
//                       subsample_portion value is in (0.0, 1.0].
//                       subsample_portion == 1.0 when the whole dataset is
//                       used on each step. Count of samples used on each
//                       step is computed as
//                       int(total_samples_count * subsample_portion).
// shrinkage           - regularization parameter.
//                       Each tree prediction is multiplied by the shrinkage value.

/* Training parameters of CvGBTrees; extends the single-tree parameters
   (see the comment block above for the meaning of each field). */
struct CV_EXPORTS_W_MAP CvGBTreesParams : public CvDTreeParams
{
    CV_PROP_RW int weak_count;
    CV_PROP_RW int loss_function_type;
    CV_PROP_RW float subsample_portion;
    CV_PROP_RW float shrinkage;

    CvGBTreesParams();
    CvGBTreesParams( int loss_function_type, int weak_count, float shrinkage,
                     float subsample_portion, int max_depth, bool use_surrogates );
};

// DataType: CLASS CvGBTrees
// Gradient Boosting Trees (GBT) algorithm implementation.
//
// data             - training dataset
// params           - parameters of the CvGBTrees
// weak             - array[0..(class_count-1)] of CvSeq
//                    for storing tree ensembles
// orig_response    - original responses of the training set samples
// sum_response     - predictions of the current model on the training dataset.
//                    this matrix is updated on every iteration.
// sum_response_tmp - predictions of the model on the training set on the next
//                    step. On every iteration values of sum_responses_tmp are
//                    computed via sum_responses values. When the current
//                    step is complete sum_response values become equal to
//                    sum_responses_tmp.
// sampleIdx        - indices of samples used for training the ensemble.
//                    CvGBTrees training procedure takes a set of samples
//                    (train_data) and a set of responses (responses).
//                    Only pairs (train_data[i], responses[i]), where i is
//                    in sample_idx, are used for training the ensemble.
// subsample_train  - indices of samples used for training a single decision
//                    tree on the current step. These indices are counted
//                    relative to the sample_idx, so that pairs
//                    (train_data[sample_idx[i]], responses[sample_idx[i]])
//                    are used for training a decision tree.
01437 // Training set is randomly splited 01438 // in two parts (subsample_train and subsample_test) 01439 // on every iteration accordingly to the portion parameter. 01440 // subsample_test - relative indices of samples from the training set, 01441 // which are not used for training a tree on the current 01442 // step. 01443 // missing - mask of the missing values in the training set. This 01444 // matrix has the same size as train_data. 1 - missing 01445 // value, 0 - not a missing value. 01446 // class_labels - output class labels map. 01447 // rng - random number generator. Used for spliting the 01448 // training set. 01449 // class_count - count of output classes. 01450 // class_count == 1 in the case of regression, 01451 // and > 1 in the case of classification. 01452 // delta - Huber loss function parameter. 01453 // base_value - start point of the gradient descent procedure. 01454 // model prediction is 01455 // f(x) = f_0 + sum_{i=1..weak_count-1}(f_i(x)), where 01456 // f_0 is the base value. 01457 01458 01459 01460 class CV_EXPORTS_W CvGBTrees : public CvStatModel 01461 { 01462 public: 01463 01464 /* 01465 // DataType: ENUM 01466 // Loss functions implemented in CvGBTrees. 01467 // 01468 // SQUARED_LOSS 01469 // problem: regression 01470 // loss = (x - x')^2 01471 // 01472 // ABSOLUTE_LOSS 01473 // problem: regression 01474 // loss = abs(x - x') 01475 // 01476 // HUBER_LOSS 01477 // problem: regression 01478 // loss = delta*( abs(x - x') - delta/2), if abs(x - x') > delta 01479 // 1/2*(x - x')^2, if abs(x - x') <= delta, 01480 // where delta is the alpha-quantile of pseudo responses from 01481 // the training set. 01482 // 01483 // DEVIANCE_LOSS 01484 // problem: classification 01485 // 01486 */ 01487 enum {SQUARED_LOSS=0, ABSOLUTE_LOSS, HUBER_LOSS=3, DEVIANCE_LOSS}; 01488 01489 01490 /* 01491 // Default constructor. Creates a model only (without training). 01492 // Should be followed by one form of the train(...) function. 
01493 // 01494 // API 01495 // CvGBTrees(); 01496 01497 // INPUT 01498 // OUTPUT 01499 // RESULT 01500 */ 01501 CV_WRAP CvGBTrees(); 01502 01503 01504 /* 01505 // Full form constructor. Creates a gradient boosting model and does the 01506 // train. 01507 // 01508 // API 01509 // CvGBTrees( const CvMat* trainData, int tflag, 01510 const CvMat* responses, const CvMat* varIdx=0, 01511 const CvMat* sampleIdx=0, const CvMat* varType=0, 01512 const CvMat* missingDataMask=0, 01513 CvGBTreesParams params=CvGBTreesParams() ); 01514 01515 // INPUT 01516 // trainData - a set of input feature vectors. 01517 // size of matrix is 01518 // <count of samples> x <variables count> 01519 // or <variables count> x <count of samples> 01520 // depending on the tflag parameter. 01521 // matrix values are float. 01522 // tflag - a flag showing how do samples stored in the 01523 // trainData matrix row by row (tflag=CV_ROW_SAMPLE) 01524 // or column by column (tflag=CV_COL_SAMPLE). 01525 // responses - a vector of responses corresponding to the samples 01526 // in trainData. 01527 // varIdx - indices of used variables. zero value means that all 01528 // variables are active. 01529 // sampleIdx - indices of used samples. zero value means that all 01530 // samples from trainData are in the training set. 01531 // varType - vector of <variables count> length. gives every 01532 // variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED. 01533 // varType = 0 means all variables are numerical. 01534 // missingDataMask - a mask of misiing values in trainData. 01535 // missingDataMask = 0 means that there are no missing 01536 // values. 01537 // params - parameters of GTB algorithm. 
01538 // OUTPUT 01539 // RESULT 01540 */ 01541 CvGBTrees( const CvMat* trainData, int tflag, 01542 const CvMat* responses, const CvMat* varIdx=0, 01543 const CvMat* sampleIdx=0, const CvMat* varType=0, 01544 const CvMat* missingDataMask=0, 01545 CvGBTreesParams params=CvGBTreesParams() ); 01546 01547 01548 /* 01549 // Destructor. 01550 */ 01551 virtual ~CvGBTrees(); 01552 01553 01554 /* 01555 // Gradient tree boosting model training 01556 // 01557 // API 01558 // virtual bool train( const CvMat* trainData, int tflag, 01559 const CvMat* responses, const CvMat* varIdx=0, 01560 const CvMat* sampleIdx=0, const CvMat* varType=0, 01561 const CvMat* missingDataMask=0, 01562 CvGBTreesParams params=CvGBTreesParams(), 01563 bool update=false ); 01564 01565 // INPUT 01566 // trainData - a set of input feature vectors. 01567 // size of matrix is 01568 // <count of samples> x <variables count> 01569 // or <variables count> x <count of samples> 01570 // depending on the tflag parameter. 01571 // matrix values are float. 01572 // tflag - a flag showing how do samples stored in the 01573 // trainData matrix row by row (tflag=CV_ROW_SAMPLE) 01574 // or column by column (tflag=CV_COL_SAMPLE). 01575 // responses - a vector of responses corresponding to the samples 01576 // in trainData. 01577 // varIdx - indices of used variables. zero value means that all 01578 // variables are active. 01579 // sampleIdx - indices of used samples. zero value means that all 01580 // samples from trainData are in the training set. 01581 // varType - vector of <variables count> length. gives every 01582 // variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED. 01583 // varType = 0 means all variables are numerical. 01584 // missingDataMask - a mask of misiing values in trainData. 01585 // missingDataMask = 0 means that there are no missing 01586 // values. 01587 // params - parameters of GTB algorithm. 01588 // update - is not supported now. (!) 01589 // OUTPUT 01590 // RESULT 01591 // Error state. 
01592 */ 01593 virtual bool train( const CvMat* trainData, int tflag, 01594 const CvMat* responses, const CvMat* varIdx=0, 01595 const CvMat* sampleIdx=0, const CvMat* varType=0, 01596 const CvMat* missingDataMask=0, 01597 CvGBTreesParams params=CvGBTreesParams(), 01598 bool update=false ); 01599 01600 01601 /* 01602 // Gradient tree boosting model training 01603 // 01604 // API 01605 // virtual bool train( CvMLData* data, 01606 CvGBTreesParams params=CvGBTreesParams(), 01607 bool update=false ) {return false;}; 01608 01609 // INPUT 01610 // data - training set. 01611 // params - parameters of GTB algorithm. 01612 // update - is not supported now. (!) 01613 // OUTPUT 01614 // RESULT 01615 // Error state. 01616 */ 01617 virtual bool train( CvMLData* data, 01618 CvGBTreesParams params=CvGBTreesParams(), 01619 bool update=false ); 01620 01621 01622 /* 01623 // Response value prediction 01624 // 01625 // API 01626 // virtual float predict( const CvMat* sample, const CvMat* missing=0, 01627 CvMat* weak_responses=0, CvSlice slice = CV_WHOLE_SEQ, 01628 int k=-1 ) const; 01629 01630 // INPUT 01631 // sample - input sample of the same type as in the training set. 01632 // missing - missing values mask. missing=0 if there are no 01633 // missing values in sample vector. 01634 // weak_responses - predictions of all of the trees. 01635 // not implemented (!) 01636 // slice - part of the ensemble used for prediction. 01637 // slice = CV_WHOLE_SEQ when all trees are used. 01638 // k - number of ensemble used. 01639 // k is in {-1,0,1,..,<count of output classes-1>}. 01640 // in the case of classification problem 01641 // <count of output classes-1> ensembles are built. 01642 // If k = -1 ordinary prediction is the result, 01643 // otherwise function gives the prediction of the 01644 // k-th ensemble only. 01645 // OUTPUT 01646 // RESULT 01647 // Predicted value. 
*/
    virtual float predict( const CvMat* sample, const CvMat* missing=0,
            CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
            int k=-1 ) const;

    /*
    // Delete all temporary data.
    //
    // API
    // virtual void clear();
    //
    // INPUT
    // OUTPUT
    // deletes data, weak, orig_response, sum_response,
    //         weak_eval, subsample_train, subsample_test,
    //         sample_idx, missing, class_labels
    // delta = 0.0
    // RESULT
    */
    CV_WRAP virtual void clear();

    /*
    // Compute error on the train/test set.
    //
    // API
    // virtual float calc_error( CvMLData* _data, int type,
    //                           std::vector<float> *resp = 0 );
    //
    // INPUT
    // data - dataset
    // type - defines which error is to compute: train (CV_TRAIN_ERROR) or
    //        test (CV_TEST_ERROR).
    // OUTPUT
    // resp - vector of predictions
    // RESULT
    // Error value.
    */
    virtual float calc_error( CvMLData* _data, int type,
            std::vector<float> *resp = 0 );

    /*
    // Write parameters of the gtb model and data. Write learned model.
    //
    // API
    // virtual void write( CvFileStorage* fs, const char* name ) const;
    //
    // INPUT
    // fs - file storage to write parameters to.
    // name - model name.
    // OUTPUT
    // RESULT
    */
    virtual void write( CvFileStorage* fs, const char* name ) const;

    /*
    // Read parameters of the gtb model and data. Read learned model.
    //
    // API
    // virtual void read( CvFileStorage* fs, CvFileNode* node );
    //
    // INPUT
    // fs - file storage to read parameters from.
    // node - file node.
    // OUTPUT
    // RESULT
    */
    virtual void read( CvFileStorage* fs, CvFileNode* node );

    // new-style C++ interface
    CV_WRAP CvGBTrees( const cv::Mat& trainData, int tflag,
              const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
              const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
              const cv::Mat& missingDataMask=cv::Mat(),
              CvGBTreesParams params=CvGBTreesParams() );

    CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
                       const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
                       const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
                       const cv::Mat& missingDataMask=cv::Mat(),
                       CvGBTreesParams params=CvGBTreesParams(),
                       bool update=false );

    CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(),
                           const cv::Range& slice = cv::Range::all(),
                           int k=-1 ) const;

protected:

    /*
    // Compute the gradient vector components.
    //
    // API
    // virtual void find_gradient( const int k = 0);
    //
    // INPUT
    // k - used for classification problem, determining current
    //     tree ensemble.
    // OUTPUT
    // changes components of data->responses
    // which correspond to samples used for training
    // on the current step.
    // RESULT
    */
    virtual void find_gradient( const int k = 0);

    /*
    // Change values in tree leaves according to the used loss function.
    //
    // API
    // virtual void change_values(CvDTree* tree, const int k = 0);
    //
    // INPUT
    // tree - decision tree to change.
    // k - used for classification problem, determining current
    //     tree ensemble.
    // OUTPUT
    // changes 'value' fields of the trees' leaves.
    // changes sum_response_tmp.
    // RESULT
    */
    virtual void change_values(CvDTree* tree, const int k = 0);

    /*
    // Find optimal constant prediction value according to the used loss
    // function.
    // The goal is to find a constant which gives the minimal summary loss
    // on the _Idx samples.
    //
    // API
    // virtual float find_optimal_value( const CvMat* _Idx );
    //
    // INPUT
    // _Idx - indices of the samples from the training set.
    // OUTPUT
    // RESULT
    // optimal constant value.
    */
    virtual float find_optimal_value( const CvMat* _Idx );

    /*
    // Randomly split the whole training set in two parts according
    // to params.portion.
    //
    // API
    // virtual void do_subsample();
    //
    // INPUT
    // OUTPUT
    // subsample_train - indices of samples used for training
    // subsample_test  - indices of samples used for test
    // RESULT
    */
    virtual void do_subsample();

    /*
    // Internal recursive function giving an array of subtree leaves.
    //
    // API
    // void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
    //
    // INPUT
    // node - current leaf.
    // OUTPUT
    // count - count of leaves in the subtree.
    // leaves - array of pointers to leaves.
    // RESULT
    */
    void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );

    /*
    // Get leaves of the tree.
    //
    // API
    // CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );
    //
    // INPUT
    // dtree - decision tree.
    // OUTPUT
    // len - count of the leaves.
    // RESULT
    // CvDTreeNode** - array of pointers to leaves.
    */
    CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );

    /*
    // Is it a regression or a classification.
    //
    // API
    // bool problem_type();
    //
    // INPUT
    // OUTPUT
    // RESULT
    // false if it is a classification problem,
    // true - if regression.
    */
    virtual bool problem_type() const;

    /*
    // Write parameters of the gtb model.
    //
    // API
    // virtual void write_params( CvFileStorage* fs ) const;
    //
    // INPUT
    // fs - file storage to write parameters to.
    // OUTPUT
    // RESULT
    */
    virtual void write_params( CvFileStorage* fs ) const;

    /*
    // Read parameters of the gtb model and data.
    //
    // API
    // virtual void read_params( CvFileStorage* fs );
    //
    // INPUT
    // fs - file storage to read parameters from.
    // OUTPUT
    // params - parameters of the gtb model.
    // data - contains information about the structure
    //        of the data set (count of variables,
    //        their types, etc.).
    // class_labels - output class labels map.
    // RESULT
    */
    virtual void read_params( CvFileStorage* fs, CvFileNode* fnode );


    CvDTreeTrainData* data;     // training-data layout shared by the weak trees
    CvGBTreesParams params;     // gradient boosting parameters

    CvSeq** weak;               // sequences of weak learners, one per tree ensemble
    CvMat* orig_response;       // original responses of the training samples
    CvMat* sum_response;        // cumulative ensemble responses
    CvMat* sum_response_tmp;    // working copy updated by change_values()
    CvMat* weak_eval;           // evaluations of the weak learners
    CvMat* sample_idx;          // indices of the samples used for training
    CvMat* subsample_train;     // indices of samples used for training (see do_subsample())
    CvMat* subsample_test;      // indices of samples used for test (see do_subsample())
    CvMat* missing;             // missing-values mask
    CvMat* class_labels;        // class labels map

    cv::RNG* rng;               // random number generator (subsampling)

    int class_count;            // count of tree ensembles (classes)
    float delta;                // loss-function parameter; reset to 0.0 by clear()
    float base_value;           // NOTE(review): presumably the initial constant
                                // approximation of the model — confirm in the
                                // implementation

};



/****************************************************************************************\
*                              Artificial Neural Networks (ANN)                          *
\****************************************************************************************/

// Training parameters for CvANN_MLP (multi-layer perceptron).
struct CV_EXPORTS_W_MAP CvANN_MLP_TrainParams
{
    CvANN_MLP_TrainParams();
    CvANN_MLP_TrainParams( CvTermCriteria term_crit, int train_method,
                           double param1, double param2=0 );
    ~CvANN_MLP_TrainParams();

    // available training methods
    enum { BACKPROP=0, RPROP=1 };

    CV_PROP_RW CvTermCriteria term_crit;    // termination criteria of the training
    CV_PROP_RW int train_method;            // BACKPROP or RPROP

    // backpropagation parameters
    CV_PROP_RW double bp_dw_scale, bp_moment_scale;

    // rprop parameters
    CV_PROP_RW double rp_dw0, rp_dw_plus, rp_dw_minus, rp_dw_min, rp_dw_max;
};


// Multi-layer perceptron (feed-forward artificial neural network).
class CV_EXPORTS_W CvANN_MLP : public CvStatModel
{
public:
    CV_WRAP CvANN_MLP();
    CvANN_MLP( const CvMat* layerSizes,
               int activateFunc=CvANN_MLP::SIGMOID_SYM,
               double fparam1=0, double fparam2=0 );

    virtual ~CvANN_MLP();

    virtual void create( const CvMat* layerSizes,
                         int activateFunc=CvANN_MLP::SIGMOID_SYM,
                         double fparam1=0, double fparam2=0 );

    virtual int train( const CvMat* inputs, const CvMat* outputs,
                       const CvMat* sampleWeights, const CvMat* sampleIdx=0,
                       CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(),
                       int flags=0 );
    virtual float predict( const CvMat* inputs, CV_OUT CvMat* outputs ) const;

#ifndef SWIG
    // new-style C++ interface
    CV_WRAP CvANN_MLP( const cv::Mat& layerSizes,
              int activateFunc=CvANN_MLP::SIGMOID_SYM,
              double fparam1=0, double fparam2=0 );

    CV_WRAP virtual void create( const cv::Mat& layerSizes,
                        int activateFunc=CvANN_MLP::SIGMOID_SYM,
                        double fparam1=0, double fparam2=0 );

    CV_WRAP virtual int train( const cv::Mat& inputs, const cv::Mat& outputs,
                      const cv::Mat& sampleWeights, const cv::Mat& sampleIdx=cv::Mat(),
                      CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(),
                      int flags=0 );

    CV_WRAP virtual float predict( const cv::Mat& inputs, cv::Mat& outputs ) const;
#endif

    CV_WRAP virtual void clear();

    // possible activation functions
    enum { IDENTITY = 0, SIGMOID_SYM = 1, GAUSSIAN = 2 };

    // available training flags
    enum { UPDATE_WEIGHTS = 1, NO_INPUT_SCALE = 2, NO_OUTPUT_SCALE = 4 };

    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void write( CvFileStorage* storage, const char* name ) const;

    int get_layer_count() { return layer_sizes ? layer_sizes->cols : 0; }
    const CvMat* get_layer_sizes() { return layer_sizes; }
    // Returns the weight buffer of the given layer, or 0 when the network is
    // not created or the index is out of range.
    // NOTE(review): the bound is '<=' (not '<'), so `weights` is assumed to
    // hold layer_sizes->cols + 1 accessible entries — confirm against the
    // implementation before "fixing" it.
    double* get_weights(int layer)
    {
        return layer_sizes && weights &&
            (unsigned)layer <= (unsigned)layer_sizes->cols ? weights[layer] : 0;
    }

protected:

    virtual bool prepare_to_train( const CvMat* _inputs, const CvMat* _outputs,
            const CvMat* _sample_weights, const CvMat* sampleIdx,
            CvVectors* _ivecs, CvVectors* _ovecs, double** _sw, int _flags );

    // sequential random backpropagation
    virtual int train_backprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );

    // RPROP algorithm
    virtual int train_rprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );

    virtual void calc_activ_func( CvMat* xf, const double* bias ) const;
    virtual void calc_activ_func_deriv( CvMat* xf, CvMat* deriv, const double* bias ) const;
    virtual void set_activ_func( int _activ_func=SIGMOID_SYM,
                                 double _f_param1=0, double _f_param2=0 );
    virtual void init_weights();
    virtual void scale_input( const CvMat* _src, CvMat* _dst ) const;
    virtual void scale_output( const CvMat* _src, CvMat* _dst ) const;
    virtual void calc_input_scale( const CvVectors* vecs, int flags );
    virtual void calc_output_scale( const CvVectors* vecs, int flags );

    virtual void write_params( CvFileStorage* fs ) const;
    virtual void read_params( CvFileStorage* fs, CvFileNode* node );

    CvMat* layer_sizes;         // row matrix; cols == layer count (see get_layer_count())
    CvMat* wbuf;                // weight storage (see get_weights())
    CvMat* sample_weights;      // per-sample training weights
    double** weights;           // per-layer weight pointers returned by get_weights()
    double f_param1, f_param2;  // activation function parameters (see set_activ_func())
    double min_val, max_val, min_val1, max_val1;
    int activ_func;             // IDENTITY, SIGMOID_SYM or GAUSSIAN
    int max_count, max_buf_sz;
    CvANN_MLP_TrainParams params;
    cv::RNG* rng;               // random number generator (see init_weights())
};

/****************************************************************************************\
*                           Auxiliary functions declarations                             *
\****************************************************************************************/

/* Generates <sample> from multivariate normal distribution, where <mean> - is an
   average row vector, <cov> - symmetric covariance matrix */
CVAPI(void) cvRandMVNormal( CvMat* mean, CvMat* cov, CvMat* sample, 02049 CvRNG* rng CV_DEFAULT(0) ); 02050 02051 /* Generates sample from gaussian mixture distribution */ 02052 CVAPI(void) cvRandGaussMixture( CvMat* means[], 02053 CvMat* covs[], 02054 float weights[], 02055 int clsnum, 02056 CvMat* sample, 02057 CvMat* sampClasses CV_DEFAULT(0) ); 02058 02059 #define CV_TS_CONCENTRIC_SPHERES 0 02060 02061 /* creates test set */ 02062 CVAPI(void) cvCreateTestSet( int type, CvMat** samples, 02063 int num_samples, 02064 int num_features, 02065 CvMat** responses, 02066 int num_classes, ... ); 02067 02068 02069 #endif 02070 02071 /****************************************************************************************\ 02072 * Data * 02073 \****************************************************************************************/ 02074 02075 #include <map> 02076 #include <string> 02077 #include <iostream> 02078 02079 #define CV_COUNT 0 02080 #define CV_PORTION 1 02081 02082 struct CV_EXPORTS CvTrainTestSplit 02083 { 02084 public: 02085 CvTrainTestSplit(); 02086 CvTrainTestSplit( int _train_sample_count, bool _mix = true); 02087 CvTrainTestSplit( float _train_sample_portion, bool _mix = true); 02088 02089 union 02090 { 02091 int count; 02092 float portion; 02093 } train_sample_part; 02094 int train_sample_part_mode; 02095 02096 union 02097 { 02098 int *count; 02099 float *portion; 02100 } *class_part; 02101 int class_part_mode; 02102 02103 bool mix; 02104 }; 02105 02106 class CV_EXPORTS CvMLData 02107 { 02108 public: 02109 CvMLData(); 02110 virtual ~CvMLData(); 02111 02112 // returns: 02113 // 0 - OK 02114 // 1 - file can not be opened or is not correct 02115 int read_csv(const char* filename); 02116 02117 const CvMat* get_values(){ return values; }; 02118 02119 const CvMat* get_responses(); 02120 02121 const CvMat* get_missing(){ return missing; }; 02122 02123 void set_response_idx( int idx ); // old response become predictors, new response_idx = idx 02124 // if idx < 0 
there will be no response 02125 int get_response_idx() { return response_idx; } 02126 02127 const CvMat* get_train_sample_idx() { return train_sample_idx; }; 02128 const CvMat* get_test_sample_idx() { return test_sample_idx; }; 02129 void mix_train_and_test_idx(); 02130 void set_train_test_split( const CvTrainTestSplit * spl); 02131 02132 const CvMat* get_var_idx(); 02133 void chahge_var_idx( int vi, bool state ); // state == true to set vi-variable as predictor 02134 02135 const CvMat* get_var_types(); 02136 int get_var_type( int var_idx ) { return var_types->data.ptr[var_idx]; }; 02137 // following 2 methods enable to change vars type 02138 // use these methods to assign CV_VAR_CATEGORICAL type for categorical variable 02139 // with numerical labels; in the other cases var types are correctly determined automatically 02140 void set_var_types( const char* str ); // str examples: 02141 // "ord[0-17],cat[18]", "ord[0,2,4,10-12], cat[1,3,5-9,13,14]", 02142 // "cat", "ord" (all vars are categorical/ordered) 02143 void change_var_type( int var_idx, int type); // type in { CV_VAR_ORDERED, CV_VAR_CATEGORICAL } 02144 02145 void set_delimiter( char ch ); 02146 char get_delimiter() { return delimiter; }; 02147 02148 void set_miss_ch( char ch ); 02149 char get_miss_ch() { return miss_ch; }; 02150 02151 protected: 02152 virtual void clear(); 02153 02154 void str_to_flt_elem( const char* token, float& flt_elem, int& type); 02155 void free_train_test_idx(); 02156 02157 char delimiter; 02158 char miss_ch; 02159 //char flt_separator; 02160 02161 CvMat* values; 02162 CvMat* missing; 02163 CvMat* var_types; 02164 CvMat* var_idx_mask; 02165 02166 CvMat* response_out; // header 02167 CvMat* var_idx_out; // mat 02168 CvMat* var_types_out; // mat 02169 02170 int response_idx; 02171 02172 int train_sample_count; 02173 bool mix; 02174 02175 int total_class_count; 02176 std::map<std::string, int> *class_map; 02177 02178 CvMat* train_sample_idx; 02179 CvMat* test_sample_idx; 02180 int* 
sample_idx; // data of train_sample_idx and test_sample_idx 02181 02182 cv::RNG* rng; 02183 }; 02184 02185 02186 namespace cv 02187 { 02188 02189 typedef CvStatModel StatModel; 02190 typedef CvParamGrid ParamGrid; 02191 typedef CvNormalBayesClassifier NormalBayesClassifier; 02192 typedef CvKNearest KNearest; 02193 typedef CvSVMParams SVMParams; 02194 typedef CvSVMKernel SVMKernel; 02195 typedef CvSVMSolver SVMSolver; 02196 typedef CvSVM SVM; 02197 typedef CvEMParams EMParams; 02198 typedef CvEM ExpectationMaximization; 02199 typedef CvDTreeParams DTreeParams; 02200 typedef CvMLData TrainData; 02201 typedef CvDTree DecisionTree; 02202 typedef CvForestTree ForestTree; 02203 typedef CvRTParams RandomTreeParams; 02204 typedef CvRTrees RandomTrees; 02205 typedef CvERTreeTrainData ERTreeTRainData; 02206 typedef CvForestERTree ERTree; 02207 typedef CvERTrees ERTrees; 02208 typedef CvBoostParams BoostParams; 02209 typedef CvBoostTree BoostTree; 02210 typedef CvBoost Boost; 02211 typedef CvANN_MLP_TrainParams ANN_MLP_TrainParams; 02212 typedef CvANN_MLP NeuralNet_MLP; 02213 typedef CvGBTreesParams GradientBoostingTreeParams; 02214 typedef CvGBTrees GradientBoostingTrees; 02215 02216 template<> CV_EXPORTS void Ptr<CvDTreeSplit>::delete_obj(); 02217 02218 } 02219 02220 #endif 02221 /* End of file. */