SHOGUN v2.0.0
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 1999-2009 Soeren Sonnenburg
 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
 */

#include <shogun/multiclass/MulticlassLibSVM.h>
#include <shogun/multiclass/MulticlassOneVsOneStrategy.h>
#include <shogun/labels/MulticlassLabels.h>
#include <shogun/io/SGIO.h>

using namespace shogun;

CMulticlassLibSVM::CMulticlassLibSVM(LIBSVM_SOLVER_TYPE st)
: CMulticlassSVM(new CMulticlassOneVsOneStrategy()), model(NULL), solver_type(st)
{
}

CMulticlassLibSVM::CMulticlassLibSVM(float64_t C, CKernel* k, CLabels* lab)
: CMulticlassSVM(new CMulticlassOneVsOneStrategy(), C, k, lab), model(NULL),
	solver_type(LIBSVM_C_SVC)
{
}

CMulticlassLibSVM::~CMulticlassLibSVM()
{
}

bool CMulticlassLibSVM::train_machine(CFeatures* data)
{
	struct svm_node* x_space;

	problem = svm_problem();

	ASSERT(m_labels && m_labels->get_num_labels());
	ASSERT(m_labels->get_label_type() == LT_MULTICLASS);
	int32_t num_classes = m_multiclass_strategy->get_num_classes();
	problem.l = m_labels->get_num_labels();
	SG_INFO("%d trainlabels, %d classes\n", problem.l, num_classes);

	if (data)
	{
		if (m_labels->get_num_labels() != data->get_num_vectors())
		{
			SG_ERROR("Number of training vectors does not match number of "
					"labels\n");
		}
		m_kernel->init(data, data);
	}

	problem.y=SG_MALLOC(float64_t, problem.l);
	problem.x=SG_MALLOC(struct svm_node*, problem.l);
	problem.pv=SG_MALLOC(float64_t, problem.l);
	problem.C=SG_MALLOC(float64_t, problem.l);

	x_space=SG_MALLOC(struct svm_node, 2*problem.l);

	// Each example is a dummy two-entry sparse vector: its index plus a
	// terminator. The actual features are accessed through the kernel.
	for (int32_t i=0; i<problem.l; i++)
	{
		problem.pv[i]=-1.0;
		problem.y[i]=((CMulticlassLabels*) m_labels)->get_label(i);
		problem.x[i]=&x_space[2*i];
		x_space[2*i].index=i;
		x_space[2*i+1].index=-1;
	}

	ASSERT(m_kernel);

	param.svm_type=solver_type; // C_SVC or NU_SVC
	param.kernel_type = LINEAR;
	param.degree = 3;
	param.gamma = 0; // 1/k
	param.coef0 = 0;
	param.nu = get_nu();
	param.kernel=m_kernel;
	param.cache_size = m_kernel->get_cache_size();
	param.max_train_time = m_max_train_time;
	param.C = get_C();
	param.eps = get_epsilon();
	param.p = 0.1;
	param.shrinking = 1;
	param.nr_weight = 0;
	param.weight_label = NULL;
	param.weight = NULL;
	param.use_bias = svm_proto()->get_bias_enabled();

	const char* error_msg = svm_check_parameter(&problem, &param);

	if (error_msg)
		SG_ERROR("Error: %s\n", error_msg);

	model = svm_train(&problem, &param);

	if (model)
	{
		if (model->nr_class!=num_classes)
		{
			SG_ERROR("LibSVM model->nr_class=%d while num_classes=%d\n",
					model->nr_class, num_classes);
		}
		ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef));
		create_multiclass_svm(num_classes);

		// offsets[i] is the position of the first support vector of class i
		// in model->SV / model->sv_coef.
		int32_t* offsets=SG_MALLOC(int32_t, num_classes);
		offsets[0]=0;

		for (int32_t i=1; i<num_classes; i++)
			offsets[i]=offsets[i-1]+model->nSV[i-1];

		// Convert every one-vs-one sub-problem (i,j) of the LibSVM model
		// into a separate CSVM.
		int32_t s=0;
		for (int32_t i=0; i<num_classes; i++)
		{
			for (int32_t j=i+1; j<num_classes; j++)
			{
				int32_t k, l;

				float64_t sgn=1;
				if (model->label[i]>model->label[j])
					sgn=-1;

				int32_t num_sv=model->nSV[i]+model->nSV[j];
				float64_t bias=-model->rho[s];

				ASSERT(num_sv>0);
				ASSERT(model->sv_coef[i] && model->sv_coef[j-1]);

				CSVM* svm=new CSVM(num_sv);

				svm->set_bias(sgn*bias);

				// support vectors of class i with their coefficients
				// against class j
				int32_t sv_idx=0;
				for (k=0; k<model->nSV[i]; k++)
				{
					SG_DEBUG("setting SV[%d] to %d\n", sv_idx,
							model->SV[offsets[i]+k]->index);
					svm->set_support_vector(sv_idx, model->SV[offsets[i]+k]->index);
					svm->set_alpha(sv_idx, sgn*model->sv_coef[j-1][offsets[i]+k]);
					sv_idx++;
				}

				// support vectors of class j with their coefficients
				// against class i
				for (k=0; k<model->nSV[j]; k++)
				{
					SG_DEBUG("setting SV[%d] to %d\n", sv_idx,
							model->SV[offsets[j]+k]->index);
					svm->set_support_vector(sv_idx, model->SV[offsets[j]+k]->index);
					svm->set_alpha(sv_idx, sgn*model->sv_coef[i][offsets[j]+k]);
					sv_idx++;
				}

				// map the ordered label pair to its position in the
				// one-vs-one machine array
				int32_t idx=0;

				if (sgn>0)
				{
					for (k=0; k<model->label[i]; k++)
						idx+=num_classes-k-1;

					for (l=model->label[i]+1; l<model->label[j]; l++)
						idx++;
				}
				else
				{
					for (k=0; k<model->label[j]; k++)
						idx+=num_classes-k-1;

					for (l=model->label[j]+1; l<model->label[i]; l++)
						idx++;
				}

				// if (sgn>0)
				// 	idx=((num_classes-1)*model->label[i]+model->label[j])/2;
				// else
				// 	idx=((num_classes-1)*model->label[j]+model->label[i])/2;

				SG_DEBUG("svm[%d] has %d sv (total: %d), b=%f "
						"label:(%d,%d) -> svm[%d]\n",
						s, num_sv, model->l, bias, model->label[i],
						model->label[j], idx);

				set_svm(idx, svm);
				s++;
			}
		}

		set_objective(model->objective);

		SG_FREE(offsets);
		SG_FREE(problem.x);
		SG_FREE(problem.y);
		SG_FREE(x_space);
		SG_FREE(problem.pv);
		SG_FREE(problem.C);

		svm_destroy_model(model);
		model=NULL;

		return true;
	}
	else
		return false;
}
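
// ---------------------------------------------------------------------------
// Usage sketch (not part of the original file): a minimal, hypothetical
// driver showing how CMulticlassLibSVM can be trained and applied, assuming
// the Shogun 2.x dense-feature / Gaussian-kernel API. The function name
// example_train_and_apply, the variables feat_matrix / label_vector, and the
// kernel parameters are placeholders, not values taken from this file.
// ---------------------------------------------------------------------------

#include <shogun/features/DenseFeatures.h>
#include <shogun/kernel/GaussianKernel.h>

void example_train_and_apply(SGMatrix<float64_t> feat_matrix,
		SGVector<float64_t> label_vector)
{
	// wrap raw data into Shogun containers
	CDenseFeatures<float64_t>* feats=new CDenseFeatures<float64_t>(feat_matrix);
	CMulticlassLabels* labels=new CMulticlassLabels(label_vector);

	// example kernel: cache size 10 MB, width 2.0 (arbitrary values);
	// initializing it here matches the data==NULL path of train_machine()
	CGaussianKernel* kernel=new CGaussianKernel(10, 2.0);
	kernel->init(feats, feats);

	// C=1.0; the one-vs-one strategy is set up by the constructor above
	CMulticlassLibSVM* svm=new CMulticlassLibSVM(1.0, kernel, labels);
	svm->train();

	// classify the training data with the resulting one-vs-one ensemble
	CMulticlassLabels* pred=svm->apply_multiclass(feats);

	SG_UNREF(pred);
	SG_UNREF(svm);
}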