SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Sergey Lisitsyn 00008 * Copyright (C) 2012 Sergey Lisitsyn 00009 */ 00010 00011 #ifndef _MULTICLASSOCAS_H___ 00012 #define _MULTICLASSOCAS_H___ 00013 00014 #include <shogun/lib/common.h> 00015 #include <shogun/features/DotFeatures.h> 00016 #include <shogun/lib/external/libocas.h> 00017 #include <shogun/machine/LinearMulticlassMachine.h> 00018 00019 namespace shogun 00020 { 00021 00023 class CMulticlassOCAS : public CLinearMulticlassMachine 00024 { 00025 public: 00026 MACHINE_PROBLEM_TYPE(PT_MULTICLASS) 00027 00028 00029 CMulticlassOCAS(); 00030 00036 CMulticlassOCAS(float64_t C, CDotFeatures* features, CLabels* labs); 00037 00039 virtual ~CMulticlassOCAS(); 00040 00042 virtual const char* get_name() const 00043 { 00044 return "MulticlassOCAS"; 00045 } 00046 00050 inline void set_C(float64_t C) 00051 { 00052 ASSERT(C>0); 00053 m_C = C; 00054 } 00058 inline float64_t get_C() const { return m_C; } 00059 00063 inline void set_epsilon(float64_t epsilon) 00064 { 00065 ASSERT(epsilon>0); 00066 m_epsilon = epsilon; 00067 } 00071 inline float64_t get_epsilon() const { return m_epsilon; } 00072 00076 inline void set_max_iter(int32_t max_iter) 00077 { 00078 ASSERT(max_iter>0); 00079 m_max_iter = max_iter; 00080 } 00084 inline int32_t get_max_iter() const { return m_max_iter; } 00085 00089 inline void set_method(int32_t method) 00090 { 00091 ASSERT(method==0 || method==1); 00092 m_method = method; 00093 } 00097 inline int32_t get_method() const { return m_method; } 00098 00102 inline void set_buf_size(int32_t buf_size) 00103 { 00104 ASSERT(buf_size>0); 00105 m_buf_size = buf_size; 00106 } 00110 inline int32_t get_buf_size() const { return m_buf_size; } 00111 00112 protected: 00113 00115 virtual bool train_machine(CFeatures* data = NULL); 00116 00118 static float64_t msvm_update_W(float64_t t, void* user_data); 00119 00121 static void msvm_full_compute_W(float64_t *sq_norm_W, float64_t *dp_WoldW, 00122 float64_t *alpha, uint32_t nSel, void* user_data); 00123 00125 static int msvm_full_add_new_cut(float64_t *new_col_H, uint32_t *new_cut, 00126 uint32_t nSel, void* user_data); 00127 00129 static int msvm_full_compute_output(float64_t *output, void* user_data); 00130 00132 static int msvm_sort_data(float64_t* vals, float64_t* data, uint32_t size); 00133 00135 static void msvm_print(ocas_return_value_T value); 00136 00137 private: 00138 00140 void register_parameters(); 00141 00142 protected: 00143 00145 float64_t m_C; 00146 00148 float64_t m_epsilon; 00149 00151 int32_t m_max_iter; 00152 00154 int32_t m_method; 00155 00157 int32_t m_buf_size; 00158 }; 00159 } 00160 #endif