SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Copyright (C) 2012 Sergey Lisitsyn 00008 */ 00009 00010 #include <shogun/transfer/multitask/MultitaskLinearMachine.h> 00011 #include <shogun/lib/slep/slep_solver.h> 00012 #include <shogun/lib/slep/slep_options.h> 00013 00014 #include <map> 00015 #include <vector> 00016 00017 using namespace std; 00018 00019 namespace shogun 00020 { 00021 00022 CMultitaskLinearMachine::CMultitaskLinearMachine() : 00023 CLinearMachine(), m_current_task(0), 00024 m_task_relation(NULL) 00025 { 00026 register_parameters(); 00027 } 00028 00029 CMultitaskLinearMachine::CMultitaskLinearMachine( 00030 CDotFeatures* train_features, 00031 CLabels* train_labels, CTaskRelation* task_relation) : 00032 CLinearMachine(), m_current_task(0), m_task_relation(NULL) 00033 { 00034 set_features(train_features); 00035 set_labels(train_labels); 00036 set_task_relation(task_relation); 00037 register_parameters(); 00038 } 00039 00040 CMultitaskLinearMachine::~CMultitaskLinearMachine() 00041 { 00042 SG_UNREF(m_task_relation); 00043 } 00044 00045 void CMultitaskLinearMachine::register_parameters() 00046 { 00047 SG_ADD((CSGObject**)&m_task_relation, "task_relation", "task relation", MS_NOT_AVAILABLE); 00048 } 00049 00050 int32_t CMultitaskLinearMachine::get_current_task() const 00051 { 00052 return m_current_task; 00053 } 00054 00055 void CMultitaskLinearMachine::set_current_task(int32_t task) 00056 { 00057 ASSERT(task>=0); 00058 ASSERT(task<m_tasks_w.num_cols); 00059 m_current_task = task; 00060 } 00061 00062 CTaskRelation* CMultitaskLinearMachine::get_task_relation() const 00063 { 00064 SG_REF(m_task_relation); 00065 return m_task_relation; 00066 } 00067 00068 void CMultitaskLinearMachine::set_task_relation(CTaskRelation* task_relation) 00069 { 00070 SG_UNREF(m_task_relation); 00071 SG_REF(task_relation); 00072 m_task_relation = task_relation; 00073 } 00074 00075 bool CMultitaskLinearMachine::train_machine(CFeatures* data) 00076 { 00077 SG_NOTIMPLEMENTED; 00078 return false; 00079 } 00080 00081 void CMultitaskLinearMachine::post_lock(CLabels* labels, CFeatures* features_) 00082 { 00083 set_features((CDotFeatures*)features_); 00084 int n_tasks = ((CTaskGroup*)m_task_relation)->get_num_tasks(); 00085 SGVector<index_t>* tasks_indices = ((CTaskGroup*)m_task_relation)->get_tasks_indices(); 00086 00087 m_tasks_indices.clear(); 00088 for (int32_t i=0; i<n_tasks; i++) 00089 { 00090 set<index_t> indices_set; 00091 SGVector<index_t> task_indices = tasks_indices[i]; 00092 for (int32_t j=0; j<task_indices.vlen; j++) 00093 indices_set.insert(task_indices[j]); 00094 00095 m_tasks_indices.push_back(indices_set); 00096 } 00097 00098 for (int32_t i=0; i<n_tasks; i++) 00099 tasks_indices[i].~SGVector<index_t>(); 00100 SG_FREE(tasks_indices); 00101 } 00102 00103 bool CMultitaskLinearMachine::train_locked(SGVector<index_t> indices) 00104 { 00105 int n_tasks = ((CTaskGroup*)m_task_relation)->get_num_tasks(); 00106 ASSERT((int)m_tasks_indices.size()==n_tasks); 00107 vector< vector<index_t> > cutted_task_indices; 00108 for (int32_t i=0; i<n_tasks; i++) 00109 cutted_task_indices.push_back(vector<index_t>()); 00110 for (int32_t i=0; i<indices.vlen; i++) 00111 { 00112 for (int32_t j=0; j<n_tasks; j++) 00113 { 00114 if (m_tasks_indices[j].count(indices[i])) 00115 { 00116 cutted_task_indices[j].push_back(indices[i]); 00117 break; 00118 } 00119 } 00120 } 00121 SGVector<index_t>* tasks = SG_MALLOC(SGVector<index_t>, n_tasks); 00122 for (int32_t i=0; i<n_tasks; i++) 00123 { 00124 new (&tasks[i]) SGVector<index_t>(cutted_task_indices[i].size()); 00125 for (int32_t j=0; j<(int)cutted_task_indices[i].size(); j++) 00126 tasks[i][j] = cutted_task_indices[i][j]; 00127 //tasks[i].display_vector(); 00128 } 00129 bool res = train_locked_implementation(tasks); 00130 for (int32_t i=0; i<n_tasks; i++) 00131 tasks[i].~SGVector<index_t>(); 00132 SG_FREE(tasks); 00133 return res; 00134 } 00135 00136 bool CMultitaskLinearMachine::train_locked_implementation(SGVector<index_t>* tasks) 00137 { 00138 SG_NOTIMPLEMENTED; 00139 return false; 00140 } 00141 00142 CBinaryLabels* CMultitaskLinearMachine::apply_locked_binary(SGVector<index_t> indices) 00143 { 00144 int n_tasks = ((CTaskGroup*)m_task_relation)->get_num_tasks(); 00145 SGVector<float64_t> result(indices.vlen); 00146 result.zero(); 00147 for (int32_t i=0; i<indices.vlen; i++) 00148 { 00149 for (int32_t j=0; j<n_tasks; j++) 00150 { 00151 if (m_tasks_indices[j].count(indices[i])) 00152 { 00153 set_current_task(j); 00154 result[i] = apply_one(indices[i]); 00155 break; 00156 } 00157 } 00158 } 00159 return new CBinaryLabels(result); 00160 } 00161 00162 float64_t CMultitaskLinearMachine::apply_one(int32_t i) 00163 { 00164 SG_NOTIMPLEMENTED; 00165 return 0.0; 00166 } 00167 00168 SGVector<float64_t> CMultitaskLinearMachine::apply_get_outputs(CFeatures* data) 00169 { 00170 if (data) 00171 { 00172 if (!data->has_property(FP_DOT)) 00173 SG_ERROR("Specified features are not of type CDotFeatures\n"); 00174 00175 set_features((CDotFeatures*) data); 00176 } 00177 00178 if (!features) 00179 return SGVector<float64_t>(); 00180 00181 int32_t num=features->get_num_vectors(); 00182 ASSERT(num>0); 00183 float64_t* out=SG_MALLOC(float64_t, num); 00184 for (int32_t i=0; i<num; i++) 00185 out[i] = apply_one(i); 00186 00187 return SGVector<float64_t>(out,num); 00188 } 00189 00190 SGVector<float64_t> CMultitaskLinearMachine::get_w() const 00191 { 00192 SGVector<float64_t> w_(m_tasks_w.num_rows); 00193 for (int32_t i=0; i<w_.vlen; i++) 00194 w_[i] = m_tasks_w(i,m_current_task); 00195 return w_; 00196 } 00197 00198 void CMultitaskLinearMachine::set_w(const SGVector<float64_t> src_w) 00199 { 00200 for (int32_t i=0; i<m_tasks_w.num_rows; i++) 00201 m_tasks_w(i,m_current_task) = src_w[i]; 00202 } 00203 00204 void CMultitaskLinearMachine::set_bias(float64_t b) 00205 { 00206 m_tasks_c[m_current_task] = b; 00207 } 00208 00209 float64_t CMultitaskLinearMachine::get_bias() 00210 { 00211 return m_tasks_c[m_current_task]; 00212 } 00213 00214 SGVector<index_t>* CMultitaskLinearMachine::get_subset_tasks_indices() 00215 { 00216 int n_tasks = ((CTaskGroup*)m_task_relation)->get_num_tasks(); 00217 SGVector<index_t>* tasks_indices = ((CTaskGroup*)m_task_relation)->get_tasks_indices(); 00218 00219 CSubsetStack* sstack = features->get_subset_stack(); 00220 map<index_t,index_t> subset_inv_map = map<index_t,index_t>(); 00221 for (int32_t i=0; i<sstack->get_size(); i++) 00222 subset_inv_map[sstack->subset_idx_conversion(i)] = i; 00223 00224 SGVector<index_t>* subset_tasks_indices = SG_MALLOC(SGVector<index_t>, n_tasks); 00225 for (int32_t i=0; i<n_tasks; i++) 00226 { 00227 new (&subset_tasks_indices[i]) SGVector<index_t>(); 00228 SGVector<index_t> task = tasks_indices[i]; 00229 //task.display_vector("task"); 00230 vector<index_t> cutted = vector<index_t>(); 00231 for (int32_t j=0; j<task.vlen; j++) 00232 { 00233 if (subset_inv_map.count(task[j])) 00234 cutted.push_back(subset_inv_map[task[j]]); 00235 } 00236 SGVector<index_t> cutted_task(cutted.size()); 00237 for (int32_t j=0; j<cutted_task.vlen; j++) 00238 cutted_task[j] = cutted[j]; 00239 //cutted_task.display_vector("cutted"); 00240 subset_tasks_indices[i] = cutted_task; 00241 } 00242 for (int32_t i=0; i<n_tasks; i++) 00243 tasks_indices[i].~SGVector<index_t>(); 00244 SG_FREE(tasks_indices); 00245 00246 return subset_tasks_indices; 00247 } 00248 00249 00250 }