SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #include <shogun/classifier/svm/LibSVM.h> 00012 #include <shogun/io/SGIO.h> 00013 #include <shogun/labels/BinaryLabels.h> 00014 00015 using namespace shogun; 00016 00017 CLibSVM::CLibSVM(LIBSVM_SOLVER_TYPE st) 00018 : CSVM(), model(NULL), solver_type(st) 00019 { 00020 } 00021 00022 CLibSVM::CLibSVM(float64_t C, CKernel* k, CLabels* lab) 00023 : CSVM(C, k, lab), model(NULL), solver_type(LIBSVM_C_SVC) 00024 { 00025 problem = svm_problem(); 00026 } 00027 00028 CLibSVM::~CLibSVM() 00029 { 00030 } 00031 00032 00033 bool CLibSVM::train_machine(CFeatures* data) 00034 { 00035 struct svm_node* x_space; 00036 00037 ASSERT(m_labels && m_labels->get_num_labels()); 00038 ASSERT(m_labels->get_label_type() == LT_BINARY); 00039 00040 if (data) 00041 { 00042 if (m_labels->get_num_labels() != data->get_num_vectors()) 00043 { 00044 SG_ERROR("%s::train_machine(): Number of training vectors (%d) does" 00045 " not match number of labels (%d)\n", get_name(), 00046 data->get_num_vectors(), m_labels->get_num_labels()); 00047 } 00048 kernel->init(data, data); 00049 } 00050 00051 problem.l=m_labels->get_num_labels(); 00052 SG_INFO( "%d trainlabels\n", problem.l); 00053 00054 // set linear term 00055 if (m_linear_term.vlen>0) 00056 { 00057 if (m_labels->get_num_labels()!=m_linear_term.vlen) 00058 SG_ERROR("Number of training vectors does not match length of linear term\n"); 00059 00060 // set with linear term from base class 00061 problem.pv = get_linear_term_array(); 00062 } 00063 else 00064 { 00065 // fill with minus ones 00066 problem.pv = SG_MALLOC(float64_t, problem.l); 00067 00068 for (int i=0; i!=problem.l; i++) 00069 problem.pv[i] = -1.0; 00070 } 00071 00072 problem.y=SG_MALLOC(float64_t, problem.l); 00073 problem.x=SG_MALLOC(struct svm_node*, problem.l); 00074 problem.C=SG_MALLOC(float64_t, problem.l); 00075 00076 x_space=SG_MALLOC(struct svm_node, 2*problem.l); 00077 00078 for (int32_t i=0; i<problem.l; i++) 00079 { 00080 problem.y[i]=((CBinaryLabels*) m_labels)->get_label(i); 00081 problem.x[i]=&x_space[2*i]; 00082 x_space[2*i].index=i; 00083 x_space[2*i+1].index=-1; 00084 } 00085 00086 int32_t weights_label[2]={-1,+1}; 00087 float64_t weights[2]={1.0,get_C2()/get_C1()}; 00088 00089 ASSERT(kernel && kernel->has_features()); 00090 ASSERT(kernel->get_num_vec_lhs()==problem.l); 00091 00092 param.svm_type=solver_type; // C SVM or NU_SVM 00093 param.kernel_type = LINEAR; 00094 param.degree = 3; 00095 param.gamma = 0; // 1/k 00096 param.coef0 = 0; 00097 param.nu = get_nu(); 00098 param.kernel=kernel; 00099 param.cache_size = kernel->get_cache_size(); 00100 param.max_train_time = m_max_train_time; 00101 param.C = get_C1(); 00102 param.eps = epsilon; 00103 param.p = 0.1; 00104 param.shrinking = 1; 00105 param.nr_weight = 2; 00106 param.weight_label = weights_label; 00107 param.weight = weights; 00108 param.use_bias = get_bias_enabled(); 00109 00110 const char* error_msg = svm_check_parameter(&problem, ¶m); 00111 00112 if(error_msg) 00113 SG_ERROR("Error: %s\n",error_msg); 00114 00115 model = svm_train(&problem, ¶m); 00116 00117 if (model) 00118 { 00119 ASSERT(model->nr_class==2); 00120 ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef[0])); 00121 00122 int32_t num_sv=model->l; 00123 00124 create_new_model(num_sv); 00125 CSVM::set_objective(model->objective); 00126 00127 float64_t sgn=model->label[0]; 00128 00129 set_bias(-sgn*model->rho[0]); 00130 00131 for (int32_t i=0; i<num_sv; i++) 00132 { 00133 set_support_vector(i, (model->SV[i])->index); 00134 set_alpha(i, sgn*model->sv_coef[0][i]); 00135 } 00136 00137 SG_FREE(problem.x); 00138 SG_FREE(problem.y); 00139 SG_FREE(problem.pv); 00140 SG_FREE(problem.C); 00141 00142 00143 SG_FREE(x_space); 00144 00145 svm_destroy_model(model); 00146 model=NULL; 00147 return true; 00148 } 00149 else 00150 return false; 00151 }