SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2009 Soeren Sonnenburg 00008 * Written (W) 2009 Marius Kloft 00009 * Copyright (C) 2009 TU Berlin and Max-Planck-Society 00010 */ 00011 #ifdef USE_SVMLIGHT 00012 #include <shogun/classifier/svm/SVMLightOneClass.h> 00013 #endif //USE_SVMLIGHT 00014 00015 #include <shogun/kernel/Kernel.h> 00016 #include <shogun/multiclass/ScatterSVM.h> 00017 #include <shogun/kernel/normalizer/ScatterKernelNormalizer.h> 00018 #include <shogun/multiclass/MulticlassOneVsRestStrategy.h> 00019 #include <shogun/io/SGIO.h> 00020 00021 using namespace shogun; 00022 00023 CScatterSVM::CScatterSVM() 00024 : CMulticlassSVM(new CMulticlassOneVsRestStrategy()), scatter_type(NO_BIAS_LIBSVM), 00025 model(NULL), norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0) 00026 { 00027 SG_UNSTABLE("CScatterSVM::CScatterSVM()", "\n"); 00028 } 00029 00030 CScatterSVM::CScatterSVM(SCATTER_TYPE type) 00031 : CMulticlassSVM(new CMulticlassOneVsRestStrategy()), scatter_type(type), model(NULL), 00032 norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0) 00033 { 00034 } 00035 00036 CScatterSVM::CScatterSVM(float64_t C, CKernel* k, CLabels* lab) 00037 : CMulticlassSVM(new CMulticlassOneVsRestStrategy(), C, k, lab), scatter_type(NO_BIAS_LIBSVM), model(NULL), 00038 norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0) 00039 { 00040 } 00041 00042 CScatterSVM::~CScatterSVM() 00043 { 00044 SG_FREE(norm_wc); 00045 SG_FREE(norm_wcw); 00046 } 00047 00048 bool CScatterSVM::train_machine(CFeatures* data) 00049 { 00050 ASSERT(m_labels && m_labels->get_num_labels()); 00051 ASSERT(m_labels->get_label_type() == LT_MULTICLASS); 00052 00053 m_num_classes = m_multiclass_strategy->get_num_classes(); 00054 int32_t num_vectors = m_labels->get_num_labels(); 00055 00056 if (data) 00057 { 00058 if (m_labels->get_num_labels() != data->get_num_vectors()) 00059 SG_ERROR("Number of training vectors does not match number of labels\n"); 00060 m_kernel->init(data, data); 00061 } 00062 00063 int32_t* numc=SG_MALLOC(int32_t, m_num_classes); 00064 SGVector<int32_t>::fill_vector(numc, m_num_classes, 0); 00065 00066 for (int32_t i=0; i<num_vectors; i++) 00067 numc[(int32_t) ((CMulticlassLabels*) m_labels)->get_int_label(i)]++; 00068 00069 int32_t Nc=0; 00070 int32_t Nmin=num_vectors; 00071 for (int32_t i=0; i<m_num_classes; i++) 00072 { 00073 if (numc[i]>0) 00074 { 00075 Nc++; 00076 Nmin=CMath::min(Nmin, numc[i]); 00077 } 00078 00079 } 00080 SG_FREE(numc); 00081 m_num_classes=m_num_classes; 00082 00083 bool result=false; 00084 00085 if (scatter_type==NO_BIAS_LIBSVM) 00086 { 00087 result=train_no_bias_libsvm(); 00088 } 00089 #ifdef USE_SVMLIGHT 00090 else if (scatter_type==NO_BIAS_SVMLIGHT) 00091 { 00092 result=train_no_bias_svmlight(); 00093 } 00094 #endif //USE_SVMLIGHT 00095 else if (scatter_type==TEST_RULE1 || scatter_type==TEST_RULE2) 00096 { 00097 float64_t nu_min=((float64_t) Nc)/num_vectors; 00098 float64_t nu_max=((float64_t) Nc)*Nmin/num_vectors; 00099 00100 SG_INFO("valid nu interval [%f ... %f]\n", nu_min, nu_max); 00101 00102 if (get_nu()<nu_min || get_nu()>nu_max) 00103 SG_ERROR("nu out of valid range [%f ... %f]\n", nu_min, nu_max); 00104 00105 result=train_testrule12(); 00106 } 00107 else 00108 SG_ERROR("Unknown Scatter type\n"); 00109 00110 return result; 00111 } 00112 00113 bool CScatterSVM::train_no_bias_libsvm() 00114 { 00115 struct svm_node* x_space; 00116 00117 problem.l=m_labels->get_num_labels(); 00118 SG_INFO( "%d trainlabels\n", problem.l); 00119 00120 problem.y=SG_MALLOC(float64_t, problem.l); 00121 problem.x=SG_MALLOC(struct svm_node*, problem.l); 00122 x_space=SG_MALLOC(struct svm_node, 2*problem.l); 00123 00124 for (int32_t i=0; i<problem.l; i++) 00125 { 00126 problem.y[i]=+1; 00127 problem.x[i]=&x_space[2*i]; 00128 x_space[2*i].index=i; 00129 x_space[2*i+1].index=-1; 00130 } 00131 00132 int32_t weights_label[2]={-1,+1}; 00133 float64_t weights[2]={1.0,get_C()/get_C()}; 00134 00135 ASSERT(m_kernel && m_kernel->has_features()); 00136 ASSERT(m_kernel->get_num_vec_lhs()==problem.l); 00137 00138 param.svm_type=C_SVC; // Nu MC SVM 00139 param.kernel_type = LINEAR; 00140 param.degree = 3; 00141 param.gamma = 0; // 1/k 00142 param.coef0 = 0; 00143 param.nu = get_nu(); // Nu 00144 CKernelNormalizer* prev_normalizer=m_kernel->get_normalizer(); 00145 m_kernel->set_normalizer(new CScatterKernelNormalizer( 00146 m_num_classes-1, -1, m_labels, prev_normalizer)); 00147 param.kernel=m_kernel; 00148 param.cache_size = m_kernel->get_cache_size(); 00149 param.C = 0; 00150 param.eps = get_epsilon(); 00151 param.p = 0.1; 00152 param.shrinking = 0; 00153 param.nr_weight = 2; 00154 param.weight_label = weights_label; 00155 param.weight = weights; 00156 param.nr_class=m_num_classes; 00157 param.use_bias = svm_proto()->get_bias_enabled(); 00158 00159 const char* error_msg = svm_check_parameter(&problem,¶m); 00160 00161 if(error_msg) 00162 SG_ERROR("Error: %s\n",error_msg); 00163 00164 model = svm_train(&problem, ¶m); 00165 m_kernel->set_normalizer(prev_normalizer); 00166 SG_UNREF(prev_normalizer); 00167 00168 if (model) 00169 { 00170 ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef)); 00171 00172 ASSERT(model->nr_class==m_num_classes); 00173 create_multiclass_svm(m_num_classes); 00174 00175 rho=model->rho[0]; 00176 00177 SG_FREE(norm_wcw); 00178 norm_wcw = SG_MALLOC(float64_t, m_machines->get_num_elements()); 00179 00180 for (int32_t i=0; i<m_num_classes; i++) 00181 { 00182 int32_t num_sv=model->nSV[i]; 00183 00184 CSVM* svm=new CSVM(num_sv); 00185 svm->set_bias(model->rho[i+1]); 00186 norm_wcw[i]=model->normwcw[i]; 00187 00188 00189 for (int32_t j=0; j<num_sv; j++) 00190 { 00191 svm->set_alpha(j, model->sv_coef[i][j]); 00192 svm->set_support_vector(j, model->SV[i][j].index); 00193 } 00194 00195 set_svm(i, svm); 00196 } 00197 00198 SG_FREE(problem.x); 00199 SG_FREE(problem.y); 00200 SG_FREE(x_space); 00201 for (int32_t i=0; i<m_num_classes; i++) 00202 { 00203 SG_FREE(model->SV[i]); 00204 model->SV[i]=NULL; 00205 } 00206 svm_destroy_model(model); 00207 00208 if (scatter_type==TEST_RULE2) 00209 compute_norm_wc(); 00210 00211 model=NULL; 00212 return true; 00213 } 00214 else 00215 return false; 00216 } 00217 00218 #ifdef USE_SVMLIGHT 00219 bool CScatterSVM::train_no_bias_svmlight() 00220 { 00221 CKernelNormalizer* prev_normalizer=m_kernel->get_normalizer(); 00222 CScatterKernelNormalizer* n=new CScatterKernelNormalizer( 00223 m_num_classes-1, -1, m_labels, prev_normalizer); 00224 m_kernel->set_normalizer(n); 00225 m_kernel->init_normalizer(); 00226 00227 CSVMLightOneClass* light=new CSVMLightOneClass(get_C(), m_kernel); 00228 light->set_linadd_enabled(false); 00229 light->train(); 00230 00231 SG_FREE(norm_wcw); 00232 norm_wcw = SG_MALLOC(float64_t, m_num_classes); 00233 00234 int32_t num_sv=light->get_num_support_vectors(); 00235 svm_proto()->create_new_model(num_sv); 00236 00237 for (int32_t i=0; i<num_sv; i++) 00238 { 00239 svm_proto()->set_alpha(i, light->get_alpha(i)); 00240 svm_proto()->set_support_vector(i, light->get_support_vector(i)); 00241 } 00242 00243 m_kernel->set_normalizer(prev_normalizer); 00244 return true; 00245 } 00246 #endif //USE_SVMLIGHT 00247 00248 bool CScatterSVM::train_testrule12() 00249 { 00250 struct svm_node* x_space; 00251 problem.l=m_labels->get_num_labels(); 00252 SG_INFO( "%d trainlabels\n", problem.l); 00253 00254 problem.y=SG_MALLOC(float64_t, problem.l); 00255 problem.x=SG_MALLOC(struct svm_node*, problem.l); 00256 x_space=SG_MALLOC(struct svm_node, 2*problem.l); 00257 00258 for (int32_t i=0; i<problem.l; i++) 00259 { 00260 problem.y[i]=((CMulticlassLabels*) m_labels)->get_label(i); 00261 problem.x[i]=&x_space[2*i]; 00262 x_space[2*i].index=i; 00263 x_space[2*i+1].index=-1; 00264 } 00265 00266 int32_t weights_label[2]={-1,+1}; 00267 float64_t weights[2]={1.0,get_C()/get_C()}; 00268 00269 ASSERT(m_kernel && m_kernel->has_features()); 00270 ASSERT(m_kernel->get_num_vec_lhs()==problem.l); 00271 00272 param.svm_type=NU_MULTICLASS_SVC; // Nu MC SVM 00273 param.kernel_type = LINEAR; 00274 param.degree = 3; 00275 param.gamma = 0; // 1/k 00276 param.coef0 = 0; 00277 param.nu = get_nu(); // Nu 00278 param.kernel=m_kernel; 00279 param.cache_size = m_kernel->get_cache_size(); 00280 param.C = 0; 00281 param.eps = get_epsilon(); 00282 param.p = 0.1; 00283 param.shrinking = 0; 00284 param.nr_weight = 2; 00285 param.weight_label = weights_label; 00286 param.weight = weights; 00287 param.nr_class=m_num_classes; 00288 param.use_bias = svm_proto()->get_bias_enabled(); 00289 00290 const char* error_msg = svm_check_parameter(&problem,¶m); 00291 00292 if(error_msg) 00293 SG_ERROR("Error: %s\n",error_msg); 00294 00295 model = svm_train(&problem, ¶m); 00296 00297 if (model) 00298 { 00299 ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef)); 00300 00301 ASSERT(model->nr_class==m_num_classes); 00302 create_multiclass_svm(m_num_classes); 00303 00304 rho=model->rho[0]; 00305 00306 SG_FREE(norm_wcw); 00307 norm_wcw = SG_MALLOC(float64_t, m_machines->get_num_elements()); 00308 00309 for (int32_t i=0; i<m_num_classes; i++) 00310 { 00311 int32_t num_sv=model->nSV[i]; 00312 00313 CSVM* svm=new CSVM(num_sv); 00314 svm->set_bias(model->rho[i+1]); 00315 norm_wcw[i]=model->normwcw[i]; 00316 00317 00318 for (int32_t j=0; j<num_sv; j++) 00319 { 00320 svm->set_alpha(j, model->sv_coef[i][j]); 00321 svm->set_support_vector(j, model->SV[i][j].index); 00322 } 00323 00324 set_svm(i, svm); 00325 } 00326 00327 SG_FREE(problem.x); 00328 SG_FREE(problem.y); 00329 SG_FREE(x_space); 00330 for (int32_t i=0; i<m_num_classes; i++) 00331 { 00332 SG_FREE(model->SV[i]); 00333 model->SV[i]=NULL; 00334 } 00335 svm_destroy_model(model); 00336 00337 if (scatter_type==TEST_RULE2) 00338 compute_norm_wc(); 00339 00340 model=NULL; 00341 return true; 00342 } 00343 else 00344 return false; 00345 } 00346 00347 void CScatterSVM::compute_norm_wc() 00348 { 00349 SG_FREE(norm_wc); 00350 norm_wc = SG_MALLOC(float64_t, m_machines->get_num_elements()); 00351 for (int32_t i=0; i<m_machines->get_num_elements(); i++) 00352 norm_wc[i]=0; 00353 00354 00355 for (int c=0; c<m_machines->get_num_elements(); c++) 00356 { 00357 CSVM* svm=get_svm(c); 00358 int32_t num_sv = svm->get_num_support_vectors(); 00359 00360 for (int32_t i=0; i<num_sv; i++) 00361 { 00362 int32_t ii=svm->get_support_vector(i); 00363 for (int32_t j=0; j<num_sv; j++) 00364 { 00365 int32_t jj=svm->get_support_vector(j); 00366 norm_wc[c]+=svm->get_alpha(i)*m_kernel->kernel(ii,jj)*svm->get_alpha(j); 00367 } 00368 } 00369 } 00370 00371 for (int32_t i=0; i<m_machines->get_num_elements(); i++) 00372 norm_wc[i]=CMath::sqrt(norm_wc[i]); 00373 00374 SGVector<float64_t>::display_vector(norm_wc, m_machines->get_num_elements(), "norm_wc"); 00375 } 00376 00377 CLabels* CScatterSVM::classify_one_vs_rest() 00378 { 00379 CMulticlassLabels* output=NULL; 00380 if (!m_kernel) 00381 { 00382 SG_ERROR( "SVM can not proceed without kernel!\n"); 00383 return NULL; 00384 } 00385 00386 if (!( m_kernel && m_kernel->get_num_vec_lhs() && m_kernel->get_num_vec_rhs())) 00387 return NULL; 00388 00389 int32_t num_vectors=m_kernel->get_num_vec_rhs(); 00390 00391 output=new CMulticlassLabels(num_vectors); 00392 SG_REF(output); 00393 00394 if (scatter_type == TEST_RULE1) 00395 { 00396 ASSERT(m_machines->get_num_elements()>0); 00397 for (int32_t i=0; i<num_vectors; i++) 00398 output->set_label(i, apply(i)); 00399 } 00400 #ifdef USE_SVMLIGHT 00401 else if (scatter_type == NO_BIAS_SVMLIGHT) 00402 { 00403 float64_t* outputs=SG_MALLOC(float64_t, num_vectors*m_num_classes); 00404 SGVector<float64_t>::fill_vector(outputs,num_vectors*m_num_classes,0.0); 00405 00406 for (int32_t i=0; i<num_vectors; i++) 00407 { 00408 for (int32_t j=0; j<svm_proto()->get_num_support_vectors(); j++) 00409 { 00410 float64_t score=m_kernel->kernel(svm_proto()->get_support_vector(j), i)*svm_proto()->get_alpha(j); 00411 int32_t label=((CMulticlassLabels*) m_labels)->get_int_label(svm_proto()->get_support_vector(j)); 00412 for (int32_t c=0; c<m_num_classes; c++) 00413 { 00414 float64_t s= (label==c) ? (m_num_classes-1) : (-1); 00415 outputs[c+i*m_num_classes]+=s*score; 00416 } 00417 } 00418 } 00419 00420 for (int32_t i=0; i<num_vectors; i++) 00421 { 00422 int32_t winner=0; 00423 float64_t max_out=outputs[i*m_num_classes+0]; 00424 00425 for (int32_t j=1; j<m_num_classes; j++) 00426 { 00427 float64_t out=outputs[i*m_num_classes+j]; 00428 00429 if (out>max_out) 00430 { 00431 winner=j; 00432 max_out=out; 00433 } 00434 } 00435 00436 output->set_label(i, winner); 00437 } 00438 00439 SG_FREE(outputs); 00440 } 00441 #endif //USE_SVMLIGHT 00442 else 00443 { 00444 ASSERT(m_machines->get_num_elements()>0); 00445 ASSERT(num_vectors==output->get_num_labels()); 00446 CLabels** outputs=SG_MALLOC(CLabels*, m_machines->get_num_elements()); 00447 00448 for (int32_t i=0; i<m_machines->get_num_elements(); i++) 00449 { 00450 //SG_PRINT("svm %d\n", i); 00451 CSVM *svm = get_svm(i); 00452 ASSERT(svm); 00453 svm->set_kernel(m_kernel); 00454 svm->set_labels(m_labels); 00455 outputs[i]=svm->apply(); 00456 SG_UNREF(svm); 00457 } 00458 00459 for (int32_t i=0; i<num_vectors; i++) 00460 { 00461 int32_t winner=0; 00462 float64_t max_out=((CRegressionLabels*) outputs[0])->get_label(i)/norm_wc[0]; 00463 00464 for (int32_t j=1; j<m_machines->get_num_elements(); j++) 00465 { 00466 float64_t out=((CRegressionLabels*) outputs[j])->get_label(i)/norm_wc[j]; 00467 00468 if (out>max_out) 00469 { 00470 winner=j; 00471 max_out=out; 00472 } 00473 } 00474 00475 output->set_label(i, winner); 00476 } 00477 00478 for (int32_t i=0; i<m_machines->get_num_elements(); i++) 00479 SG_UNREF(outputs[i]); 00480 00481 SG_FREE(outputs); 00482 } 00483 00484 return output; 00485 } 00486 00487 float64_t CScatterSVM::apply(int32_t num) 00488 { 00489 ASSERT(m_machines->get_num_elements()>0); 00490 float64_t* outputs=SG_MALLOC(float64_t, m_machines->get_num_elements()); 00491 int32_t winner=0; 00492 00493 if (scatter_type == TEST_RULE1) 00494 { 00495 for (int32_t c=0; c<m_machines->get_num_elements(); c++) 00496 outputs[c]=get_svm(c)->get_bias()-rho; 00497 00498 for (int32_t c=0; c<m_machines->get_num_elements(); c++) 00499 { 00500 float64_t v=0; 00501 00502 for (int32_t i=0; i<get_svm(c)->get_num_support_vectors(); i++) 00503 { 00504 float64_t alpha=get_svm(c)->get_alpha(i); 00505 int32_t svidx=get_svm(c)->get_support_vector(i); 00506 v += alpha*m_kernel->kernel(svidx, num); 00507 } 00508 00509 outputs[c] += v; 00510 for (int32_t j=0; j<m_machines->get_num_elements(); j++) 00511 outputs[j] -= v/m_machines->get_num_elements(); 00512 } 00513 00514 for (int32_t j=0; j<m_machines->get_num_elements(); j++) 00515 outputs[j]/=norm_wcw[j]; 00516 00517 float64_t max_out=outputs[0]; 00518 for (int32_t j=0; j<m_machines->get_num_elements(); j++) 00519 { 00520 if (outputs[j]>max_out) 00521 { 00522 max_out=outputs[j]; 00523 winner=j; 00524 } 00525 } 00526 } 00527 #ifdef USE_SVMLIGHT 00528 else if (scatter_type == NO_BIAS_SVMLIGHT) 00529 { 00530 SG_ERROR("Use classify...\n"); 00531 } 00532 #endif //USE_SVMLIGHT 00533 else 00534 { 00535 float64_t max_out=get_svm(0)->apply_one(num)/norm_wc[0]; 00536 00537 for (int32_t i=1; i<m_machines->get_num_elements(); i++) 00538 { 00539 outputs[i]=get_svm(i)->apply_one(num)/norm_wc[i]; 00540 if (outputs[i]>max_out) 00541 { 00542 winner=i; 00543 max_out=outputs[i]; 00544 } 00545 } 00546 } 00547 00548 SG_FREE(outputs); 00549 return winner; 00550 }