SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Chiyuan Zhang 00008 * Copyright (C) 2012 Chiyuan Zhang 00009 */ 00010 00011 #include <shogun/multiclass/ecoc/ECOCStrategy.h> 00012 #include <shogun/labels/BinaryLabels.h> 00013 #include <shogun/labels/MulticlassLabels.h> 00014 00015 using namespace shogun; 00016 00017 CECOCStrategy::CECOCStrategy() 00018 { 00019 init(); 00020 } 00021 00022 CECOCStrategy::CECOCStrategy(CECOCEncoder *encoder, CECOCDecoder *decoder) 00023 :m_encoder(encoder), m_decoder(decoder) 00024 { 00025 init(); 00026 } 00027 00028 void CECOCStrategy::init() 00029 { 00030 SG_REF(m_encoder); 00031 SG_REF(m_decoder); 00032 00033 SG_ADD((CSGObject **)&m_encoder, "encoder", "ECOC Encoder", MS_NOT_AVAILABLE); 00034 SG_ADD((CSGObject **)&m_decoder, "decoder", "ECOC Decoder", MS_NOT_AVAILABLE); 00035 } 00036 00037 CECOCStrategy::~CECOCStrategy() 00038 { 00039 SG_UNREF(m_encoder); 00040 SG_UNREF(m_decoder); 00041 } 00042 00043 void CECOCStrategy::train_start(CMulticlassLabels *orig_labels, CBinaryLabels *train_labels) 00044 { 00045 CMulticlassStrategy::train_start(orig_labels, train_labels); 00046 00047 m_codebook = m_encoder->create_codebook(m_num_classes); 00048 } 00049 00050 bool CECOCStrategy::train_has_more() 00051 { 00052 return m_train_iter < m_codebook.num_rows; 00053 } 00054 00055 SGVector<int32_t> CECOCStrategy::train_prepare_next() 00056 { 00057 SGVector<int32_t> subset(m_orig_labels->get_num_labels(), false); 00058 int32_t tot=0; 00059 for (int32_t i=0; i < m_orig_labels->get_num_labels(); ++i) 00060 { 00061 int32_t label = ((CMulticlassLabels*) m_orig_labels)->get_int_label(i); 00062 switch (m_codebook(m_train_iter, label)) 00063 { 00064 case -1: 00065 ((CBinaryLabels*) m_train_labels)->set_label(i, -1); 00066 subset[tot++]=i; 00067 break; 00068 case 1: 00069 ((CBinaryLabels*) m_train_labels)->set_label(i, 1); 00070 subset[tot++]=i; 00071 break; 00072 default: 00073 // 0 means ignore 00074 break; 00075 } 00076 } 00077 00078 CMulticlassStrategy::train_prepare_next(); 00079 return SGVector<int32_t>(subset.vector, tot, true); 00080 } 00081 00082 int32_t CECOCStrategy::decide_label(SGVector<float64_t> outputs) 00083 { 00084 return m_decoder->decide_label(outputs, m_codebook); 00085 } 00086 00087 int32_t CECOCStrategy::get_num_machines() 00088 { 00089 return m_codebook.num_cols; 00090 }