opencv 2.2.0
/usr/src/RPM/BUILD/libopencv2.2-2.2.0/modules/ml/include/opencv2/ml/ml.hpp
Go to the documentation of this file.
00001 /*M///////////////////////////////////////////////////////////////////////////////////////
00002 //
00003 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
00004 //
00005 //  By downloading, copying, installing or using the software you agree to this license.
00006 //  If you do not agree to this license, do not download, install,
00007 //  copy or use the software.
00008 //
00009 //
00010 //                        Intel License Agreement
00011 //
00012 // Copyright (C) 2000, Intel Corporation, all rights reserved.
00013 // Third party copyrights are property of their respective owners.
00014 //
00015 // Redistribution and use in source and binary forms, with or without modification,
00016 // are permitted provided that the following conditions are met:
00017 //
00018 //   * Redistribution's of source code must retain the above copyright notice,
00019 //     this list of conditions and the following disclaimer.
00020 //
00021 //   * Redistribution's in binary form must reproduce the above copyright notice,
00022 //     this list of conditions and the following disclaimer in the documentation
00023 //     and/or other materials provided with the distribution.
00024 //
00025 //   * The name of Intel Corporation may not be used to endorse or promote products
00026 //     derived from this software without specific prior written permission.
00027 //
00028 // This software is provided by the copyright holders and contributors "as is" and
00029 // any express or implied warranties, including, but not limited to, the implied
00030 // warranties of merchantability and fitness for a particular purpose are disclaimed.
00031 // In no event shall the Intel Corporation or contributors be liable for any direct,
00032 // indirect, incidental, special, exemplary, or consequential damages
00033 // (including, but not limited to, procurement of substitute goods or services;
00034 // loss of use, data, or profits; or business interruption) however caused
00035 // and on any theory of liability, whether in contract, strict liability,
00036 // or tort (including negligence or otherwise) arising in any way out of
00037 // the use of this software, even if advised of the possibility of such damage.
00038 //
00039 //M*/
00040 
00041 #ifndef __OPENCV_ML_HPP__
00042 #define __OPENCV_ML_HPP__
00043 
00044 // disable deprecation warning which appears in VisualStudio 8.0
00045 #if _MSC_VER >= 1400
00046 #pragma warning( disable : 4996 )
00047 #endif
00048 
00049 #ifndef SKIP_INCLUDES
00050 
00051   #include "opencv2/core/core.hpp"
00052   #include <limits.h>
00053 
00054   #if defined WIN32 || defined _WIN32
00055     #include <windows.h>
00056   #endif
00057 
00058 #else // SKIP_INCLUDES
00059 
00060   #if defined WIN32 || defined _WIN32
00061     #define CV_CDECL __cdecl
00062     #define CV_STDCALL __stdcall
00063   #else
00064     #define CV_CDECL
00065     #define CV_STDCALL
00066   #endif
00067 
00068   #ifndef CV_EXTERN_C
00069     #ifdef __cplusplus
00070       #define CV_EXTERN_C extern "C"
00071       #define CV_DEFAULT(val) = val
00072     #else
00073       #define CV_EXTERN_C
00074       #define CV_DEFAULT(val)
00075     #endif
00076   #endif
00077 
00078   #ifndef CV_EXTERN_C_FUNCPTR
00079     #ifdef __cplusplus
00080       #define CV_EXTERN_C_FUNCPTR(x) extern "C" { typedef x; }
00081     #else
00082       #define CV_EXTERN_C_FUNCPTR(x) typedef x
00083     #endif
00084   #endif
00085 
00086   #ifndef CV_INLINE
00087     #if defined __cplusplus
00088       #define CV_INLINE inline
00089     #elif (defined WIN32 || defined _WIN32) && !defined __GNUC__
00090       #define CV_INLINE __inline
00091     #else
00092       #define CV_INLINE static
00093     #endif
00094   #endif /* CV_INLINE */
00095 
00096   #if (defined WIN32 || defined _WIN32) && defined CVAPI_EXPORTS
00097     #define CV_EXPORTS __declspec(dllexport)
00098   #else
00099     #define CV_EXPORTS
00100   #endif
00101 
00102   #ifndef CVAPI
00103     #define CVAPI(rettype) CV_EXTERN_C CV_EXPORTS rettype CV_CDECL
00104   #endif
00105 
00106 #endif // SKIP_INCLUDES
00107 
00108 
00109 #ifdef __cplusplus
00110 
00111 // Apple defines a check() macro somewhere in the debug headers
00112 // that interferes with a method definition in this header
00113 #undef check
00114 
00115 /****************************************************************************************\
00116 *                               Main struct definitions                                  *
00117 \****************************************************************************************/
00118 
00119 /* log(2*PI) */
00120 #define CV_LOG2PI (1.8378770664093454835606594728112)
00121 
00122 /* columns of <trainData> matrix are training samples */
00123 #define CV_COL_SAMPLE 0
00124 
00125 /* rows of <trainData> matrix are training samples */
00126 #define CV_ROW_SAMPLE 1
00127 
00128 #define CV_IS_ROW_SAMPLE(flags) ((flags) & CV_ROW_SAMPLE)
00129 
// Node of a singly-linked list holding one batch of training vectors.
// Storage is an array of row pointers; the union gives typed views of the
// same pointer array, selected by `type`.
// NOTE(review): exact meaning of `dims`/`count` (vector length / number of
// vectors) is inferred from field names — confirm against the ml sources.
00130 struct CvVectors
00131 {
00132     int type;            // selects which union member below is valid
00133     int dims, count;
00134     CvVectors* next;     // next batch in the list (0 terminates)
00135     union
00136     {
00137         uchar** ptr;
00138         float** fl;
00139         double** db;
00140     } data;              // array of row pointers, one per vector
00141 };
00142 
// Entire block below is compiled out (#if 0): the legacy CvParamLattice API
// was superseded by CvParamGrid (declared later in this header). Kept only
// for reference.
00143 #if 0
00144 /* A structure, representing the lattice range of statmodel parameters.
00145    It is used for optimizing statmodel parameters by cross-validation method.
00146    The lattice is logarithmic, so <step> must be greater than 1. */
00147 typedef struct CvParamLattice
00148 {
00149     double min_val;
00150     double max_val;
00151     double step;
00152 }
00153 CvParamLattice;
00154 
// Constructs a lattice, normalizing the range (min<=max) and clamping the
// multiplicative step to at least 1.
00155 CV_INLINE CvParamLattice cvParamLattice( double min_val, double max_val,
00156                                          double log_step )
00157 {
00158     CvParamLattice pl;
00159     pl.min_val = MIN( min_val, max_val );
00160     pl.max_val = MAX( min_val, max_val );
00161     pl.step = MAX( log_step, 1. );
00162     return pl;
00163 }
00164 
// All-zero lattice, used as the "not specified" sentinel.
00165 CV_INLINE CvParamLattice cvDefaultParamLattice( void )
00166 {
00167     CvParamLattice pl = {0,0,0};
00168     return pl;
00169 }
00170 #endif
00171 
00172 /* Variable type */
00173 #define CV_VAR_NUMERICAL    0
00174 #define CV_VAR_ORDERED      0
00175 #define CV_VAR_CATEGORICAL  1
00176 
00177 #define CV_TYPE_NAME_ML_SVM         "opencv-ml-svm"
00178 #define CV_TYPE_NAME_ML_KNN         "opencv-ml-knn"
00179 #define CV_TYPE_NAME_ML_NBAYES      "opencv-ml-bayesian"
00180 #define CV_TYPE_NAME_ML_EM          "opencv-ml-em"
00181 #define CV_TYPE_NAME_ML_BOOSTING    "opencv-ml-boost-tree"
00182 #define CV_TYPE_NAME_ML_TREE        "opencv-ml-tree"
00183 #define CV_TYPE_NAME_ML_ANN_MLP     "opencv-ml-ann-mlp"
00184 #define CV_TYPE_NAME_ML_CNN         "opencv-ml-cnn"
00185 #define CV_TYPE_NAME_ML_RTREES      "opencv-ml-random-trees"
00186 #define CV_TYPE_NAME_ML_GBT         "opencv-ml-gradient-boosting-trees"
00187 
00188 #define CV_TRAIN_ERROR  0
00189 #define CV_TEST_ERROR   1
00190 
// Abstract base class of all statistical models in the ml module.
// Declares the common lifecycle: clear() releases the trained state,
// save()/load() persist to/from a file, write()/read() serialize through an
// already-open CvFileStorage. Implementations live in the ml sources.
00191 class CV_EXPORTS_W CvStatModel
00192 {
00193 public:
00194     CvStatModel();
00195     virtual ~CvStatModel();
00196 
00197     virtual void clear();
00198 
// name==0 lets the implementation fall back to a default model name.
00199     CV_WRAP virtual void save( const char* filename, const char* name=0 ) const;
00200     CV_WRAP virtual void load( const char* filename, const char* name=0 );
00201 
00202     virtual void write( CvFileStorage* storage, const char* name ) const;
00203     virtual void read( CvFileStorage* storage, CvFileNode* node );
00204 
00205 protected:
// Fallback name used when callers pass name==0 to save()/write()
// (presumably set by each derived class — confirm in ml sources).
00206     const char* default_model_name;
00207 };
00208 
00209 /****************************************************************************************\
00210 *                                 Normal Bayes Classifier                                *
00211 \****************************************************************************************/
00212 
00213 /* The structure, representing the grid range of statmodel parameters.
00214    It is used for optimizing statmodel accuracy by varying model parameters,
00215    the accuracy estimate being computed by cross-validation.
00216    The grid is logarithmic, so <step> must be greater then 1. */
00217 
00218 class CvMLData;
00219 
// Logarithmic grid of values [min_val, max_val] with multiplicative step,
// used by CvSVM::train_auto to cross-validate one SVM hyper-parameter.
00220 struct CV_EXPORTS_W_MAP CvParamGrid
00221 {
00222     // SVM params type
00223     enum { SVM_C=0, SVM_GAMMA=1, SVM_P=2, SVM_NU=3, SVM_COEF=4, SVM_DEGREE=5 };
00224 
// Default-constructed grid is all zeros ("not specified").
00225     CvParamGrid()
00226     {
00227         min_val = max_val = step = 0;
00228     }
00229 
00230     CvParamGrid( double _min_val, double _max_val, double log_step )
00231     {
00232         min_val = _min_val;
00233         max_val = _max_val;
00234         step = log_step;
00235     }
00236     //CvParamGrid( int param_id );
// Validates the grid (defined in the ml sources); given the logarithmic
// stepping, step is expected to exceed 1 — confirm exact rules there.
00237     bool check() const;
00238 
00239     CV_PROP_RW double min_val;
00240     CV_PROP_RW double max_val;
00241     CV_PROP_RW double step;     // multiplicative step between grid points
00242 };
00243 
// Normal (Gaussian) Bayes classifier. Declares a dual API: legacy CvMat*
// overloads and cv::Mat overloads (the latter hidden from SWIG wrappers).
// varIdx/sampleIdx select a subset of features/samples; update=true
// requests incremental training on top of the existing model.
00244 class CV_EXPORTS_W CvNormalBayesClassifier : public CvStatModel
00245 {
00246 public:
00247     CV_WRAP CvNormalBayesClassifier();
00248     virtual ~CvNormalBayesClassifier();
00249 
// Convenience constructor: trains immediately on the given data.
00250     CvNormalBayesClassifier( const CvMat* trainData, const CvMat* responses,
00251         const CvMat* varIdx=0, const CvMat* sampleIdx=0 );
00252     
00253     virtual bool train( const CvMat* trainData, const CvMat* responses,
00254         const CvMat* varIdx = 0, const CvMat* sampleIdx=0, bool update=false );
00255    
// Optional `results` receives one prediction per input sample row.
00256     virtual float predict( const CvMat* samples, CV_OUT CvMat* results=0 ) const;
00257     CV_WRAP virtual void clear();
00258 
00259 #ifndef SWIG
00260     CV_WRAP CvNormalBayesClassifier( const cv::Mat& trainData, const cv::Mat& responses,
00261                             const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat() );
00262     CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
00263                        const cv::Mat& varIdx = cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
00264                        bool update=false );
00265     CV_WRAP virtual float predict( const cv::Mat& samples, CV_OUT cv::Mat* results=0 ) const;
00266 #endif
00267     
00268     virtual void write( CvFileStorage* storage, const char* name ) const;
00269     virtual void read( CvFileStorage* storage, CvFileNode* node );
00270 
00271 protected:
// Trained per-class statistics. The CvMat** members are arrays indexed by
// class (presumably one matrix per label in cls_labels — confirm in the
// ml sources).
00272     int     var_count, var_all;
00273     CvMat*  var_idx;
00274     CvMat*  cls_labels;
00275     CvMat** count;
00276     CvMat** sum;
00277     CvMat** productsum;
00278     CvMat** avg;
00279     CvMat** inv_eigen_values;
00280     CvMat** cov_rotate_mats;
00281     CvMat*  c;
00282 };
00283 
00284 
00285 /****************************************************************************************\
00286 *                          K-Nearest Neighbour Classifier                                *
00287 \****************************************************************************************/
00288 
00289 // k Nearest Neighbors
// k-Nearest-Neighbours model. Stores the training samples (see `samples`
// below) and answers queries through find_nearest(). maxK caps the `k`
// usable at prediction time; isRegression selects averaging vs voting
// semantics in the implementation.
00290 class CV_EXPORTS_W CvKNearest : public CvStatModel
00291 {
00292 public:
00293 
00294     CV_WRAP CvKNearest();
00295     virtual ~CvKNearest();
00296 
00297     CvKNearest( const CvMat* trainData, const CvMat* responses,
00298                 const CvMat* sampleIdx=0, bool isRegression=false, int max_k=32 );
00299 
// updateBase=true adds the new samples to the stored base instead of
// replacing it.
00300     virtual bool train( const CvMat* trainData, const CvMat* responses,
00301                         const CvMat* sampleIdx=0, bool is_regression=false,
00302                         int maxK=32, bool updateBase=false );
00303     
// Optional outputs: per-sample predictions, pointers to the k neighbour
// vectors, their responses, and their distances.
00304     virtual float find_nearest( const CvMat* samples, int k, CV_OUT CvMat* results=0,
00305         const float** neighbors=0, CV_OUT CvMat* neighborResponses=0, CV_OUT CvMat* dist=0 ) const;
00306     
00307 #ifndef SWIG
00308     CV_WRAP CvKNearest( const cv::Mat& trainData, const cv::Mat& responses,
00309                const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false, int max_k=32 );
00310     
00311     CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
00312                        const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false,
00313                        int maxK=32, bool updateBase=false );    
00314     
00315     virtual float find_nearest( const cv::Mat& samples, int k, cv::Mat* results=0,
00316                                 const float** neighbors=0, cv::Mat* neighborResponses=0,
00317                                 cv::Mat* dist=0 ) const;
// Wrapper-friendly overload: all three outputs are mandatory references.
00318     CV_WRAP virtual float find_nearest( const cv::Mat& samples, int k, CV_OUT cv::Mat& results,
00319                                         CV_OUT cv::Mat& neighborResponses, CV_OUT cv::Mat& dists) const;
00320 #endif
00321     
00322     virtual void clear();
00323     int get_max_k() const;
00324     int get_var_count() const;
00325     int get_sample_count() const;
00326     bool is_regression() const;
00327 
00328 protected:
00329 
// Internal helpers shared by the find_nearest overloads; [start,end) is a
// sample range, allowing the implementation to process queries in chunks.
00330     virtual float write_results( int k, int k1, int start, int end,
00331         const float* neighbor_responses, const float* dist, CvMat* _results,
00332         CvMat* _neighbor_responses, CvMat* _dist, Cv32suf* sort_buf ) const;
00333 
00334     virtual void find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
00335         float* neighbor_responses, const float** neighbors, float* dist ) const;
00336 
00337 
00338     int max_k, var_count;
00339     int total;            // number of stored training samples
00340     bool regression;
00341     CvVectors* samples;   // stored training base (batched vector list)
00342 };
00343 
00344 /****************************************************************************************\
00345 *                                   Support Vector Machines                              *
00346 \****************************************************************************************/
00347 
00348 // SVM training parameters
// SVM training parameters. Which fields are actually used depends on
// svm_type and kernel_type (see the per-field notes kept below, and the
// CvSVM enums later in this header). Both constructors are defined in the
// ml sources.
00349 struct CV_EXPORTS_W_MAP CvSVMParams
00350 {
00351     CvSVMParams();
00352     CvSVMParams( int _svm_type, int _kernel_type,
00353                  double _degree, double _gamma, double _coef0,
00354                  double Cvalue, double _nu, double _p,
00355                  CvMat* _class_weights, CvTermCriteria _term_crit );
00356 
00357     CV_PROP_RW int         svm_type;
00358     CV_PROP_RW int         kernel_type;
00359     CV_PROP_RW double      degree; // for poly
00360     CV_PROP_RW double      gamma;  // for poly/rbf/sigmoid
00361     CV_PROP_RW double      coef0;  // for poly/sigmoid
00362 
00363     CV_PROP_RW double      C;  // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR
00364     CV_PROP_RW double      nu; // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR
00365     CV_PROP_RW double      p; // for CV_SVM_EPS_SVR
00366     CvMat*      class_weights; // for CV_SVM_C_SVC
00367     CV_PROP_RW CvTermCriteria term_crit; // termination criteria
00368 };
00369 
00370 
// SVM kernel evaluator. calc() computes one row of the kernel matrix:
// K(vecs[i], another) for i in [0, vcount), writing into `results`.
// Dispatch goes through the `calc_func` pointer-to-member, bound to one of
// the calc_* variants below (linear / RBF / polynomial / sigmoid).
00371 struct CV_EXPORTS CvSVMKernel
00372 {
00373     typedef void (CvSVMKernel::*Calc)( int vec_count, int vec_size, const float** vecs,
00374                                        const float* another, float* results );
00375     CvSVMKernel();
00376     CvSVMKernel( const CvSVMParams* params, Calc _calc_func );
00377     virtual bool create( const CvSVMParams* params, Calc _calc_func );
00378     virtual ~CvSVMKernel();
00379 
00380     virtual void clear();
00381     virtual void calc( int vcount, int n, const float** vecs, const float* another, float* results );
00382 
00383     const CvSVMParams* params;  // not owned; supplies gamma/degree/coef0 etc.
00384     Calc calc_func;
00385 
// Shared helper for the non-RBF kernels: computes alpha*<x,y>+beta style
// base values (exact use is in the ml sources).
00386     virtual void calc_non_rbf_base( int vec_count, int vec_size, const float** vecs,
00387                                     const float* another, float* results,
00388                                     double alpha, double beta );
00389 
00390     virtual void calc_linear( int vec_count, int vec_size, const float** vecs,
00391                               const float* another, float* results );
00392     virtual void calc_rbf( int vec_count, int vec_size, const float** vecs,
00393                            const float* another, float* results );
00394     virtual void calc_poly( int vec_count, int vec_size, const float** vecs,
00395                             const float* another, float* results );
00396     virtual void calc_sigmoid( int vec_count, int vec_size, const float** vecs,
00397                                const float* another, float* results );
00398 };
00399 
00400 
// Doubly-linked cache entry holding one computed kernel-matrix row
// (used by CvSVMSolver's row cache; see lru_list/rows members there).
00401 struct CvSVMKernelRow
00402 {
00403     CvSVMKernelRow* prev;
00404     CvSVMKernelRow* next;
00405     float* data;    // the cached row values
00406 };
00407 
00408 
// Results of one SMO optimization run, filled in by CvSVMSolver::solve_*.
00409 struct CvSVMSolutionInfo
00410 {
00411     double obj;             // final objective value
00412     double rho;             // decision-function offset
00413     double upper_bound_p;
00414     double upper_bound_n;
00415     double r;   // for Solver_NU
00416 };
00417 
// Generic SMO-style quadratic solver behind CvSVM training. A solve_*
// entry point is provided per SVM formulation (C-SVC, nu-SVC, one-class,
// eps-SVR, nu-SVR); each configures the pointer-to-member strategy hooks
// (working-set selection, rho computation, kernel-row retrieval) and then
// runs solve_generic(). Kernel rows are cached (lru_list/rows) to bound
// memory; pointers passed in (samples, y, alpha, storage, kernel) are
// borrowed, not owned.
00418 class CV_EXPORTS CvSVMSolver
00419 {
00420 public:
// Strategy hooks, bound differently per SVM formulation.
00421     typedef bool (CvSVMSolver::*SelectWorkingSet)( int& i, int& j );
00422     typedef float* (CvSVMSolver::*GetRow)( int i, float* row, float* dst, bool existed );
00423     typedef void (CvSVMSolver::*CalcRho)( double& rho, double& r );
00424 
00425     CvSVMSolver();
00426 
00427     CvSVMSolver( int count, int var_count, const float** samples, schar* y,
00428                  int alpha_count, double* alpha, double Cp, double Cn,
00429                  CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
00430                  SelectWorkingSet select_working_set, CalcRho calc_rho );
00431     virtual bool create( int count, int var_count, const float** samples, schar* y,
00432                  int alpha_count, double* alpha, double Cp, double Cn,
00433                  CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
00434                  SelectWorkingSet select_working_set, CalcRho calc_rho );
00435     virtual ~CvSVMSolver();
00436 
00437     virtual void clear();
// Core optimization loop; the solve_* wrappers below set up state first.
00438     virtual bool solve_generic( CvSVMSolutionInfo& si );
00439 
00440     virtual bool solve_c_svc( int count, int var_count, const float** samples, schar* y,
00441                               double Cp, double Cn, CvMemStorage* storage,
00442                               CvSVMKernel* kernel, double* alpha, CvSVMSolutionInfo& si );
00443     virtual bool solve_nu_svc( int count, int var_count, const float** samples, schar* y,
00444                                CvMemStorage* storage, CvSVMKernel* kernel,
00445                                double* alpha, CvSVMSolutionInfo& si );
00446     virtual bool solve_one_class( int count, int var_count, const float** samples,
00447                                   CvMemStorage* storage, CvSVMKernel* kernel,
00448                                   double* alpha, CvSVMSolutionInfo& si );
00449 
00450     virtual bool solve_eps_svr( int count, int var_count, const float** samples, const float* y,
00451                                 CvMemStorage* storage, CvSVMKernel* kernel,
00452                                 double* alpha, CvSVMSolutionInfo& si );
00453 
00454     virtual bool solve_nu_svr( int count, int var_count, const float** samples, const float* y,
00455                                CvMemStorage* storage, CvSVMKernel* kernel,
00456                                double* alpha, CvSVMSolutionInfo& si );
00457 
// Kernel-row cache access: get_row_base reports via *_existed whether the
// row was already cached.
00458     virtual float* get_row_base( int i, bool* _existed );
00459     virtual float* get_row( int i, float* dst );
00460 
00461     int sample_count;
00462     int var_count;
00463     int cache_size;       // row-cache capacity (units defined in ml sources)
00464     int cache_line_size;
00465     const float** samples;
00466     const CvSVMParams* params;
00467     CvMemStorage* storage;
00468     CvSVMKernelRow lru_list;   // sentinel of the cached-row LRU list
00469     CvSVMKernelRow* rows;
00470 
00471     int alpha_count;
00472 
00473     double* G;            // gradient of the objective
00474     double* alpha;
00475 
00476     // -1 - lower bound, 0 - free, 1 - upper bound
00477     schar* alpha_status;
00478 
00479     schar* y;             // sample labels (+1/-1) for classification modes
00480     double* b;
00481     float* buf[2];
00482     double eps;           // stopping tolerance
00483     int max_iter;
00484     double C[2];  // C[0] == Cn, C[1] == Cp
00485     CvSVMKernel* kernel;
00486 
00487     SelectWorkingSet select_working_set_func;
00488     CalcRho calc_rho_func;
00489     GetRow get_row_func;
00490 
// Default vs nu-SVM variants of the strategy hooks.
00491     virtual bool select_working_set( int& i, int& j );
00492     virtual bool select_working_set_nu_svm( int& i, int& j );
00493     virtual void calc_rho( double& rho, double& r );
00494     virtual void calc_rho_nu_svm( double& rho, double& r );
00495 
// Row retrieval per formulation (SVR duplicates samples as +/- pairs,
// hence the distinct variant — confirm details in ml sources).
00496     virtual float* get_row_svc( int i, float* row, float* dst, bool existed );
00497     virtual float* get_row_one_class( int i, float* row, float* dst, bool existed );
00498     virtual float* get_row_svr( int i, float* row, float* dst, bool existed );
00499 };
00500 
00501 
// One trained decision function: sum_i alpha[i]*K(sv[sv_index[i]], x) - rho.
// alpha/sv_index arrays have sv_count entries; sv_index addresses the
// support-vector array stored in CvSVM (`sv` member there).
00502 struct CvSVMDecisionFunc
00503 {
00504     double rho;
00505     int sv_count;
00506     double* alpha;
00507     int* sv_index;
00508 };
00509 
00510 
00511 // SVM model
// Support Vector Machine model. Offers plain train() with fixed
// CvSVMParams, and train_auto() which cross-validates each hyper-parameter
// over a CvParamGrid (kfold folds; a default grid per parameter comes from
// get_default_grid()). Legacy CvMat* and cv::Mat overloads are declared in
// parallel, as elsewhere in this header.
00512 class CV_EXPORTS_W CvSVM : public CvStatModel
00513 {
00514 public:
00515     // SVM type
00516     enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 };
00517 
00518     // SVM kernel type
00519     enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3 };
00520 
00521     // SVM params type
00522     enum { C=0, GAMMA=1, P=2, NU=3, COEF=4, DEGREE=5 };
00523 
00524     CV_WRAP CvSVM();
00525     virtual ~CvSVM();
00526 
00527     CvSVM( const CvMat* trainData, const CvMat* responses,
00528            const CvMat* varIdx=0, const CvMat* sampleIdx=0,
00529            CvSVMParams params=CvSVMParams() );
00530 
00531     virtual bool train( const CvMat* trainData, const CvMat* responses,
00532                         const CvMat* varIdx=0, const CvMat* sampleIdx=0,
00533                         CvSVMParams params=CvSVMParams() );
00534     
// Grid-search training: each CvParamGrid spans one hyper-parameter;
// balanced=true requests class-balanced folds.
00535     virtual bool train_auto( const CvMat* trainData, const CvMat* responses,
00536         const CvMat* varIdx, const CvMat* sampleIdx, CvSVMParams params,
00537         int kfold = 10,
00538         CvParamGrid Cgrid      = get_default_grid(CvSVM::C),
00539         CvParamGrid gammaGrid  = get_default_grid(CvSVM::GAMMA),
00540         CvParamGrid pGrid      = get_default_grid(CvSVM::P),
00541         CvParamGrid nuGrid     = get_default_grid(CvSVM::NU),
00542         CvParamGrid coeffGrid  = get_default_grid(CvSVM::COEF),
00543         CvParamGrid degreeGrid = get_default_grid(CvSVM::DEGREE),
00544         bool balanced=false );
00545 
// returnDFVal=true returns the raw decision-function value instead of the
// class label / regression response.
00546     virtual float predict( const CvMat* sample, bool returnDFVal=false ) const;
00547 
00548 #ifndef SWIG
00549     CV_WRAP CvSVM( const cv::Mat& trainData, const cv::Mat& responses,
00550           const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
00551           CvSVMParams params=CvSVMParams() );
00552     
00553     CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
00554                        const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
00555                        CvSVMParams params=CvSVMParams() );
00556     
00557     CV_WRAP virtual bool train_auto( const cv::Mat& trainData, const cv::Mat& responses,
00558                             const cv::Mat& varIdx, const cv::Mat& sampleIdx, CvSVMParams params,
00559                             int k_fold = 10,
00560                             CvParamGrid Cgrid      = CvSVM::get_default_grid(CvSVM::C),
00561                             CvParamGrid gammaGrid  = CvSVM::get_default_grid(CvSVM::GAMMA),
00562                             CvParamGrid pGrid      = CvSVM::get_default_grid(CvSVM::P),
00563                             CvParamGrid nuGrid     = CvSVM::get_default_grid(CvSVM::NU),
00564                             CvParamGrid coeffGrid  = CvSVM::get_default_grid(CvSVM::COEF),
00565                             CvParamGrid degreeGrid = CvSVM::get_default_grid(CvSVM::DEGREE),
00566                             bool balanced=false);
00567     CV_WRAP virtual float predict( const cv::Mat& sample, bool returnDFVal=false ) const;    
00568 #endif
00569     
00570     CV_WRAP virtual int get_support_vector_count() const;
00571     virtual const float* get_support_vector(int i) const;
00572     virtual CvSVMParams get_params() const { return params; };
00573     CV_WRAP virtual void clear();
00574 
// Built-in search grid for the given hyper-parameter id (C/GAMMA/...).
00575     static CvParamGrid get_default_grid( int param_id );
00576 
00577     virtual void write( CvFileStorage* storage, const char* name ) const;
00578     virtual void read( CvFileStorage* storage, CvFileNode* node );
// Active feature count: columns of var_idx when a subset is selected,
// otherwise all variables.
00579     CV_WRAP int get_var_count() const { return var_idx ? var_idx->cols : var_all; }
00580 
00581 protected:
00582 
// Internal training pipeline; implementations are in the ml sources.
00583     virtual bool set_params( const CvSVMParams& params );
00584     virtual bool train1( int sample_count, int var_count, const float** samples,
00585                     const void* responses, double Cp, double Cn,
00586                     CvMemStorage* _storage, double* alpha, double& rho );
00587     virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples,
00588                     const CvMat* responses, CvMemStorage* _storage, double* alpha );
00589     virtual void create_kernel();
00590     virtual void create_solver();
00591 
00592     virtual float predict( const float* row_sample, int row_len, bool returnDFVal=false ) const;
00593 
00594     virtual void write_params( CvFileStorage* fs ) const;
00595     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
00596 
00597     CvSVMParams params;
00598     CvMat* class_labels;
00599     int var_all;
00600     float** sv;               // support vectors (sv_total rows)
00601     int sv_total;
00602     CvMat* var_idx;
00603     CvMat* class_weights;
00604     CvSVMDecisionFunc* decision_func;
00605     CvMemStorage* storage;
00606 
00607     CvSVMSolver* solver;
00608     CvSVMKernel* kernel;
00609 };
00610 
00611 /****************************************************************************************\
00612 *                              Expectation - Maximization                                *
00613 \****************************************************************************************/
00614 
// Parameters of CvEM (Gaussian-mixture EM). The int defaults encode CvEM
// enum values by number because CvEM is declared later in this header
// (see the inline /*CvEM::...*/ annotations). The CvMat pointers supply
// optional initial estimates and are borrowed, not owned.
00615 struct CV_EXPORTS_W_MAP CvEMParams
00616 {
// Defaults: 10 clusters, diagonal covariances, automatic start step,
// terminate after 100 iterations or when improvement < FLT_EPSILON.
00616     CvEMParams() : nclusters(10), cov_mat_type(1/*CvEM::COV_MAT_DIAGONAL*/),
00617         start_step(0/*CvEM::START_AUTO_STEP*/), probs(0), weights(0), means(0), covs(0)
00618     {
00619         term_crit=cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON );
00620     }
00621 
00622     CvEMParams( int _nclusters, int _cov_mat_type=1/*CvEM::COV_MAT_DIAGONAL*/,
00623                 int _start_step=0/*CvEM::START_AUTO_STEP*/,
00624                 CvTermCriteria _term_crit=cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON),
00625                 const CvMat* _probs=0, const CvMat* _weights=0, const CvMat* _means=0, const CvMat** _covs=0 ) :
00626                 nclusters(_nclusters), cov_mat_type(_cov_mat_type), start_step(_start_step),
00627                 probs(_probs), weights(_weights), means(_means), covs(_covs), term_crit(_term_crit)
00628     {}
00629 
00630     CV_PROP_RW int nclusters;
00631     CV_PROP_RW int cov_mat_type;   // a CvEM COV_MAT_* value
00632     CV_PROP_RW int start_step;     // a CvEM START_* value
00633     const CvMat* probs;            // optional initial estimates (0 = none):
00634     const CvMat* weights;
00635     const CvMat* means;
00636     const CvMat** covs;
00637     CV_PROP_RW CvTermCriteria term_crit;
00638 };
00640 
00641 
// Expectation-Maximization estimator of a Gaussian mixture. train() fits
// the mixture to unlabeled samples (optional `labels` output receives the
// per-sample cluster assignment); predict() returns a cluster index for a
// sample and can output per-cluster probabilities.
00642 class CV_EXPORTS_W CvEM : public CvStatModel
00643 {
00644 public:
00645     // Type of covariation matrices
00646     enum { COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2 };
00647 
00648     // The initial step
00649     enum { START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0 };
00650 
00651     CV_WRAP CvEM();
00652     CvEM( const CvMat* samples, const CvMat* sampleIdx=0,
00653           CvEMParams params=CvEMParams(), CvMat* labels=0 );
00654     //CvEM (CvEMParams params, CvMat * means, CvMat ** covs, CvMat * weights,
00655     // CvMat * probs, CvMat * log_weight_div_det, CvMat * inv_eigen_values, CvMat** cov_rotate_mats);
00656 
00657     virtual ~CvEM();
00658 
00659     virtual bool train( const CvMat* samples, const CvMat* sampleIdx=0,
00660                         CvEMParams params=CvEMParams(), CvMat* labels=0 );
00661 
00662     virtual float predict( const CvMat* sample, CV_OUT CvMat* probs ) const;
00663 
00664 #ifndef SWIG
00665     CV_WRAP CvEM( const cv::Mat& samples, const cv::Mat& sampleIdx=cv::Mat(),
00666                   CvEMParams params=CvEMParams() );
00667     
00668     CV_WRAP virtual bool train( const cv::Mat& samples,
00669                                 const cv::Mat& sampleIdx=cv::Mat(),
00670                                 CvEMParams params=CvEMParams(),
00671                                 CV_OUT cv::Mat* labels=0 );
00672     
00673     CV_WRAP virtual float predict( const cv::Mat& sample, CV_OUT cv::Mat* probs=0 ) const;
00674     
// cv::Mat flavoured accessors for the fitted mixture (wrapper-exported).
00675     CV_WRAP int  getNClusters() const;
00676     CV_WRAP cv::Mat  getMeans()     const;
00677     CV_WRAP void getCovs(CV_OUT std::vector<cv::Mat>& covs)      const;
00678     CV_WRAP cv::Mat  getWeights()   const;
00679     CV_WRAP cv::Mat  getProbs()     const;
00680     
00681     CV_WRAP inline double getLikelihood() const { return log_likelihood;     };
00682 #endif
00683     
00684     CV_WRAP virtual void clear();
00685 
// Legacy CvMat accessors for the same fitted state.
00686     int           get_nclusters() const;
00687     const CvMat*  get_means()     const;
00688     const CvMat** get_covs()      const;
00689     const CvMat*  get_weights()   const;
00690     const CvMat*  get_probs()     const;
00691 
00692     inline double         get_log_likelihood     () const { return log_likelihood;     };
00693     
00694 //    inline const CvMat *  get_log_weight_div_det () const { return log_weight_div_det; };
00695 //    inline const CvMat *  get_inv_eigen_values   () const { return inv_eigen_values;   };
00696 //    inline const CvMat ** get_cov_rotate_mats    () const { return cov_rotate_mats;    };
00697 
00698 protected:
00699 
// Internal EM machinery; implementations in the ml sources.
00700     virtual void set_params( const CvEMParams& params,
00701                              const CvVectors& train_data );
00702     virtual void init_em( const CvVectors& train_data );
00703     virtual double run_em( const CvVectors& train_data );
00704     virtual void init_auto( const CvVectors& samples );
// k-means initialization used for the automatic start step.
00705     virtual void kmeans( const CvVectors& train_data, int nclusters,
00706                          CvMat* labels, CvTermCriteria criteria,
00707                          const CvMat* means );
00708     CvEMParams params;
00709     double log_likelihood;
00710 
// Fitted mixture: per-cluster means, covariances, weights, and per-sample
// posterior probabilities.
00711     CvMat* means;
00712     CvMat** covs;
00713     CvMat* weights;
00714     CvMat* probs;
00715 
// Cached per-cluster terms used by predict (precomputed at train time —
// confirm exact role in the ml sources).
00716     CvMat* log_weight_div_det;
00717     CvMat* inv_eigen_values;
00718     CvMat** cov_rotate_mats;
00719 };
00720 
00721 /****************************************************************************************\
00722 *                                      Decision Tree                                     *
00723 \****************************************************************************************/
// Pair of parallel views over the same index data: 16-bit (compact) and
// 32-bit (full) forms. Used by the tree training code to address sample
// indices in either width.
00724 struct CvPair16u32s
00725 {
00726     unsigned short* u;
00727     int* i;
00728 };
00729 
00730 
00731 #define CV_DTREE_CAT_DIR(idx,subset) \
00732     (2*((subset[(idx)>>5]&(1 << ((idx) & 31)))==0)-1)
00733 
// One split of a decision-tree node. For an ordered variable the `ord`
// member holds the threshold `c`; for a categorical variable the `subset`
// bit mask selects which categories go left/right (decoded via the
// CV_DTREE_CAT_DIR macro above). `next` chains surrogate splits.
00734 struct CvDTreeSplit
00735 {
00736     int var_idx;           // index of the variable being split on
00737     int condensed_idx;
00738     int inversed;          // nonzero: swap the left/right direction
00739     float quality;         // split quality score (used to pick the best)
00740     CvDTreeSplit* next;    // next (surrogate) split for the same node
00741     union
00742     {
00743         int subset[2];     // categorical: 64-category membership bit mask
00744         struct
00745         {
00746             float c;             // ordered: threshold value
00747             int split_point;
00748         }
00749         ord;
00750     };
00751 };
00752 
// One node of a decision tree: tree links, the chosen split, per-node
// training statistics, and the bookkeeping used by cost-complexity
// pruning (both the global pass and the cross-validated pass).
00753 struct CvDTreeNode
00754 {
00755     int class_idx;          // class for classification trees
00756     int Tn;                 // pruning "sequence number" of the node
00757     double value;           // predicted value at this node
00758 
00759     CvDTreeNode* parent;
00760     CvDTreeNode* left;
00761     CvDTreeNode* right;
00762 
00763     CvDTreeSplit* split;    // best split, chained with surrogates
00764 
00765     int sample_count;       // training samples reaching this node
00766     int depth;
00767     int* num_valid;         // per-variable count of non-missing values (may be 0)
00768     int offset;
00769     int buf_idx;
00770     double maxlr;
00771 
00772     // global pruning data
00773     int complexity;
00774     double alpha;
00775     double node_risk, tree_risk, tree_error;
00776 
00777     // cross-validation pruning data
00778     int* cv_Tn;
00779     double* cv_node_risk;
00780     double* cv_node_error;
00781 
// When no per-variable counts are stored (num_valid==0), every sample is
// considered valid, so sample_count is returned.
00782     int get_num_valid(int vi) { return num_valid ? num_valid[vi] : sample_count; }
00783     void set_num_valid(int vi, int n) { if( num_valid ) num_valid[vi] = n; }
00784 };
00785 
00786 
// Training parameters of a single decision tree; also the base struct for
// the random-trees, boosting and GBT parameter structs below.
struct CV_EXPORTS_W_MAP CvDTreeParams
{
    CV_PROP_RW int   max_categories;       // cluster categorical values when a variable has more categories
    CV_PROP_RW int   max_depth;            // maximum possible tree depth
    CV_PROP_RW int   min_sample_count;     // do not split a node with fewer samples
    CV_PROP_RW int   cv_folds;             // K for K-fold cross-validation pruning (0 = no CV pruning)
    CV_PROP_RW bool  use_surrogates;       // build surrogate splits (needed for missing data / var importance)
    CV_PROP_RW bool  use_1se_rule;         // apply the "1 standard error" rule for harsher pruning
    CV_PROP_RW bool  truncate_pruned_tree; // physically remove pruned branches from the tree
    CV_PROP_RW float regression_accuracy;  // stop criterion for regression-tree node splitting
    const float* priors;                   // optional class priors; not owned, may be 0

    // Defaults: max_categories=10, unbounded depth, min_sample_count=10,
    // 10-fold CV pruning, surrogates on, 1SE rule on, pruned branches removed,
    // regression accuracy 0.01, no priors.
    CvDTreeParams() : max_categories(10), max_depth(INT_MAX), min_sample_count(10),
        cv_folds(10), use_surrogates(true), use_1se_rule(true),
        truncate_pruned_tree(true), regression_accuracy(0.01f), priors(0)
    {}

    CvDTreeParams( int _max_depth, int _min_sample_count,
                   float _regression_accuracy, bool _use_surrogates,
                   int _max_categories, int _cv_folds,
                   bool _use_1se_rule, bool _truncate_pruned_tree,
                   const float* _priors ) :
        max_categories(_max_categories), max_depth(_max_depth),
        min_sample_count(_min_sample_count), cv_folds (_cv_folds),
        use_surrogates(_use_surrogates), use_1se_rule(_use_1se_rule),
        truncate_pruned_tree(_truncate_pruned_tree),
        regression_accuracy(_regression_accuracy),
        priors(_priors)
    {}
};
00817 
00818 
// Preprocessed training data shared by a decision tree or a whole tree
// ensemble (random trees, boosting). Owns the sorted/condensed per-variable
// buffers, category maps, priors and the memory storages for tree nodes.
struct CV_EXPORTS CvDTreeTrainData
{
    CvDTreeTrainData();
    CvDTreeTrainData( const CvMat* trainData, int tflag,
                      const CvMat* responses, const CvMat* varIdx=0,
                      const CvMat* sampleIdx=0, const CvMat* varType=0,
                      const CvMat* missingDataMask=0,
                      const CvDTreeParams& params=CvDTreeParams(),
                      bool _shared=false, bool _add_labels=false );
    virtual ~CvDTreeTrainData();

    // (Re)builds all the internal buffers from the raw matrices.
    // _shared=true allows several trees to train on the same data object.
    virtual void set_data( const CvMat* trainData, int tflag,
                          const CvMat* responses, const CvMat* varIdx=0,
                          const CvMat* sampleIdx=0, const CvMat* varType=0,
                          const CvMat* missingDataMask=0,
                          const CvDTreeParams& params=CvDTreeParams(),
                          bool _shared=false, bool _add_labels=false,
                          bool _update_data=false );
    virtual void do_responses_copy();

    virtual void get_vectors( const CvMat* _subsample_idx,
         float* values, uchar* missing, float* responses, bool get_class_idx=false );

    // Creates a root node covering only the given subset of samples.
    virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );

    virtual void write_params( CvFileStorage* fs ) const;
    virtual void read_params( CvFileStorage* fs, CvFileNode* node );

    // release all the data
    virtual void clear();

    int get_num_classes() const;
    int get_var_type(int vi) const;
    int get_work_var_count() const {return work_var_count;}

    // Accessors that gather per-node, per-variable data into caller-provided
    // buffers and return a pointer to the filled (or internal) array.
    virtual const float* get_ord_responses( CvDTreeNode* n, float* values_buf, int* sample_indices_buf );
    virtual const int* get_class_labels( CvDTreeNode* n, int* labels_buf );
    virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
    virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
    virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
    virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* sorted_indices_buf,
                                   const float** ord_values, const int** sorted_indices, int* sample_indices_buf );
    virtual int get_child_buf_idx( CvDTreeNode* n );

    // Allocation helpers backed by the memory storages/heaps below.
    virtual bool set_params( const CvDTreeParams& params );
    virtual CvDTreeNode* new_node( CvDTreeNode* parent, int count,
                                   int storage_idx, int offset );

    virtual CvDTreeSplit* new_split_ord( int vi, float cmp_val,
                int split_point, int inversed, float quality );
    virtual CvDTreeSplit* new_split_cat( int vi, float quality );
    virtual void free_node_data( CvDTreeNode* node );
    virtual void free_train_data();
    virtual void free_node( CvDTreeNode* node );

    // Dataset dimensions and flags.
    int sample_count, var_all, var_count, max_c_count;
    int ord_var_count, cat_var_count, work_var_count;
    bool have_labels, have_priors;
    bool is_classifier;     // classification vs regression problem
    int tflag;              // sample layout of train_data (row- or column-wise)

    const CvMat* train_data;    // not owned
    const CvMat* responses;     // not owned
    CvMat* responses_copy; // used in Boosting

    int buf_count, buf_size;
    bool shared;            // true when shared between trees of an ensemble
    int is_buf_16u;         // nonzero => internal index buffers use 16-bit elements

    // Categorical-variable bookkeeping: per-category counts, offsets and
    // the original-value <-> normalized-index map.
    CvMat* cat_count;
    CvMat* cat_ofs;
    CvMat* cat_map;

    CvMat* counts;
    CvMat* buf;             // the big condensed per-variable work buffer
    CvMat* direction;
    CvMat* split_buf;

    CvMat* var_idx;
    CvMat* var_type; // i-th element =
                     //   k<0  - ordered
                     //   k>=0 - categorical, see k-th element of cat_* arrays
    CvMat* priors;
    CvMat* priors_mult;

    CvDTreeParams params;

    // Storages: tree_storage holds the nodes/splits of built trees,
    // temp_storage holds transient training-time data.
    CvMemStorage* tree_storage;
    CvMemStorage* temp_storage;

    CvDTreeNode* data_root;     // root node spanning the whole (sub)sample

    CvSet* node_heap;
    CvSet* split_heap;
    CvSet* cv_heap;
    CvSet* nv_heap;

    cv::RNG* rng;
};
00920 
class CvDTree;
class CvForestTree;

// Internal best-split search helpers (defined in the implementation files);
// forward-declared here so CvDTree/CvForestTree can declare them as friends.
namespace cv
{
    struct DTreeBestSplitFinder;
    struct ForestTreeBestSplitFinder;
}
00929 
// Single decision tree for classification or regression, with surrogate
// splits for missing data and cross-validation cost-complexity pruning.
class CV_EXPORTS_W CvDTree : public CvStatModel
{
public:
    CV_WRAP CvDTree();
    virtual ~CvDTree();

    // Trains the tree from raw matrices. tflag selects row- vs column-wise
    // sample layout; varType marks each variable as ordered or categorical.
    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvDTreeParams params=CvDTreeParams() );

    virtual bool train( CvMLData* trainData, CvDTreeParams params=CvDTreeParams() );

    // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
    virtual float calc_error( CvMLData* trainData, int type, std::vector<float> *resp = 0 );

    // Trains on pre-built (possibly shared) training data; used by ensembles.
    virtual bool train( CvDTreeTrainData* trainData, const CvMat* subsampleIdx );

    // Returns the leaf the sample lands in; its value/class_idx fields hold
    // the prediction. preprocessedInput=true means sample is already in the
    // internal normalized representation.
    virtual CvDTreeNode* predict( const CvMat* sample, const CvMat* missingDataMask=0,
                                  bool preprocessedInput=false ) const;

#ifndef SWIG
    // cv::Mat overloads of the CvMat-based methods above.
    CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
                       const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
                       const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
                       const cv::Mat& missingDataMask=cv::Mat(),
                       CvDTreeParams params=CvDTreeParams() );

    CV_WRAP virtual CvDTreeNode* predict( const cv::Mat& sample, const cv::Mat& missingDataMask=cv::Mat(),
                                  bool preprocessedInput=false ) const;
    CV_WRAP virtual cv::Mat getVarImportance();
#endif

    virtual const CvMat* get_var_importance();
    CV_WRAP virtual void clear();

    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void write( CvFileStorage* fs, const char* name ) const;

    // special read & write methods for trees in the tree ensembles
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvDTreeTrainData* data );
    virtual void write( CvFileStorage* fs ) const;

    const CvDTreeNode* get_root() const;
    int get_pruned_tree_idx() const;
    CvDTreeTrainData* get_data();

protected:
    friend struct cv::DTreeBestSplitFinder;

    virtual bool do_train( const CvMat* _subsample_idx );

    // Recursive construction: pick the best split of a node, partition its
    // data and recurse into the children.
    virtual void try_split_node( CvDTreeNode* n );
    virtual void split_node_data( CvDTreeNode* n );
    virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
    // Split search, one routine per {ordered,categorical} x {class,regression}.
    virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
                            float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
                            float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
                            float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
                            float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    // Surrogate splits approximate the primary split for samples whose
    // primary variable is missing.
    virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
    virtual double calc_node_dir( CvDTreeNode* node );
    virtual void complete_node_dir( CvDTreeNode* node );
    virtual void cluster_categories( const int* vectors, int vector_count,
        int var_count, int* sums, int k, int* cluster_labels );

    virtual void calc_node_value( CvDTreeNode* node );

    // Cross-validation cost-complexity pruning.
    virtual void prune_cv();
    virtual double update_tree_rnc( int T, int fold );
    virtual int cut_tree( int T, int fold, double min_alpha );
    virtual void free_prune_data(bool cut_tree);
    virtual void free_tree();

    // Serialization of individual nodes/splits and the whole node list.
    virtual void write_node( CvFileStorage* fs, CvDTreeNode* node ) const;
    virtual void write_split( CvFileStorage* fs, CvDTreeSplit* split ) const;
    virtual CvDTreeNode* read_node( CvFileStorage* fs, CvFileNode* node, CvDTreeNode* parent );
    virtual CvDTreeSplit* read_split( CvFileStorage* fs, CvFileNode* node );
    virtual void write_tree_nodes( CvFileStorage* fs ) const;
    virtual void read_tree_nodes( CvFileStorage* fs, CvFileNode* node );

    CvDTreeNode* root;          // root of the built tree, 0 if not trained
    CvMat* var_importance;      // lazily computed by get_var_importance()
    CvDTreeTrainData* data;     // training data, possibly shared with an ensemble

public:
    int pruned_tree_idx;        // index of the pruned tree variant used for prediction
};
01024 
01025 
01026 /****************************************************************************************\
01027 *                                   Random Trees Classifier                              *
01028 \****************************************************************************************/
01029 
01030 class CvRTrees;
01031 
// A single tree of a random forest: a CvDTree whose best-split search is
// restricted to a random subset of active variables chosen by the forest.
class CV_EXPORTS CvForestTree: public CvDTree
{
public:
    CvForestTree();
    virtual ~CvForestTree();

    // Trains this tree on a bootstrap subsample, on behalf of `forest`.
    virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx, CvRTrees* forest );

    virtual int get_var_count() const {return data ? data->var_count : 0;}
    virtual void read( CvFileStorage* fs, CvFileNode* node, CvRTrees* forest, CvDTreeTrainData* _data );

    /* dummy methods to avoid warnings: BEGIN */
    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvDTreeParams params=CvDTreeParams() );

    virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );
    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvDTreeTrainData* data );
    /* dummy methods to avoid warnings: END */

protected:
    friend struct cv::ForestTreeBestSplitFinder;

    virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
    CvRTrees* forest;   // owning forest; not owned by this tree
};
01062 
01063 
// Training parameters of a random forest; extends the single-tree
// parameters with forest-level settings.
struct CV_EXPORTS_W_MAP CvRTParams : public CvDTreeParams
{
    //Parameters for the forest
    CV_PROP_RW bool calc_var_importance; // true <=> RF processes variable importance
    CV_PROP_RW int nactive_vars;         // size of the random variable subset per split (0 = default)
    CV_PROP_RW CvTermCriteria term_crit; // max tree count and/or OOB-error threshold

    // Defaults: depth 5, min 10 samples per node, no surrogates, no pruning,
    // no variable importance; stop after 50 trees or OOB error < 0.1.
    CvRTParams() : CvDTreeParams( 5, 10, 0, false, 10, 0, false, false, 0 ),
        calc_var_importance(false), nactive_vars(0)
    {
        term_crit = cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 50, 0.1 );
    }

    // Note: cv_folds is forced to 0 and pruning disabled -- forest trees
    // are grown unpruned by design.
    CvRTParams( int _max_depth, int _min_sample_count,
                float _regression_accuracy, bool _use_surrogates,
                int _max_categories, const float* _priors, bool _calc_var_importance,
                int _nactive_vars, int max_num_of_trees_in_the_forest,
                float forest_accuracy, int termcrit_type ) :
        CvDTreeParams( _max_depth, _min_sample_count, _regression_accuracy,
                       _use_surrogates, _max_categories, 0,
                       false, false, _priors ),
        calc_var_importance(_calc_var_importance),
        nactive_vars(_nactive_vars)
    {
        term_crit = cvTermCriteria(termcrit_type,
            max_num_of_trees_in_the_forest, forest_accuracy);
    }
};
01092 
01093 
// Random Trees (random forest) classifier/regressor: an ensemble of
// CvForestTree grown on bootstrap samples; prediction aggregates the trees
// (majority vote / averaging).
class CV_EXPORTS_W CvRTrees : public CvStatModel
{
public:
    CV_WRAP CvRTrees();
    virtual ~CvRTrees();
    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvRTParams params=CvRTParams() );

    virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
    // Aggregated forest prediction for one sample.
    virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const;
    // Fraction of trees voting for the "positive" outcome.
    // NOTE(review): presumably defined for 2-class problems only -- confirm.
    virtual float predict_prob( const CvMat* sample, const CvMat* missing = 0 ) const;

#ifndef SWIG
    // cv::Mat overloads of the CvMat-based methods above.
    CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
                       const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
                       const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
                       const cv::Mat& missingDataMask=cv::Mat(),
                       CvRTParams params=CvRTParams() );
    CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
    CV_WRAP virtual float predict_prob( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
    CV_WRAP virtual cv::Mat getVarImportance();
#endif

    CV_WRAP virtual void clear();

    virtual const CvMat* get_var_importance();
    virtual float get_proximity( const CvMat* sample1, const CvMat* sample2,
        const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const;

    virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}

    virtual float get_train_error();

    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void write( CvFileStorage* fs, const char* name ) const;

    CvMat* get_active_var_mask();
    CvRNG* get_rng();

    int get_tree_count() const;
    CvForestTree* get_tree(int i) const;

protected:

    // Grows trees until the termination criteria (tree count / OOB error) hit.
    virtual bool grow_forest( const CvTermCriteria term_crit );

    // array of the trees of the forest
    CvForestTree** trees;
    CvDTreeTrainData* data;     // training data shared by all trees
    int ntrees;
    int nclasses;
    double oob_error;           // out-of-bag error estimate of the forest
    CvMat* var_importance;
    int nsamples;

    cv::RNG* rng;
    CvMat* active_var_mask;     // mask of variables considered at each split
};
01155 
01156 /****************************************************************************************\
01157 *                           Extremely randomized trees Classifier                        *
01158 \****************************************************************************************/
// Training data for Extremely Randomized Trees: overrides the data
// accessors of CvDTreeTrainData (ERT does not pre-sort ordered variables,
// since split thresholds are drawn at random).
struct CV_EXPORTS CvERTreeTrainData : public CvDTreeTrainData
{
    virtual void set_data( const CvMat* trainData, int tflag,
                          const CvMat* responses, const CvMat* varIdx=0,
                          const CvMat* sampleIdx=0, const CvMat* varType=0,
                          const CvMat* missingDataMask=0,
                          const CvDTreeParams& params=CvDTreeParams(),
                          bool _shared=false, bool _add_labels=false,
                          bool _update_data=false );
    // Note: unlike the base class, returns values with a missing-mask
    // instead of sorted values/indices.
    virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf,
                                   const float** ord_values, const int** missing, int* sample_buf = 0 );
    virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
    virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
    virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
    virtual void get_vectors( const CvMat* _subsample_idx, float* values, uchar* missing,
                              float* responses, bool get_class_idx=false );
    virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
    const CvMat* missing_mask;  // not owned
};
01178 
// A single extremely randomized tree: overrides the split-search routines
// of CvForestTree (split thresholds/subsets are chosen at random rather
// than optimized over all values).
class CV_EXPORTS CvForestERTree : public CvForestTree
{
protected:
    virtual double calc_node_dir( CvDTreeNode* node );
    // One split-search routine per {ordered,categorical} x {class,regression}.
    virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual void split_node_data( CvDTreeNode* n );
};
01193 
// Extremely Randomized Trees ensemble: a CvRTrees whose member trees are
// CvForestERTree instances grown with grow_forest() overridden.
class CV_EXPORTS_W CvERTrees : public CvRTrees
{
public:
    CV_WRAP CvERTrees();
    virtual ~CvERTrees();
    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvRTParams params=CvRTParams());
#ifndef SWIG
    // cv::Mat overload of the CvMat-based train above.
    CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
                       const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
                       const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
                       const cv::Mat& missingDataMask=cv::Mat(),
                       CvRTParams params=CvRTParams());
#endif
    virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
protected:
    virtual bool grow_forest( const CvTermCriteria term_crit );
};
01215 
01216 
01217 /****************************************************************************************\
01218 *                                   Boosted tree classifier                              *
01219 \****************************************************************************************/
01220 
// Training parameters for boosted tree classifiers (see CvBoost below for
// the boost_type / split_criteria enumerations).
struct CV_EXPORTS_W_MAP CvBoostParams : public CvDTreeParams
{
    CV_PROP_RW int boost_type;          // CvBoost::{DISCRETE,REAL,LOGIT,GENTLE}
    CV_PROP_RW int weak_count;          // number of weak classifiers (trees) to train
    CV_PROP_RW int split_criteria;      // CvBoost::{DEFAULT,GINI,MISCLASS,SQERR}
    CV_PROP_RW double weight_trim_rate; // drop samples whose total weight is below 1-rate (0 = off)

    CvBoostParams();
    CvBoostParams( int boost_type, int weak_count, double weight_trim_rate,
                   int max_depth, bool use_surrogates, const float* priors );
};
01232 
01233 
01234 class CvBoost;
01235 
// Weak classifier of a boosted ensemble: a CvDTree trained on the ensemble's
// weighted samples, with split/value computations overridden to honor the
// boosting weights.
class CV_EXPORTS CvBoostTree: public CvDTree
{
public:
    CvBoostTree();
    virtual ~CvBoostTree();

    // Trains this weak tree on the given subsample, on behalf of `ensemble`.
    virtual bool train( CvDTreeTrainData* trainData,
                        const CvMat* subsample_idx, CvBoost* ensemble );

    // Multiplies all leaf values by s (used e.g. by LogitBoost shrinkage).
    // NOTE(review): caller semantics inferred from name -- confirm in cvboost.cpp.
    virtual void scale( double s );
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvBoost* ensemble, CvDTreeTrainData* _data );
    virtual void clear();

    /* dummy methods to avoid warnings: BEGIN */
    virtual bool train( const CvMat* trainData, int tflag,
                        const CvMat* responses, const CvMat* varIdx=0,
                        const CvMat* sampleIdx=0, const CvMat* varType=0,
                        const CvMat* missingDataMask=0,
                        CvDTreeParams params=CvDTreeParams() );
    virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );

    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvDTreeTrainData* data );
    /* dummy methods to avoid warnings: END */

protected:

    virtual void try_split_node( CvDTreeNode* n );
    virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
    // Weighted split search, one routine per {ordered,categorical} x {class,reg}.
    virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
        float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    virtual void calc_node_value( CvDTreeNode* n );
    virtual double calc_node_dir( CvDTreeNode* n );

    CvBoost* ensemble;  // owning ensemble; not owned by this tree
};
01281 
01282 
// Boosted (binary) tree classifier: trains a sequence of CvBoostTree weak
// classifiers with sample re-weighting and predicts by their weighted sum.
class CV_EXPORTS_W CvBoost : public CvStatModel
{
public:
    // Boosting type
    enum { DISCRETE=0, REAL=1, LOGIT=2, GENTLE=3 };

    // Splitting criteria
    enum { DEFAULT=0, GINI=1, MISCLASS=3, SQERR=4 };

    CV_WRAP CvBoost();
    virtual ~CvBoost();

    // Constructs and trains in one step (same arguments as train below).
    CvBoost( const CvMat* trainData, int tflag,
             const CvMat* responses, const CvMat* varIdx=0,
             const CvMat* sampleIdx=0, const CvMat* varType=0,
             const CvMat* missingDataMask=0,
             CvBoostParams params=CvBoostParams() );

    // update=true continues training the existing ensemble instead of
    // restarting from scratch.
    virtual bool train( const CvMat* trainData, int tflag,
             const CvMat* responses, const CvMat* varIdx=0,
             const CvMat* sampleIdx=0, const CvMat* varType=0,
             const CvMat* missingDataMask=0,
             CvBoostParams params=CvBoostParams(),
             bool update=false );

    virtual bool train( CvMLData* data,
             CvBoostParams params=CvBoostParams(),
             bool update=false );

    // Predicts the class of a sample. slice selects a sub-range of weak
    // classifiers; return_sum=true returns the raw weighted sum instead of
    // the class label; weak_responses optionally receives the per-tree outputs.
    virtual float predict( const CvMat* sample, const CvMat* missing=0,
                           CvMat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ,
                           bool raw_mode=false, bool return_sum=false ) const;

#ifndef SWIG
    // cv::Mat overloads of the CvMat-based methods above.
    CV_WRAP CvBoost( const cv::Mat& trainData, int tflag,
            const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
            const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
            const cv::Mat& missingDataMask=cv::Mat(),
            CvBoostParams params=CvBoostParams() );

    CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
                       const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
                       const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
                       const cv::Mat& missingDataMask=cv::Mat(),
                       CvBoostParams params=CvBoostParams(),
                       bool update=false );

    CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(),
                                   const cv::Range& slice=cv::Range::all(), bool rawMode=false,
                                   bool returnSum=false ) const;
#endif

    virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}

    // Removes the weak classifiers in the given slice from the ensemble.
    CV_WRAP virtual void prune( CvSlice slice );

    CV_WRAP virtual void clear();

    virtual void write( CvFileStorage* storage, const char* name ) const;
    virtual void read( CvFileStorage* storage, CvFileNode* node );
    virtual const CvMat* get_active_vars(bool absolute_idx=true);

    CvSeq* get_weak_predictors();

    CvMat* get_weights();
    CvMat* get_subtree_weights();
    CvMat* get_weak_response();
    const CvBoostParams& get_params() const;
    const CvDTreeTrainData* get_data() const;

protected:

    virtual bool set_params( const CvBoostParams& params );
    // Re-computes sample weights after a weak classifier is added.
    virtual void update_weights( CvBoostTree* tree );
    // Drops low-weight samples per params.weight_trim_rate.
    virtual void trim_weights();
    virtual void write_params( CvFileStorage* fs ) const;
    virtual void read_params( CvFileStorage* fs, CvFileNode* node );

    CvDTreeTrainData* data;
    CvBoostParams params;
    CvSeq* weak;                // sequence of CvBoostTree* weak classifiers

    CvMat* active_vars;
    CvMat* active_vars_abs;
    bool have_active_cat_vars;

    CvMat* orig_response;
    CvMat* sum_response;
    CvMat* weak_eval;
    CvMat* subsample_mask;
    CvMat* weights;             // current per-sample boosting weights
    CvMat* subtree_weights;
    bool have_subsample;
};
01377 
01378 
01379 /****************************************************************************************\
01380 *                                   Gradient Boosted Trees                               *
01381 \****************************************************************************************/
01382 
01383 // DataType: STRUCT CvGBTreesParams
01384 // Parameters of GBT (Gradient Boosted trees model), including single
01385 // tree settings and ensemble parameters.
01386 //
01387 // weak_count          - count of trees in the ensemble
01388 // loss_function_type  - loss function used for ensemble training
01389 // subsample_portion   - portion of whole training set used for
01390 //                       every single tree training.
01391 //                       subsample_portion value is in (0.0, 1.0].
01392 //                       subsample_portion == 1.0 when whole dataset is
01393 //                       used on each step. Count of sample used on each
01394 //                       step is computed as
01395 //                       int(total_samples_count * subsample_portion).
01396 // shrinkage           - regularization parameter.
01397 //                       Each tree prediction is multiplied on shrinkage value.
01398 
01399 
// Training parameters of the gradient boosted trees model; the individual
// fields are described in the comment block above this struct.
struct CV_EXPORTS_W_MAP CvGBTreesParams : public CvDTreeParams
{
    CV_PROP_RW int weak_count;          // number of boosting iterations (trees per class)
    CV_PROP_RW int loss_function_type;  // one of CvGBTrees loss enum values
    CV_PROP_RW float subsample_portion; // fraction of samples used per tree, in (0, 1]
    CV_PROP_RW float shrinkage;         // learning rate multiplying each tree's prediction

    CvGBTreesParams();
    CvGBTreesParams( int loss_function_type, int weak_count, float shrinkage,
        float subsample_portion, int max_depth, bool use_surrogates );
};
01411 
01412 // DataType: CLASS CvGBTrees
01413 // Gradient Boosting Trees (GBT) algorithm implementation.
01414 // 
01415 // data             - training dataset
01416 // params           - parameters of the CvGBTrees
01417 // weak             - array[0..(class_count-1)] of CvSeq
01418 //                    for storing tree ensembles
01419 // orig_response    - original responses of the training set samples
// sum_response     - predictions of the current model on the training dataset.
//                    this matrix is updated on every iteration.
// sum_response_tmp - predictions of the model on the training set on the next
//                    step. On every iteration values of sum_responses_tmp are
//                    computed via sum_responses values. When the current
//                    step is complete sum_response values become equal to
//                    sum_responses_tmp.
01427 // sampleIdx       - indices of samples used for training the ensemble.
01428 //                    CvGBTrees training procedure takes a set of samples
01429 //                    (train_data) and a set of responses (responses).
01430 //                    Only pairs (train_data[i], responses[i]), where i is
01431 //                    in sample_idx are used for training the ensemble.
// subsample_train  - indices of samples used for training a single decision
//                    tree on the current step. These indices are counted
//                    relative to the sample_idx, so that pairs
//                    (train_data[sample_idx[i]], responses[sample_idx[i]])
//                    are used for training a decision tree.
//                    Training set is randomly split
//                    in two parts (subsample_train and subsample_test)
//                    on every iteration according to the portion parameter.
01440 // subsample_test   - relative indices of samples from the training set,
01441 //                    which are not used for training a tree on the current
01442 //                    step.
01443 // missing          - mask of the missing values in the training set. This
01444 //                    matrix has the same size as train_data. 1 - missing
01445 //                    value, 0 - not a missing value.
01446 // class_labels     - output class labels map. 
01447 // rng              - random number generator. Used for splitting the
01448 //                    training set.
01449 // class_count      - count of output classes.
01450 //                    class_count == 1 in the case of regression,
01451 //                    and > 1 in the case of classification.
01452 // delta            - Huber loss function parameter.
01453 // base_value       - start point of the gradient descent procedure.
01454 //                    model prediction is
01455 //                    f(x) = f_0 + sum_{i=1..weak_count-1}(f_i(x)), where
01456 //                    f_0 is the base value.
01457 
01458 
01459 
01460 class CV_EXPORTS_W CvGBTrees : public CvStatModel
01461 {
01462 public:
01463 
01464     /*
01465     // DataType: ENUM
01466     // Loss functions implemented in CvGBTrees.
01467     //  
01468     // SQUARED_LOSS
01469     // problem: regression
01470     // loss = (x - x')^2
01471     // 
01472     // ABSOLUTE_LOSS
01473     // problem: regression
01474     // loss = abs(x - x')
01475     // 
01476     // HUBER_LOSS
01477     // problem: regression
01478     // loss = delta*( abs(x - x') - delta/2), if abs(x - x') > delta
01479     //           1/2*(x - x')^2, if abs(x - x') <= delta,
01480     //           where delta is the alpha-quantile of pseudo responses from
01481     //           the training set.
01482     //
01483     // DEVIANCE_LOSS
01484     // problem: classification
01485     // 
01486     */ 
01487     enum {SQUARED_LOSS=0, ABSOLUTE_LOSS, HUBER_LOSS=3, DEVIANCE_LOSS};
01488  
01489     
01490     /*
01491     // Default constructor. Creates a model only (without training).
01492     // Should be followed by one form of the train(...) function.
01493     //
01494     // API
01495     // CvGBTrees();
01496     
01497     // INPUT
01498     // OUTPUT
01499     // RESULT
01500     */
01501     CV_WRAP CvGBTrees();
01502 
01503 
01504     /*
01505     // Full form constructor. Creates a gradient boosting model and does the
01506     // train.
01507     //
01508     // API
01509     // CvGBTrees( const CvMat* trainData, int tflag,
01510              const CvMat* responses, const CvMat* varIdx=0,
01511              const CvMat* sampleIdx=0, const CvMat* varType=0,
01512              const CvMat* missingDataMask=0,
01513              CvGBTreesParams params=CvGBTreesParams() );
01514     
01515     // INPUT
01516     // trainData    - a set of input feature vectors.
01517     //                  size of matrix is
01518     //                  <count of samples> x <variables count>
01519     //                  or <variables count> x <count of samples>
01520     //                  depending on the tflag parameter.
01521     //                  matrix values are float.
01522     // tflag         - a flag showing how the samples are stored in the
01523     //                  trainData matrix: row by row (tflag=CV_ROW_SAMPLE)
01524     //                  or column by column (tflag=CV_COL_SAMPLE).
01525     // responses     - a vector of responses corresponding to the samples
01526     //                  in trainData.
01527     // varIdx       - indices of used variables. zero value means that all
01528     //                  variables are active.
01529     // sampleIdx    - indices of used samples. zero value means that all
01530     //                  samples from trainData are in the training set.
01531     // varType      - vector of <variables count> length. gives every
01532     //                  variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED.
01533     //                  varType = 0 means all variables are numerical.
01534     // missingDataMask  - a mask of missing values in trainData.
01535     //                  missingDataMask = 0 means that there are no missing
01536     //                  values.
01537     // params         - parameters of GTB algorithm.
01538     // OUTPUT
01539     // RESULT
01540     */
01541     CvGBTrees( const CvMat* trainData, int tflag,
01542              const CvMat* responses, const CvMat* varIdx=0,
01543              const CvMat* sampleIdx=0, const CvMat* varType=0,
01544              const CvMat* missingDataMask=0,
01545              CvGBTreesParams params=CvGBTreesParams() );
01546 
01547              
01548     /*
01549     // Destructor.
01550     */
01551     virtual ~CvGBTrees();
01552     
01553     
01554     /*
01555     // Gradient tree boosting model training
01556     //
01557     // API
01558     // virtual bool train( const CvMat* trainData, int tflag,
01559              const CvMat* responses, const CvMat* varIdx=0,
01560              const CvMat* sampleIdx=0, const CvMat* varType=0,
01561              const CvMat* missingDataMask=0,
01562              CvGBTreesParams params=CvGBTreesParams(),
01563              bool update=false );
01564     
01565     // INPUT
01566     // trainData    - a set of input feature vectors.
01567     //                  size of matrix is
01568     //                  <count of samples> x <variables count>
01569     //                  or <variables count> x <count of samples>
01570     //                  depending on the tflag parameter.
01571     //                  matrix values are float.
01572     // tflag         - a flag showing how the samples are stored in the
01573     //                  trainData matrix: row by row (tflag=CV_ROW_SAMPLE)
01574     //                  or column by column (tflag=CV_COL_SAMPLE).
01575     // responses     - a vector of responses corresponding to the samples
01576     //                  in trainData.
01577     // varIdx       - indices of used variables. zero value means that all
01578     //                  variables are active.
01579     // sampleIdx    - indices of used samples. zero value means that all
01580     //                  samples from trainData are in the training set.
01581     // varType      - vector of <variables count> length. gives every
01582     //                  variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED.
01583     //                  varType = 0 means all variables are numerical.
01584     // missingDataMask  - a mask of missing values in trainData.
01585     //                  missingDataMask = 0 means that there are no missing
01586     //                  values.
01587     // params         - parameters of GTB algorithm.
01588     // update         - is not supported now. (!)
01589     // OUTPUT
01590     // RESULT
01591     // Error state.
01592     */
01593     virtual bool train( const CvMat* trainData, int tflag,
01594              const CvMat* responses, const CvMat* varIdx=0,
01595              const CvMat* sampleIdx=0, const CvMat* varType=0,
01596              const CvMat* missingDataMask=0,
01597              CvGBTreesParams params=CvGBTreesParams(),
01598              bool update=false );
01599     
01600     
01601     /*
01602     // Gradient tree boosting model training
01603     //
01604     // API
01605     // virtual bool train( CvMLData* data,
01606              CvGBTreesParams params=CvGBTreesParams(),
01607              bool update=false ) {return false;};
01608     
01609     // INPUT
01610     // data          - training set.
01611     // params        - parameters of GTB algorithm.
01612     // update        - is not supported now. (!)
01613     // OUTPUT
01614     // RESULT
01615     // Error state.
01616     */
01617     virtual bool train( CvMLData* data,
01618              CvGBTreesParams params=CvGBTreesParams(),
01619              bool update=false );
01620 
01621     
01622     /*
01623     // Response value prediction
01624     //
01625     // API
01626     // virtual float predict( const CvMat* sample, const CvMat* missing=0,
01627              CvMat* weak_responses=0, CvSlice slice = CV_WHOLE_SEQ,
01628              int k=-1 ) const;
01629     
01630     // INPUT
01631     // sample         - input sample of the same type as in the training set.
01632     // missing        - missing values mask. missing=0 if there are no
01633     //                   missing values in sample vector.
01634     // weak_responses  - predictions of all of the trees.
01635     //                   not implemented (!)
01636     // slice           - part of the ensemble used for prediction.
01637     //                   slice = CV_WHOLE_SEQ when all trees are used.
01638     // k               - number of ensemble used.
01639     //                   k is in {-1,0,1,..,<count of output classes-1>}.
01640     //                   in the case of classification problem 
01641     //                   <count of output classes-1> ensembles are built.
01642     //                   If k = -1 ordinary prediction is the result,
01643     //                   otherwise function gives the prediction of the
01644     //                   k-th ensemble only.
01645     // OUTPUT
01646     // RESULT
01647     // Predicted value.
01648     */
01649     virtual float predict( const CvMat* sample, const CvMat* missing=0,
01650             CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
01651             int k=-1 ) const;
01652 
01653     /*
01654     // Delete all temporary data.
01655     //
01656     // API
01657     // virtual void clear();
01658     
01659     // INPUT
01660     // OUTPUT
01661     // delete data, weak, orig_response, sum_response,
01662     //        weak_eval, subsample_train, subsample_test,
01663     //        sample_idx, missing, class_labels
01664     // delta = 0.0
01665     // RESULT
01666     */
01667     CV_WRAP virtual void clear();
01668 
01669     /*
01670     // Compute error on the train/test set.
01671     //
01672     // API
01673     // virtual float calc_error( CvMLData* _data, int type,
01674     //        std::vector<float> *resp = 0 );
01675     //
01676     // INPUT
01677     // data  - dataset
01678     // type  - defines which error is to compute: train (CV_TRAIN_ERROR) or
01679     //         test (CV_TEST_ERROR).
01680     // OUTPUT
01681     // resp  - vector of predictions
01682     // RESULT
01683     // Error value.
01684     */
01685     virtual float calc_error( CvMLData* _data, int type,
01686             std::vector<float> *resp = 0 );
01687 
01688 
01689     /*
01690     // 
01691     // Write parameters of the gtb model and data. Write learned model.
01692     //
01693     // API
01694     // virtual void write( CvFileStorage* fs, const char* name ) const;
01695     //
01696     // INPUT
01697     // fs     - file storage to write the model to.
01698     // name   - model name.
01699     // OUTPUT
01700     // RESULT
01701     */
01702     virtual void write( CvFileStorage* fs, const char* name ) const;
01703 
01704 
01705     /*
01706     // 
01707     // Read parameters of the gtb model and data. Read learned model.
01708     //
01709     // API
01710     // virtual void read( CvFileStorage* fs, CvFileNode* node );
01711     //
01712     // INPUT
01713     // fs     - file storage to read parameters from.
01714     // node   - file node.
01715     // OUTPUT
01716     // RESULT
01717     */
01718     virtual void read( CvFileStorage* fs, CvFileNode* node );
01719 
01720     
01721     // new-style C++ interface
01722     CV_WRAP CvGBTrees( const cv::Mat& trainData, int tflag,
01723               const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
01724               const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
01725               const cv::Mat& missingDataMask=cv::Mat(),
01726               CvGBTreesParams params=CvGBTreesParams() );
01727     
01728     CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
01729                        const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
01730                        const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
01731                        const cv::Mat& missingDataMask=cv::Mat(),
01732                        CvGBTreesParams params=CvGBTreesParams(),
01733                        bool update=false );
01734 
01735     CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(),
01736                            const cv::Range& slice = cv::Range::all(),
01737                            int k=-1 ) const;
01738     
01739 protected:
01740 
01741     /*
01742     // Compute the gradient vector components.
01743     //
01744     // API
01745     // virtual void find_gradient( const int k = 0);
01746     
01747     // INPUT
01748     // k        - used for classification problem, determining current
01749     //            tree ensemble.
01750     // OUTPUT
01751     // changes components of data->responses
01752     // which correspond to samples used for training
01753     // on the current step.
01754     // RESULT
01755     */
01756     virtual void find_gradient( const int k = 0);
01757 
01758     
01759     /*
01760     // 
01761     // Change values in tree leaves according to the used loss function.
01762     //
01763     // API
01764     // virtual void change_values(CvDTree* tree, const int k = 0);
01765     //
01766     // INPUT
01767     // tree      - decision tree to change.
01768     // k         - used for classification problem, determining current
01769     //             tree ensemble.
01770     // OUTPUT
01771     // changes 'value' fields of the trees' leaves.
01772     // changes sum_response_tmp.
01773     // RESULT
01774     */
01775     virtual void change_values(CvDTree* tree, const int k = 0);
01776 
01777 
01778     /*
01779     // 
01780     // Find optimal constant prediction value according to the used loss
01781     // function.
01782     // The goal is to find a constant which gives the minimal summary loss
01783     // on the _Idx samples.
01784     //
01785     // API
01786     // virtual float find_optimal_value( const CvMat* _Idx );
01787     //
01788     // INPUT
01789     // _Idx        - indices of the samples from the training set.
01790     // OUTPUT
01791     // RESULT
01792     // optimal constant value.
01793     */
01794     virtual float find_optimal_value( const CvMat* _Idx );
01795 
01796     
01797     /*
01798     // 
01799     // Randomly split the whole training set in two parts according
01800     // to params.portion.
01801     //
01802     // API
01803     // virtual void do_subsample();
01804     //
01805     // INPUT
01806     // OUTPUT
01807     // subsample_train - indices of samples used for training
01808     // subsample_test  - indices of samples used for test
01809     // RESULT
01810     */
01811     virtual void do_subsample();
01812 
01813 
01814     /*
01815     // 
01816     // Internal recursive function giving an array of subtree tree leaves.
01817     //
01818     // API
01819     // void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
01820     //
01821     // INPUT
01822     // node         - current leaf.
01823     // OUTPUT
01824     // count        - count of leaves in the subtree.
01825     // leaves       - array of pointers to leaves.
01826     // RESULT
01827     */
01828     void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
01829     
01830     
01831     /*
01832     // 
01833     // Get leaves of the tree.
01834     //
01835     // API
01836     // CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );
01837     //
01838     // INPUT
01839     // dtree            - decision tree.
01840     // OUTPUT
01841     // len              - count of the leaves.
01842     // RESULT
01843     // CvDTreeNode**    - array of pointers to leaves.
01844     */
01845     CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );
01846 
01847     
01848     /*
01849     // 
01850     // Is it a regression or a classification.
01851     //
01852     // API
01853     // bool problem_type();
01854     //
01855     // INPUT
01856     // OUTPUT
01857     // RESULT
01858     // false if it is a classification problem,
01859     // true - if regression.
01860     */
01861     virtual bool problem_type() const;
01862 
01863 
01864     /*
01865     // 
01866     // Write parameters of the gtb model.
01867     //
01868     // API
01869     // virtual void write_params( CvFileStorage* fs ) const;
01870     //
01871     // INPUT
01872     // fs           - file storage to write parameters to.
01873     // OUTPUT
01874     // RESULT
01875     */
01876     virtual void write_params( CvFileStorage* fs ) const;
01877 
01878 
01879     /*
01880     // 
01881     // Read parameters of the gtb model and data.
01882     //
01883     // API
01884     // virtual void read_params( CvFileStorage* fs );
01885     //
01886     // INPUT
01887     // fs           - file storage to read parameters from.
01888     // OUTPUT
01889     // params       - parameters of the gtb model.
01890     // data         - contains information about the structure
01891     //                of the data set (count of variables,
01892     //                their types, etc.).
01893     // class_labels - output class labels map.
01894     // RESULT
01895     */
01896     virtual void read_params( CvFileStorage* fs, CvFileNode* fnode );
01897 
01898     
01899     CvDTreeTrainData* data;
01900     CvGBTreesParams params;
01901 
01902     CvSeq** weak;
01903     CvMat* orig_response;
01904     CvMat* sum_response;
01905     CvMat* sum_response_tmp;
01906     CvMat* weak_eval;
01907     CvMat* sample_idx;
01908     CvMat* subsample_train;
01909     CvMat* subsample_test;
01910     CvMat* missing;
01911     CvMat* class_labels;
01912 
01913     cv::RNG* rng;
01914 
01915     int class_count;
01916     float delta;
01917     float base_value;
01918 
01919 };
01920 
01921 
01922 
01923 /****************************************************************************************\
01924 *                              Artificial Neural Networks (ANN)                          *
01925 \****************************************************************************************/
01926 
01928 
01929 struct CV_EXPORTS_W_MAP CvANN_MLP_TrainParams
01930 {
01931     CvANN_MLP_TrainParams();
01932     CvANN_MLP_TrainParams( CvTermCriteria term_crit, int train_method,
01933                            double param1, double param2=0 );
01934     ~CvANN_MLP_TrainParams();
01935 
01936     enum { BACKPROP=0, RPROP=1 };  // available values for train_method
01937 
01938     CV_PROP_RW CvTermCriteria term_crit;  // termination criteria of the training procedure
01939     CV_PROP_RW int train_method;          // BACKPROP or RPROP
01940 
01941     // backpropagation parameters
01942     CV_PROP_RW double bp_dw_scale, bp_moment_scale;
01943 
01944     // RPROP algorithm parameters
01945     CV_PROP_RW double rp_dw0, rp_dw_plus, rp_dw_minus, rp_dw_min, rp_dw_max;
01946 };
01947 
01948 
01949 class CV_EXPORTS_W CvANN_MLP : public CvStatModel
01950 {
01951 public:
01952     CV_WRAP CvANN_MLP();
01953     CvANN_MLP( const CvMat* layerSizes,
01954                int activateFunc=CvANN_MLP::SIGMOID_SYM,
01955                double fparam1=0, double fparam2=0 );
01956 
01957     virtual ~CvANN_MLP();
01958 
01959     virtual void create( const CvMat* layerSizes,
01960                          int activateFunc=CvANN_MLP::SIGMOID_SYM,
01961                          double fparam1=0, double fparam2=0 );
01962     
01963     virtual int train( const CvMat* inputs, const CvMat* outputs,
01964                        const CvMat* sampleWeights, const CvMat* sampleIdx=0,
01965                        CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(),
01966                        int flags=0 );
01967     virtual float predict( const CvMat* inputs, CV_OUT CvMat* outputs ) const;
01968     
01969 #ifndef SWIG
01970     // new-style C++ interface
01971     CV_WRAP CvANN_MLP( const cv::Mat& layerSizes,
01972               int activateFunc=CvANN_MLP::SIGMOID_SYM,
01973               double fparam1=0, double fparam2=0 );
01974     
01975     CV_WRAP virtual void create( const cv::Mat& layerSizes,
01976                         int activateFunc=CvANN_MLP::SIGMOID_SYM,
01977                         double fparam1=0, double fparam2=0 );    
01978     
01979     CV_WRAP virtual int train( const cv::Mat& inputs, const cv::Mat& outputs,
01980                       const cv::Mat& sampleWeights, const cv::Mat& sampleIdx=cv::Mat(),
01981                       CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(),
01982                       int flags=0 );    
01983     
01984     CV_WRAP virtual float predict( const cv::Mat& inputs, cv::Mat& outputs ) const;
01985 #endif
01986     
01987     CV_WRAP virtual void clear();
01988 
01989     // possible activation functions
01990     enum { IDENTITY = 0, SIGMOID_SYM = 1, GAUSSIAN = 2 };
01991 
01992     // available training flags
01993     enum { UPDATE_WEIGHTS = 1, NO_INPUT_SCALE = 2, NO_OUTPUT_SCALE = 4 };
01994 
01995     virtual void read( CvFileStorage* fs, CvFileNode* node );
01996     virtual void write( CvFileStorage* storage, const char* name ) const;
01997 
01998     int get_layer_count() { return layer_sizes ? layer_sizes->cols : 0; }
01999     const CvMat* get_layer_sizes() { return layer_sizes; }
02000     double* get_weights(int layer)
02001     {
02002         return layer_sizes && weights &&
02003             (unsigned)layer <= (unsigned)layer_sizes->cols ? weights[layer] : 0;  // NOTE(review): '<=' (not '<') allows layer == cols; presumably weights holds cols+1 entries — confirm in the implementation
02004     }
02005 
02006 protected:
02007 
02008     virtual bool prepare_to_train( const CvMat* _inputs, const CvMat* _outputs,
02009             const CvMat* _sample_weights, const CvMat* sampleIdx,
02010             CvVectors* _ivecs, CvVectors* _ovecs, double** _sw, int _flags );
02011 
02012     // sequential random backpropagation
02013     virtual int train_backprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
02014 
02015     // RPROP algorithm
02016     virtual int train_rprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
02017 
02018     virtual void calc_activ_func( CvMat* xf, const double* bias ) const;
02019     virtual void calc_activ_func_deriv( CvMat* xf, CvMat* deriv, const double* bias ) const;
02020     virtual void set_activ_func( int _activ_func=SIGMOID_SYM,
02021                                  double _f_param1=0, double _f_param2=0 );
02022     virtual void init_weights();
02023     virtual void scale_input( const CvMat* _src, CvMat* _dst ) const;
02024     virtual void scale_output( const CvMat* _src, CvMat* _dst ) const;
02025     virtual void calc_input_scale( const CvVectors* vecs, int flags );
02026     virtual void calc_output_scale( const CvVectors* vecs, int flags );
02027 
02028     virtual void write_params( CvFileStorage* fs ) const;
02029     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
02030 
02031     CvMat* layer_sizes;
02032     CvMat* wbuf;
02033     CvMat* sample_weights;
02034     double** weights;
02035     double f_param1, f_param2;
02036     double min_val, max_val, min_val1, max_val1;
02037     int activ_func;
02038     int max_count, max_buf_sz;
02039     CvANN_MLP_TrainParams params;
02040     cv::RNG* rng;
02041 };
02041 
02042 /****************************************************************************************\
02043 *                          Auxiliary function declarations                              *
02044 \****************************************************************************************/
02045 
02046 /* Generates <sample> from multivariate normal distribution, where <mean> is an
02047    average row vector, <cov> is a symmetric covariance matrix */
02048 CVAPI(void) cvRandMVNormal( CvMat* mean, CvMat* cov, CvMat* sample,
02049                            CvRNG* rng CV_DEFAULT(0) );
02050 
02051 /* Generates sample from Gaussian mixture distribution */
02052 CVAPI(void) cvRandGaussMixture( CvMat* means[],
02053                                CvMat* covs[],
02054                                float weights[],
02055                                int clsnum,
02056                                CvMat* sample,
02057                                CvMat* sampClasses CV_DEFAULT(0) );
02058 
02059 #define CV_TS_CONCENTRIC_SPHERES 0
02060 
02061 /* Creates a synthetic test set of the given type (e.g. CV_TS_CONCENTRIC_SPHERES) */
02062 CVAPI(void) cvCreateTestSet( int type, CvMat** samples,
02063                  int num_samples,
02064                  int num_features,
02065                  CvMat** responses,
02066                  int num_classes, ... );
02067 
02068 
02069 #endif
02070 
02071 /****************************************************************************************\
02072 *                                      Data                                             *
02073 \****************************************************************************************/
02074 
02075 #include <map>
02076 #include <string>
02077 #include <iostream>
02078 
02079 #define CV_COUNT     0
02080 #define CV_PORTION   1
02081 
02082 struct CV_EXPORTS CvTrainTestSplit  // describes how CvMLData splits samples into train/test subsets (see CvMLData::set_train_test_split)
02083 {
02084 public:
02085     CvTrainTestSplit();
02086     CvTrainTestSplit( int _train_sample_count, bool _mix = true);
02087     CvTrainTestSplit( float _train_sample_portion, bool _mix = true);
02088 
02089     union
02090     {
02091         int count;
02092         float portion;
02093     } train_sample_part;
02094     int train_sample_part_mode;  // CV_COUNT or CV_PORTION: selects the valid member of train_sample_part
02095 
02096     union
02097     {
02098         int *count;
02099         float *portion;
02100     } *class_part;
02101     int class_part_mode;  // CV_COUNT or CV_PORTION: selects the valid member of class_part
02102 
02103     bool mix;    // presumably shuffles samples before splitting (cf. CvMLData::mix_train_and_test_idx) — confirm in implementation
02104 };
02105 
02106 class CV_EXPORTS CvMLData
02107 {
02108 public:
02109     CvMLData();
02110     virtual ~CvMLData();
02111 
02112     // returns:
02113     // 0 - OK  
02114     // 1 - file can not be opened or is not correct
02115     int read_csv(const char* filename);
02116 
02117     const CvMat* get_values(){ return values; };
02118 
02119     const CvMat* get_responses();
02120 
02121     const CvMat* get_missing(){ return missing; };
02122 
02123     void set_response_idx( int idx ); // old response becomes a predictor, new response_idx = idx
02124                                       // if idx < 0 there will be no response
02125     int get_response_idx() { return response_idx; }
02126 
02127     const CvMat* get_train_sample_idx() { return train_sample_idx; };
02128     const CvMat* get_test_sample_idx() { return test_sample_idx; };
02129     void mix_train_and_test_idx();
02130     void set_train_test_split( const CvTrainTestSplit * spl);
02131     
02132     const CvMat* get_var_idx();
02133     void chahge_var_idx( int vi, bool state ); // state == true to set vi-variable as predictor
02134                                                // NOTE: 'chahge' is a typo in the public API name; renaming would break callers
02135     const CvMat* get_var_types();
02136     int get_var_type( int var_idx ) { return var_types->data.ptr[var_idx]; };
02137     // following 2 methods enable to change vars type
02138     // use these methods to assign CV_VAR_CATEGORICAL type for categorical variable
02139     // with numerical labels; in the other cases var types are correctly determined automatically
02140     void set_var_types( const char* str );  // str examples:
02141                                             // "ord[0-17],cat[18]", "ord[0,2,4,10-12], cat[1,3,5-9,13,14]",
02142                                             // "cat", "ord" (all vars are categorical/ordered)
02143     void change_var_type( int var_idx, int type); // type in { CV_VAR_ORDERED, CV_VAR_CATEGORICAL }    
02144  
02145     void set_delimiter( char ch );
02146     char get_delimiter() { return delimiter; };
02147 
02148     void set_miss_ch( char ch );
02149     char get_miss_ch() { return miss_ch; };
02150     
02151 protected:
02152     virtual void clear();
02153 
02154     void str_to_flt_elem( const char* token, float& flt_elem, int& type);
02155     void free_train_test_idx();
02156     
02157     char delimiter;
02158     char miss_ch;
02159     //char flt_separator;
02160 
02161     CvMat* values;
02162     CvMat* missing;
02163     CvMat* var_types;
02164     CvMat* var_idx_mask;
02165 
02166     CvMat* response_out; // header
02167     CvMat* var_idx_out; // mat
02168     CvMat* var_types_out; // mat
02169 
02170     int response_idx;
02171 
02172     int train_sample_count;
02173     bool mix;
02174    
02175     int total_class_count;
02176     std::map<std::string, int> *class_map;
02177 
02178     CvMat* train_sample_idx;
02179     CvMat* test_sample_idx;
02180     int* sample_idx; // data of train_sample_idx and test_sample_idx
02181 
02182     cv::RNG* rng;
02183 };
02184 
02185 
02186 namespace cv  // convenience C++ aliases for the C-style ML classes above
02187 {
02188     
02189 typedef CvStatModel StatModel;
02190 typedef CvParamGrid ParamGrid;
02191 typedef CvNormalBayesClassifier NormalBayesClassifier;
02192 typedef CvKNearest KNearest;
02193 typedef CvSVMParams SVMParams;
02194 typedef CvSVMKernel SVMKernel;
02195 typedef CvSVMSolver SVMSolver;
02196 typedef CvSVM SVM;
02197 typedef CvEMParams EMParams;
02198 typedef CvEM ExpectationMaximization;
02199 typedef CvDTreeParams DTreeParams;
02200 typedef CvMLData TrainData;
02201 typedef CvDTree DecisionTree;
02202 typedef CvForestTree ForestTree;
02203 typedef CvRTParams RandomTreeParams;
02204 typedef CvRTrees RandomTrees;
02205 typedef CvERTreeTrainData ERTreeTRainData;  // sic: the 'TRain' capitalization typo is part of the public API; do not rename
02206 typedef CvForestERTree ERTree;
02207 typedef CvERTrees ERTrees;
02208 typedef CvBoostParams BoostParams;
02209 typedef CvBoostTree BoostTree;
02210 typedef CvBoost Boost;
02211 typedef CvANN_MLP_TrainParams ANN_MLP_TrainParams;
02212 typedef CvANN_MLP NeuralNet_MLP;
02213 typedef CvGBTreesParams GradientBoostingTreeParams;
02214 typedef CvGBTrees GradientBoostingTrees;
02215 
02216 template<> CV_EXPORTS void Ptr<CvDTreeSplit>::delete_obj();
02217     
02218 }
02219 
02220 #endif
02221 /* End of file. */
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines