SHOGUN v2.0.0
/*
   SVM with stochastic gradient
   Copyright (C) 2007- Leon Bottou

   This program is free software; you can redistribute it and/or
   modify it under the terms of the GNU Lesser General Public
   License as published by the Free Software Foundation; either
   version 2.1 of the License, or (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program; if not, write to the Free Software
   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111, USA
   $Id: svmsgd.cpp,v 1.13 2007/10/02 20:40:06 cvs Exp $

   Shogun adjustments (w) 2008-2009 Soeren Sonnenburg
*/

#include <shogun/classifier/svm/SVMSGD.h>
#include <shogun/base/Parameter.h>
#include <shogun/lib/Signal.h>
#include <shogun/labels/BinaryLabels.h>
#include <shogun/loss/HingeLoss.h>

using namespace shogun;

CSVMSGD::CSVMSGD()
: CLinearMachine()
{
	init();
}

CSVMSGD::CSVMSGD(float64_t C)
: CLinearMachine()
{
	init();

	C1=C;
	C2=C;
}

CSVMSGD::CSVMSGD(float64_t C, CDotFeatures* traindat, CLabels* trainlab)
: CLinearMachine()
{
	init();
	C1=C;
	C2=C;

	set_features(traindat);
	set_labels(trainlab);
}

CSVMSGD::~CSVMSGD()
{
	SG_UNREF(loss);
}

void CSVMSGD::set_loss_function(CLossFunction* loss_func)
{
	if (loss)
		SG_UNREF(loss);
	loss=loss_func;
	SG_REF(loss);
}
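/* Editor's summary of the training routine below: it trains the linear SVM
 * by stochastic gradient descent, following Bottou's SvmSgd. w and bias
 * start at zero, t is shifted so that the initial learning rate eta0 is
 * reasonable, calibrate() estimates the sparsity-dependent constants skip
 * and bscale, and then `epochs` SGD passes are run over the data. */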
bool CSVMSGD::train_machine(CFeatures* data)
{
	// allocate memory for w and initialize w and bias with 0
	ASSERT(m_labels);
	ASSERT(m_labels->get_label_type() == LT_BINARY);

	if (data)
	{
		if (!data->has_property(FP_DOT))
			SG_ERROR("Specified features are not of type CDotFeatures\n");
		set_features((CDotFeatures*) data);
	}

	ASSERT(features);

	int32_t num_train_labels=m_labels->get_num_labels();
	int32_t num_vec=features->get_num_vectors();

	ASSERT(num_vec==num_train_labels);
	ASSERT(num_vec>0);

	w=SGVector<float64_t>(features->get_dim_feature_space());
	w.zero();
	bias=0;

	float64_t lambda= 1.0/(C1*num_vec);

	// Shift t in order to have a
	// reasonable initial learning rate.
	// This assumes |x| \approx 1.
	float64_t maxw = 1.0 / sqrt(lambda);
	float64_t typw = sqrt(maxw);
	float64_t eta0 = typw / CMath::max(1.0,-loss->first_derivative(-typw,1));
	t = 1 / (eta0 * lambda);

	SG_INFO("lambda=%f, epochs=%d, eta0=%f\n", lambda, epochs, eta0);

	// estimate the sparsity-dependent constants (skip, bscale) before the SGD passes
	calibrate();

	SG_INFO("Training on %d vectors\n", num_vec);
	CSignal::clear_cancel();

	ELossType loss_type = loss->get_loss_type();
	bool is_log_loss = false;
	if ((loss_type == L_LOGLOSS) || (loss_type == L_LOGLOSSMARGIN))
		is_log_loss = true;

	for (int32_t e=0; e<epochs && (!CSignal::cancel_computations()); e++)
	{
		count = skip;
		for (int32_t i=0; i<num_vec; i++)
		{
			float64_t eta = 1.0 / (lambda * t);
			float64_t y = ((CBinaryLabels*) m_labels)->get_label(i);
			float64_t z = y * (features->dense_dot(i, w.vector, w.vlen) + bias);

			if (z < 1 || is_log_loss)
			{
				float64_t etd = -eta * loss->first_derivative(z,1);
				features->add_to_dense_vec(etd * y / wscale, i, w.vector, w.vlen);

				if (use_bias)
				{
					if (use_regularized_bias)
						bias *= 1 - eta * lambda * bscale;
					bias += etd * y * bscale;
				}
			}

			if (--count <= 0)
			{
				// apply the weight decay accumulated over the last `skip` steps
				float64_t r = 1 - eta * lambda * skip;
				if (r < 0.8)
					r = pow(1 - eta * lambda, skip);
				SGVector<float64_t>::scale_vector(r, w.vector, w.vlen);
				count = skip;
			}
			t++;
		}
	}

	float64_t wnorm = SGVector<float64_t>::dot(w.vector,w.vector, w.vlen);
	SG_INFO("Norm: %.6f, Bias: %.6f\n", wnorm, bias);

	return true;
}
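/* Editor's note (sketch of the update implemented above): with learning
 * rate eta_t = 1/(lambda*t), the plain SGD step for the regularized risk
 *
 *     lambda/2 * |w|^2 + (1/n) * sum_i loss(y_i * (<w, x_i> + b))
 *
 * is  w <- (1 - eta_t*lambda) * w - eta_t * loss'(z_i) * y_i * x_i.
 * The loss part is applied per example (etd = -eta * loss'(z)); the
 * (1 - eta_t*lambda) shrinkage is applied lazily, once every `skip`
 * examples, as the single rescaling by r. This keeps the cost of one
 * update proportional to the number of non-zero features of x_i rather
 * than to the full feature dimension. */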
/* Estimate the data-dependent constants used by train_machine():
 * bscale (bias update scale, from the average feature magnitude) and
 * skip (how many examples may pass between weight-decay updates,
 * based on the average number of non-zero features per vector). */
void CSVMSGD::calibrate()
{
	ASSERT(features);
	int32_t num_vec=features->get_num_vectors();
	int32_t c_dim=features->get_dim_feature_space();

	ASSERT(num_vec>0);
	ASSERT(c_dim>0);

	float64_t* c=SG_MALLOC(float64_t, c_dim);
	memset(c, 0, c_dim*sizeof(float64_t));

	SG_INFO("Estimating sparsity and bscale num_vec=%d num_feat=%d.\n", num_vec, c_dim);

	// compute average gradient size
	int32_t n = 0;
	float64_t m = 0;
	float64_t r = 0;

	for (int32_t j=0; j<num_vec && m<=1000; j++, n++)
	{
		r += features->get_nnz_features_for_vector(j);
		features->add_to_dense_vec(1, j, c, c_dim, true);

		// waste cpu cycles for readability
		// (only changed dims need checking)
		m=SGVector<float64_t>::max(c, c_dim);
	}

	// bias update scaling
	bscale = 0.5*m/n;

	// compute weight decay skip
	skip = (int32_t) ((16 * n * c_dim) / r);
	SG_INFO("using %d examples. skip=%d bscale=%.6f\n", n, skip, bscale);

	SG_FREE(c);
}

void CSVMSGD::init()
{
	t=1;
	C1=1;
	C2=1;
	wscale=1;
	bscale=1;
	epochs=5;
	skip=1000;
	count=1000;
	use_bias=true;

	use_regularized_bias=false;

	loss=new CHingeLoss();
	SG_REF(loss);

	m_parameters->add(&C1, "C1", "Cost constant 1.");
	m_parameters->add(&C2, "C2", "Cost constant 2.");
	m_parameters->add(&wscale, "wscale", "W scale");
	m_parameters->add(&bscale, "bscale", "b scale");
	m_parameters->add(&epochs, "epochs", "epochs");
	m_parameters->add(&skip, "skip", "skip");
	m_parameters->add(&count, "count", "count");
	m_parameters->add(&use_bias, "use_bias", "Indicates if bias is used.");
	m_parameters->add(&use_regularized_bias, "use_regularized_bias", "Indicates if bias is regularized.");
}
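Editor's usage sketch: the program below shows one way to drive this class
through the Shogun 2.x C++ API. It is illustrative only; the toy matrix
values, the epoch count, and the assumption that SVMSGD.h declares a
set_epochs() setter for the `epochs` member are the editor's, not taken
from this file.

#include <shogun/base/init.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/labels/BinaryLabels.h>
#include <shogun/classifier/svm/SVMSGD.h>

using namespace shogun;

int main()
{
	init_shogun_with_defaults();

	// 2 features x 4 examples, column-major; values made up for illustration
	SGMatrix<float64_t> mat(2, 4);
	float64_t data[8]={-1,-1, -1,1, 1,-1, 1,1};
	for (int32_t k=0; k<8; k++)
		mat.matrix[k]=data[k];

	CDenseFeatures<float64_t>* feats=new CDenseFeatures<float64_t>(mat);

	// one +/-1 label per example; the class is separable on the first feature
	SGVector<float64_t> lab(4);
	lab[0]=-1; lab[1]=-1; lab[2]=1; lab[3]=1;
	CBinaryLabels* labels=new CBinaryLabels(lab);

	// C=1.0; this constructor refs the features and labels it is given
	CSVMSGD* svm=new CSVMSGD(1.0, feats, labels);
	svm->set_epochs(10); // more SGD passes than the default of 5

	// the default loss is CHingeLoss (see init()); set_loss_function()
	// above allows swapping in another CLossFunction subclass
	svm->train();

	SG_SPRINT("bias after training: %f\n", svm->get_bias());

	SG_UNREF(svm);
	exit_shogun();
	return 0;
}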