SHOGUN  v2.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
MulticlassLibSVM.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include <shogun/multiclass/MulticlassLibSVM.h>
00012 #include <shogun/multiclass/MulticlassOneVsOneStrategy.h>
00013 #include <shogun/labels/MulticlassLabels.h>
00014 #include <shogun/io/SGIO.h>
00015 
00016 using namespace shogun;
00017 
00018 CMulticlassLibSVM::CMulticlassLibSVM(LIBSVM_SOLVER_TYPE st)
00019 : CMulticlassSVM(new CMulticlassOneVsOneStrategy()), model(NULL), solver_type(st)
00020 {
00021 }
00022 
00023 CMulticlassLibSVM::CMulticlassLibSVM(float64_t C, CKernel* k, CLabels* lab)
00024 : CMulticlassSVM(new CMulticlassOneVsOneStrategy(), C, k, lab), model(NULL), solver_type(LIBSVM_C_SVC)
00025 {
00026 }
00027 
00028 CMulticlassLibSVM::~CMulticlassLibSVM()
00029 {
00030 }
00031 
00032 bool CMulticlassLibSVM::train_machine(CFeatures* data)
00033 {
00034     struct svm_node* x_space;
00035 
00036     problem = svm_problem();
00037 
00038     ASSERT(m_labels && m_labels->get_num_labels());
00039     ASSERT(m_labels->get_label_type() == LT_MULTICLASS);
00040     int32_t num_classes = m_multiclass_strategy->get_num_classes();
00041     problem.l=m_labels->get_num_labels();
00042     SG_INFO( "%d trainlabels, %d classes\n", problem.l, num_classes);
00043 
00044 
00045     if (data)
00046     {
00047         if (m_labels->get_num_labels() != data->get_num_vectors())
00048         {
00049             SG_ERROR("Number of training vectors does not match number of "
00050                     "labels\n");
00051         }
00052         m_kernel->init(data, data);
00053     }
00054 
00055     problem.y=SG_MALLOC(float64_t, problem.l);
00056     problem.x=SG_MALLOC(struct svm_node*, problem.l);
00057     problem.pv=SG_MALLOC(float64_t, problem.l);
00058     problem.C=SG_MALLOC(float64_t, problem.l);
00059 
00060     x_space=SG_MALLOC(struct svm_node, 2*problem.l);
00061 
00062     for (int32_t i=0; i<problem.l; i++)
00063     {
00064         problem.pv[i]=-1.0;
00065         problem.y[i]=((CMulticlassLabels*) m_labels)->get_label(i);
00066         problem.x[i]=&x_space[2*i];
00067         x_space[2*i].index=i;
00068         x_space[2*i+1].index=-1;
00069     }
00070 
00071     ASSERT(m_kernel);
00072 
00073     param.svm_type=solver_type; // C SVM or NU_SVM
00074     param.kernel_type = LINEAR;
00075     param.degree = 3;
00076     param.gamma = 0;    // 1/k
00077     param.coef0 = 0;
00078     param.nu = get_nu(); // Nu
00079     param.kernel=m_kernel;
00080     param.cache_size = m_kernel->get_cache_size();
00081     param.max_train_time = m_max_train_time;
00082     param.C = get_C();
00083     param.eps = get_epsilon();
00084     param.p = 0.1;
00085     param.shrinking = 1;
00086     param.nr_weight = 0;
00087     param.weight_label = NULL;
00088     param.weight = NULL;
00089     param.use_bias = svm_proto()->get_bias_enabled();
00090 
00091     const char* error_msg = svm_check_parameter(&problem,&param);
00092 
00093     if(error_msg)
00094         SG_ERROR("Error: %s\n",error_msg);
00095 
00096     model = svm_train(&problem, &param);
00097 
00098     if (model)
00099     {
00100         if (model->nr_class!=num_classes)
00101         {
00102             SG_ERROR("LibSVM model->nr_class=%d while num_classes=%d\n",
00103                     model->nr_class, num_classes);
00104         }
00105         ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef));
00106         create_multiclass_svm(num_classes);
00107 
00108         int32_t* offsets=SG_MALLOC(int32_t, num_classes);
00109         offsets[0]=0;
00110 
00111         for (int32_t i=1; i<num_classes; i++)
00112             offsets[i] = offsets[i-1]+model->nSV[i-1];
00113 
00114         int32_t s=0;
00115         for (int32_t i=0; i<num_classes; i++)
00116         {
00117             for (int32_t j=i+1; j<num_classes; j++)
00118             {
00119                 int32_t k, l;
00120 
00121                 float64_t sgn=1;
00122                 if (model->label[i]>model->label[j])
00123                     sgn=-1;
00124 
00125                 int32_t num_sv=model->nSV[i]+model->nSV[j];
00126                 float64_t bias=-model->rho[s];
00127 
00128                 ASSERT(num_sv>0);
00129                 ASSERT(model->sv_coef[i] && model->sv_coef[j-1]);
00130 
00131                 CSVM* svm=new CSVM(num_sv);
00132 
00133                 svm->set_bias(sgn*bias);
00134 
00135                 int32_t sv_idx=0;
00136                 for (k=0; k<model->nSV[i]; k++)
00137                 {
00138                     SG_DEBUG("setting SV[%d] to %d\n", sv_idx,
00139                             model->SV[offsets[i]+k]->index);
00140                     svm->set_support_vector(sv_idx, model->SV[offsets[i]+k]->index);
00141                     svm->set_alpha(sv_idx, sgn*model->sv_coef[j-1][offsets[i]+k]);
00142                     sv_idx++;
00143                 }
00144 
00145                 for (k=0; k<model->nSV[j]; k++)
00146                 {
00147                     SG_DEBUG("setting SV[%d] to %d\n", sv_idx,
00148                             model->SV[offsets[i]+k]->index);
00149                     svm->set_support_vector(sv_idx, model->SV[offsets[j]+k]->index);
00150                     svm->set_alpha(sv_idx, sgn*model->sv_coef[i][offsets[j]+k]);
00151                     sv_idx++;
00152                 }
00153 
00154                 int32_t idx=0;
00155 
00156                 if (sgn>0)
00157                 {
00158                     for (k=0; k<model->label[i]; k++)
00159                         idx+=num_classes-k-1;
00160 
00161                     for (l=model->label[i]+1; l<model->label[j]; l++)
00162                         idx++;
00163                 }
00164                 else
00165                 {
00166                     for (k=0; k<model->label[j]; k++)
00167                         idx+=num_classes-k-1;
00168 
00169                     for (l=model->label[j]+1; l<model->label[i]; l++)
00170                         idx++;
00171                 }
00172 
00173 
00174 //              if (sgn>0)
00175 //                  idx=((num_classes-1)*model->label[i]+model->label[j])/2;
00176 //              else
00177 //                  idx=((num_classes-1)*model->label[j]+model->label[i])/2;
00178 //
00179                 SG_DEBUG("svm[%d] has %d sv (total: %d), b=%f "
00180                         "label:(%d,%d) -> svm[%d]\n",
00181                         s, num_sv, model->l, bias, model->label[i],
00182                         model->label[j], idx);
00183 
00184                 set_svm(idx, svm);
00185                 s++;
00186             }
00187         }
00188 
00189         set_objective(model->objective);
00190 
00191         SG_FREE(offsets);
00192         SG_FREE(problem.x);
00193         SG_FREE(problem.y);
00194         SG_FREE(x_space);
00195         SG_FREE(problem.pv);
00196         SG_FREE(problem.C);
00197 
00198         svm_destroy_model(model);
00199         model=NULL;
00200 
00201         return true;
00202     }
00203     else
00204         return false;
00205 }
00206 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation