SHOGUN v2.0.0
VwConditionalProbabilityTree.cpp
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 2012 Chiyuan Zhang
 * Copyright (C) 2012 Chiyuan Zhang
 */

#include <vector>
#include <stack>
#include <shogun/multiclass/tree/VwConditionalProbabilityTree.h>

using namespace shogun;
using namespace std;

CMulticlassLabels* CVwConditionalProbabilityTree::apply_multiclass(CFeatures* data)
{
    if (data)
    {
        if (data->get_feature_class() != C_STREAMING_VW)
            SG_ERROR("Expected StreamingVwFeatures\n");
        set_features(dynamic_cast<CStreamingVwFeatures*>(data));
    }

    vector<int32_t> predicts;

    m_feats->start_parser();
    while (m_feats->get_next_example())
    {
        predicts.push_back(apply_multiclass_example(m_feats->get_example()));
        m_feats->release_example();
    }
    m_feats->end_parser();

    CMulticlassLabels *labels = new CMulticlassLabels(predicts.size());
    for (size_t i=0; i < predicts.size(); ++i)
        labels->set_int_label(i, predicts[i]);
    return labels;
}

int32_t CVwConditionalProbabilityTree::apply_multiclass_example(VwExample* ex)
{
    ex->ld->label = FLT_MAX; // this will disable VW learning from this example

    compute_conditional_probabilities(ex);
    SGVector<float64_t> probs(m_leaves.size());
    for (map<int32_t,node_t*>::iterator it = m_leaves.begin(); it != m_leaves.end(); ++it)
    {
        probs[it->first] = accumulate_conditional_probability(it->second);
    }
    return SGVector<float64_t>::arg_max(probs.vector, 1, probs.vlen);
}

void CVwConditionalProbabilityTree::compute_conditional_probabilities(VwExample *ex)
{
    stack<node_t *> nodes;
    nodes.push(m_root);

    while (!nodes.empty())
    {
        node_t *node = nodes.top();
        nodes.pop();
        if (node->left())
        {
            nodes.push(node->left());
            nodes.push(node->right());

            // don't calculate for leaf
            node->data.p_right = train_node(ex, node);
        }
    }
}

float64_t CVwConditionalProbabilityTree::accumulate_conditional_probability(node_t *leaf)
{
    float64_t prob = 1;
    node_t *par = leaf->parent();
    while (par != NULL)
    {
        if (leaf == par->left())
            prob *= (1-par->data.p_right);
        else
            prob *= par->data.p_right;

        leaf = par;
        par = leaf->parent();
    }

    return prob;
}
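/* Note on the two helpers above: each internal node's VW machine outputs
 * p_right, the probability of taking the right branch at that node.  The
 * probability assigned to a leaf is the product of the branch factors along
 * its root-to-leaf path: p_right for every right turn and (1 - p_right) for
 * every left turn.  For example, a leaf reached by going left and then right
 * gets (1 - p_right(root)) * p_right(root->left()).  Because
 * apply_multiclass_example() sets the label to FLT_MAX first, train_node()
 * only predicts here and performs no VW update; the label with the largest
 * accumulated product is returned as the prediction.
 */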
bool CVwConditionalProbabilityTree::train_machine(CFeatures* data)
{
    if (data)
    {
        if (data->get_feature_class() != C_STREAMING_VW)
            SG_ERROR("Expected StreamingVwFeatures\n");
        set_features(dynamic_cast<CStreamingVwFeatures*>(data));
    }
    else
    {
        if (!m_feats)
            SG_ERROR("No data features provided\n");
    }

    m_machines->reset_array();
    SG_UNREF(m_root);
    m_root = NULL;

    m_leaves.clear();

    m_feats->start_parser();
    for (int32_t ipass=0; ipass < m_num_passes; ++ipass)
    {
        while (m_feats->get_next_example())
        {
            train_example(m_feats->get_example());
            m_feats->release_example();
        }

        if (ipass < m_num_passes-1)
            m_feats->reset_stream();
    }
    m_feats->end_parser();

    return true;
}

void CVwConditionalProbabilityTree::train_example(VwExample *ex)
{
    int32_t label = static_cast<int32_t>(ex->ld->label);

    if (m_root == NULL)
    {
        m_root = new node_t();
        m_root->data.label = label;
        printf(" insert %d %p\n", label, m_root);
        m_leaves.insert(make_pair(label, m_root));
        m_root->machine(create_machine(ex));
        return;
    }

    if (m_leaves.find(label) != m_leaves.end())
    {
        train_path(ex, m_leaves[label]);
    }
    else
    {
        node_t *node = m_root;
        while (node->left() != NULL)
        {
            // not a leaf
            bool is_left = which_subtree(node, ex);
            if (is_left)
                ex->ld->label = 0;
            else
                ex->ld->label = 1;
            train_node(ex, node);

            if (is_left)
                node = node->left();
            else
                node = node->right();
        }

        printf(" remove %d %p\n", node->data.label, m_leaves[node->data.label]);
        m_leaves.erase(node->data.label);

        node_t *left_node = new node_t();
        left_node->data.label = node->data.label;
        node->data.label = -1;
        CVowpalWabbit *node_vw = dynamic_cast<CVowpalWabbit *>(m_machines->get_element(node->machine()));
        CVowpalWabbit *vw = new CVowpalWabbit(node_vw);
        SG_UNREF(node_vw);
        vw->set_learner();
        m_machines->push_back(vw);
        left_node->machine(m_machines->get_num_elements()-1);
        printf(" insert %d %p\n", left_node->data.label, left_node);
        m_leaves.insert(make_pair(left_node->data.label, left_node));
        node->left(left_node);

        node_t *right_node = new node_t();
        right_node->data.label = label;
        right_node->machine(create_machine(ex));
        printf(" insert %d %p\n", label, right_node);
        m_leaves.insert(make_pair(label, right_node));
        node->right(right_node);
    }
}

void CVwConditionalProbabilityTree::train_path(VwExample *ex, node_t *node)
{
    ex->ld->label = 0;
    train_node(ex, node);

    node_t *par = node->parent();
    while (par != NULL)
    {
        if (par->left() == node)
            ex->ld->label = 0;
        else
            ex->ld->label = 1;

        train_node(ex, par);
        node = par;
        par = node->parent();
    }
}

float64_t CVwConditionalProbabilityTree::train_node(VwExample *ex, node_t *node)
{
    CVowpalWabbit *vw = dynamic_cast<CVowpalWabbit*>(m_machines->get_element(node->machine()));
    ASSERT(vw);
    float64_t pred = vw->predict_and_finalize(ex);
    if (ex->ld->label != FLT_MAX)
        vw->get_learner()->train(ex, ex->eta_round);
    SG_UNREF(vw);
    return pred;
}

int32_t CVwConditionalProbabilityTree::create_machine(VwExample *ex)
{
    CVowpalWabbit *vw = new CVowpalWabbit(m_feats);
    vw->set_learner();
    ex->ld->label = 0;
    vw->predict_and_finalize(ex);
    m_machines->push_back(vw);
    return m_machines->get_num_elements()-1;
}
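For context, a minimal usage sketch (not part of the file above). It assumes VW-format input at the placeholder paths "train.vw" and "test.vw", the usual CStreamingVwFile / CStreamingVwFeatures constructors from Shogun's streaming API, the default CVwConditionalProbabilityTree constructor, and a parser ring size of 1024; the Shogun headers are omitted because their paths differ across Shogun versions. Note that train_example() keys leaves by the integer label and apply_multiclass_example() indexes the probability vector by it, so class labels in the input are expected to be 0, 1, ..., K-1. Error handling and cleanup of the feature/file objects are omitted for brevity.

// Sketch only; required Shogun headers (init, StreamingVwFile,
// StreamingVwFeatures, VwConditionalProbabilityTree) are omitted here.
int main()
{
    init_shogun_with_defaults();

    // Labelled VW-format stream for training ("train.vw" is a placeholder path).
    char train_fname[] = "train.vw";
    CStreamingVwFile* train_file = new CStreamingVwFile(train_fname);
    CStreamingVwFeatures* train_feats = new CStreamingVwFeatures(train_file, true, 1024);

    CVwConditionalProbabilityTree* tree = new CVwConditionalProbabilityTree();
    tree->train(train_feats); // train_machine() above: grows the tree, one CVowpalWabbit per node

    // Separate stream for prediction ("test.vw" is a placeholder path).
    char test_fname[] = "test.vw";
    CStreamingVwFile* test_file = new CStreamingVwFile(test_fname);
    CStreamingVwFeatures* test_feats = new CStreamingVwFeatures(test_file, true, 1024);
    CMulticlassLabels* pred = tree->apply_multiclass(test_feats);

    SG_UNREF(pred);
    SG_UNREF(tree);
    exit_shogun();
    return 0;
}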