SHOGUN
v2.0.0
|
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 2012 Fernando José Iglesias García
 * Copyright (C) 2012 Fernando José Iglesias García
 */

#ifdef USE_MOSEK

#include <shogun/lib/DynamicObjectArray.h>
#include <shogun/lib/List.h>
#include <shogun/mathematics/Math.h>
#include <shogun/structure/PrimalMosekSOSVM.h>

using namespace shogun;

CPrimalMosekSOSVM::CPrimalMosekSOSVM()
: CLinearStructuredOutputMachine()
{
	// Register the parameters of this machine (m_slacks)
	init();
}

CPrimalMosekSOSVM::CPrimalMosekSOSVM(
		CStructuredModel* model,
		CLossFunction* loss,
		CStructuredLabels* labs)
: CLinearStructuredOutputMachine(model, loss, labs)
{
	// Register the parameters of this machine (m_slacks)
	init();
}

// Registers m_slacks with the parameter framework via SG_ADD.
void CPrimalMosekSOSVM::init()
{
	SG_ADD(&m_slacks, "m_slacks", "Slacks vector", MS_NOT_AVAILABLE);
}

CPrimalMosekSOSVM::~CPrimalMosekSOSVM()
{
}

// Trains the machine by iteratively growing a set of constraints and
// re-solving the primal QP with MOSEK.
//
// In every pass, for each training example i the prediction returned by
// m_model->argmax(m_w, i) is added as a new constraint if its loss is larger
// than the loss of every constraint already stored for that example (or if
// none is stored yet). The QP is then solved and m_w / m_slacks are updated.
// The loop stops when a full pass adds no constraint, or when an error aborts
// training.
//
// @param data training features; if NULL the features already set are used
// @return true if training finished; false if MOSEK could not be set up,
//         or if a constraint could not be stored/added (training aborted)
bool CPrimalMosekSOSVM::train_machine(CFeatures* data)
{
	if (data)
		set_features(data);

	// get_features() hands back a referenced object; it is released on every
	// exit path of this method
	CFeatures* model_features = get_features();
	// Check that the scenario is correct to start with training
	m_model->check_training_setup();

	// Dimensionality of the joint feature space
	int32_t M = m_model->get_dim();
	// Number of auxiliary variables in the optimization vector
	int32_t num_aux = m_model->get_num_aux();
	// Number of auxiliary constraints
	int32_t num_aux_con = m_model->get_num_aux_con();
	// Number of training examples
	int32_t N = model_features->get_num_vectors();

	// Interface with MOSEK. The optimization vector is laid out as
	// [ w (M entries) | auxiliary variables (num_aux) | slacks (N) ]
	CMosek* mosek = new CMosek(0, M+num_aux+N);
	SG_REF(mosek);
	if ( mosek->get_rescode() != MSK_RES_OK )
	{
		SG_PRINT("Mosek object could not be properly created..."
			 "aborting training of PrimalMosekSOSVM\n");

		// Release everything acquired so far before bailing out
		SG_UNREF(mosek);
		SG_UNREF(model_features);
		return false;
	}

	// Initialize the terms of the optimization problem
	SGMatrix< float64_t > A, B, C;
	SGVector< float64_t > a, b, lb, ub;
	m_model->init_opt(A, a, B, b, lb, ub, C);

	// Input terms of the problem that do not change between iterations
	if ( mosek->init_sosvm(M, N, num_aux, num_aux_con, C, lb, ub, A, b) != MSK_RES_OK )
	{
		// MOSEK error took place; release everything acquired so far
		SG_UNREF(mosek);
		SG_UNREF(model_features);
		return false;
	}

	// Initialize the weight vector and the slacks
	m_w = SGVector< float64_t >(M);
	m_w.zero();

	m_slacks = SGVector< float64_t >(N);
	m_slacks.zero();

	// results[i] is a CList holding the CResultSets whose constraints have
	// already been added for the ith training example
	CDynamicObjectArray* results = new CDynamicObjectArray(N);
	SG_REF(results);
	for ( int32_t i = 0 ; i < N ; ++i )
	{
		CList* list = new CList(true);
		results->push_back(list);
	}

	int32_t num_con = num_aux_con; // number of constraints passed to MOSEK
	int32_t old_num_con = num_con;
	bool exception = false;        // set when training must be aborted

	SGVector< float64_t > sol(M+num_aux+N);
	sol.zero();

	SGVector< float64_t > aux(num_aux);

	do
	{
		old_num_con = num_con;

		for ( int32_t i = 0 ; i < N ; ++i )
		{
			// Predict the result of the ith training example
			CResultSet* result = m_model->argmax(m_w, i);

			// Compute the loss associated with the prediction
			float64_t slack = m_loss->loss( compute_loss_arg(result) );
			CList* cur_list = (CList*) results->get_element(i);

			// With no stored constraints (first pass) the prediction is always
			// added; otherwise only when its loss exceeds the largest loss
			// among the constraints already stored for this example
			bool violated = true;
			if ( cur_list->get_num_elements() > 0 )
			{
				float64_t max_slack = -CMath::INFTY;
				CResultSet* cur_res = (CResultSet*) cur_list->get_first_element();

				while ( cur_res != NULL )
				{
					max_slack = CMath::max(max_slack,
							m_loss->loss( compute_loss_arg(cur_res) ));

					SG_UNREF(cur_res);
					cur_res = (CResultSet*) cur_list->get_next_element();
				}

				violated = slack > max_slack;
			}

			if ( violated )
			{
				if ( ! insert_result(cur_list, result) )
				{
					// insert_result already reported the failure
					exception = true;
				}
				else if ( ! add_constraint(mosek, result, num_con, i) )
				{
					// Continuing would desynchronize num_con from the
					// constraint indices known to MOSEK
					SG_PRINT("Constraint could not be added to MOSEK..."
						 "aborting training of PrimalMosekSOSVM\n");
					exception = true;
				}
				else
				{
					++num_con;
				}
			}

			// Release the references obtained in this iteration; this must
			// also happen when aborting, otherwise they leak
			SG_UNREF(cur_list);
			SG_UNREF(result);

			if ( exception )
				break;
		}

		// Solve the QP and unpack the solution [ w | aux | slacks ]
		mosek->optimize(sol);
		for ( int32_t i = 0 ; i < M+num_aux+N ; ++i )
		{
			if ( i < M )
				m_w[i] = sol[i];
			else if ( i < M+num_aux )
				aux[i-M] = sol[i];
			else
				m_slacks[i-M-num_aux] = sol[i];
		}

	} while ( old_num_con != num_con && ! exception );

	// Free resources
	SG_UNREF(results);
	SG_UNREF(mosek);
	SG_UNREF(model_features);

	// Report an aborted training as a failure, like the other abort paths
	return ! exception;
}

// Computes the argument passed to the loss for a given result:
//   <w, psi_pred> + delta - <w, psi_truth>
// using the current weight vector m_w.
//
// @param result result set providing psi_pred, psi_truth and delta
// @return the value above
float64_t CPrimalMosekSOSVM::compute_loss_arg(CResultSet* result) const
{
	// Dimensionality of the joint feature space
	int32_t M = m_w.vlen;

	return SGVector< float64_t >::dot(m_w.vector, result->psi_pred.vector, M) +
		result->delta -
		SGVector< float64_t >::dot(m_w.vector, result->psi_truth.vector, M);
}

// Inserts result into result_list, printing a message on failure.
//
// @param result_list list of constraints of one training example
// @param result result set to insert
// @return true if the insertion succeeded
bool CPrimalMosekSOSVM::insert_result(CList* result_list, CResultSet* result) const
{
	bool succeed = result_list->insert_element(result);

	if ( ! succeed )
	{
		SG_PRINT("ResultSet could not be inserted in the list..."
			 "aborting training of PrimalMosekSOSVM\n");
	}

	return succeed;
}

// Adds to MOSEK the constraint associated with result for one training
// example. The constraint vector is dPsi = psi_pred - psi_truth and the
// value passed as bound is -delta.
//
// @param mosek MOSEK interface object
// @param result result set defining the constraint
// @param con_idx index of the new constraint
// @param train_idx index of the training example
// @return true if MOSEK reported MSK_RES_OK
bool CPrimalMosekSOSVM::add_constraint(
		CMosek* mosek,
		CResultSet* result,
		index_t con_idx,
		index_t train_idx) const
{
	int32_t M = m_model->get_dim();
	SGVector< float64_t > dPsi(M);

	for ( int32_t i = 0 ; i < M ; ++i )
		dPsi[i] = result->psi_pred[i] - result->psi_truth[i];

	return ( mosek->add_constraint_sosvm(dPsi, con_idx, train_idx,
			m_model->get_num_aux(), -result->delta) == MSK_RES_OK );
}

#endif /* USE_MOSEK */