SHOGUN v2.0.0
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 1999-2008 Vojtech Franc, xfrancv@cmp.felk.cvut.cz
 * Copyright (C) 1999-2008 Center for Machine Perception, CTU FEL Prague
 */

#include <shogun/io/SGIO.h>
#include <shogun/labels/MulticlassLabels.h>
#include <shogun/multiclass/GMNPSVM.h>
#include <shogun/multiclass/GMNPLib.h>
#include <shogun/multiclass/MulticlassOneVsRestStrategy.h>

#define INDEX(ROW,COL,DIM) (((COL)*(DIM))+(ROW))
#define MINUS_INF INT_MIN
#define PLUS_INF INT_MAX
#define KDELTA(A,B) (A==B)
#define KDELTA4(A1,A2,A3,A4) ((A1==A2)||(A1==A3)||(A1==A4)||(A2==A3)||(A2==A4)||(A3==A4))

using namespace shogun;

CGMNPSVM::CGMNPSVM()
: CMulticlassSVM(new CMulticlassOneVsRestStrategy())
{
    init();
}

CGMNPSVM::CGMNPSVM(float64_t C, CKernel* k, CLabels* lab)
: CMulticlassSVM(new CMulticlassOneVsRestStrategy(), C, k, lab)
{
    init();
}

CGMNPSVM::~CGMNPSVM()
{
    if (m_basealphas != NULL)
        SG_FREE(m_basealphas);
}

void CGMNPSVM::init()
{
    m_parameters->add_matrix(&m_basealphas,
            &m_basealphas_y, &m_basealphas_x,
            "m_basealphas",
            "Is the basic untransformed alpha.");

    m_basealphas = NULL;
    m_basealphas_y = 0;
    m_basealphas_x = 0;
}

bool CGMNPSVM::train_machine(CFeatures* data)
{
    ASSERT(m_kernel);
    ASSERT(m_labels && m_labels->get_num_labels());
    ASSERT(m_labels->get_label_type() == LT_MULTICLASS);

    if (data)
    {
        if (m_labels->get_num_labels() != data->get_num_vectors())
        {
            SG_ERROR("%s::train_machine(): Number of training vectors (%d) does"
                    " not match number of labels (%d)\n", get_name(),
                    data->get_num_vectors(), m_labels->get_num_labels());
        }
        m_kernel->init(data, data);
    }

    int32_t num_data = m_labels->get_num_labels();
    int32_t num_classes = m_multiclass_strategy->get_num_classes();
    int32_t num_virtual_data = num_data*(num_classes-1);

    SG_INFO("%d trainlabels, %d classes\n", num_data, num_classes);

    /* the GMNP solver expects labels in 1..num_classes */
    float64_t* vector_y = SG_MALLOC(float64_t, num_data);
    for (int32_t i=0; i<num_data; i++)
        vector_y[i] = ((CMulticlassLabels*) m_labels)->get_label(i)+1;

    float64_t C = get_C();
    int32_t tmax = 1000000000;
    float64_t tolabs = 0;
    float64_t tolrel = get_epsilon();

    /* C enters the QP through the regularization constant 1/(2C) */
    float64_t reg_const = 0;
    if (C != 0)
        reg_const = 1/(2*C);

    float64_t* alpha = SG_MALLOC(float64_t, num_virtual_data);
    float64_t* vector_c = SG_MALLOC(float64_t, num_virtual_data);
    memset(vector_c, 0, num_virtual_data*sizeof(float64_t));

    float64_t thlb = 10000000000.0;
    int32_t t = 0;
    float64_t* History = NULL;
    int32_t verb = 0;

    /* solve the single GMNP quadratic program over the virtual examples */
    CGMNPLib mnp(vector_y, m_kernel, num_data, num_virtual_data, num_classes, reg_const);

    mnp.gmnp_imdm(vector_c, num_virtual_data, tmax,
            tolabs, tolrel, thlb, alpha, &t, &History, verb);

    /* matrix alpha [num_classes x num_data] */
    float64_t* all_alphas = SG_MALLOC(float64_t, num_classes*num_data);
    memset(all_alphas, 0, num_classes*num_data*sizeof(float64_t));

    /* bias vector b [num_classes x 1] */
    float64_t* all_bs = SG_MALLOC(float64_t, num_classes);
    memset(all_bs, 0, num_classes*sizeof(float64_t));

    /* compute alpha/b from virt_data */
    for (int32_t i=0; i<num_classes; i++)
    {
        for (int32_t j=0; j<num_virtual_data; j++)
        {
            int32_t inx1=0;
            int32_t inx2=0;

            mnp.get_indices2(&inx1, &inx2, j);

            all_alphas[(inx1*num_classes)+i] +=
                alpha[j]*(KDELTA(vector_y[inx1],i+1)-KDELTA(i+1,inx2));
            all_bs[i] += alpha[j]*(KDELTA(vector_y[inx1],i+1)-KDELTA(i+1,inx2));
        }
    }

    /* build one CSVM per class from the corresponding column of all_alphas */
    create_multiclass_svm(num_classes);

    for (int32_t i=0; i<num_classes; i++)
    {
        int32_t num_sv=0;
        for (int32_t j=0; j<num_data; j++)
        {
            if (all_alphas[j*num_classes+i] != 0)
                num_sv++;
        }
        ASSERT(num_sv>0);
        SG_DEBUG("svm[%d] has %d sv, b=%f\n", i, num_sv, all_bs[i]);

        CSVM* svm = new CSVM(num_sv);

        int32_t k=0;
        for (int32_t j=0; j<num_data; j++)
        {
            if (all_alphas[j*num_classes+i] != 0)
            {
                svm->set_alpha(k, all_alphas[j*num_classes+i]);
                svm->set_support_vector(k, j);
                k++;
            }
        }

        svm->set_bias(all_bs[i]);
        set_svm(i, svm);
    }

    /* keep a copy of the raw (untransformed) alphas, queryable via get_basealphas_ptr() */
    if (m_basealphas != NULL)
        SG_FREE(m_basealphas);
    m_basealphas_y = num_classes;
    m_basealphas_x = num_data;
    m_basealphas = SG_MALLOC(float64_t, m_basealphas_y*m_basealphas_x);
    for (index_t i=0; i<m_basealphas_y*m_basealphas_x; i++)
        m_basealphas[i] = 0.0;

    for (index_t j=0; j<num_virtual_data; j++)
    {
        index_t inx1=0, inx2=0;

        mnp.get_indices2(&inx1, &inx2, j);
        m_basealphas[inx1*m_basealphas_y + (inx2-1)] = alpha[j];
    }

    SG_FREE(vector_c);
    SG_FREE(alpha);
    SG_FREE(all_alphas);
    SG_FREE(all_bs);
    SG_FREE(vector_y);
    SG_FREE(History);

    return true;
}

float64_t* CGMNPSVM::get_basealphas_ptr(index_t* y, index_t* x)
{
    if (y == NULL || x == NULL)
        return NULL;

    *y = m_basealphas_y;
    *x = m_basealphas_x;
    return m_basealphas;
}
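
For orientation, a minimal usage sketch follows; it is not part of GMNPSVM.cpp. It drives the CGMNPSVM class above through the Shogun 2.x C++ API (CDenseFeatures, CMulticlassLabels, CGaussianKernel, init_shogun_with_defaults). The toy data, kernel width, and C value are invented for illustration, and header locations or reference-counting details may differ between builds, so treat it as an untested starting point rather than shipped example code.

/* usage sketch, assuming the Shogun 2.x API named above */
#include <shogun/base/init.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/labels/MulticlassLabels.h>
#include <shogun/kernel/GaussianKernel.h>
#include <shogun/multiclass/GMNPSVM.h>

using namespace shogun;

int main()
{
    init_shogun_with_defaults();

    /* six 2-d points, one example per column (SGMatrix is column-major) */
    SGMatrix<float64_t> feat_mat(2, 6);
    const float64_t points[] = {0,0,  0,1,  1,0,  5,5,  5,6,  9,9};
    for (int32_t i=0; i<12; i++)
        feat_mat.matrix[i] = points[i];

    /* three classes labelled 0,1,2; train_machine() shifts them to 1..3 internally */
    SGVector<float64_t> lab_vec(6);
    const float64_t labs[] = {0, 0, 0, 1, 1, 2};
    for (int32_t i=0; i<6; i++)
        lab_vec[i] = labs[i];

    CDenseFeatures<float64_t>* features = new CDenseFeatures<float64_t>(feat_mat);
    CMulticlassLabels* labels = new CMulticlassLabels(lab_vec);
    CGaussianKernel* kernel = new CGaussianKernel(features, features, 2.0);

    /* C=1.0; kernel and labels go straight into the constructor defined above */
    CGMNPSVM* svm = new CGMNPSVM(1.0, kernel, labels);
    svm->train();

    /* apply to the training data and print the predicted classes */
    CMulticlassLabels* out = (CMulticlassLabels*) svm->apply();
    for (int32_t i=0; i<out->get_num_labels(); i++)
        SG_SPRINT("example %d -> class %.0f\n", i, out->get_label(i));

    SG_UNREF(out);
    SG_UNREF(svm);
    exit_shogun();
    return 0;
}

Because the class is built on CMulticlassOneVsRestStrategy, the single GMNP solution is unpacked into one CSVM per class (see train_machine above), so the trained machine can be applied exactly like any other kernel multiclass machine.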