// SHOGUN v2.0.0 — source listing header (doxygen extraction artifact)
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Chiyuan Zhang 00008 * Copyright (C) 2012 Chiyuan Zhang 00009 */ 00010 00011 #include <vector> 00012 #include <stack> 00013 00014 #include <shogun/multiclass/tree/ConditionalProbabilityTree.h> 00015 #include <shogun/classifier/svm/OnlineLibLinear.h> 00016 00017 using namespace shogun; 00018 using namespace std; 00019 00020 CMulticlassLabels* CConditionalProbabilityTree::apply_multiclass(CFeatures* data) 00021 { 00022 if (data) 00023 { 00024 if (data->get_feature_class() != C_STREAMING_DENSE) 00025 SG_ERROR("Expected StreamingDenseFeatures\n"); 00026 if (data->get_feature_type() != F_SHORTREAL) 00027 SG_ERROR("Expected float32_t feature type\n"); 00028 00029 set_features(dynamic_cast<CStreamingDenseFeatures<float32_t>* >(data)); 00030 } 00031 00032 vector<int32_t> predicts; 00033 00034 m_feats->start_parser(); 00035 while (m_feats->get_next_example()) 00036 { 00037 predicts.push_back(apply_multiclass_example(m_feats->get_vector())); 00038 m_feats->release_example(); 00039 } 00040 m_feats->end_parser(); 00041 00042 CMulticlassLabels *labels = new CMulticlassLabels(predicts.size()); 00043 for (size_t i=0; i < predicts.size(); ++i) 00044 labels->set_int_label(i, predicts[i]); 00045 return labels; 00046 } 00047 00048 int32_t CConditionalProbabilityTree::apply_multiclass_example(SGVector<float32_t> ex) 00049 { 00050 compute_conditional_probabilities(ex); 00051 SGVector<float64_t> probs(m_leaves.size()); 00052 for (map<int32_t,node_t*>::iterator it = m_leaves.begin(); it != m_leaves.end(); ++it) 00053 { 00054 probs[it->first] = accumulate_conditional_probability(it->second); 00055 } 00056 return SGVector<float64_t>::arg_max(probs.vector, 1, probs.vlen); 00057 } 
00058 00059 void CConditionalProbabilityTree::compute_conditional_probabilities(SGVector<float32_t> ex) 00060 { 00061 stack<node_t *> nodes; 00062 nodes.push(m_root); 00063 00064 while (!nodes.empty()) 00065 { 00066 node_t *node = nodes.top(); 00067 nodes.pop(); 00068 if (node->left()) 00069 { 00070 nodes.push(node->left()); 00071 nodes.push(node->right()); 00072 00073 // don't calculate for leaf 00074 node->data.p_right = predict_node(ex, node); 00075 } 00076 } 00077 } 00078 00079 float64_t CConditionalProbabilityTree::accumulate_conditional_probability(node_t *leaf) 00080 { 00081 float64_t prob = 1; 00082 node_t *par = leaf->parent(); 00083 while (par != NULL) 00084 { 00085 if (leaf == par->left()) 00086 prob *= (1-par->data.p_right); 00087 else 00088 prob *= par->data.p_right; 00089 00090 leaf = par; 00091 par = leaf->parent(); 00092 } 00093 00094 return prob; 00095 } 00096 00097 bool CConditionalProbabilityTree::train_machine(CFeatures* data) 00098 { 00099 if (data) 00100 { 00101 if (data->get_feature_class() != C_STREAMING_DENSE) 00102 SG_ERROR("Expected StreamingDenseFeatures\n"); 00103 if (data->get_feature_type() != F_SHORTREAL) 00104 SG_ERROR("Expected float32_t features\n"); 00105 set_features(dynamic_cast<CStreamingDenseFeatures<float32_t> *>(data)); 00106 } 00107 else 00108 { 00109 if (!m_feats) 00110 SG_ERROR("No data features provided\n"); 00111 } 00112 00113 m_machines->reset_array(); 00114 SG_UNREF(m_root); 00115 m_root = NULL; 00116 00117 m_leaves.clear(); 00118 00119 m_feats->start_parser(); 00120 for (int32_t ipass=0; ipass < m_num_passes; ++ipass) 00121 { 00122 while (m_feats->get_next_example()) 00123 { 00124 train_example(m_feats->get_vector(), static_cast<int32_t>(m_feats->get_label())); 00125 m_feats->release_example(); 00126 } 00127 00128 if (ipass < m_num_passes-1) 00129 m_feats->reset_stream(); 00130 } 00131 m_feats->end_parser(); 00132 00133 for (int32_t i=0; i < m_machines->get_num_elements(); ++i) 00134 { 00135 COnlineLibLinear *lll = 
dynamic_cast<COnlineLibLinear *>(m_machines->get_element(i)); 00136 lll->stop_train(); 00137 SG_UNREF(lll); 00138 } 00139 00140 return true; 00141 } 00142 00143 void CConditionalProbabilityTree::print_tree() 00144 { 00145 if (m_root) 00146 m_root->debug_print(ConditionalProbabilityTreeNodeData::print_data); 00147 else 00148 printf("Empty Tree\n"); 00149 } 00150 00151 void CConditionalProbabilityTree::train_example(SGVector<float32_t> ex, int32_t label) 00152 { 00153 if (m_root == NULL) 00154 { 00155 m_root = new node_t(); 00156 m_root->data.label = label; 00157 m_leaves.insert(make_pair(label, m_root)); 00158 m_root->machine(create_machine(ex)); 00159 return; 00160 } 00161 00162 if (m_leaves.find(label) != m_leaves.end()) 00163 { 00164 train_path(ex, m_leaves[label]); 00165 } 00166 else 00167 { 00168 node_t *node = m_root; 00169 while (node->left() != NULL) 00170 { 00171 // not a leaf 00172 bool is_left = which_subtree(node, ex); 00173 float64_t node_label; 00174 if (is_left) 00175 node_label = 0; 00176 else 00177 node_label = 1; 00178 train_node(ex, node_label, node); 00179 00180 if (is_left) 00181 node = node->left(); 00182 else 00183 node = node->right(); 00184 } 00185 00186 m_leaves.erase(node->data.label); 00187 00188 node_t *left_node = new node_t(); 00189 left_node->data.label = node->data.label; 00190 node->data.label = -1; 00191 COnlineLibLinear *node_mch = dynamic_cast<COnlineLibLinear *>(m_machines->get_element(node->machine())); 00192 COnlineLibLinear *mch = new COnlineLibLinear(node_mch); 00193 SG_UNREF(node_mch); 00194 mch->start_train(); 00195 m_machines->push_back(mch); 00196 left_node->machine(m_machines->get_num_elements()-1); 00197 m_leaves.insert(make_pair(left_node->data.label, left_node)); 00198 node->left(left_node); 00199 00200 node_t *right_node = new node_t(); 00201 right_node->data.label = label; 00202 right_node->machine(create_machine(ex)); 00203 m_leaves.insert(make_pair(label, right_node)); 00204 node->right(right_node); 00205 } 00206 
} 00207 00208 void CConditionalProbabilityTree::train_path(SGVector<float32_t> ex, node_t *node) 00209 { 00210 float64_t node_label = 0; 00211 train_node(ex, node_label, node); 00212 00213 node_t *par = node->parent(); 00214 while (par != NULL) 00215 { 00216 if (par->left() == node) 00217 node_label = 0; 00218 else 00219 node_label = 1; 00220 00221 train_node(ex, node_label, par); 00222 node = par; 00223 par = node->parent(); 00224 } 00225 } 00226 00227 void CConditionalProbabilityTree::train_node(SGVector<float32_t> ex, float64_t label, node_t *node) 00228 { 00229 COnlineLibLinear *mch = dynamic_cast<COnlineLibLinear *>(m_machines->get_element(node->machine())); 00230 ASSERT(mch); 00231 mch->train_one(ex, label); 00232 SG_UNREF(mch); 00233 } 00234 00235 float64_t CConditionalProbabilityTree::predict_node(SGVector<float32_t> ex, node_t *node) 00236 { 00237 COnlineLibLinear *mch = dynamic_cast<COnlineLibLinear *>(m_machines->get_element(node->machine())); 00238 ASSERT(mch); 00239 float64_t pred = mch->apply_one(ex.vector, ex.vlen); 00240 SG_UNREF(mch); 00241 // use sigmoid function to turn the decision value into valid probability 00242 return 1.0/(1+CMath::exp(-pred)); 00243 } 00244 00245 int32_t CConditionalProbabilityTree::create_machine(SGVector<float32_t> ex) 00246 { 00247 COnlineLibLinear *mch = new COnlineLibLinear(); 00248 mch->start_train(); 00249 mch->train_one(ex, 0); 00250 m_machines->push_back(mch); 00251 return m_machines->get_num_elements()-1; 00252 }