SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2007-2011 Christian Widmer 00008 * Copyright (C) 2007-2011 Max-Planck-Society 00009 */ 00010 00011 #include <shogun/lib/config.h> 00012 00013 #ifdef USE_SVMLIGHT 00014 00015 #include <shogun/transfer/domain_adaptation/DomainAdaptationSVM.h> 00016 #include <shogun/io/SGIO.h> 00017 #include <shogun/labels/Labels.h> 00018 #include <shogun/labels/BinaryLabels.h> 00019 #include <shogun/labels/RegressionLabels.h> 00020 #include <iostream> 00021 #include <vector> 00022 00023 using namespace shogun; 00024 00025 CDomainAdaptationSVM::CDomainAdaptationSVM() : CSVMLight() 00026 { 00027 } 00028 00029 CDomainAdaptationSVM::CDomainAdaptationSVM(float64_t C, CKernel* k, CLabels* lab, CSVM* pre_svm, float64_t B_param) : CSVMLight(C, k, lab) 00030 { 00031 init(); 00032 init(pre_svm, B_param); 00033 } 00034 00035 CDomainAdaptationSVM::~CDomainAdaptationSVM() 00036 { 00037 SG_UNREF(presvm); 00038 SG_DEBUG("deleting DomainAdaptationSVM\n"); 00039 } 00040 00041 00042 void CDomainAdaptationSVM::init(CSVM* pre_svm, float64_t B_param) 00043 { 00044 // increase reference counts 00045 SG_REF(pre_svm); 00046 00047 this->presvm=pre_svm; 00048 this->B=B_param; 00049 this->train_factor=1.0; 00050 00051 // set bias of parent svm to zero 00052 this->presvm->set_bias(0.0); 00053 00054 // invoke sanity check 00055 is_presvm_sane(); 00056 } 00057 00058 bool CDomainAdaptationSVM::is_presvm_sane() 00059 { 00060 if (!presvm) { 00061 SG_ERROR("presvm is null"); 00062 } 00063 00064 if (presvm->get_num_support_vectors() == 0) { 00065 SG_ERROR("presvm has no support vectors, please train first"); 00066 } 00067 00068 if (presvm->get_bias() != 0) { 00069 SG_ERROR("presvm bias not set to zero"); 00070 } 00071 00072 if (presvm->get_kernel()->get_kernel_type() != this->get_kernel()->get_kernel_type()) { 00073 SG_ERROR("kernel types do not agree"); 00074 } 00075 00076 if (presvm->get_kernel()->get_feature_type() != this->get_kernel()->get_feature_type()) { 00077 SG_ERROR("feature types do not agree"); 00078 } 00079 00080 return true; 00081 } 00082 00083 00084 bool CDomainAdaptationSVM::train_machine(CFeatures* data) 00085 { 00086 00087 if (data) 00088 { 00089 if (m_labels->get_num_labels() != data->get_num_vectors()) 00090 SG_ERROR("Number of training vectors does not match number of labels\n"); 00091 kernel->init(data, data); 00092 } 00093 00094 if (m_labels->get_label_type() != LT_BINARY) 00095 SG_ERROR("DomainAdaptationSVM requires binary labels\n"); 00096 00097 int32_t num_training_points = get_labels()->get_num_labels(); 00098 CBinaryLabels* labels = (CBinaryLabels*) get_labels(); 00099 00100 float64_t* lin_term = SG_MALLOC(float64_t, num_training_points); 00101 00102 // grab current training features 00103 CFeatures* train_data = get_kernel()->get_lhs(); 00104 00105 // bias of parent SVM was set to zero in constructor, already contains B 00106 CBinaryLabels* parent_svm_out = presvm->apply_binary(train_data); 00107 00108 // pre-compute linear term 00109 for (int32_t i=0; i<num_training_points; i++) 00110 { 00111 lin_term[i] = train_factor * B * labels->get_label(i) * parent_svm_out->get_label(i) - 1.0; 00112 } 00113 00114 //set linear term for QP 00115 this->set_linear_term(SGVector<float64_t>(lin_term, num_training_points)); 00116 00117 //train SVM 00118 bool success = CSVMLight::train_machine(); 00119 SG_UNREF(labels); 00120 00121 ASSERT(presvm) 00122 00123 return success; 00124 00125 } 00126 00127 00128 CSVM* CDomainAdaptationSVM::get_presvm() 00129 { 00130 SG_REF(presvm); 00131 return presvm; 00132 } 00133 00134 00135 float64_t CDomainAdaptationSVM::get_B() 00136 { 00137 return B; 00138 } 00139 00140 00141 float64_t CDomainAdaptationSVM::get_train_factor() 00142 { 00143 return train_factor; 00144 } 00145 00146 00147 void CDomainAdaptationSVM::set_train_factor(float64_t factor) 00148 { 00149 train_factor = factor; 00150 } 00151 00152 00153 CBinaryLabels* CDomainAdaptationSVM::apply_binary(CFeatures* data) 00154 { 00155 ASSERT(data); 00156 ASSERT(presvm->get_bias()==0.0); 00157 00158 int32_t num_examples = data->get_num_vectors(); 00159 00160 CBinaryLabels* out_current = CSVMLight::apply_binary(data); 00161 00162 // recursive call if used on DomainAdaptationSVM object 00163 CBinaryLabels* out_presvm = presvm->apply_binary(data); 00164 00165 // combine outputs 00166 SGVector<float64_t> out_combined(num_examples); 00167 for (int32_t i=0; i<num_examples; i++) 00168 { 00169 out_combined[i] = out_current->get_confidence(i) + B*out_presvm->get_confidence(i); 00170 } 00171 SG_UNREF(out_current); 00172 SG_UNREF(out_presvm); 00173 00174 return new CBinaryLabels(out_combined); 00175 00176 } 00177 00178 void CDomainAdaptationSVM::init() 00179 { 00180 presvm = NULL; 00181 B = 0; 00182 train_factor = 1.0; 00183 00184 m_parameters->add((CSGObject**) &presvm, "presvm", 00185 "SVM to regularize against."); 00186 m_parameters->add(&B, "B", "regularization parameter B."); 00187 m_parameters->add(&train_factor, 00188 "train_factor", "flag to switch off regularization in training."); 00189 } 00190 00191 #endif //USE_SVMLIGHT