SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2011 Soeren Sonnenburg 00008 * Written (W) 2012 Fernando José Iglesias García and Sergey Lisitsyn 00009 * Copyright (C) 2012 Sergey Lisitsyn, Fernando José Iglesias Garcia 00010 */ 00011 00012 #include <shogun/multiclass/MulticlassOneVsRestStrategy.h> 00013 #include <shogun/machine/LinearMachine.h> 00014 #include <shogun/machine/KernelMachine.h> 00015 #include <shogun/machine/MulticlassMachine.h> 00016 #include <shogun/base/Parameter.h> 00017 #include <shogun/labels/MulticlassLabels.h> 00018 #include <shogun/labels/RegressionLabels.h> 00019 00020 using namespace shogun; 00021 00022 CMulticlassMachine::CMulticlassMachine() 00023 : CBaseMulticlassMachine(), m_multiclass_strategy(new CMulticlassOneVsRestStrategy()), 00024 m_machine(NULL) 00025 { 00026 SG_REF(m_multiclass_strategy); 00027 register_parameters(); 00028 } 00029 00030 CMulticlassMachine::CMulticlassMachine( 00031 CMulticlassStrategy *strategy, 00032 CMachine* machine, CLabels* labs) 00033 : CBaseMulticlassMachine(), m_multiclass_strategy(strategy) 00034 { 00035 SG_REF(strategy); 00036 set_labels(labs); 00037 SG_REF(machine); 00038 m_machine = machine; 00039 register_parameters(); 00040 00041 if (labs) 00042 init_strategy(); 00043 } 00044 00045 CMulticlassMachine::~CMulticlassMachine() 00046 { 00047 SG_UNREF(m_multiclass_strategy); 00048 SG_UNREF(m_machine); 00049 } 00050 00051 void CMulticlassMachine::set_labels(CLabels* lab) 00052 { 00053 CMachine::set_labels(lab); 00054 if (lab) 00055 init_strategy(); 00056 } 00057 00058 void CMulticlassMachine::register_parameters() 00059 { 00060 SG_ADD((CSGObject**)&m_multiclass_strategy,"m_multiclass_type", "Multiclass strategy", MS_NOT_AVAILABLE); 00061 SG_ADD((CSGObject**)&m_machine, "m_machine", "The base machine", MS_NOT_AVAILABLE); 00062 } 00063 00064 void CMulticlassMachine::init_strategy() 00065 { 00066 int32_t num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes(); 00067 m_multiclass_strategy->set_num_classes(num_classes); 00068 } 00069 00070 CBinaryLabels* CMulticlassMachine::get_submachine_outputs(int32_t i) 00071 { 00072 CMachine *machine = (CMachine*)m_machines->get_element(i); 00073 ASSERT(machine); 00074 CBinaryLabels* output = machine->apply_binary(); 00075 SG_UNREF(machine); 00076 return output; 00077 } 00078 00079 float64_t CMulticlassMachine::get_submachine_output(int32_t i, int32_t num) 00080 { 00081 CMachine *machine = get_machine(i); 00082 float64_t output = 0.0; 00083 // dirty hack 00084 if (dynamic_cast<CLinearMachine*>(machine)) 00085 output = ((CLinearMachine*)machine)->apply_one(num); 00086 if (dynamic_cast<CKernelMachine*>(machine)) 00087 output = ((CKernelMachine*)machine)->apply_one(num); 00088 SG_UNREF(machine); 00089 return output; 00090 } 00091 00092 CMulticlassLabels* CMulticlassMachine::apply_multiclass(CFeatures* data) 00093 { 00094 SG_DEBUG("entering %s::apply_multiclass(%s at %p)\n", 00095 get_name(), data ? data->get_name() : "NULL", data); 00096 00097 CMulticlassLabels* return_labels=NULL; 00098 00099 if (data) 00100 init_machines_for_apply(data); 00101 else 00102 init_machines_for_apply(NULL); 00103 00104 if (is_ready()) 00105 { 00106 /* num vectors depends on whether data is provided */ 00107 int32_t num_vectors=data ? data->get_num_vectors() : 00108 get_num_rhs_vectors(); 00109 00110 int32_t num_machines=m_machines->get_num_elements(); 00111 if (num_machines <= 0) 00112 SG_ERROR("num_machines = %d, did you train your machine?", num_machines); 00113 00114 CMulticlassLabels* result=new CMulticlassLabels(num_vectors); 00115 CBinaryLabels** outputs=SG_MALLOC(CBinaryLabels*, num_machines); 00116 00117 for (int32_t i=0; i < num_machines; ++i) 00118 outputs[i] = (CBinaryLabels*) get_submachine_outputs(i); 00119 00120 SGVector<float64_t> output_for_i(num_machines); 00121 for (int32_t i=0; i<num_vectors; i++) 00122 { 00123 for (int32_t j=0; j<num_machines; j++) 00124 output_for_i[j] = outputs[j]->get_confidence(i); 00125 00126 result->set_label(i, m_multiclass_strategy->decide_label(output_for_i)); 00127 result->set_multiclass_confidences(i, output_for_i.clone()); 00128 } 00129 00130 for (int32_t i=0; i < num_machines; ++i) 00131 SG_UNREF(outputs[i]); 00132 00133 SG_FREE(outputs); 00134 00135 return_labels=result; 00136 } 00137 else 00138 SG_ERROR("Not ready"); 00139 00140 00141 SG_DEBUG("leaving %s::apply_multiclass(%s at %p)\n", 00142 get_name(), data ? data->get_name() : "NULL", data); 00143 return return_labels; 00144 } 00145 00146 CMulticlassMultipleOutputLabels* CMulticlassMachine::apply_multiclass_multiple_output(CFeatures* data, int32_t n_outputs) 00147 { 00148 CMulticlassMultipleOutputLabels* return_labels=NULL; 00149 00150 if (data) 00151 init_machines_for_apply(data); 00152 else 00153 init_machines_for_apply(NULL); 00154 00155 if (is_ready()) 00156 { 00157 /* num vectors depends on whether data is provided */ 00158 int32_t num_vectors=data ? data->get_num_vectors() : 00159 get_num_rhs_vectors(); 00160 00161 int32_t num_machines=m_machines->get_num_elements(); 00162 if (num_machines <= 0) 00163 SG_ERROR("num_machines = %d, did you train your machine?", num_machines); 00164 REQUIRE(n_outputs<=num_machines,"You request more outputs than machines available"); 00165 00166 CMulticlassMultipleOutputLabels* result=new CMulticlassMultipleOutputLabels(num_vectors); 00167 CBinaryLabels** outputs=SG_MALLOC(CBinaryLabels*, num_machines); 00168 00169 for (int32_t i=0; i < num_machines; ++i) 00170 outputs[i] = (CBinaryLabels*) get_submachine_outputs(i); 00171 00172 SGVector<float64_t> output_for_i(num_machines); 00173 for (int32_t i=0; i<num_vectors; i++) 00174 { 00175 for (int32_t j=0; j<num_machines; j++) 00176 output_for_i[j] = outputs[j]->get_confidence(i); 00177 00178 result->set_label(i, m_multiclass_strategy->decide_label_multiple_output(output_for_i, n_outputs)); 00179 } 00180 00181 for (int32_t i=0; i < num_machines; ++i) 00182 SG_UNREF(outputs[i]); 00183 00184 SG_FREE(outputs); 00185 00186 return_labels=result; 00187 } 00188 else 00189 SG_ERROR("Not ready"); 00190 00191 return return_labels; 00192 } 00193 00194 bool CMulticlassMachine::train_machine(CFeatures* data) 00195 { 00196 ASSERT(m_multiclass_strategy); 00197 00198 if ( !data && !is_ready() ) 00199 SG_ERROR("Please provide training data.\n"); 00200 else 00201 init_machine_for_train(data); 00202 00203 m_machines->reset_array(); 00204 CBinaryLabels* train_labels = new CBinaryLabels(get_num_rhs_vectors()); 00205 SG_REF(train_labels); 00206 m_machine->set_labels(train_labels); 00207 00208 m_multiclass_strategy->train_start(CMulticlassLabels::obtain_from_generic(m_labels), train_labels); 00209 while (m_multiclass_strategy->train_has_more()) 00210 { 00211 SGVector<index_t> subset=m_multiclass_strategy->train_prepare_next(); 00212 if (subset.vlen) 00213 { 00214 train_labels->add_subset(subset); 00215 add_machine_subset(subset); 00216 } 00217 00218 m_machine->train(); 00219 m_machines->push_back(get_machine_from_trained(m_machine)); 00220 00221 if (subset.vlen) 00222 { 00223 train_labels->remove_subset(); 00224 remove_machine_subset(); 00225 } 00226 } 00227 00228 m_multiclass_strategy->train_stop(); 00229 SG_UNREF(train_labels); 00230 00231 return true; 00232 } 00233 00234 float64_t CMulticlassMachine::apply_one(int32_t vec_idx) 00235 { 00236 init_machines_for_apply(NULL); 00237 00238 ASSERT(m_machines->get_num_elements()>0); 00239 SGVector<float64_t> outputs(m_machines->get_num_elements()); 00240 00241 for (int32_t i=0; i<m_machines->get_num_elements(); i++) 00242 outputs[i] = get_submachine_output(i, vec_idx); 00243 00244 float64_t result = m_multiclass_strategy->decide_label(outputs); 00245 00246 return result; 00247 }