SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Viktor Gal 00008 * Copyright (C) 2012 Viktor Gal 00009 */ 00010 00011 #include <shogun/latent/LatentModel.h> 00012 #include <shogun/labels/BinaryLabels.h> 00013 00014 using namespace shogun; 00015 00016 CLatentModel::CLatentModel() 00017 : m_features(NULL), m_labels(NULL) 00018 { 00019 register_parameters(); 00020 } 00021 00022 CLatentModel::CLatentModel(CLatentFeatures* feats, CLatentLabels* labels) 00023 : m_features(feats), m_labels(labels) 00024 { 00025 register_parameters(); 00026 SG_REF(m_features); 00027 SG_REF(m_labels); 00028 } 00029 00030 CLatentModel::~CLatentModel() 00031 { 00032 SG_UNREF(m_labels); 00033 SG_UNREF(m_features); 00034 } 00035 00036 int32_t CLatentModel::get_num_vectors() const 00037 { 00038 return m_features->get_num_vectors(); 00039 } 00040 00041 void CLatentModel::set_labels(CLatentLabels* labs) 00042 { 00043 SG_UNREF(m_labels); 00044 SG_REF(labs); 00045 m_labels = labs; 00046 } 00047 00048 CLatentLabels* CLatentModel::get_labels() const 00049 { 00050 SG_REF(m_labels); 00051 return m_labels; 00052 } 00053 00054 void CLatentModel::set_features(CLatentFeatures* feats) 00055 { 00056 SG_UNREF(m_features); 00057 SG_REF(feats); 00058 m_features = feats; 00059 } 00060 00061 void CLatentModel::argmax_h(const SGVector<float64_t>& w) 00062 { 00063 int32_t num = get_num_vectors(); 00064 CBinaryLabels* y = CBinaryLabels::obtain_from_generic(m_labels->get_labels()); 00065 ASSERT(num > 0); 00066 ASSERT(num == m_labels->get_num_labels()); 00067 00068 00069 // argmax_h only for positive examples 00070 for (int32_t i = 0; i < num; ++i) 00071 { 00072 if (y->get_label(i) == 1) 00073 { 00074 // infer h and set it for the argmax_h <w,psi(x,h)> 00075 CData* latent_data = infer_latent_variable(w, i); 00076 m_labels->set_latent_label(i, latent_data); 00077 } 00078 } 00079 } 00080 00081 void CLatentModel::register_parameters() 00082 { 00083 m_parameters->add((CSGObject**) &m_features, "features", "Latent features"); 00084 m_parameters->add((CSGObject**) &m_labels, "labels", "Latent labels"); 00085 } 00086