SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #ifndef _PLUGINESTIMATE_H___ 00012 #define _PLUGINESTIMATE_H___ 00013 00014 #include <shogun/machine/Machine.h> 00015 #include <shogun/features/StringFeatures.h> 00016 #include <shogun/labels/BinaryLabels.h> 00017 #include <shogun/distributions/LinearHMM.h> 00018 00019 namespace shogun 00020 { 00034 class CPluginEstimate: public CMachine 00035 { 00036 public: 00037 00039 MACHINE_PROBLEM_TYPE(PT_BINARY); 00040 00045 CPluginEstimate(float64_t pos_pseudo=1e-10, float64_t neg_pseudo=1e-10); 00046 virtual ~CPluginEstimate(); 00047 00053 virtual CBinaryLabels* apply_binary(CFeatures* data=NULL); 00054 00059 virtual inline void set_features(CStringFeatures<uint16_t>* feat) 00060 { 00061 SG_UNREF(features); 00062 SG_REF(feat); 00063 features=feat; 00064 } 00065 00070 virtual CStringFeatures<uint16_t>* get_features() { SG_REF(features); return features; } 00071 00073 float64_t apply_one(int32_t vec_idx); 00074 00081 inline float64_t posterior_log_odds_obsolete( 00082 uint16_t* vector, int32_t len) 00083 { 00084 return pos_model->get_log_likelihood_example(vector, len) - neg_model->get_log_likelihood_example(vector, len); 00085 } 00086 00093 inline float64_t get_parameterwise_log_odds( 00094 uint16_t obs, int32_t position) 00095 { 00096 return pos_model->get_positional_log_parameter(obs, position) - neg_model->get_positional_log_parameter(obs, position); 00097 } 00098 00105 inline float64_t log_derivative_pos_obsolete(uint16_t obs, int32_t pos) 00106 { 00107 return pos_model->get_log_derivative_obsolete(obs, pos); 00108 } 00109 00116 inline float64_t log_derivative_neg_obsolete(uint16_t obs, int32_t pos) 00117 { 00118 return neg_model->get_log_derivative_obsolete(obs, pos); 00119 } 00120 00129 inline bool get_model_params( 00130 float64_t*& pos_params, float64_t*& neg_params, 00131 int32_t &seq_length, int32_t &num_symbols) 00132 { 00133 if ((!pos_model) || (!neg_model)) 00134 { 00135 SG_ERROR( "no model available\n"); 00136 return false; 00137 } 00138 00139 SGVector<float64_t> log_pos_trans = pos_model->get_log_transition_probs(); 00140 pos_params = log_pos_trans.vector; 00141 SGVector<float64_t> log_neg_trans = neg_model->get_log_transition_probs(); 00142 neg_params = log_neg_trans.vector; 00143 00144 seq_length = pos_model->get_sequence_length(); 00145 num_symbols = pos_model->get_num_symbols(); 00146 ASSERT(pos_model->get_num_model_parameters()==neg_model->get_num_model_parameters()); 00147 ASSERT(pos_model->get_num_symbols()==neg_model->get_num_symbols()); 00148 return true; 00149 } 00150 00157 inline void set_model_params( 00158 float64_t* pos_params, float64_t* neg_params, 00159 int32_t seq_length, int32_t num_symbols) 00160 { 00161 int32_t num_params; 00162 00163 SG_UNREF(pos_model); 00164 pos_model=new CLinearHMM(seq_length, num_symbols); 00165 SG_REF(pos_model); 00166 00167 00168 SG_UNREF(neg_model); 00169 neg_model=new CLinearHMM(seq_length, num_symbols); 00170 SG_REF(neg_model); 00171 00172 num_params=pos_model->get_num_model_parameters(); 00173 ASSERT(seq_length*num_symbols==num_params); 00174 ASSERT(num_params==neg_model->get_num_model_parameters()); 00175 00176 pos_model->set_log_transition_probs(SGVector<float64_t>(pos_params, num_params)); 00177 neg_model->set_log_transition_probs(SGVector<float64_t>(neg_params, num_params)); 00178 } 00179 00184 inline int32_t get_num_params() 00185 { 00186 return pos_model->get_num_model_parameters()+neg_model->get_num_model_parameters(); 00187 } 00188 00193 inline bool check_models() 00194 { 00195 return ( (pos_model!=NULL) && (neg_model!=NULL) ); 00196 } 00197 00199 inline virtual const char* get_name() const { return "PluginEstimate"; } 00200 00201 protected: 00210 virtual bool train_machine(CFeatures* data=NULL); 00211 00212 protected: 00214 float64_t m_pos_pseudo; 00216 float64_t m_neg_pseudo; 00217 00219 CLinearHMM* pos_model; 00221 CLinearHMM* neg_model; 00222 00224 CStringFeatures<uint16_t>* features; 00225 }; 00226 } 00227 #endif