SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Written (W) 2011-2012 Heiko Strathmann 00009 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00010 */ 00011 00012 #ifndef _MACHINE_H__ 00013 #define _MACHINE_H__ 00014 00015 #include <shogun/lib/common.h> 00016 #include <shogun/base/SGObject.h> 00017 #include <shogun/labels/Labels.h> 00018 #include <shogun/labels/BinaryLabels.h> 00019 #include <shogun/labels/RegressionLabels.h> 00020 #include <shogun/labels/MulticlassLabels.h> 00021 #include <shogun/labels/StructuredLabels.h> 00022 #include <shogun/labels/LatentLabels.h> 00023 #include <shogun/features/Features.h> 00024 00025 namespace shogun 00026 { 00027 00028 class CFeatures; 00029 class CLabels; 00030 class CMath; 00031 00033 enum EMachineType 00034 { 00035 CT_NONE = 0, 00036 CT_LIGHT = 10, 00037 CT_LIGHTONECLASS = 11, 00038 CT_LIBSVM = 20, 00039 CT_LIBSVMONECLASS=30, 00040 CT_LIBSVMMULTICLASS=40, 00041 CT_MPD = 50, 00042 CT_GPBT = 60, 00043 CT_CPLEXSVM = 70, 00044 CT_PERCEPTRON = 80, 00045 CT_KERNELPERCEPTRON = 90, 00046 CT_LDA = 100, 00047 CT_LPM = 110, 00048 CT_LPBOOST = 120, 00049 CT_KNN = 130, 00050 CT_SVMLIN=140, 00051 CT_KERNELRIDGEREGRESSION = 150, 00052 CT_GNPPSVM = 160, 00053 CT_GMNPSVM = 170, 00054 CT_SUBGRADIENTSVM = 180, 00055 CT_SUBGRADIENTLPM = 190, 00056 CT_SVMPERF = 200, 00057 CT_LIBSVR = 210, 00058 CT_SVRLIGHT = 220, 00059 CT_LIBLINEAR = 230, 00060 CT_KMEANS = 240, 00061 CT_HIERARCHICAL = 250, 00062 CT_SVMOCAS = 260, 00063 CT_WDSVMOCAS = 270, 00064 CT_SVMSGD = 280, 00065 CT_MKLMULTICLASS = 290, 00066 CT_MKLCLASSIFICATION = 300, 00067 CT_MKLONECLASS = 310, 00068 CT_MKLREGRESSION = 320, 00069 CT_SCATTERSVM = 330, 00070 CT_DASVM = 340, 00071 CT_LARANK = 350, 00072 CT_DASVMLINEAR = 360, 00073 CT_GAUSSIANNAIVEBAYES = 370, 00074 CT_AVERAGEDPERCEPTRON = 380, 00075 CT_SGDQN = 390, 00076 CT_CONJUGATEINDEX = 400, 00077 CT_LINEARRIDGEREGRESSION = 410, 00078 CT_LEASTSQUARESREGRESSION = 420, 00079 CT_QDA = 430, 00080 CT_NEWTONSVM = 440, 00081 CT_GAUSSIANPROCESSREGRESSION = 450, 00082 CT_LARS = 460, 00083 CT_MULTICLASS = 470, 00084 CT_DIRECTORLINEAR = 480, 00085 CT_DIRECTORKERNEL = 490 00086 }; 00087 00089 enum ESolverType 00090 { 00091 ST_AUTO=0, 00092 ST_CPLEX=1, 00093 ST_GLPK=2, 00094 ST_NEWTON=3, 00095 ST_DIRECT=4, 00096 ST_ELASTICNET=5, 00097 ST_BLOCK_NORM=6 00098 }; 00099 00101 enum EProblemType 00102 { 00103 PT_BINARY = 0, 00104 PT_REGRESSION = 1, 00105 PT_MULTICLASS = 2, 00106 PT_STRUCTURED = 3, 00107 PT_LATENT = 4 00108 }; 00109 00110 #define MACHINE_PROBLEM_TYPE(PT) \ 00111 \ 00114 virtual EProblemType get_machine_problem_type() const { return PT; } 00115 00133 class CMachine : public CSGObject 00134 { 00135 public: 00137 CMachine(); 00138 00140 virtual ~CMachine(); 00141 00151 virtual bool train(CFeatures* data=NULL); 00152 00159 virtual CLabels* apply(CFeatures* data=NULL); 00160 00162 virtual CBinaryLabels* apply_binary(CFeatures* data=NULL); 00164 virtual CRegressionLabels* apply_regression(CFeatures* data=NULL); 00166 virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL); 00168 virtual CStructuredLabels* apply_structured(CFeatures* data=NULL); 00170 virtual CLatentLabels* apply_latent(CFeatures* data=NULL); 00171 00176 virtual void set_labels(CLabels* lab); 00177 00182 virtual CLabels* get_labels(); 00183 00188 void set_max_train_time(float64_t t); 00189 00194 float64_t get_max_train_time(); 00195 00200 virtual EMachineType get_classifier_type(); 00201 00206 void set_solver_type(ESolverType st); 00207 00212 ESolverType get_solver_type(); 00213 00219 virtual void set_store_model_features(bool store_model); 00220 00229 virtual bool train_locked(SGVector<index_t> indices) 00230 { 00231 SG_ERROR("train_locked(SGVector<index_t>) is not yet implemented " 00232 "for %s\n", get_name()); 00233 return false; 00234 } 00235 00237 virtual float64_t apply_one(int32_t i) 00238 { 00239 SG_NOTIMPLEMENTED; 00240 return 0.0; 00241 } 00242 00248 virtual CLabels* apply_locked(SGVector<index_t> indices); 00249 00251 virtual CBinaryLabels* apply_locked_binary( 00252 SGVector<index_t> indices); 00254 virtual CRegressionLabels* apply_locked_regression( 00255 SGVector<index_t> indices); 00257 virtual CMulticlassLabels* apply_locked_multiclass( 00258 SGVector<index_t> indices); 00260 virtual CStructuredLabels* apply_locked_structured( 00261 SGVector<index_t> indices); 00263 virtual CLatentLabels* apply_locked_latent( 00264 SGVector<index_t> indices); 00265 00274 virtual void data_lock(CLabels* labs, CFeatures* features); 00275 00277 virtual void post_lock(CLabels* labs, CFeatures* features) { }; 00278 00280 virtual void data_unlock(); 00281 00283 virtual bool supports_locking() const { return false; } 00284 00286 bool is_data_locked() const { return m_data_locked; } 00287 00289 virtual EProblemType get_machine_problem_type() const 00290 { 00291 SG_NOTIMPLEMENTED; 00292 return PT_BINARY; 00293 } 00294 00296 virtual CMachine* clone() 00297 { 00298 SG_NOTIMPLEMENTED; 00299 return NULL; 00300 } 00301 00302 virtual const char* get_name() const { return "Machine"; } 00303 00304 protected: 00315 virtual bool train_machine(CFeatures* data=NULL) 00316 { 00317 SG_ERROR("train_machine is not yet implemented for %s!\n", 00318 get_name()); 00319 return false; 00320 } 00321 00332 virtual void store_model_features() 00333 { 00334 SG_ERROR("Model storage and therefore unlocked Cross-Validation and" 00335 " Model-Selection is not supported for %s. Locked may" 00336 " work though.\n", get_name()); 00337 } 00338 00345 virtual bool is_label_valid(CLabels *lab) const 00346 { 00347 return true; 00348 } 00349 00351 virtual bool train_require_labels() const { return true; } 00352 00353 protected: 00355 float64_t m_max_train_time; 00356 00358 CLabels* m_labels; 00359 00361 ESolverType m_solver_type; 00362 00364 bool m_store_model_features; 00365 00367 bool m_data_locked; 00368 }; 00369 } 00370 #endif // _MACHINE_H__