SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Sergey Lisitsyn 00008 * Copyright (C) 2012 Sergey Lisitsyn 00009 */ 00010 00011 #ifndef _LINEARMULTICLASSMACHINE_H___ 00012 #define _LINEARMULTICLASSMACHINE_H___ 00013 00014 #include <shogun/lib/common.h> 00015 #include <shogun/features/DotFeatures.h> 00016 #include <shogun/machine/LinearMachine.h> 00017 #include <shogun/machine/MulticlassMachine.h> 00018 00019 namespace shogun 00020 { 00021 00022 class CDotFeatures; 00023 class CLinearMachine; 00024 class CMulticlassStrategy; 00025 00027 class CLinearMulticlassMachine : public CMulticlassMachine 00028 { 00029 public: 00031 CLinearMulticlassMachine() : CMulticlassMachine(), m_features(NULL) 00032 { 00033 SG_ADD((CSGObject**)&m_features, "m_features", "Feature object.", 00034 MS_NOT_AVAILABLE); 00035 } 00036 00043 CLinearMulticlassMachine(CMulticlassStrategy *strategy, CDotFeatures* features, CLinearMachine* machine, CLabels* labs) : 00044 CMulticlassMachine(strategy,(CMachine*)machine,labs), m_features(NULL) 00045 { 00046 set_features(features); 00047 SG_ADD((CSGObject**)&m_features, "m_features", "Feature object.", 00048 MS_NOT_AVAILABLE); 00049 } 00050 00052 virtual ~CLinearMulticlassMachine() 00053 { 00054 SG_UNREF(m_features); 00055 } 00056 00058 virtual const char* get_name() const 00059 { 00060 return "LinearMulticlassMachine"; 00061 } 00062 00067 void set_features(CDotFeatures* f) 00068 { 00069 SG_REF(f); 00070 SG_UNREF(m_features); 00071 m_features = f; 00072 } 00073 00078 CDotFeatures* get_features() const 00079 { 00080 SG_REF(m_features); 00081 return m_features; 00082 } 00083 00084 protected: 00085 00087 virtual bool init_machine_for_train(CFeatures* data) 00088 { 00089 if (!m_machine) 00090 SG_ERROR("No machine given in Multiclass constructor\n"); 00091 00092 if (data) 00093 set_features((CDotFeatures*)data); 00094 00095 ((CLinearMachine*)m_machine)->set_features(m_features); 00096 00097 return true; 00098 } 00099 00101 virtual bool init_machines_for_apply(CFeatures* data) 00102 { 00103 if (data) 00104 set_features((CDotFeatures*)data); 00105 00106 for (int32_t i=0; i<m_machines->get_num_elements(); i++) 00107 { 00108 CLinearMachine* machine = (CLinearMachine*)m_machines->get_element(i); 00109 ASSERT(m_features); 00110 ASSERT(machine); 00111 machine->set_features(m_features); 00112 SG_UNREF(machine); 00113 } 00114 00115 return true; 00116 } 00117 00119 virtual bool is_ready() 00120 { 00121 if (m_features) 00122 return true; 00123 00124 return false; 00125 } 00126 00128 virtual CMachine* get_machine_from_trained(CMachine* machine) 00129 { 00130 return new CLinearMachine((CLinearMachine*)machine); 00131 } 00132 00134 virtual int32_t get_num_rhs_vectors() 00135 { 00136 return m_features->get_num_vectors(); 00137 } 00138 00143 virtual void add_machine_subset(SGVector<index_t> subset) 00144 { 00145 /* changing the subset structure to use subset stacks. This might 00146 * have to be revised. Heiko Strathmann */ 00147 m_features->add_subset(subset); 00148 } 00149 00151 virtual void remove_machine_subset() 00152 { 00153 /* changing the subset structure to use subset stacks. This might 00154 * have to be revised. Heiko Strathmann */ 00155 m_features->remove_subset(); 00156 } 00157 00162 virtual void store_model_features() {} 00163 00164 protected: 00165 00167 CDotFeatures* m_features; 00168 }; 00169 } 00170 #endif