SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Sergey Lisitsyn 00008 * Copyright (C) 2012 Sergey Lisitsyn 00009 */ 00010 00011 #include <shogun/lib/config.h> 00012 #ifdef HAVE_LAPACK 00013 #include <shogun/multiclass/MulticlassLibLinear.h> 00014 #include <shogun/multiclass/MulticlassOneVsRestStrategy.h> 00015 #include <shogun/mathematics/Math.h> 00016 #include <shogun/lib/v_array.h> 00017 #include <shogun/lib/Signal.h> 00018 #include <shogun/labels/MulticlassLabels.h> 00019 00020 using namespace shogun; 00021 00022 CMulticlassLibLinear::CMulticlassLibLinear() : 00023 CLinearMulticlassMachine() 00024 { 00025 init_defaults(); 00026 } 00027 00028 CMulticlassLibLinear::CMulticlassLibLinear(float64_t C, CDotFeatures* features, CLabels* labs) : 00029 CLinearMulticlassMachine(new CMulticlassOneVsRestStrategy(),features,NULL,labs) 00030 { 00031 init_defaults(); 00032 set_C(C); 00033 } 00034 00035 void CMulticlassLibLinear::init_defaults() 00036 { 00037 set_C(1.0); 00038 set_epsilon(1e-2); 00039 set_max_iter(10000); 00040 set_use_bias(false); 00041 set_save_train_state(false); 00042 m_train_state = NULL; 00043 } 00044 00045 void CMulticlassLibLinear::register_parameters() 00046 { 00047 SG_ADD(&m_C, "m_C", "regularization constant",MS_AVAILABLE); 00048 SG_ADD(&m_epsilon, "m_epsilon", "tolerance epsilon",MS_NOT_AVAILABLE); 00049 SG_ADD(&m_max_iter, "m_max_iter", "max number of iterations",MS_NOT_AVAILABLE); 00050 SG_ADD(&m_use_bias, "m_use_bias", "indicates whether bias should be used",MS_NOT_AVAILABLE); 00051 SG_ADD(&m_save_train_state, "m_save_train_state", "indicates whether bias should be used",MS_NOT_AVAILABLE); 00052 } 00053 00054 CMulticlassLibLinear::~CMulticlassLibLinear() 00055 { 00056 reset_train_state(); 00057 } 00058 00059 SGVector<int32_t> CMulticlassLibLinear::get_support_vectors() const 00060 { 00061 if (!m_train_state) 00062 SG_ERROR("Please enable save_train_state option and train machine.\n"); 00063 00064 ASSERT(m_labels && m_labels->get_label_type() == LT_MULTICLASS); 00065 00066 int32_t num_vectors = m_features->get_num_vectors(); 00067 int32_t num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes(); 00068 00069 v_array<int32_t> nz_idxs; 00070 nz_idxs.reserve(num_vectors); 00071 00072 for (int32_t i=0; i<num_vectors; i++) 00073 { 00074 for (int32_t y=0; y<num_classes; y++) 00075 { 00076 if (CMath::abs(m_train_state->alpha[i*num_classes+y])>1e-6) 00077 { 00078 nz_idxs.push(i); 00079 break; 00080 } 00081 } 00082 } 00083 int32_t num_nz = nz_idxs.index(); 00084 nz_idxs.reserve(num_nz); 00085 return SGVector<int32_t>(nz_idxs.begin,num_nz); 00086 } 00087 00088 SGMatrix<float64_t> CMulticlassLibLinear::obtain_regularizer_matrix() const 00089 { 00090 return SGMatrix<float64_t>(); 00091 } 00092 00093 bool CMulticlassLibLinear::train_machine(CFeatures* data) 00094 { 00095 if (data) 00096 set_features((CDotFeatures*)data); 00097 00098 ASSERT(m_features); 00099 ASSERT(m_labels && m_labels->get_label_type()==LT_MULTICLASS); 00100 ASSERT(m_multiclass_strategy); 00101 00102 int32_t num_vectors = m_features->get_num_vectors(); 00103 int32_t num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes(); 00104 int32_t bias_n = m_use_bias ? 1 : 0; 00105 00106 problem mc_problem; 00107 mc_problem.l = num_vectors; 00108 mc_problem.n = m_features->get_dim_feature_space() + bias_n; 00109 mc_problem.y = SG_MALLOC(float64_t, mc_problem.l); 00110 for (int32_t i=0; i<num_vectors; i++) 00111 mc_problem.y[i] = ((CMulticlassLabels*) m_labels)->get_int_label(i); 00112 00113 mc_problem.x = m_features; 00114 mc_problem.use_bias = m_use_bias; 00115 00116 SGMatrix<float64_t> w0 = obtain_regularizer_matrix(); 00117 00118 if (!m_train_state) 00119 m_train_state = new mcsvm_state(); 00120 00121 float64_t* C = SG_MALLOC(float64_t, num_vectors); 00122 for (int32_t i=0; i<num_vectors; i++) 00123 C[i] = m_C; 00124 00125 CSignal::clear_cancel(); 00126 00127 Solver_MCSVM_CS solver(&mc_problem,num_classes,C,w0.matrix,m_epsilon, 00128 m_max_iter,m_max_train_time,m_train_state); 00129 solver.solve(); 00130 00131 m_machines->reset_array(); 00132 for (int32_t i=0; i<num_classes; i++) 00133 { 00134 CLinearMachine* machine = new CLinearMachine(); 00135 SGVector<float64_t> cw(mc_problem.n-bias_n); 00136 00137 for (int32_t j=0; j<mc_problem.n-bias_n; j++) 00138 cw[j] = m_train_state->w[j*num_classes+i]; 00139 00140 machine->set_w(cw); 00141 00142 if (m_use_bias) 00143 machine->set_bias(m_train_state->w[(mc_problem.n-bias_n)*num_classes+i]); 00144 00145 m_machines->push_back(machine); 00146 } 00147 00148 if (!m_save_train_state) 00149 reset_train_state(); 00150 00151 SG_FREE(C); 00152 SG_FREE(mc_problem.y); 00153 00154 return true; 00155 } 00156 #endif /* HAVE_LAPACK */