SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Chiyuan Zhang 00008 * Copyright (C) 2012 Chiyuan Zhang 00009 */ 00010 00011 #ifndef RELAXEDTREE_H__ 00012 #define RELAXEDTREE_H__ 00013 00014 #include <utility> 00015 #include <vector> 00016 00017 #include <shogun/features/DenseFeatures.h> 00018 #include <shogun/classifier/svm/LibSVM.h> 00019 #include <shogun/multiclass/tree/TreeMachine.h> 00020 #include <shogun/multiclass/tree/RelaxedTreeNodeData.h> 00021 00022 namespace shogun 00023 { 00024 00025 class CBaseMulticlassMachine; 00026 00034 class CRelaxedTree: public CTreeMachine<RelaxedTreeNodeData> 00035 { 00036 public: 00038 CRelaxedTree(); 00039 00041 virtual ~CRelaxedTree(); 00042 00044 virtual const char* get_name() const { return "RelaxedTree"; } 00045 00047 virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL); 00048 00052 void set_features(CDenseFeatures<float64_t> *feats) 00053 { 00054 SG_REF(feats); 00055 SG_UNREF(m_feats); 00056 m_feats = feats; 00057 } 00058 00062 virtual void set_kernel(CKernel *kernel) 00063 { 00064 SG_REF(kernel); 00065 SG_UNREF(m_kernel); 00066 m_kernel = kernel; 00067 } 00068 00073 virtual void set_labels(CLabels* lab) 00074 { 00075 CMulticlassLabels *mlab = dynamic_cast<CMulticlassLabels *>(lab); 00076 REQUIRE(lab, "requires MulticlassLabes\n"); 00077 00078 CMachine::set_labels(mlab); 00079 m_num_classes = mlab->get_num_classes(); 00080 } 00081 00085 void set_machine_for_confusion_matrix(CBaseMulticlassMachine *machine) 00086 { 00087 SG_REF(machine); 00088 SG_UNREF(m_machine_for_confusion_matrix); 00089 m_machine_for_confusion_matrix = machine; 00090 } 00091 00095 void set_svm_C(float64_t C) 00096 { 00097 m_svm_C = C; 00098 } 00102 float64_t get_svm_C() const 00103 { 00104 return m_svm_C; 00105 } 00106 00110 void set_svm_epsilon(float64_t epsilon) 00111 { 00112 m_svm_epsilon = epsilon; 00113 } 00117 float64_t get_svm_epsilon() const 00118 { 00119 return m_svm_epsilon; 00120 } 00121 00127 void set_A(float64_t A) 00128 { 00129 m_A = A; 00130 } 00134 float64_t get_A() const 00135 { 00136 return m_A; 00137 } 00138 00143 void set_B(int32_t B) 00144 { 00145 m_B = B; 00146 } 00150 int32_t get_B() const 00151 { 00152 return m_B; 00153 } 00154 00158 void set_max_num_iter(int32_t n_iter) 00159 { 00160 m_max_num_iter = n_iter; 00161 } 00165 int32_t get_max_num_iter() const 00166 { 00167 return m_max_num_iter; 00168 } 00169 00179 virtual bool train(CFeatures* data=NULL) 00180 { 00181 return CMachine::train(data); 00182 } 00183 00185 typedef std::pair<std::pair<int32_t, int32_t>, float64_t> entry_t; 00186 protected: 00193 float64_t apply_one(int32_t idx); 00194 00201 virtual bool train_machine(CFeatures* data); 00202 00204 node_t *train_node(const SGMatrix<float64_t> &conf_mat, SGVector<int32_t> classes); 00206 std::vector<entry_t> init_node(const SGMatrix<float64_t> &global_conf_mat, SGVector<int32_t> classes); 00208 SGVector<int32_t> train_node_with_initialization(const CRelaxedTree::entry_t &mu_entry, SGVector<int32_t> classes, CSVM *svm); 00209 00211 float64_t compute_score(SGVector<int32_t> mu, CSVM *svm); 00213 SGVector<int32_t> color_label_space(CSVM *svm, SGVector<int32_t> classes); 00215 SGVector<float64_t> eval_binary_model_K(CSVM *svm); 00216 00218 void enforce_balance_constraints_upper(SGVector<int32_t> &mu, SGVector<float64_t> &delta_neg, SGVector<float64_t> &delta_pos, int32_t B_prime, SGVector<float64_t>& xi_neg_class); 00220 void enforce_balance_constraints_lower(SGVector<int32_t> &mu, SGVector<float64_t> &delta_neg, SGVector<float64_t> &delta_pos, int32_t B_prime, SGVector<float64_t>& xi_neg_class); 00221 00223 int32_t m_max_num_iter; 00225 float64_t m_A; 00227 int32_t m_B; 00229 float64_t m_svm_C; 00231 float64_t m_svm_epsilon; 00233 CKernel *m_kernel; 00235 CDenseFeatures<float64_t> *m_feats; 00237 CBaseMulticlassMachine *m_machine_for_confusion_matrix; 00239 int32_t m_num_classes; 00240 }; 00241 00242 } /* shogun */ 00243 00244 #endif /* end of include guard: RELAXEDTREE_H__ */ 00245