SHOGUN  v2.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
QDA.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Fernando José Iglesias García
00008  * Copyright (C) 2012 Fernando José Iglesias García
00009  */
00010 
00011 #include <shogun/lib/common.h>
00012 
00013 #ifdef HAVE_LAPACK
00014 
00015 #include <shogun/multiclass/QDA.h>
00016 #include <shogun/machine/NativeMulticlassMachine.h>
00017 #include <shogun/features/Features.h>
00018 #include <shogun/labels/Labels.h>
00019 #include <shogun/labels/MulticlassLabels.h>
00020 #include <shogun/mathematics/Math.h>
00021 #include <shogun/mathematics/lapack.h>
00022 
00023 using namespace shogun;
00024 
/** Standard constructor.
 *
 * @param tolerance singular values of a class covariance below this
 *        threshold are considered zero (checked during training)
 * @param store_covs if true, the per-class covariance matrices are
 *        reconstructed and kept after training
 */
CQDA::CQDA(float64_t tolerance, bool store_covs)
: CNativeMulticlassMachine(), m_tolerance(tolerance), 
    m_store_covs(store_covs), m_num_classes(0), m_dim(0)
{
    init();
}
00031 
/** Convenience constructor that also sets training features and labels.
 *
 * @param traindat training features (reference taken by set_features)
 * @param trainlab training labels (reference taken by set_labels)
 * @param tolerance singular-value tolerance, see standard constructor
 * @param store_covs whether to keep the per-class covariance matrices
 */
CQDA::CQDA(CDenseFeatures<float64_t>* traindat, CLabels* trainlab, float64_t tolerance, bool store_covs)
: CNativeMulticlassMachine(), m_tolerance(tolerance), m_store_covs(store_covs), m_num_classes(0), m_dim(0)
{
    init();
    set_features(traindat);
    set_labels(trainlab);
}
00039 
/** Destructor: releases the feature reference and frees all model data. */
CQDA::~CQDA()
{
    // drop the reference acquired in set_features()
    SG_UNREF(m_features);

    cleanup();
}
00046 
/** Registers the model parameters for serialization / model selection and
 *  initializes the feature pointer.
 *
 *  NOTE(review): the SG_ADD calls define the parameter registration order;
 *  do not reorder them, as serialized models depend on it.
 */
void CQDA::init()
{
    SG_ADD(&m_tolerance, "m_tolerance", "Tolerance member.", MS_AVAILABLE);
    SG_ADD(&m_store_covs, "m_store_covs", "Store covariances member", MS_NOT_AVAILABLE);
    SG_ADD((CSGObject**) &m_features, "m_features", "Feature object.", MS_NOT_AVAILABLE);
    SG_ADD(&m_means, "m_means", "Mean vectors list", MS_NOT_AVAILABLE);
    SG_ADD(&m_slog, "m_slog", "Vector used in classification", MS_NOT_AVAILABLE);

    //TODO include SGNDArray objects for serialization
    // m_covs and m_M are not registered yet, so a deserialized machine
    // must be retrained before apply_multiclass() can be used.

    m_features  = NULL;
}
00059 
/** Frees all data computed by train_machine() and resets the class count.
 *  Safe to call on an untrained machine (train_machine() calls it first).
 */
void CQDA::cleanup()
{
    // m_covs owns its dims array only when covariances were stored;
    // destroy_ndarray() also frees the per-matrix data in that case
    if ( m_store_covs )
        m_covs.destroy_ndarray();

    m_covs.free_ndarray();
    m_M.free_ndarray();
    // release the means matrix by assigning an empty one
    m_means=SGMatrix<float64_t>();

    m_num_classes = 0;
}
00071 
/** Classifies features with the trained QDA model.
 *
 *  For every example x and class k it evaluates
 *      -0.5 * ( | (x - mean_k)^T M_k |^2 + m_slog[k] ),
 *  where M_k (whitening matrix) and m_slog (log-determinant terms) were
 *  precomputed in train_machine(), and picks the class with the largest
 *  value.
 *
 *  @param data features to classify; if NULL, the previously set
 *              features are used
 *  @return a new CMulticlassLabels with one label per vector, or NULL if
 *          no features are available
 */
CMulticlassLabels* CQDA::apply_multiclass(CFeatures* data)
{
    if (data)
    {
        if (!data->has_property(FP_DOT))
            SG_ERROR("Specified features are not of type CDotFeatures\n");

        set_features((CDotFeatures*) data);
    }

    if ( !m_features )
        return NULL;

    int32_t num_vecs = m_features->get_num_vectors();
    ASSERT(num_vecs > 0);
    // the machine must have been trained on features of the same dimension
    ASSERT( m_dim == m_features->get_dim_feature_space() );

    CDenseFeatures< float64_t >* rf = (CDenseFeatures< float64_t >*) m_features;

    // X and A are column-major num_vecs x m_dim work matrices; norm2 holds
    // the per-class scores laid out class-major (class k occupies entries
    // [k*num_vecs, (k+1)*num_vecs))
    SGMatrix< float64_t > X(num_vecs, m_dim);
    SGMatrix< float64_t > A(num_vecs, m_dim);
    SGVector< float64_t > norm2(num_vecs*m_num_classes);
    norm2.zero();

    int i, j, k, vlen;
    bool vfree;
    float64_t* vec;
    for ( k = 0 ; k < m_num_classes ; ++k )
    {
        // X = features - means
        for ( i = 0 ; i < num_vecs ; ++i )
        {
            vec = rf->get_feature_vector(i, vlen, vfree);
            ASSERT(vec);

            for ( j = 0 ; j < m_dim ; ++j )
                X[i + j*num_vecs] = vec[j] - m_means[k*m_dim + j];

            rf->free_feature_vector(vec, i, vfree);

        }

        // A = X * M_k  (whitened, centered features for class k)
        cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, num_vecs, m_dim,
            m_dim, 1.0, X.matrix, num_vecs, m_M.get_matrix(k), m_dim, 0.0,
            A.matrix, num_vecs);

        // squared Mahalanobis distance of each vector to class k
        for ( i = 0 ; i < num_vecs ; ++i )
            for ( j = 0 ; j < m_dim ; ++j )
                norm2[i + k*num_vecs] += CMath::sq(A[i + j*num_vecs]);

#ifdef DEBUG_QDA
    CMath::display_matrix(A.matrix, num_vecs, m_dim, "A");
#endif
    }

    // add the log-determinant term and apply the -0.5 factor
    for ( i = 0 ; i < num_vecs ; ++i )
        for ( k = 0 ; k < m_num_classes ; ++k )
        {
            norm2[i + k*num_vecs] += m_slog[k];
            norm2[i + k*num_vecs] *= -0.5;
        }

    CMulticlassLabels* out = new CMulticlassLabels(num_vecs);

    // arg_max over the m_num_classes entries of vector i (stride num_vecs)
    for ( i = 0 ; i < num_vecs ; ++i )
        out->set_label(i, SGVector<float64_t>::arg_max(norm2.vector+i, num_vecs, m_num_classes));

#ifdef DEBUG_QDA
    CMath::display_matrix(norm2.vector, num_vecs, m_num_classes, "norm2");
    CMath::display_vector(out->get_labels().vector, num_vecs, "Labels");
#endif

    return out;
}
00146 
00147 bool CQDA::train_machine(CFeatures* data)
00148 {
00149     if ( !m_labels )
00150         SG_ERROR("No labels allocated in QDA training\n");
00151 
00152     if ( data )
00153     {
00154         if ( !data->has_property(FP_DOT) )
00155             SG_ERROR("Speficied features are not of type CDotFeatures\n");
00156         set_features((CDotFeatures*) data);
00157     }
00158     if ( !m_features )
00159         SG_ERROR("No features allocated in QDA training\n");
00160     SGVector< int32_t > train_labels = ((CMulticlassLabels*) m_labels)->get_int_labels();
00161     if ( !train_labels.vector )
00162         SG_ERROR("No train_labels allocated in QDA training\n");
00163 
00164     cleanup();
00165 
00166     m_num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes();
00167     m_dim = m_features->get_dim_feature_space();
00168     int32_t num_vec  = m_features->get_num_vectors();
00169     if ( num_vec != train_labels.vlen )
00170         SG_ERROR("Dimension mismatch between features and labels in QDA training");
00171 
00172     int32_t* class_idxs = SG_MALLOC(int32_t, num_vec*m_num_classes);
00173     // number of examples of each class
00174     int32_t* class_nums = SG_MALLOC(int32_t, m_num_classes);
00175     memset(class_nums, 0, m_num_classes*sizeof(int32_t));
00176     int32_t class_idx;
00177     int32_t i, j, k;
00178     for ( i = 0 ; i < train_labels.vlen ; ++i )
00179     {
00180         class_idx = train_labels.vector[i];
00181 
00182         if ( class_idx < 0 || class_idx >= m_num_classes )
00183         {
00184             SG_ERROR("found label out of {0, 1, 2, ..., num_classes-1}...");
00185             return false;
00186         }
00187         else
00188         {
00189             class_idxs[ class_idx*num_vec + class_nums[class_idx]++ ] = i;
00190         }
00191     }
00192 
00193     for ( i = 0 ; i < m_num_classes ; ++i )
00194     {
00195         if ( class_nums[i] <= 0 )
00196         {
00197             SG_ERROR("What? One class with no elements\n");
00198             return false;
00199         }
00200     }
00201 
00202     if ( m_store_covs )
00203     {
00204         // cov_dims will be free in m_covs.destroy_ndarray()
00205         index_t * cov_dims = SG_MALLOC(index_t, 3);
00206         cov_dims[0] = m_dim;
00207         cov_dims[1] = m_dim;
00208         cov_dims[2] = m_num_classes;
00209         m_covs = SGNDArray< float64_t >(cov_dims, 3, true);
00210     }
00211 
00212     m_means = SGMatrix< float64_t >(m_dim, m_num_classes, true);
00213     SGMatrix< float64_t > scalings  = SGMatrix< float64_t >(m_dim, m_num_classes);
00214 
00215     // rot_dims will be freed in rotations.destroy_ndarray()
00216     index_t* rot_dims = SG_MALLOC(index_t, 3);
00217     rot_dims[0] = m_dim;
00218     rot_dims[1] = m_dim;
00219     rot_dims[2] = m_num_classes;
00220     SGNDArray< float64_t > rotations = SGNDArray< float64_t >(rot_dims, 3);
00221 
00222     CDenseFeatures< float64_t >* rf = (CDenseFeatures< float64_t >*) m_features;
00223 
00224     m_means.zero();
00225 
00226     int32_t vlen;
00227     bool vfree;
00228     float64_t* vec;
00229     for ( k = 0 ; k < m_num_classes ; ++k )
00230     {
00231         SGMatrix< float64_t > buffer(class_nums[k], m_dim);
00232         for ( i = 0 ; i < class_nums[k] ; ++i )
00233         {
00234             vec = rf->get_feature_vector(class_idxs[k*num_vec + i], vlen, vfree);
00235             ASSERT(vec);
00236 
00237             for ( j = 0 ; j < vlen ; ++j )
00238             {
00239                 m_means[k*m_dim + j] += vec[j];
00240                 buffer[i + j*class_nums[k]] = vec[j];
00241             }
00242 
00243             rf->free_feature_vector(vec, class_idxs[k*num_vec + i], vfree);
00244         }
00245 
00246         for ( j = 0 ; j < m_dim ; ++j )
00247             m_means[k*m_dim + j] /= class_nums[k];
00248 
00249         for ( i = 0 ; i < class_nums[k] ; ++i )
00250             for ( j = 0 ; j < m_dim ; ++j )
00251                 buffer[i + j*class_nums[k]] -= m_means[k*m_dim + j];
00252 
00253         /* calling external lib, buffer = U * S * V^T, U is not interesting here */
00254         char jobu = 'N', jobvt = 'A';
00255         int m = class_nums[k], n = m_dim;
00256         int lda = m, ldu = m, ldvt = n;
00257         int info = -1;
00258         float64_t * col = scalings.get_column_vector(k);
00259         float64_t * rot_mat = rotations.get_matrix(k);
00260 
00261         wrap_dgesvd(jobu, jobvt, m, n, buffer.matrix, lda, col, NULL, ldu,
00262             rot_mat, ldvt, &info);
00263         ASSERT(info == 0);
00264         buffer=SGMatrix<float64_t>();
00265 
00266         SGVector<float64_t>::vector_multiply(col, col, col, m_dim);
00267         SGVector<float64_t>::scale_vector(1.0/(m-1), col, m_dim);
00268         rotations.transpose_matrix(k);
00269 
00270         if ( m_store_covs )
00271         {
00272             SGMatrix< float64_t > M(n ,n);
00273 
00274             M.matrix = SGVector<float64_t>::clone_vector(rot_mat, n*n);
00275             for ( i = 0 ; i < m_dim ; ++i )
00276                 for ( j = 0 ; j < m_dim ; ++j )
00277                     M[i + j*m_dim] *= scalings[k*m_dim + j];
00278 
00279             cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, n, n, n, 1.0,
00280                 M.matrix, n, rot_mat, n, 0.0, m_covs.get_matrix(k), n);
00281         }
00282     }
00283 
00284     /* Computation of terms required for classification */
00285 
00286     SGVector< float32_t > sinvsqrt(m_dim);
00287 
00288     // M_dims will be freed in m_M.destroy_ndarray()
00289     index_t* M_dims = SG_MALLOC(index_t, 3);
00290     M_dims[0] = m_dim;
00291     M_dims[1] = m_dim;
00292     M_dims[2] = m_num_classes;
00293     m_M = SGNDArray< float64_t >(M_dims, 3, true);
00294 
00295     m_slog = SGVector< float32_t >(m_num_classes);
00296     m_slog.zero();
00297 
00298     index_t idx = 0;
00299     for ( k = 0 ; k < m_num_classes ; ++k )
00300     {
00301         for ( j = 0 ; j < m_dim ; ++j )
00302         {
00303             sinvsqrt[j] = 1.0 / CMath::sqrt(scalings[k*m_dim + j]);
00304             m_slog[k]  += CMath::log(scalings[k*m_dim + j]);
00305         }
00306 
00307         for ( i = 0 ; i < m_dim ; ++i )
00308             for ( j = 0 ; j < m_dim ; ++j )
00309             {
00310                 idx = k*m_dim*m_dim + i + j*m_dim;
00311                 m_M[idx] = rotations[idx] * sinvsqrt[j];
00312             }
00313     }
00314 
00315 #ifdef DEBUG_QDA
00316     SG_PRINT(">>> QDA machine trained with %d classes\n", m_num_classes);
00317 
00318     SG_PRINT("\n>>> Displaying means ...\n");
00319     CMath::display_matrix(m_means.matrix, m_dim, m_num_classes);
00320 
00321     SG_PRINT("\n>>> Displaying scalings ...\n");
00322     CMath::display_matrix(scalings.matrix, m_dim, m_num_classes);
00323 
00324     SG_PRINT("\n>>> Displaying rotations ... \n");
00325     for ( k = 0 ; k < m_num_classes ; ++k )
00326         CMath::display_matrix(rotations.get_matrix(k), m_dim, m_dim);
00327 
00328     SG_PRINT("\n>>> Displaying sinvsqrt ... \n");
00329     sinvsqrt.display_vector();
00330 
00331     SG_PRINT("\n>>> Diplaying m_M matrices ... \n");
00332     for ( k = 0 ; k < m_num_classes ; ++k )
00333         CMath::display_matrix(m_M.get_matrix(k), m_dim, m_dim);
00334 
00335     SG_PRINT("\n>>> Exit DEBUG_QDA\n");
00336 #endif
00337 
00338     rotations.destroy_ndarray();
00339     SG_FREE(class_idxs);
00340     SG_FREE(class_nums);
00341     return true;
00342 }
00343 
00344 #endif /* HAVE_LAPACK */
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation