SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2011 Sergey Lisitsyn 00008 * Copyright (C) 2011 Sergey Lisitsyn 00009 */ 00010 00011 #include <shogun/multiclass/ConjugateIndex.h> 00012 #ifdef HAVE_LAPACK 00013 #include <shogun/machine/Machine.h> 00014 #include <shogun/features/Features.h> 00015 #include <shogun/labels/Labels.h> 00016 #include <shogun/labels/MulticlassLabels.h> 00017 #include <shogun/mathematics/lapack.h> 00018 #include <shogun/mathematics/Math.h> 00019 #include <shogun/lib/Signal.h> 00020 00021 using namespace shogun; 00022 00023 CConjugateIndex::CConjugateIndex() : CMachine() 00024 { 00025 m_classes = NULL; 00026 m_features = NULL; 00027 }; 00028 00029 CConjugateIndex::CConjugateIndex(CFeatures* train_features, CLabels* train_labels) : CMachine() 00030 { 00031 m_features = NULL; 00032 set_features(train_features); 00033 set_labels(train_labels); 00034 m_classes = NULL; 00035 }; 00036 00037 CConjugateIndex::~CConjugateIndex() 00038 { 00039 clean_classes(); 00040 SG_UNREF(m_features); 00041 }; 00042 00043 void CConjugateIndex::set_features(CFeatures* features) 00044 { 00045 ASSERT(features->get_feature_class()==C_DENSE); 00046 SG_REF(features); 00047 SG_UNREF(m_features); 00048 m_features = (CDenseFeatures<float64_t>*)features; 00049 } 00050 00051 CDenseFeatures<float64_t>* CConjugateIndex::get_features() 00052 { 00053 SG_REF(m_features); 00054 return m_features; 00055 } 00056 00057 void CConjugateIndex::clean_classes() 00058 { 00059 if (m_classes) 00060 { 00061 for (int32_t i=0; i<m_num_classes; i++) 00062 m_classes[i]=SGMatrix<float64_t>(); 00063 00064 delete[] m_classes; 00065 } 00066 } 00067 00068 bool CConjugateIndex::train_machine(CFeatures* data) 00069 { 00070 if (data) 00071 set_features(data); 00072 00073 ASSERT(m_labels); 00074 ASSERT(m_labels->get_label_type()==LT_MULTICLASS); 00075 00076 m_num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes(); 00077 ASSERT(m_num_classes>=2); 00078 clean_classes(); 00079 00080 int32_t num_vectors; 00081 int32_t num_features; 00082 float64_t* feature_matrix = m_features->get_feature_matrix(num_features,num_vectors); 00083 00084 m_classes = new SGMatrix<float64_t>[m_num_classes](); 00085 for (int32_t i=0; i<m_num_classes; i++) 00086 m_classes[i] = SGMatrix<float64_t>(num_features,num_features); 00087 00088 m_feature_vector = SGVector<float64_t>(num_features); 00089 00090 SG_PROGRESS(0,0,m_num_classes-1); 00091 00092 for (int32_t label=0; label<m_num_classes; label++) 00093 { 00094 int32_t count = 0; 00095 for (int32_t i=0; i<num_vectors; i++) 00096 { 00097 if (((CMulticlassLabels*) m_labels)->get_int_label(i) == label) 00098 count++; 00099 } 00100 00101 SGMatrix<float64_t> class_feature_matrix(num_features,count); 00102 SGMatrix<float64_t> matrix(count,count); 00103 SGMatrix<float64_t> helper_matrix(num_features,count); 00104 00105 count = 0; 00106 for (int32_t i=0; i<num_vectors; i++) 00107 { 00108 if (((CMulticlassLabels*) m_labels)->get_label(i) == label) 00109 { 00110 memcpy(class_feature_matrix.matrix+count*num_features, 00111 feature_matrix+i*num_features, 00112 sizeof(float64_t)*num_features); 00113 count++; 00114 } 00115 } 00116 00117 cblas_dgemm(CblasColMajor,CblasTrans,CblasNoTrans, 00118 count,count,num_features, 00119 1.0,class_feature_matrix.matrix,num_features, 00120 class_feature_matrix.matrix,num_features, 00121 0.0,matrix.matrix,count); 00122 00123 SGMatrix<float64_t>::inverse(matrix); 00124 00125 cblas_dgemm(CblasColMajor,CblasNoTrans,CblasTrans, 00126 count,num_features,count, 00127 1.0,matrix.matrix,count, 00128 class_feature_matrix.matrix,num_features, 00129 0.0,helper_matrix.matrix,count); 00130 00131 cblas_dgemm(CblasColMajor,CblasNoTrans,CblasNoTrans, 00132 num_features,num_features,count, 00133 1.0,class_feature_matrix.matrix,num_features, 00134 helper_matrix.matrix,count, 00135 0.0,m_classes[label].matrix,num_features); 00136 00137 SG_PROGRESS(label+1,0,m_num_classes); 00138 } 00139 SG_DONE(); 00140 00141 return true; 00142 }; 00143 00144 CMulticlassLabels* CConjugateIndex::apply_multiclass(CFeatures* data) 00145 { 00146 if (data) 00147 set_features(data); 00148 00149 ASSERT(m_features); 00150 00151 ASSERT(m_classes); 00152 ASSERT(m_num_classes>1); 00153 ASSERT(m_features->get_num_features()==m_feature_vector.vlen); 00154 00155 int32_t num_vectors = m_features->get_num_vectors(); 00156 00157 CMulticlassLabels* predicted_labels = new CMulticlassLabels(num_vectors); 00158 00159 for (int32_t i=0; i<num_vectors;i++) 00160 { 00161 SG_PROGRESS(i,0,num_vectors-1); 00162 predicted_labels->set_label(i,apply_one(i)); 00163 } 00164 SG_DONE(); 00165 00166 return predicted_labels; 00167 }; 00168 00169 float64_t CConjugateIndex::conjugate_index(SGVector<float64_t> feature_vector, int32_t label) 00170 { 00171 int32_t num_features = feature_vector.vlen; 00172 float64_t norm = cblas_ddot(num_features,feature_vector.vector,1, 00173 feature_vector.vector,1); 00174 00175 cblas_dgemv(CblasColMajor,CblasNoTrans, 00176 num_features,num_features, 00177 1.0,m_classes[label].matrix,num_features, 00178 feature_vector.vector,1, 00179 0.0,m_feature_vector.vector,1); 00180 00181 float64_t product = cblas_ddot(num_features,feature_vector.vector,1, 00182 m_feature_vector.vector,1); 00183 return product/norm; 00184 }; 00185 00186 float64_t CConjugateIndex::apply_one(int32_t index) 00187 { 00188 int32_t predicted_label = 0; 00189 float64_t max_conjugate_index = 0.0; 00190 float64_t current_conjugate_index; 00191 00192 SGVector<float64_t> feature_vector = m_features->get_feature_vector(index); 00193 for (int32_t i=0; i<m_num_classes; i++) 00194 { 00195 current_conjugate_index = conjugate_index(feature_vector,i); 00196 00197 if (current_conjugate_index > max_conjugate_index) 00198 { 00199 max_conjugate_index = current_conjugate_index; 00200 predicted_label = i; 00201 } 00202 } 00203 00204 return predicted_label; 00205 }; 00206 00207 #endif /* HAVE_LAPACK */