SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Copyright (C) 2012 Sergey Lisitsyn 00008 */ 00009 00010 #include <shogun/transfer/multitask/MultitaskCompositeMachine.h> 00011 00012 #include <map> 00013 #include <vector> 00014 00015 using namespace std; 00016 00017 namespace shogun 00018 { 00019 00020 CMultitaskCompositeMachine::CMultitaskCompositeMachine() : 00021 CMachine(), m_machine(NULL), m_features(NULL), m_current_task(0), 00022 m_task_group(NULL) 00023 { 00024 register_parameters(); 00025 } 00026 00027 CMultitaskCompositeMachine::CMultitaskCompositeMachine( 00028 CMachine* machine, CFeatures* train_features, 00029 CLabels* train_labels, CTaskGroup* task_group) : 00030 CMachine(), m_machine(NULL), m_features(NULL), 00031 m_current_task(0), m_task_group(NULL) 00032 { 00033 set_machine(machine); 00034 set_features(train_features); 00035 set_labels(train_labels); 00036 set_task_group(task_group); 00037 register_parameters(); 00038 } 00039 00040 CMultitaskCompositeMachine::~CMultitaskCompositeMachine() 00041 { 00042 SG_UNREF(m_machine); 00043 SG_UNREF(m_features); 00044 SG_UNREF(m_task_machines); 00045 SG_UNREF(m_task_group); 00046 } 00047 00048 void CMultitaskCompositeMachine::register_parameters() 00049 { 00050 SG_ADD((CSGObject**)&m_machine, "machine", "machine", MS_AVAILABLE); 00051 SG_ADD((CSGObject**)&m_features, "features", "features", MS_NOT_AVAILABLE); 00052 SG_ADD((CSGObject**)&m_task_machines, "task_machines", "task machines", MS_NOT_AVAILABLE); 00053 SG_ADD((CSGObject**)&m_task_group, "task_group", "task group", MS_NOT_AVAILABLE); 00054 } 00055 00056 int32_t CMultitaskCompositeMachine::get_current_task() const 00057 { 00058 return m_current_task; 00059 } 00060 00061 void CMultitaskCompositeMachine::set_current_task(int32_t task) 00062 { 00063 m_current_task = task; 00064 } 00065 00066 CTaskGroup* CMultitaskCompositeMachine::get_task_group() const 00067 { 00068 SG_REF(m_task_group); 00069 return m_task_group; 00070 } 00071 00072 void CMultitaskCompositeMachine::set_task_group(CTaskGroup* task_group) 00073 { 00074 SG_UNREF(m_task_group); 00075 SG_REF(task_group); 00076 m_task_group = task_group; 00077 } 00078 00079 bool CMultitaskCompositeMachine::train_machine(CFeatures* data) 00080 { 00081 SG_NOTIMPLEMENTED; 00082 return false; 00083 } 00084 00085 void CMultitaskCompositeMachine::post_lock(CLabels* labels, CFeatures* features) 00086 { 00087 ASSERT(m_task_group); 00088 set_features(m_features); 00089 if (!m_machine->is_data_locked()) 00090 m_machine->data_lock(labels,features); 00091 00092 int n_tasks = m_task_group->get_num_tasks(); 00093 SGVector<index_t>* tasks_indices = m_task_group->get_tasks_indices(); 00094 00095 m_tasks_indices.clear(); 00096 for (int32_t i=0; i<n_tasks; i++) 00097 { 00098 set<index_t> indices_set; 00099 SGVector<index_t> task_indices = tasks_indices[i]; 00100 for (int32_t j=0; j<task_indices.vlen; j++) 00101 indices_set.insert(task_indices[j]); 00102 00103 m_tasks_indices.push_back(indices_set); 00104 } 00105 00106 for (int32_t i=0; i<n_tasks; i++) 00107 tasks_indices[i].~SGVector<index_t>(); 00108 SG_FREE(tasks_indices); 00109 } 00110 00111 bool CMultitaskCompositeMachine::train_locked(SGVector<index_t> indices) 00112 { 00113 int n_tasks = m_task_group->get_num_tasks(); 00114 ASSERT((int)m_tasks_indices.size()==n_tasks); 00115 vector< vector<index_t> > cutted_task_indices; 00116 for (int32_t i=0; i<n_tasks; i++) 00117 cutted_task_indices.push_back(vector<index_t>()); 00118 for (int32_t i=0; i<indices.vlen; i++) 00119 { 00120 for (int32_t j=0; j<n_tasks; j++) 00121 { 00122 if (m_tasks_indices[j].count(indices[i])) 00123 { 00124 cutted_task_indices[j].push_back(indices[i]); 00125 break; 00126 } 00127 } 00128 } 00129 //SG_UNREF(m_task_machines); 00130 m_task_machines = new CDynamicObjectArray(); 00131 for (int32_t i=0; i<n_tasks; i++) 00132 { 00133 SGVector<index_t> task_indices(cutted_task_indices[i].size()); 00134 for (int32_t j=0; j<(int)cutted_task_indices[i].size(); j++) 00135 task_indices[j] = cutted_task_indices[i][j]; 00136 00137 m_machine->train_locked(task_indices); 00138 m_task_machines->push_back(m_machine->clone()); 00139 } 00140 return true; 00141 } 00142 00143 float64_t CMultitaskCompositeMachine::apply_one(int32_t i) 00144 { 00145 CMachine* m = (CMachine*)(m_task_machines->get_element(m_current_task)); 00146 float64_t result = m->apply_one(i); 00147 SG_UNREF(m); 00148 return result; 00149 } 00150 00151 CBinaryLabels* CMultitaskCompositeMachine::apply_locked_binary(SGVector<index_t> indices) 00152 { 00153 int n_tasks = m_task_group->get_num_tasks(); 00154 SGVector<float64_t> result(indices.vlen); 00155 result.zero(); 00156 for (int32_t i=0; i<indices.vlen; i++) 00157 { 00158 for (int32_t j=0; j<n_tasks; j++) 00159 { 00160 if (m_tasks_indices[j].count(indices[i])) 00161 { 00162 set_current_task(j); 00163 result[i] = apply_one(indices[i]); 00164 break; 00165 } 00166 } 00167 } 00168 return new CBinaryLabels(result); 00169 } 00170 00171 }