SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Copyright (C) 2012 Sergey Lisitsyn 00008 */ 00009 00010 #include <shogun/transfer/multitask/MultitaskLogisticRegression.h> 00011 #include <shogun/lib/slep/slep_solver.h> 00012 #include <shogun/lib/slep/slep_options.h> 00013 00014 namespace shogun 00015 { 00016 00017 CMultitaskLogisticRegression::CMultitaskLogisticRegression() : 00018 CMultitaskLinearMachine() 00019 { 00020 initialize_parameters(); 00021 register_parameters(); 00022 } 00023 00024 CMultitaskLogisticRegression::CMultitaskLogisticRegression( 00025 float64_t z, CDotFeatures* train_features, 00026 CBinaryLabels* train_labels, CTaskRelation* task_relation) : 00027 CMultitaskLinearMachine(train_features,(CLabels*)train_labels,task_relation) 00028 { 00029 initialize_parameters(); 00030 register_parameters(); 00031 set_z(z); 00032 } 00033 00034 CMultitaskLogisticRegression::~CMultitaskLogisticRegression() 00035 { 00036 } 00037 00038 void CMultitaskLogisticRegression::register_parameters() 00039 { 00040 SG_ADD(&m_z, "z", "regularization coefficient", MS_AVAILABLE); 00041 SG_ADD(&m_q, "q", "q of L1/Lq", MS_AVAILABLE); 00042 SG_ADD(&m_termination, "termination", "termination", MS_NOT_AVAILABLE); 00043 SG_ADD(&m_regularization, "regularization", "regularization", MS_NOT_AVAILABLE); 00044 SG_ADD(&m_tolerance, "tolerance", "tolerance", MS_NOT_AVAILABLE); 00045 SG_ADD(&m_max_iter, "max_iter", "maximum number of iterations", MS_NOT_AVAILABLE); 00046 } 00047 00048 void CMultitaskLogisticRegression::initialize_parameters() 00049 { 00050 set_z(0.0); 00051 set_q(2.0); 00052 set_termination(0); 00053 set_regularization(0); 00054 set_tolerance(1e-3); 00055 set_max_iter(1000); 00056 } 00057 00058 bool CMultitaskLogisticRegression::train_machine(CFeatures* data) 00059 { 00060 if (data && (CDotFeatures*)data) 00061 set_features((CDotFeatures*)data); 00062 00063 ASSERT(features); 00064 ASSERT(m_labels); 00065 00066 SGVector<float64_t> y(m_labels->get_num_labels()); 00067 for (int32_t i=0; i<y.vlen; i++) 00068 y[i] = ((CBinaryLabels*)m_labels)->get_label(i); 00069 00070 slep_options options = slep_options::default_options(); 00071 options.n_tasks = m_task_relation->get_num_tasks(); 00072 options.tasks_indices = m_task_relation->get_tasks_indices(); 00073 options.q = m_q; 00074 options.regularization = m_regularization; 00075 options.termination = m_termination; 00076 options.tolerance = m_tolerance; 00077 options.max_iter = m_max_iter; 00078 00079 ETaskRelationType relation_type = m_task_relation->get_relation_type(); 00080 switch (relation_type) 00081 { 00082 case TASK_GROUP: 00083 { 00084 //CTaskGroup* task_group = (CTaskGroup*)m_task_relation; 00085 options.mode = MULTITASK_GROUP; 00086 options.loss = LOGISTIC; 00087 slep_result_t result = slep_solver(features, y.vector, m_z, options); 00088 m_tasks_w = result.w; 00089 m_tasks_c = result.c; 00090 } 00091 break; 00092 case TASK_TREE: 00093 { 00094 CTaskTree* task_tree = (CTaskTree*)m_task_relation; 00095 SGVector<float64_t> ind_t = task_tree->get_SLEP_ind_t(); 00096 options.ind_t = ind_t.vector; 00097 options.n_nodes = ind_t.vlen / 3; 00098 options.mode = MULTITASK_TREE; 00099 options.loss = LOGISTIC; 00100 slep_result_t result = slep_solver(features, y.vector, m_z, options); 00101 m_tasks_w = result.w; 00102 m_tasks_c = result.c; 00103 } 00104 break; 00105 default: 00106 SG_ERROR("Not supported task relation type\n"); 00107 } 00108 for (int32_t i=0; i<options.n_tasks; i++) 00109 options.tasks_indices[i].~SGVector<index_t>(); 00110 SG_FREE(options.tasks_indices); 00111 00112 return true; 00113 } 00114 00115 bool CMultitaskLogisticRegression::train_locked_implementation(SGVector<index_t>* tasks) 00116 { 00117 ASSERT(features); 00118 ASSERT(m_labels); 00119 00120 SGVector<float64_t> y(m_labels->get_num_labels()); 00121 for (int32_t i=0; i<y.vlen; i++) 00122 y[i] = ((CBinaryLabels*)m_labels)->get_label(i); 00123 00124 slep_options options = slep_options::default_options(); 00125 options.n_tasks = m_task_relation->get_num_tasks(); 00126 options.tasks_indices = tasks; 00127 options.q = m_q; 00128 options.regularization = m_regularization; 00129 options.termination = m_termination; 00130 options.tolerance = m_tolerance; 00131 options.max_iter = m_max_iter; 00132 00133 ETaskRelationType relation_type = m_task_relation->get_relation_type(); 00134 switch (relation_type) 00135 { 00136 case TASK_GROUP: 00137 { 00138 //CTaskGroup* task_group = (CTaskGroup*)m_task_relation; 00139 options.mode = MULTITASK_GROUP; 00140 options.loss = LOGISTIC; 00141 slep_result_t result = slep_solver(features, y.vector, m_z, options); 00142 m_tasks_w = result.w; 00143 m_tasks_c = result.c; 00144 } 00145 break; 00146 case TASK_TREE: 00147 { 00148 CTaskTree* task_tree = (CTaskTree*)m_task_relation; 00149 SGVector<float64_t> ind_t = task_tree->get_SLEP_ind_t(); 00150 options.ind_t = ind_t.vector; 00151 options.n_nodes = ind_t.vlen / 3; 00152 options.mode = MULTITASK_TREE; 00153 options.loss = LOGISTIC; 00154 slep_result_t result = slep_solver(features, y.vector, m_z, options); 00155 m_tasks_w = result.w; 00156 m_tasks_c = result.c; 00157 } 00158 break; 00159 default: 00160 SG_ERROR("Not supported task relation type\n"); 00161 } 00162 return true; 00163 } 00164 00165 float64_t CMultitaskLogisticRegression::apply_one(int32_t i) 00166 { 00167 float64_t dot = features->dense_dot(i,m_tasks_w.get_column_vector(m_current_task),m_tasks_w.num_rows); 00168 //float64_t ep = CMath::exp(-(dot + m_tasks_c[m_current_task])); 00169 //return 2.0/(1.0+ep) - 1.0; 00170 return dot + m_tasks_c[m_current_task]; 00171 } 00172 00173 int32_t CMultitaskLogisticRegression::get_max_iter() const 00174 { 00175 return m_max_iter; 00176 } 00177 int32_t CMultitaskLogisticRegression::get_regularization() const 00178 { 00179 return m_regularization; 00180 } 00181 int32_t CMultitaskLogisticRegression::get_termination() const 00182 { 00183 return m_termination; 00184 } 00185 float64_t CMultitaskLogisticRegression::get_tolerance() const 00186 { 00187 return m_tolerance; 00188 } 00189 float64_t CMultitaskLogisticRegression::get_z() const 00190 { 00191 return m_z; 00192 } 00193 float64_t CMultitaskLogisticRegression::get_q() const 00194 { 00195 return m_q; 00196 } 00197 00198 void CMultitaskLogisticRegression::set_max_iter(int32_t max_iter) 00199 { 00200 ASSERT(max_iter>=0); 00201 m_max_iter = max_iter; 00202 } 00203 void CMultitaskLogisticRegression::set_regularization(int32_t regularization) 00204 { 00205 ASSERT(regularization==0 || regularization==1); 00206 m_regularization = regularization; 00207 } 00208 void CMultitaskLogisticRegression::set_termination(int32_t termination) 00209 { 00210 ASSERT(termination>=0 && termination<=4); 00211 m_termination = termination; 00212 } 00213 void CMultitaskLogisticRegression::set_tolerance(float64_t tolerance) 00214 { 00215 ASSERT(tolerance>0.0); 00216 m_tolerance = tolerance; 00217 } 00218 void CMultitaskLogisticRegression::set_z(float64_t z) 00219 { 00220 m_z = z; 00221 } 00222 void CMultitaskLogisticRegression::set_q(float64_t q) 00223 { 00224 m_q = q; 00225 } 00226 00227 }