SHOGUN  v2.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
MulticlassLibLinear.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Sergey Lisitsyn
00008  * Copyright (C) 2012 Sergey Lisitsyn
00009  */
00010 
00011 #include <shogun/lib/config.h>
00012 #ifdef HAVE_LAPACK
00013 #include <shogun/multiclass/MulticlassLibLinear.h>
00014 #include <shogun/multiclass/MulticlassOneVsRestStrategy.h>
00015 #include <shogun/mathematics/Math.h>
00016 #include <shogun/lib/v_array.h>
00017 #include <shogun/lib/Signal.h>
00018 #include <shogun/labels/MulticlassLabels.h>
00019 
00020 using namespace shogun;
00021 
00022 CMulticlassLibLinear::CMulticlassLibLinear() :
00023     CLinearMulticlassMachine()
00024 {
00025     init_defaults();
00026 }
00027 
00028 CMulticlassLibLinear::CMulticlassLibLinear(float64_t C, CDotFeatures* features, CLabels* labs) :
00029     CLinearMulticlassMachine(new CMulticlassOneVsRestStrategy(),features,NULL,labs)
00030 {
00031     init_defaults();
00032     set_C(C);
00033 }
00034 
00035 void CMulticlassLibLinear::init_defaults()
00036 {
00037     set_C(1.0);
00038     set_epsilon(1e-2);
00039     set_max_iter(10000);
00040     set_use_bias(false);
00041     set_save_train_state(false);
00042     m_train_state = NULL;
00043 }
00044 
00045 void CMulticlassLibLinear::register_parameters()
00046 {
00047     SG_ADD(&m_C, "m_C", "regularization constant",MS_AVAILABLE);
00048     SG_ADD(&m_epsilon, "m_epsilon", "tolerance epsilon",MS_NOT_AVAILABLE);
00049     SG_ADD(&m_max_iter, "m_max_iter", "max number of iterations",MS_NOT_AVAILABLE);
00050     SG_ADD(&m_use_bias, "m_use_bias", "indicates whether bias should be used",MS_NOT_AVAILABLE);
00051     SG_ADD(&m_save_train_state, "m_save_train_state", "indicates whether bias should be used",MS_NOT_AVAILABLE);
00052 }
00053 
00054 CMulticlassLibLinear::~CMulticlassLibLinear()
00055 {
00056     reset_train_state();
00057 }
00058 
00059 SGVector<int32_t> CMulticlassLibLinear::get_support_vectors() const
00060 {
00061     if (!m_train_state)
00062         SG_ERROR("Please enable save_train_state option and train machine.\n");
00063 
00064     ASSERT(m_labels && m_labels->get_label_type() == LT_MULTICLASS);
00065 
00066     int32_t num_vectors = m_features->get_num_vectors();
00067     int32_t num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes();
00068 
00069     v_array<int32_t> nz_idxs;
00070     nz_idxs.reserve(num_vectors);
00071 
00072     for (int32_t i=0; i<num_vectors; i++)
00073     {
00074         for (int32_t y=0; y<num_classes; y++)
00075         {
00076             if (CMath::abs(m_train_state->alpha[i*num_classes+y])>1e-6)
00077             {
00078                 nz_idxs.push(i);
00079                 break;
00080             }
00081         }
00082     }
00083     int32_t num_nz = nz_idxs.index();
00084     nz_idxs.reserve(num_nz);
00085     return SGVector<int32_t>(nz_idxs.begin,num_nz);
00086 }
00087 
00088 SGMatrix<float64_t> CMulticlassLibLinear::obtain_regularizer_matrix() const
00089 {
00090     return SGMatrix<float64_t>();
00091 }
00092 
00093 bool CMulticlassLibLinear::train_machine(CFeatures* data)
00094 {
00095     if (data)
00096         set_features((CDotFeatures*)data);
00097 
00098     ASSERT(m_features);
00099     ASSERT(m_labels && m_labels->get_label_type()==LT_MULTICLASS);
00100     ASSERT(m_multiclass_strategy);
00101 
00102     int32_t num_vectors = m_features->get_num_vectors();
00103     int32_t num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes();
00104     int32_t bias_n = m_use_bias ? 1 : 0;
00105 
00106     problem mc_problem;
00107     mc_problem.l = num_vectors;
00108     mc_problem.n = m_features->get_dim_feature_space() + bias_n;
00109     mc_problem.y = SG_MALLOC(float64_t, mc_problem.l);
00110     for (int32_t i=0; i<num_vectors; i++)
00111         mc_problem.y[i] = ((CMulticlassLabels*) m_labels)->get_int_label(i);
00112 
00113     mc_problem.x = m_features;
00114     mc_problem.use_bias = m_use_bias;
00115 
00116     SGMatrix<float64_t> w0 = obtain_regularizer_matrix();
00117 
00118     if (!m_train_state)
00119         m_train_state = new mcsvm_state();
00120 
00121     float64_t* C = SG_MALLOC(float64_t, num_vectors);
00122     for (int32_t i=0; i<num_vectors; i++)
00123         C[i] = m_C;
00124 
00125     CSignal::clear_cancel();
00126 
00127     Solver_MCSVM_CS solver(&mc_problem,num_classes,C,w0.matrix,m_epsilon,
00128                            m_max_iter,m_max_train_time,m_train_state);
00129     solver.solve();
00130 
00131     m_machines->reset_array();
00132     for (int32_t i=0; i<num_classes; i++)
00133     {
00134         CLinearMachine* machine = new CLinearMachine();
00135         SGVector<float64_t> cw(mc_problem.n-bias_n);
00136 
00137         for (int32_t j=0; j<mc_problem.n-bias_n; j++)
00138             cw[j] = m_train_state->w[j*num_classes+i];
00139 
00140         machine->set_w(cw);
00141 
00142         if (m_use_bias)
00143             machine->set_bias(m_train_state->w[(mc_problem.n-bias_n)*num_classes+i]);
00144 
00145         m_machines->push_back(machine);
00146     }
00147 
00148     if (!m_save_train_state)
00149         reset_train_state();
00150 
00151     SG_FREE(C);
00152     SG_FREE(mc_problem.y);
00153 
00154     return true;
00155 }
00156 #endif /* HAVE_LAPACK */
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation