SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #include <shogun/classifier/svm/CPLEXSVM.h> 00012 #include <shogun/lib/common.h> 00013 00014 #ifdef USE_CPLEX 00015 #include <shogun/io/SGIO.h> 00016 #include <shogun/mathematics/Math.h> 00017 #include <shogun/mathematics/Cplex.h> 00018 #include <shogun/labels/Labels.h> 00019 00020 using namespace shogun; 00021 00022 CCPLEXSVM::CCPLEXSVM() 00023 : CSVM() 00024 { 00025 } 00026 00027 CCPLEXSVM::~CCPLEXSVM() 00028 { 00029 } 00030 00031 bool CCPLEXSVM::train_machine(CFeatures* data) 00032 { 00033 ASSERT(m_labels); 00034 ASSERT(m_labels->get_label_type() == LT_BINARY); 00035 00036 bool result = false; 00037 CCplex cplex; 00038 00039 if (data) 00040 { 00041 if (m_labels->get_num_labels() != data->get_num_vectors()) 00042 { 00043 SG_ERROR("%s::train_machine(): Number of training vectors (%d) does" 00044 " not match number of labels (%d)\n", get_name(), 00045 data->get_num_vectors(), m_labels->get_num_labels()); 00046 } 00047 kernel->init(data, data); 00048 } 00049 00050 if (cplex.init(E_QP)) 00051 { 00052 int32_t n,m; 00053 int32_t num_label=0; 00054 SGVector<float64_t> y=((CBinaryLabels*)m_labels)->get_labels(); 00055 SGMatrix<float64_t> H=kernel->get_kernel_matrix(); 00056 m=H.num_rows; 00057 n=H.num_cols; 00058 ASSERT(n>0 && n==m && n==num_label); 00059 float64_t* alphas=SG_MALLOC(float64_t, n); 00060 float64_t* lb=SG_MALLOC(float64_t, n); 00061 float64_t* ub=SG_MALLOC(float64_t, n); 00062 00063 //hessian y'y.*K 00064 for (int32_t i=0; i<n; i++) 00065 { 00066 lb[i]=0; 00067 ub[i]=get_C1(); 00068 00069 for (int32_t j=0; j<n; j++) 00070 H[i*n+j]*=y[j]*y[i]; 00071 } 00072 00073 //feed qp to cplex 00074 00075 00076 int32_t j=0; 00077 for (int32_t i=0; i<n; i++) 00078 { 00079 if (alphas[i]>0) 00080 { 00081 //set_alpha(j, alphas[i]*labels->get_label(i)/etas[1]); 00082 set_alpha(j, alphas[i]*((CBinaryLabels*) m_labels)->get_int_label(i)); 00083 set_support_vector(j, i); 00084 j++; 00085 } 00086 } 00087 //compute_objective(); 00088 SG_INFO( "obj = %.16f, rho = %.16f\n",get_objective(),get_bias()); 00089 SG_INFO( "Number of SV: %ld\n", get_num_support_vectors()); 00090 00091 SG_FREE(alphas); 00092 SG_FREE(lb); 00093 SG_FREE(ub); 00094 00095 result = true; 00096 } 00097 00098 if (!result) 00099 SG_ERROR( "cplex svm failed"); 00100 00101 return result; 00102 } 00103 #endif