SHOGUN
v2.0.0
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 2012 Fernando José Iglesias García
 * Copyright (C) 2012 Fernando José Iglesias García
 */

#include <shogun/lib/common.h>

#ifdef HAVE_LAPACK

#include <shogun/multiclass/QDA.h>
#include <shogun/machine/NativeMulticlassMachine.h>
#include <shogun/features/Features.h>
#include <shogun/labels/Labels.h>
#include <shogun/labels/MulticlassLabels.h>
#include <shogun/mathematics/Math.h>
#include <shogun/mathematics/lapack.h>

using namespace shogun;

CQDA::CQDA(float64_t tolerance, bool store_covs)
: CNativeMulticlassMachine(), m_tolerance(tolerance),
  m_store_covs(store_covs), m_num_classes(0), m_dim(0)
{
	init();
}

CQDA::CQDA(CDenseFeatures<float64_t>* traindat, CLabels* trainlab, float64_t tolerance, bool store_covs)
: CNativeMulticlassMachine(), m_tolerance(tolerance), m_store_covs(store_covs), m_num_classes(0), m_dim(0)
{
	init();
	set_features(traindat);
	set_labels(trainlab);
}

CQDA::~CQDA()
{
	SG_UNREF(m_features);

	cleanup();
}

void CQDA::init()
{
	SG_ADD(&m_tolerance, "m_tolerance", "Tolerance member.", MS_AVAILABLE);
	SG_ADD(&m_store_covs, "m_store_covs", "Store covariances member", MS_NOT_AVAILABLE);
	SG_ADD((CSGObject**) &m_features, "m_features", "Feature object.", MS_NOT_AVAILABLE);
	SG_ADD(&m_means, "m_means", "Mean vectors list", MS_NOT_AVAILABLE);
	SG_ADD(&m_slog, "m_slog", "Vector used in classification", MS_NOT_AVAILABLE);

	//TODO include SGNDArray objects for serialization

	m_features = NULL;
}

void CQDA::cleanup()
{
	if ( m_store_covs )
		m_covs.destroy_ndarray();

	m_covs.free_ndarray();
	m_M.free_ndarray();
	m_means = SGMatrix<float64_t>();

	m_num_classes = 0;
}

CMulticlassLabels* CQDA::apply_multiclass(CFeatures* data)
{
	if (data)
	{
		if (!data->has_property(FP_DOT))
			SG_ERROR("Specified features are not of type CDotFeatures\n");

		set_features((CDotFeatures*) data);
	}

	if ( !m_features )
		return NULL;

	int32_t num_vecs = m_features->get_num_vectors();
	ASSERT(num_vecs > 0);
	ASSERT( m_dim == m_features->get_dim_feature_space() );

	CDenseFeatures< float64_t >* rf = (CDenseFeatures< float64_t >*) m_features;

	SGMatrix< float64_t > X(num_vecs, m_dim);
	SGMatrix< float64_t > A(num_vecs, m_dim);
	SGVector< float64_t > norm2(num_vecs*m_num_classes);
	norm2.zero();

	int i, j, k, vlen;
	bool vfree;
	float64_t* vec;
	for ( k = 0 ; k < m_num_classes ; ++k )
	{
		// X = features - means
		for ( i = 0 ; i < num_vecs ; ++i )
		{
			vec = rf->get_feature_vector(i, vlen, vfree);
			ASSERT(vec);

			for ( j = 0 ; j < m_dim ; ++j )
				X[i + j*num_vecs] = vec[j] - m_means[k*m_dim + j];

			rf->free_feature_vector(vec, i, vfree);
		}
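		// A = X * M_k maps the centered vectors through the per-class
		// whitening matrix built in train_machine(); the squared row
		// norms of A accumulated below equal the Mahalanobis distances
		// (x_i - mu_k)^T Sigma_k^{-1} (x_i - mu_k).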
		cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, num_vecs, m_dim,
			m_dim, 1.0, X.matrix, num_vecs, m_M.get_matrix(k), m_dim, 0.0,
			A.matrix, num_vecs);

		for ( i = 0 ; i < num_vecs ; ++i )
			for ( j = 0 ; j < m_dim ; ++j )
				norm2[i + k*num_vecs] += CMath::sq(A[i + j*num_vecs]);

#ifdef DEBUG_QDA
		CMath::display_matrix(A.matrix, num_vecs, m_dim, "A");
#endif
	}

	for ( i = 0 ; i < num_vecs ; ++i )
		for ( k = 0 ; k < m_num_classes ; ++k )
		{
			norm2[i + k*num_vecs] += m_slog[k];
			norm2[i + k*num_vecs] *= -0.5;
		}

	CMulticlassLabels* out = new CMulticlassLabels(num_vecs);

	for ( i = 0 ; i < num_vecs ; ++i )
		out->set_label(i, SGVector<float64_t>::arg_max(norm2.vector+i, num_vecs, m_num_classes));

#ifdef DEBUG_QDA
	CMath::display_matrix(norm2.vector, num_vecs, m_num_classes, "norm2");
	CMath::display_vector(out->get_labels().vector, num_vecs, "Labels");
#endif

	return out;
}

bool CQDA::train_machine(CFeatures* data)
{
	if ( !m_labels )
		SG_ERROR("No labels allocated in QDA training\n");

	if ( data )
	{
		if ( !data->has_property(FP_DOT) )
			SG_ERROR("Specified features are not of type CDotFeatures\n");
		set_features((CDotFeatures*) data);
	}

	if ( !m_features )
		SG_ERROR("No features allocated in QDA training\n");

	SGVector< int32_t > train_labels = ((CMulticlassLabels*) m_labels)->get_int_labels();

	if ( !train_labels.vector )
		SG_ERROR("No train_labels allocated in QDA training\n");

	cleanup();

	m_num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes();
	m_dim = m_features->get_dim_feature_space();
	int32_t num_vec = m_features->get_num_vectors();

	if ( num_vec != train_labels.vlen )
		SG_ERROR("Dimension mismatch between features and labels in QDA training\n");

	int32_t* class_idxs = SG_MALLOC(int32_t, num_vec*m_num_classes);
	// number of examples of each class
	int32_t* class_nums = SG_MALLOC(int32_t, m_num_classes);
	memset(class_nums, 0, m_num_classes*sizeof(int32_t));
	int32_t class_idx;
	int32_t i, j, k;
	for ( i = 0 ; i < train_labels.vlen ; ++i )
	{
		class_idx = train_labels.vector[i];

		if ( class_idx < 0 || class_idx >= m_num_classes )
		{
			SG_ERROR("Found label out of {0, 1, 2, ..., num_classes-1}\n");
			return false;
		}
		else
		{
			class_idxs[ class_idx*num_vec + class_nums[class_idx]++ ] = i;
		}
	}
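	// QDA estimates a separate covariance matrix per class, so every
	// class must contribute at least one training example.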
	for ( i = 0 ; i < m_num_classes ; ++i )
	{
		if ( class_nums[i] <= 0 )
		{
			SG_ERROR("Found a class with no training examples\n");
			return false;
		}
	}

	if ( m_store_covs )
	{
		// cov_dims will be freed in m_covs.destroy_ndarray()
		index_t* cov_dims = SG_MALLOC(index_t, 3);
		cov_dims[0] = m_dim;
		cov_dims[1] = m_dim;
		cov_dims[2] = m_num_classes;
		m_covs = SGNDArray< float64_t >(cov_dims, 3, true);
	}

	m_means = SGMatrix< float64_t >(m_dim, m_num_classes, true);
	SGMatrix< float64_t > scalings = SGMatrix< float64_t >(m_dim, m_num_classes);

	// rot_dims will be freed in rotations.destroy_ndarray()
	index_t* rot_dims = SG_MALLOC(index_t, 3);
	rot_dims[0] = m_dim;
	rot_dims[1] = m_dim;
	rot_dims[2] = m_num_classes;
	SGNDArray< float64_t > rotations = SGNDArray< float64_t >(rot_dims, 3);

	CDenseFeatures< float64_t >* rf = (CDenseFeatures< float64_t >*) m_features;

	m_means.zero();

	int32_t vlen;
	bool vfree;
	float64_t* vec;
	for ( k = 0 ; k < m_num_classes ; ++k )
	{
		SGMatrix< float64_t > buffer(class_nums[k], m_dim);
		for ( i = 0 ; i < class_nums[k] ; ++i )
		{
			vec = rf->get_feature_vector(class_idxs[k*num_vec + i], vlen, vfree);
			ASSERT(vec);

			for ( j = 0 ; j < vlen ; ++j )
			{
				m_means[k*m_dim + j] += vec[j];
				buffer[i + j*class_nums[k]] = vec[j];
			}

			rf->free_feature_vector(vec, class_idxs[k*num_vec + i], vfree);
		}

		for ( j = 0 ; j < m_dim ; ++j )
			m_means[k*m_dim + j] /= class_nums[k];

		for ( i = 0 ; i < class_nums[k] ; ++i )
			for ( j = 0 ; j < m_dim ; ++j )
				buffer[i + j*class_nums[k]] -= m_means[k*m_dim + j];

		/* calling external lib, buffer = U * S * V^T, U is not interesting here */
		char jobu = 'N', jobvt = 'A';
		int m = class_nums[k], n = m_dim;
		int lda = m, ldu = m, ldvt = n;
		int info = -1;
		float64_t* col = scalings.get_column_vector(k);
		float64_t* rot_mat = rotations.get_matrix(k);

		wrap_dgesvd(jobu, jobvt, m, n, buffer.matrix, lda, col, NULL, ldu,
			rot_mat, ldvt, &info);
		ASSERT(info == 0);
		buffer = SGMatrix<float64_t>();

		// squared singular values divided by (m-1) are the eigenvalues
		// of the class covariance matrix
		SGVector<float64_t>::vector_multiply(col, col, col, m_dim);
		SGVector<float64_t>::scale_vector(1.0/(m-1), col, m_dim);
		rotations.transpose_matrix(k);

		if ( m_store_covs )
		{
			SGMatrix< float64_t > M(n, n);

			M.matrix = SGVector<float64_t>::clone_vector(rot_mat, n*n);
			for ( i = 0 ; i < m_dim ; ++i )
				for ( j = 0 ; j < m_dim ; ++j )
					M[i + j*m_dim] *= scalings[k*m_dim + j];

			cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, n, n, n, 1.0,
				M.matrix, n, rot_mat, n, 0.0, m_covs.get_matrix(k), n);
		}
	}

	/* Computation of terms required for classification */

	SGVector< float32_t > sinvsqrt(m_dim);

	// M_dims will be freed in m_M.destroy_ndarray()
	index_t* M_dims = SG_MALLOC(index_t, 3);
	M_dims[0] = m_dim;
	M_dims[1] = m_dim;
	M_dims[2] = m_num_classes;
	m_M = SGNDArray< float64_t >(M_dims, 3, true);

	m_slog = SGVector< float32_t >(m_num_classes);
	m_slog.zero();

	index_t idx = 0;
	for ( k = 0 ; k < m_num_classes ; ++k )
	{
		for ( j = 0 ; j < m_dim ; ++j )
		{
			sinvsqrt[j] = 1.0 / CMath::sqrt(scalings[k*m_dim + j]);
			m_slog[k] += CMath::log(scalings[k*m_dim + j]);
		}
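		// m_M[k] = rotations[k] * diag(sinvsqrt): multiplying centered
		// data by this matrix whitens it with respect to the class-k
		// covariance, while m_slog[k] accumulates log(det(Sigma_k)),
		// the normalization term of the Gaussian log-likelihood.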
		for ( i = 0 ; i < m_dim ; ++i )
			for ( j = 0 ; j < m_dim ; ++j )
			{
				idx = k*m_dim*m_dim + i + j*m_dim;
				m_M[idx] = rotations[idx] * sinvsqrt[j];
			}
	}

#ifdef DEBUG_QDA
	SG_PRINT(">>> QDA machine trained with %d classes\n", m_num_classes);

	SG_PRINT("\n>>> Displaying means ...\n");
	CMath::display_matrix(m_means.matrix, m_dim, m_num_classes);

	SG_PRINT("\n>>> Displaying scalings ...\n");
	CMath::display_matrix(scalings.matrix, m_dim, m_num_classes);

	SG_PRINT("\n>>> Displaying rotations ...\n");
	for ( k = 0 ; k < m_num_classes ; ++k )
		CMath::display_matrix(rotations.get_matrix(k), m_dim, m_dim);

	SG_PRINT("\n>>> Displaying sinvsqrt ...\n");
	sinvsqrt.display_vector();

	SG_PRINT("\n>>> Displaying m_M matrices ...\n");
	for ( k = 0 ; k < m_num_classes ; ++k )
		CMath::display_matrix(m_M.get_matrix(k), m_dim, m_dim);

	SG_PRINT("\n>>> Exit DEBUG_QDA\n");
#endif

	rotations.destroy_ndarray();
	SG_FREE(class_idxs);
	SG_FREE(class_nums);
	return true;
}

#endif /* HAVE_LAPACK */
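A minimal usage sketch follows (not part of QDA.cpp). It assumes Shogun 2.0's modular C++ API; the toy data, labels, and the explicit tolerance/store_covs arguments are illustrative.

// usage_qda.cpp: train CQDA on a tiny two-class toy problem and
// classify the training points. Each class has three points so the
// per-class covariance estimate is nonsingular in two dimensions.
#include <cstring>
#include <shogun/base/init.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/labels/MulticlassLabels.h>
#include <shogun/multiclass/QDA.h>

using namespace shogun;

int main()
{
	init_shogun_with_defaults();

	// 6 two-dimensional examples, stored column-major (one column per vector)
	SGMatrix<float64_t> feat_mat(2, 6);
	const float64_t data[12] = { 0,0,  1,0,  0,1,  4,4,  5,4,  4,5 };
	memcpy(feat_mat.matrix, data, sizeof(data));

	// examples 0-2 belong to class 0, examples 3-5 to class 1
	SGVector<float64_t> lab_vec(6);
	lab_vec[0] = 0; lab_vec[1] = 0; lab_vec[2] = 0;
	lab_vec[3] = 1; lab_vec[4] = 1; lab_vec[5] = 1;

	CDenseFeatures<float64_t>* feats = new CDenseFeatures<float64_t>(feat_mat);
	CMulticlassLabels* labels = new CMulticlassLabels(lab_vec);

	// tolerance 1e-4 and store_covs=false are illustrative choices
	CQDA* qda = new CQDA(feats, labels, 1e-4, false);
	SG_REF(qda);
	qda->train();

	// classify the training points themselves; expect 0,0,0,1,1,1
	CMulticlassLabels* out = qda->apply_multiclass();
	out->get_labels().display_vector("predictions");

	SG_UNREF(out);
	SG_UNREF(qda);
	exit_shogun();
	return 0;
}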