SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Fernando José Iglesias García 00008 * Copyright (C) 2012 Fernando José Iglesias García 00009 */ 00010 00011 #ifndef _STRUCTURED_MODEL__H__ 00012 #define _STRUCTURED_MODEL__H__ 00013 00014 #include <shogun/base/SGObject.h> 00015 #include <shogun/features/Features.h> 00016 #include <shogun/labels/StructuredLabels.h> 00017 #include <shogun/lib/SGVector.h> 00018 #include <shogun/lib/StructuredData.h> 00019 00020 namespace shogun 00021 { 00022 00023 #define IGNORE_IN_CLASSLIST 00024 00029 IGNORE_IN_CLASSLIST struct TMultipleCPinfo { 00031 uint32_t _from; 00033 uint32_t N; 00034 }; 00035 00036 class CStructuredModel; 00037 00039 struct CResultSet : public CSGObject 00040 { 00042 CResultSet() : CSGObject(), argmax(NULL) { }; 00043 00045 virtual ~CResultSet() { SG_UNREF(argmax) } 00046 00048 CStructuredData* argmax; 00049 00051 SGVector< float64_t > psi_truth; 00052 00054 SGVector< float64_t > psi_pred; 00055 00058 float64_t score; 00059 00061 float64_t delta; 00062 00064 virtual const char* get_name() const { return "ResultSet"; } 00065 }; 00066 00077 class CStructuredModel : public CSGObject 00078 { 00079 public: 00081 CStructuredModel(); 00082 00088 CStructuredModel(CFeatures* features, CStructuredLabels* labels); 00089 00091 virtual ~CStructuredModel(); 00092 00103 virtual void init_opt( 00104 SGMatrix< float64_t > & A, SGVector< float64_t > a, 00105 SGMatrix< float64_t > B, SGVector< float64_t > & b, 00106 SGVector< float64_t > lb, SGVector< float64_t > ub, 00107 SGMatrix < float64_t > & C); 00108 00113 virtual int32_t get_dim() const = 0; 00114 00119 void set_labels(CStructuredLabels* labs); 00120 00125 CStructuredLabels* get_labels(); 00126 00131 void set_features(CFeatures* feats); 00132 00137 CFeatures* get_features(); 00138 00151 SGVector< float64_t > get_joint_feature_vector(int32_t feat_idx, int32_t lab_idx); 00152 00165 virtual SGVector< float64_t > get_joint_feature_vector(int32_t feat_idx, CStructuredData* y); 00166 00180 virtual CResultSet* argmax(SGVector< float64_t > w, int32_t feat_idx, bool const training = true) = 0; 00181 00189 float64_t delta_loss(int32_t ytrue_idx, CStructuredData* ypred); 00190 00198 virtual float64_t delta_loss(CStructuredData* y1, CStructuredData* y2); 00199 00201 virtual const char* get_name() const { return "StructuredModel"; } 00202 00210 virtual bool check_training_setup() const; 00211 00221 virtual int32_t get_num_aux() const; 00222 00232 virtual int32_t get_num_aux_con() const; 00233 00241 virtual float64_t risk(float64_t* subgrad, float64_t* W, TMultipleCPinfo* info=0); 00242 00243 private: 00245 void init(); 00246 00247 protected: 00249 CStructuredLabels* m_labels; 00250 00252 CFeatures* m_features; 00253 00254 }; /* class CStructuredModel */ 00255 00256 } /* namespace shogun */ 00257 00258 #endif /* _STRUCTURED_MODEL__H__ */