/* SHOGUN v2.0.0 */
00001 /***********************************************************************/ 00002 /* */ 00003 /* SVMLight.h */ 00004 /* */ 00005 /* Author: Thorsten Joachims */ 00006 /* Date: 19.07.99 */ 00007 /* */ 00008 /* Copyright (c) 1999 Universitaet Dortmund - All rights reserved */ 00009 /* */ 00010 /* This software is available for non-commercial use only. It must */ 00011 /* not be modified and distributed without prior permission of the */ 00012 /* author. The author is not responsible for implications from the */ 00013 /* use of this software. */ 00014 /* */ 00015 /* THIS INCLUDES THE FOLLOWING ADDITIONS */ 00016 /* Generic Kernel Interfacing: Soeren Sonnenburg */ 00017 /* Parallizations: Soeren Sonnenburg */ 00018 /* Multiple Kernel Learning: Gunnar Raetsch, Soeren Sonnenburg */ 00019 /* Linadd Speedup: Gunnar Raetsch, Soeren Sonnenburg */ 00020 /* */ 00021 /***********************************************************************/ 00022 #ifndef _SVMLight_H___ 00023 #define _SVMLight_H___ 00024 00025 #include <shogun/lib/config.h> 00026 00027 #ifdef USE_SVMLIGHT 00028 #include <shogun/classifier/svm/SVM.h> 00029 #include <shogun/kernel/Kernel.h> 00030 #include <shogun/mathematics/Math.h> 00031 #include <shogun/lib/common.h> 00032 00033 #include <stdio.h> 00034 #include <ctype.h> 00035 #include <string.h> 00036 #include <stdlib.h> 00037 #include <time.h> 00038 00039 namespace shogun 00040 { 00041 //# define VERSION "V3.50 -- correct??" 00042 //# define VERSION_DATE "01.11.00 -- correct??" 
00043 00044 # define DEF_PRECISION 1E-14 00045 # define MAXSHRINK 50000 00046 00047 #ifndef DOXYGEN_SHOULD_SKIP_THIS 00048 00049 struct MODEL { 00051 int32_t sv_num; 00053 int32_t at_upper_bound; 00055 float64_t b; 00057 int32_t* supvec; 00059 float64_t *alpha; 00061 int32_t *index; 00063 int32_t totdoc; 00065 CKernel* kernel; 00066 00067 /* the following values are not written to file */ 00069 float64_t loo_error; 00071 float64_t loo_recall; 00073 float64_t loo_precision; 00074 00076 float64_t xa_error; 00078 float64_t xa_recall; 00080 float64_t xa_precision; 00081 }; 00082 00084 typedef struct quadratic_program { 00086 int32_t opt_n; 00088 int32_t opt_m; 00090 float64_t *opt_ce; 00092 float64_t *opt_ce0; 00094 float64_t *opt_g; 00096 float64_t *opt_g0; 00098 float64_t *opt_xinit; 00100 float64_t *opt_low; 00102 float64_t *opt_up; 00103 } QP; 00104 00106 typedef int32_t FNUM; 00107 00109 typedef float64_t FVAL; 00110 00112 struct LEARN_PARM { 00114 int32_t type; 00116 float64_t svm_c; 00118 float64_t* eps; 00120 float64_t svm_costratio; 00122 float64_t transduction_posratio; 00123 /* classified as positives */ 00125 int32_t biased_hyperplane; 00130 int32_t sharedslack; 00132 int32_t svm_maxqpsize; 00134 int32_t svm_newvarsinqp; 00136 int32_t kernel_cache_size; 00138 float64_t epsilon_crit; 00140 float64_t epsilon_shrink; 00142 int32_t svm_iter_to_shrink; 00146 int32_t maxiter; 00148 int32_t remove_inconsistent; 00152 int32_t skip_final_opt_check; 00154 int32_t compute_loo; 00158 float64_t rho; 00162 int32_t xa_depth; 00164 char predfile[200]; 00168 char alphafile[200]; 00169 00170 /* you probably do not want to touch the following */ 00172 float64_t epsilon_const; 00174 float64_t epsilon_a; 00176 float64_t opt_precision; 00177 00178 /* the following are only for internal use */ 00180 int32_t svm_c_steps; 00182 float64_t svm_c_factor; 00184 float64_t svm_costratio_unlab; 00186 float64_t svm_unlabbound; 00188 float64_t *svm_cost; 00189 }; 00190 00192 struct TIMING { 
00194 int32_t time_kernel; 00196 int32_t time_opti; 00198 int32_t time_shrink; 00200 int32_t time_update; 00202 int32_t time_model; 00204 int32_t time_check; 00206 int32_t time_select; 00207 }; 00208 00209 00211 struct SHRINK_STATE 00212 { 00214 int32_t *active; 00216 int32_t *inactive_since; 00218 int32_t deactnum; 00220 float64_t **a_history; 00222 int32_t maxhistory; 00224 float64_t *last_a; 00226 float64_t *last_lin; 00227 }; 00228 #endif // DOXYGEN_SHOULD_SKIP_THIS 00229 00231 class CSVMLight : public CSVM 00232 { 00233 public: 00235 CSVMLight(); 00236 00243 CSVMLight(float64_t C, CKernel* k, CLabels* lab); 00244 virtual ~CSVMLight(); 00245 00247 void init(); 00248 00253 virtual inline EMachineType get_classifier_type() { return CT_LIGHT; } 00254 00259 int32_t get_runtime(); 00260 00261 00263 void svm_learn(); 00264 00281 int32_t optimize_to_convergence( 00282 int32_t* docs, int32_t* label, int32_t totdoc, SHRINK_STATE *shrink_state, 00283 int32_t *inconsistent, float64_t *a, float64_t *lin, float64_t *c, 00284 TIMING *timing_profile, float64_t *maxdiff, int32_t heldout, 00285 int32_t retrain); 00286 00297 virtual float64_t compute_objective_function( 00298 float64_t *a, float64_t *lin, float64_t *c, float64_t* eps, int32_t *label, 00299 int32_t totdoc); 00300 00305 void clear_index(int32_t *index); 00306 00312 void add_to_index(int32_t *index, int32_t elem); 00313 00321 int32_t compute_index(int32_t *binfeature, int32_t range, int32_t *index); 00322 00341 void optimize_svm( 00342 int32_t* docs, int32_t* label, int32_t *exclude_from_eq_const, 00343 float64_t eq_target, int32_t *chosen, int32_t *active2dnum, int32_t totdoc, 00344 int32_t *working2dnum, int32_t varnum, float64_t *a, float64_t *lin, 00345 float64_t *c, float64_t *aicache, QP *qp, float64_t *epsilon_crit_target); 00346 00364 void compute_matrices_for_optimization( 00365 int32_t* docs, int32_t* label, int32_t *exclude_from_eq_const, 00366 float64_t eq_target, int32_t *chosen, int32_t *active2dnum, 
int32_t *key, 00367 float64_t *a, float64_t *lin, float64_t *c, int32_t varnum, int32_t totdoc, 00368 float64_t *aicache, QP *qp); 00369 00387 void compute_matrices_for_optimization_parallel( 00388 int32_t* docs, int32_t* label, int32_t *exclude_from_eq_const, 00389 float64_t eq_target, int32_t *chosen, int32_t *active2dnum, int32_t *key, 00390 float64_t *a, float64_t *lin, float64_t *c, int32_t varnum, int32_t totdoc, 00391 float64_t *aicache, QP *qp); 00392 00405 int32_t calculate_svm_model( 00406 int32_t* docs, int32_t *label,float64_t *lin, float64_t *a, 00407 float64_t* a_old, float64_t *c, int32_t *working2dnum, int32_t *active2dnum); 00408 00425 int32_t check_optimality( 00426 int32_t *label, float64_t *a, float64_t* lin, float64_t *c, int32_t totdoc, 00427 float64_t *maxdiff, float64_t epsilon_crit_org, int32_t *misclassified, 00428 int32_t *inconsistent,int32_t* active2dnum, int32_t *last_suboptimal_at, 00429 int32_t iteration); 00430 00444 virtual void update_linear_component( 00445 int32_t* docs, int32_t *label, int32_t *active2dnum, float64_t *a, 00446 float64_t* a_old, int32_t *working2dnum, int32_t totdoc, float64_t *lin, 00447 float64_t *aicache, float64_t* c); 00448 00453 static void* update_linear_component_mkl_linadd_helper(void* p); 00454 00467 void update_linear_component_mkl( 00468 int32_t* docs, int32_t *label, int32_t *active2dnum, float64_t *a, 00469 float64_t* a_old, int32_t *working2dnum, int32_t totdoc, float64_t *lin, 00470 float64_t *aicache); 00471 00484 void update_linear_component_mkl_linadd( 00485 int32_t* docs, int32_t *label, int32_t *active2dnum, float64_t *a, 00486 float64_t* a_old, int32_t *working2dnum, int32_t totdoc, float64_t *lin, 00487 float64_t *aicache); 00488 00489 void call_mkl_callback(float64_t* a, int32_t* label, float64_t* lin); 00490 00509 int32_t select_next_qp_subproblem_grad( 00510 int32_t *label, float64_t *a, float64_t* lin, float64_t* c, int32_t totdoc, 00511 int32_t qp_size, int32_t *inconsistent, int32_t* 
active2dnum, 00512 int32_t* working2dnum, float64_t *selcrit, int32_t *select, 00513 int32_t cache_only, int32_t *key, int32_t *chosen); 00514 00533 int32_t select_next_qp_subproblem_rand( 00534 int32_t* label, float64_t *a, float64_t *lin, float64_t *c, 00535 int32_t totdoc, int32_t qp_size, int32_t *inconsistent, 00536 int32_t *active2dnum, int32_t *working2dnum, float64_t *selcrit, 00537 int32_t *select, int32_t *key, int32_t *chosen, int32_t iteration); 00538 00546 void select_top_n( 00547 float64_t *selcrit, int32_t range, int32_t *select, int32_t n); 00548 00555 void init_shrink_state( 00556 SHRINK_STATE *shrink_state, int32_t totdoc, int32_t maxhistory); 00557 00562 void shrink_state_cleanup(SHRINK_STATE *shrink_state); 00563 00579 int32_t shrink_problem( 00580 SHRINK_STATE *shrink_state, int32_t *active2dnum, 00581 int32_t *last_suboptimal_at, int32_t iteration, int32_t totdoc, 00582 int32_t minshrink, float64_t *a, int32_t *inconsistent, float64_t* c, 00583 float64_t* lin, int* label); 00584 00599 virtual void reactivate_inactive_examples( 00600 int32_t *label,float64_t *a,SHRINK_STATE *shrink_state, float64_t *lin, 00601 float64_t *c, int32_t totdoc,int32_t iteration, int32_t *inconsistent, 00602 int32_t *docs,float64_t *aicache, float64_t* maxdiff); 00603 00604 protected: 00611 inline virtual float64_t compute_kernel(int32_t i, int32_t j) 00612 { 00613 return kernel->kernel(i, j); 00614 } 00615 00620 static void* compute_kernel_helper(void* p); 00621 00626 static void* update_linear_component_linadd_helper(void* p); 00627 00632 static void* reactivate_inactive_examples_vanilla_helper(void* p); 00633 00638 static void* reactivate_inactive_examples_linadd_helper(void* p); 00639 00641 inline virtual const char* get_name() const { return "SVMLight"; } 00642 00643 /* interface to QP-solver */ 00644 float64_t *optimize_qp( QP *qp,float64_t *epsilon_crit, int32_t nx, 00645 float64_t *threshold, int32_t& svm_maxqpsize); 00646 00655 virtual bool 
train_machine(CFeatures* data=NULL); 00656 00657 protected: 00659 MODEL* model; 00661 LEARN_PARM* learn_parm; 00663 int32_t verbosity; 00664 00666 float64_t init_margin; 00668 int32_t init_iter; 00670 int32_t precision_violations; 00672 float64_t model_b; 00674 float64_t opt_precision; 00676 float64_t* primal; 00678 float64_t* dual; 00679 00680 // MKL stuff 00681 00685 float64_t* W; 00687 int32_t count; 00689 float64_t mymaxdiff; 00691 bool use_kernel_cache; 00693 bool mkl_converged; 00694 }; 00695 } 00696 #endif //USE_SVMLIGHT 00697 #endif //_SVMLight_H___