SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Viktor Gal 00008 * Copyright (C) 2012 Viktor Gal 00009 */ 00010 00011 #include <typeinfo> 00012 00013 #include <shogun/classifier/svm/SVMOcas.h> 00014 #include <shogun/latent/LatentSVM.h> 00015 00016 using namespace shogun; 00017 00018 CLatentSVM::CLatentSVM() 00019 : CLinearLatentMachine() 00020 { 00021 } 00022 00023 CLatentSVM::CLatentSVM(CLatentModel* model, float64_t C) 00024 : CLinearLatentMachine(model, C) 00025 { 00026 } 00027 00028 CLatentSVM::~CLatentSVM() 00029 { 00030 } 00031 00032 CLatentLabels* CLatentSVM::apply_latent() 00033 { 00034 if (!m_model) 00035 SG_ERROR("LatentModel is not set!\n"); 00036 00037 if (!features) 00038 return NULL; 00039 00040 index_t num_examples = m_model->get_num_vectors(); 00041 CLatentLabels* hs = new CLatentLabels(num_examples); 00042 CBinaryLabels* ys = new CBinaryLabels(num_examples); 00043 hs->set_labels(ys); 00044 m_model->set_labels(hs); 00045 00046 for (index_t i = 0; i < num_examples; ++i) 00047 { 00048 /* find h for the example */ 00049 CData* h = m_model->infer_latent_variable(w, i); 00050 hs->add_latent_label(h); 00051 } 00052 00053 /* compute the y labels */ 00054 CDotFeatures* x = m_model->get_psi_feature_vectors(); 00055 x->dense_dot_range(ys->get_labels().vector, 0, num_examples, NULL, w.vector, w.vlen, 0.0); 00056 00057 return hs; 00058 } 00059 00060 float64_t CLatentSVM::do_inner_loop(float64_t cooling_eps) 00061 { 00062 CLabels* ys = m_model->get_labels()->get_labels(); 00063 CSVMOcas svm(m_C, features, ys); 00064 svm.set_epsilon(cooling_eps); 00065 svm.train(); 00066 SG_UNREF(ys); 00067 00068 /* copy the resulting w */ 00069 SGVector<float64_t> cur_w = svm.get_w(); 00070 memcpy(w.vector, cur_w.vector, cur_w.vlen*sizeof(float64_t)); 00071 00072 return svm.compute_primal_objective(); 00073 } 00074