SHOGUN  v2.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines
LatentModel.cpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Viktor Gal
00008  * Copyright (C) 2012 Viktor Gal
00009  */
00010 
00011 #include <shogun/latent/LatentModel.h>
00012 #include <shogun/labels/BinaryLabels.h>
00013 
00014 using namespace shogun;
00015 
00016 CLatentModel::CLatentModel()
00017     : m_features(NULL), m_labels(NULL)
00018 {
00019     register_parameters();
00020 }
00021 
00022 CLatentModel::CLatentModel(CLatentFeatures* feats, CLatentLabels* labels)
00023     : m_features(feats), m_labels(labels)
00024 {
00025     register_parameters();
00026     SG_REF(m_features);
00027     SG_REF(m_labels);
00028 }
00029 
00030 CLatentModel::~CLatentModel()
00031 {
00032     SG_UNREF(m_labels);
00033     SG_UNREF(m_features);
00034 }
00035 
00036 int32_t CLatentModel::get_num_vectors() const
00037 {
00038     return m_features->get_num_vectors();
00039 }
00040 
00041 void CLatentModel::set_labels(CLatentLabels* labs)
00042 {
00043     SG_UNREF(m_labels);
00044     SG_REF(labs);
00045     m_labels = labs;
00046 }
00047 
00048 CLatentLabels* CLatentModel::get_labels() const
00049 {
00050     SG_REF(m_labels);
00051     return m_labels;
00052 }
00053 
00054 void CLatentModel::set_features(CLatentFeatures* feats)
00055 {
00056     SG_UNREF(m_features);
00057     SG_REF(feats);
00058     m_features = feats;
00059 }
00060 
00061 void CLatentModel::argmax_h(const SGVector<float64_t>& w)
00062 {
00063     int32_t num = get_num_vectors();
00064     CBinaryLabels* y = CBinaryLabels::obtain_from_generic(m_labels->get_labels());
00065     ASSERT(num > 0);
00066     ASSERT(num == m_labels->get_num_labels());
00067     
00068 
00069     // argmax_h only for positive examples
00070     for (int32_t i = 0; i < num; ++i)
00071     {
00072         if (y->get_label(i) == 1)
00073         {
00074             // infer h and set it for the argmax_h <w,psi(x,h)>
00075             CData* latent_data = infer_latent_variable(w, i);
00076             m_labels->set_latent_label(i, latent_data);
00077         }
00078     }
00079 }
00080 
00081 void CLatentModel::register_parameters()
00082 {
00083     m_parameters->add((CSGObject**) &m_features, "features", "Latent features");
00084     m_parameters->add((CSGObject**) &m_labels, "labels", "Latent labels");
00085 }
00086 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation