SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #include <shogun/lib/common.h> 00012 #include <shogun/io/SGIO.h> 00013 #include <shogun/multiclass/MulticlassSVM.h> 00014 #include <shogun/multiclass/MulticlassOneVsRestStrategy.h> 00015 00016 using namespace shogun; 00017 00018 CMulticlassSVM::CMulticlassSVM() 00019 :CKernelMulticlassMachine(new CMulticlassOneVsRestStrategy(), NULL, new CSVM(0), NULL), m_C(0) 00020 { 00021 init(); 00022 } 00023 00024 CMulticlassSVM::CMulticlassSVM(CMulticlassStrategy *strategy) 00025 :CKernelMulticlassMachine(strategy, NULL, new CSVM(0), NULL), m_C(0) 00026 { 00027 init(); 00028 } 00029 00030 CMulticlassSVM::CMulticlassSVM( 00031 CMulticlassStrategy *strategy, float64_t C, CKernel* k, CLabels* lab) 00032 : CKernelMulticlassMachine(strategy, k, new CSVM(C, k, lab), lab), m_C(C) 00033 { 00034 init(); 00035 } 00036 00037 CMulticlassSVM::~CMulticlassSVM() 00038 { 00039 } 00040 00041 void CMulticlassSVM::init() 00042 { 00043 SG_ADD(&m_C, "C", "C regularization constant",MS_AVAILABLE); 00044 } 00045 00046 bool CMulticlassSVM::create_multiclass_svm(int32_t num_classes) 00047 { 00048 if (num_classes>0) 00049 { 00050 int32_t num_svms=m_multiclass_strategy->get_num_machines(); 00051 00052 m_machines->reset_array(); 00053 for (index_t i=0; i<num_svms; ++i) 00054 m_machines->push_back(NULL); 00055 00056 return true; 00057 } 00058 return false; 00059 } 00060 00061 bool CMulticlassSVM::set_svm(int32_t num, CSVM* svm) 00062 { 00063 if (m_machines->get_num_elements()>0 && m_machines->get_num_elements()>num && num>=0 && svm) 00064 { 00065 m_machines->set_element(svm, num); 00066 return true; 00067 } 00068 return false; 00069 } 00070 00071 bool CMulticlassSVM::init_machines_for_apply(CFeatures* data) 00072 { 00073 if (is_data_locked()) 00074 { 00075 SG_ERROR("CKernelMachine::apply(CFeatures*) cannot be called when " 00076 "data_lock was called before. Call data_unlock to allow."); 00077 } 00078 00079 if (!m_kernel) 00080 SG_ERROR("No kernel assigned!\n"); 00081 00082 CFeatures* lhs=m_kernel->get_lhs(); 00083 if (!lhs && m_kernel->get_kernel_type()!=K_COMBINED) 00084 SG_ERROR("%s: No left hand side specified\n", get_name()); 00085 00086 if (m_kernel->get_kernel_type()!=K_COMBINED && !lhs->get_num_vectors()) 00087 { 00088 SG_ERROR("%s: No vectors on left hand side (%s). This is probably due to" 00089 " an implementation error in %s, where it was forgotten to set " 00090 "the data (m_svs) indices\n", get_name(), 00091 data->get_name()); 00092 } 00093 00094 if (data && m_kernel->get_kernel_type()!=K_COMBINED) 00095 m_kernel->init(lhs, data); 00096 SG_UNREF(lhs); 00097 00098 for (int32_t i=0; i<m_machines->get_num_elements(); i++) 00099 { 00100 CSVM *the_svm = (CSVM *)m_machines->get_element(i); 00101 ASSERT(the_svm); 00102 the_svm->set_kernel(m_kernel); 00103 SG_UNREF(the_svm); 00104 } 00105 00106 return true; 00107 } 00108 00109 bool CMulticlassSVM::load(FILE* modelfl) 00110 { 00111 bool result=true; 00112 char char_buffer[1024]; 00113 int32_t int_buffer; 00114 float64_t double_buffer; 00115 int32_t line_number=1; 00116 int32_t svm_idx=-1; 00117 00118 SG_SET_LOCALE_C; 00119 00120 if (fscanf(modelfl,"%15s\n", char_buffer)==EOF) 00121 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00122 else 00123 { 00124 char_buffer[15]='\0'; 00125 if (strcmp("%MultiClassSVM", char_buffer)!=0) 00126 SG_ERROR( "error in multiclass svm file, line nr:%d\n", line_number); 00127 00128 line_number++; 00129 } 00130 00131 int_buffer=0; 00132 if (fscanf(modelfl," num_classes=%d; \n", &int_buffer) != 1) 00133 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00134 00135 if (!feof(modelfl)) 00136 line_number++; 00137 00138 if (int_buffer < 2) 00139 SG_ERROR("less than 2 classes - how is this multiclass?\n"); 00140 00141 create_multiclass_svm(int_buffer); 00142 00143 int_buffer=0; 00144 if (fscanf(modelfl," num_svms=%d; \n", &int_buffer) != 1) 00145 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00146 00147 if (!feof(modelfl)) 00148 line_number++; 00149 00150 if (m_machines->get_num_elements() != int_buffer) 00151 SG_ERROR("Mismatch in number of svms: m_num_svms=%d vs m_num_svms(file)=%d\n", m_machines->get_num_elements(), int_buffer); 00152 00153 if (fscanf(modelfl," kernel='%s'; \n", char_buffer) != 1) 00154 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00155 00156 if (!feof(modelfl)) 00157 line_number++; 00158 00159 for (int32_t n=0; n<m_machines->get_num_elements(); n++) 00160 { 00161 svm_idx=-1; 00162 if (fscanf(modelfl,"\n%4s %d of %d\n", char_buffer, &svm_idx, &int_buffer)==EOF) 00163 { 00164 result=false; 00165 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00166 } 00167 else 00168 { 00169 char_buffer[4]='\0'; 00170 if (strncmp("%SVM", char_buffer, 4)!=0) 00171 { 00172 result=false; 00173 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00174 } 00175 00176 if (svm_idx != n) 00177 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx); 00178 00179 line_number++; 00180 } 00181 00182 int_buffer=0; 00183 if (fscanf(modelfl,"numsv%d=%d;\n", &svm_idx, &int_buffer) != 2) 00184 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00185 00186 if (svm_idx != n) 00187 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx); 00188 00189 if (!feof(modelfl)) 00190 line_number++; 00191 00192 SG_INFO("loading %ld support vectors for svm %d\n",int_buffer, svm_idx); 00193 CSVM* svm=new CSVM(int_buffer); 00194 00195 double_buffer=0; 00196 00197 if (fscanf(modelfl," b%d=%lf; \n", &svm_idx, &double_buffer) != 2) 00198 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00199 00200 if (svm_idx != n) 00201 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx); 00202 00203 if (!feof(modelfl)) 00204 line_number++; 00205 00206 svm->set_bias(double_buffer); 00207 00208 if (fscanf(modelfl,"alphas%d=[\n", &svm_idx) != 1) 00209 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00210 00211 if (svm_idx != n) 00212 SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx); 00213 00214 if (!feof(modelfl)) 00215 line_number++; 00216 00217 for (int32_t i=0; i<svm->get_num_support_vectors(); i++) 00218 { 00219 double_buffer=0; 00220 int_buffer=0; 00221 00222 if (fscanf(modelfl,"\t[%lf,%d]; \n", &double_buffer, &int_buffer) != 2) 00223 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00224 00225 if (!feof(modelfl)) 00226 line_number++; 00227 00228 svm->set_support_vector(i, int_buffer); 00229 svm->set_alpha(i, double_buffer); 00230 } 00231 00232 if (fscanf(modelfl,"%2s", char_buffer) == EOF) 00233 { 00234 result=false; 00235 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00236 } 00237 else 00238 { 00239 char_buffer[3]='\0'; 00240 if (strcmp("];", char_buffer)!=0) 00241 { 00242 result=false; 00243 SG_ERROR( "error in svm file, line nr:%d\n", line_number); 00244 } 00245 line_number++; 00246 } 00247 00248 set_svm(n, svm); 00249 } 00250 00251 svm_proto()->svm_loaded=result; 00252 00253 SG_RESET_LOCALE; 00254 return result; 00255 } 00256 00257 bool CMulticlassSVM::save(FILE* modelfl) 00258 { 00259 SG_SET_LOCALE_C; 00260 00261 if (!m_kernel) 00262 SG_ERROR("Kernel not defined!\n"); 00263 00264 if (m_machines->get_num_elements()<1) 00265 SG_ERROR("Multiclass SVM not trained!\n"); 00266 00267 SG_INFO( "Writing model file..."); 00268 fprintf(modelfl,"%%MultiClassSVM\n"); 00269 fprintf(modelfl,"num_classes=%d;\n", m_multiclass_strategy->get_num_classes()); 00270 fprintf(modelfl,"num_svms=%d;\n", m_machines->get_num_elements()); 00271 fprintf(modelfl,"kernel='%s';\n", m_kernel->get_name()); 00272 00273 for (int32_t i=0; i<m_machines->get_num_elements(); i++) 00274 { 00275 CSVM* svm=get_svm(i); 00276 ASSERT(svm); 00277 fprintf(modelfl,"\n%%SVM %d of %d\n", i, m_machines->get_num_elements()-1); 00278 fprintf(modelfl,"numsv%d=%d;\n", i, svm->get_num_support_vectors()); 00279 fprintf(modelfl,"b%d=%+10.16e;\n",i,svm->get_bias()); 00280 00281 fprintf(modelfl, "alphas%d=[\n", i); 00282 00283 for(int32_t j=0; j<svm->get_num_support_vectors(); j++) 00284 { 00285 fprintf(modelfl,"\t[%+10.16e,%d];\n", 00286 svm->get_alpha(j), svm->get_support_vector(j)); 00287 } 00288 00289 fprintf(modelfl, "];\n"); 00290 } 00291 00292 SG_RESET_LOCALE; 00293 SG_DONE(); 00294 return true ; 00295 }