SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Sergey Lisitsyn 00008 * Copyright (C) 2012 Sergey Lisitsyn 00009 */ 00010 00011 #ifndef _MULTICLASSLIBLINEAR_H___ 00012 #define _MULTICLASSLIBLINEAR_H___ 00013 #include <shogun/lib/config.h> 00014 #ifdef HAVE_LAPACK 00015 #include <shogun/lib/common.h> 00016 #include <shogun/features/DotFeatures.h> 00017 #include <shogun/machine/LinearMulticlassMachine.h> 00018 #include <shogun/optimization/liblinear/shogun_liblinear.h> 00019 00020 namespace shogun 00021 { 00022 00037 class CMulticlassLibLinear : public CLinearMulticlassMachine 00038 { 00039 public: 00040 MACHINE_PROBLEM_TYPE(PT_MULTICLASS) 00041 00042 00043 CMulticlassLibLinear(); 00044 00050 CMulticlassLibLinear(float64_t C, CDotFeatures* features, CLabels* labs); 00051 00053 virtual ~CMulticlassLibLinear(); 00054 00056 virtual const char* get_name() const 00057 { 00058 return "MulticlassLibLinear"; 00059 } 00060 00064 inline void set_C(float64_t C) 00065 { 00066 ASSERT(C>0); 00067 m_C = C; 00068 } 00072 inline float64_t get_C() const { return m_C; } 00073 00077 inline void set_epsilon(float64_t epsilon) 00078 { 00079 ASSERT(epsilon>0); 00080 m_epsilon = epsilon; 00081 } 00085 inline float64_t get_epsilon() const { return m_epsilon; } 00086 00090 inline void set_use_bias(bool use_bias) 00091 { 00092 m_use_bias = use_bias; 00093 } 00097 inline bool get_use_bias() const 00098 { 00099 return m_use_bias; 00100 } 00101 00105 inline void set_save_train_state(bool save_train_state) 00106 { 00107 m_save_train_state = save_train_state; 00108 } 00112 inline bool get_save_train_state() const 00113 { 00114 return m_save_train_state; 00115 } 00116 00120 inline void set_max_iter(int32_t max_iter) 00121 { 00122 ASSERT(max_iter>0); 00123 m_max_iter = max_iter; 00124 } 00128 inline int32_t get_max_iter() const { return m_max_iter; } 00129 00131 void reset_train_state() 00132 { 00133 if (m_train_state) 00134 { 00135 delete m_train_state; 00136 m_train_state = NULL; 00137 } 00138 } 00139 00143 SGVector<int32_t> get_support_vectors() const; 00144 00145 protected: 00146 00148 virtual bool train_machine(CFeatures* data = NULL); 00149 00151 virtual SGMatrix<float64_t> obtain_regularizer_matrix() const; 00152 00153 private: 00154 00156 void init_defaults(); 00157 00159 void register_parameters(); 00160 00161 protected: 00162 00164 float64_t m_C; 00165 00167 float64_t m_epsilon; 00168 00170 int32_t m_max_iter; 00171 00173 bool m_use_bias; 00174 00176 bool m_save_train_state; 00177 00179 mcsvm_state* m_train_state; 00180 }; 00181 } 00182 #endif /* HAVE_LAPACK */ 00183 #endif