SHOGUN v2.0.0
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 2007-2011 Christian Widmer
 * Copyright (C) 2007-2011 Max-Planck-Society
 */

#include <shogun/lib/config.h>

#ifdef HAVE_LAPACK

#include <shogun/transfer/domain_adaptation/DomainAdaptationSVMLinear.h>
#include <shogun/io/SGIO.h>
#include <shogun/base/Parameter.h>
#include <shogun/labels/Labels.h>
#include <shogun/labels/BinaryLabels.h>
#include <shogun/labels/RegressionLabels.h>
#include <iostream>
#include <vector>

using namespace shogun;

CDomainAdaptationSVMLinear::CDomainAdaptationSVMLinear() : CLibLinear(L2R_L1LOSS_SVC_DUAL)
{
	init(NULL, 0.0);
}

CDomainAdaptationSVMLinear::CDomainAdaptationSVMLinear(float64_t C, CDotFeatures* f, CLabels* lab, CLinearMachine* pre_svm, float64_t B_param) : CLibLinear(C, f, lab)
{
	init(pre_svm, B_param);
}

CDomainAdaptationSVMLinear::~CDomainAdaptationSVMLinear()
{
	SG_UNREF(presvm);
	SG_DEBUG("deleting DomainAdaptationSVMLinear\n");
}

void CDomainAdaptationSVMLinear::init(CLinearMachine* pre_svm, float64_t B_param)
{
	if (pre_svm)
	{
		// increase reference count
		SG_REF(pre_svm);

		// set bias of parent SVM to zero
		pre_svm->set_bias(0.0);
	}

	this->presvm = pre_svm;
	this->B = B_param;
	this->train_factor = 1.0;

	set_liblinear_solver_type(L2R_L1LOSS_SVC_DUAL);

	// invoke sanity check
	is_presvm_sane();

	// serialization code
	m_parameters->add((CSGObject**) &presvm, "presvm", "SVM to regularize against.");
	m_parameters->add(&B, "B", "Regularization strength B.");
	m_parameters->add(&train_factor, "train_factor", "train_factor");
}

bool CDomainAdaptationSVMLinear::is_presvm_sane()
{
	if (!presvm)
	{
		SG_WARNING("presvm is null\n");
	}
	else
	{
		if (presvm->get_bias() != 0)
			SG_ERROR("presvm bias not set to zero\n");

		if (presvm->get_features()->get_feature_type() != this->get_features()->get_feature_type())
			SG_ERROR("feature types do not agree\n");
	}

	return true;
}

bool CDomainAdaptationSVMLinear::train_machine(CFeatures* train_data)
{
	CDotFeatures* tmp_data;

	if (m_labels->get_label_type() != LT_BINARY)
		SG_ERROR("DomainAdaptationSVMLinear requires binary labels\n");

	if (train_data)
	{
		if (!train_data->has_property(FP_DOT))
			SG_ERROR("DotFeatures expected\n");

		if (((CBinaryLabels*) m_labels)->get_num_labels() != train_data->get_num_vectors())
			SG_ERROR("Number of training vectors does not match number of labels\n");

		tmp_data = (CDotFeatures*) train_data;
	}
	else
	{
		tmp_data = features;
	}

	CBinaryLabels* labels = (CBinaryLabels*) get_labels();
	int32_t num_training_points = labels->get_num_labels();

	std::vector<float64_t> lin_term(num_training_points);
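	/*
	 * Core of the adaptation scheme: plain LibLinear uses a constant
	 * linear term of -1 for every example in its QP. Below, the term for
	 * example i is replaced by
	 *
	 *     lin_term[i] = train_factor * B * y_i * f_pre(x_i) - 1
	 *
	 * where y_i is the binary label and f_pre(x_i) is the bias-free output
	 * of the pre-trained SVM; together with the combined prediction
	 * f(x) = f_new(x) + B*f_pre(x) in apply_binary() further down, this
	 * regularizes the new model against the source-domain solution.
	 */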
	if (presvm)
	{
		ASSERT(presvm->get_bias() == 0.0);

		// bias of parent SVM was set to zero in constructor, already contains B
		CBinaryLabels* parent_svm_out = presvm->apply_binary(tmp_data);

		SG_DEBUG("pre-computing linear term from presvm\n");

		// pre-compute linear term
		for (int32_t i=0; i!=num_training_points; i++)
		{
			lin_term[i] = train_factor * B * labels->get_confidence(i) * parent_svm_out->get_confidence(i) - 1.0;
		}

		// set linear term for QP
		this->set_linear_term(
			SGVector<float64_t>(&lin_term[0], lin_term.size()));

		// release the pre-SVM outputs
		SG_UNREF(parent_svm_out);
	}

	SG_UNREF(labels);

	/*
	// warm-start liblinear
	//TODO test this code, measure speed-ups
	//presvm w stored in presvm
	float64_t* tmp_w;
	presvm->get_w(tmp_w, w_dim);

	//copy vector
	float64_t* tmp_w_copy = SG_MALLOC(float64_t, w_dim);
	std::copy(tmp_w, tmp_w + w_dim, tmp_w_copy);

	for (int32_t i=0; i!=w_dim; i++)
	{
		tmp_w_copy[i] = B * tmp_w_copy[i];
	}

	//set w (copied in setter)
	set_w(tmp_w_copy, w_dim);
	SG_FREE(tmp_w_copy);
	*/

	bool success = false;

	// train SVM
	if (train_data)
	{
		success = CLibLinear::train_machine(train_data);
	}
	else
	{
		success = CLibLinear::train_machine();
	}

	return success;
}

CLinearMachine* CDomainAdaptationSVMLinear::get_presvm()
{
	return presvm;
}

float64_t CDomainAdaptationSVMLinear::get_B()
{
	return B;
}

float64_t CDomainAdaptationSVMLinear::get_train_factor()
{
	return train_factor;
}

void CDomainAdaptationSVMLinear::set_train_factor(float64_t factor)
{
	train_factor = factor;
}

CBinaryLabels* CDomainAdaptationSVMLinear::apply_binary(CFeatures* data)
{
	int32_t num_examples = data->get_num_vectors();

	CBinaryLabels* out_current = CLibLinear::apply_binary(data);

	// start from the outputs of the current model so the result is
	// well-defined even without a pre-trained SVM
	SGVector<float64_t> out_combined(num_examples);
	for (int32_t i=0; i!=num_examples; i++)
		out_combined[i] = out_current->get_confidence(i);

	if (presvm)
	{
		ASSERT(presvm->get_bias()==0.0);

		// recursive call if used on DomainAdaptationSVM object
		CBinaryLabels* out_presvm = presvm->apply_binary(data);

		// combine outputs: f(x) = f_new(x) + B*f_pre(x)
		for (int32_t i=0; i!=num_examples; i++)
			out_combined[i] += B*out_presvm->get_confidence(i);

		SG_UNREF(out_presvm);
	}

	SG_UNREF(out_current);

	return new CBinaryLabels(out_combined);
}

#endif //HAVE_LAPACK
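
For orientation, a minimal usage sketch of the class above: pre-train a plain CLibLinear model on source-domain data, then regularize a target-domain model against it. This is a sketch against the Shogun 2.0 C++ API as assumed here (init_shogun_with_defaults, CDenseFeatures, CBinaryLabels); the toy matrices and the choice B=0.5 are placeholders, not taken from the file above.

#include <shogun/base/init.h>
#include <shogun/features/DenseFeatures.h>
#include <shogun/labels/BinaryLabels.h>
#include <shogun/classifier/svm/LibLinear.h>
#include <shogun/transfer/domain_adaptation/DomainAdaptationSVMLinear.h>

using namespace shogun;

int main()
{
	init_shogun_with_defaults();

	// toy source-domain data: 2 dimensions x 4 examples (column-major)
	SGMatrix<float64_t> X_src(2, 4);
	SGVector<float64_t> y_src(4);
	for (int32_t i=0; i<4; i++)
	{
		X_src(0, i) = (i < 2) ? -1.0 : 1.0;
		X_src(1, i) = 1.0;
		y_src[i] = (i < 2) ? -1.0 : 1.0;
	}

	CDenseFeatures<float64_t>* feats_src = new CDenseFeatures<float64_t>(X_src);
	CBinaryLabels* labels_src = new CBinaryLabels(y_src);

	// step 1: pre-train an ordinary linear SVM on the source domain
	CLibLinear* presvm = new CLibLinear(1.0, feats_src, labels_src);
	presvm->train();

	// toy target-domain data: shifted copy of the source data
	SGMatrix<float64_t> X_tgt(2, 4);
	SGVector<float64_t> y_tgt(4);
	for (int32_t i=0; i<4; i++)
	{
		X_tgt(0, i) = ((i < 2) ? -1.0 : 1.0) + 0.3;
		X_tgt(1, i) = 1.0;
		y_tgt[i] = (i < 2) ? -1.0 : 1.0;
	}

	CDenseFeatures<float64_t>* feats_tgt = new CDenseFeatures<float64_t>(X_tgt);
	CBinaryLabels* labels_tgt = new CBinaryLabels(y_tgt);

	// step 2: adapt to the target domain, regularizing against presvm
	// (B=0.5 is an arbitrary placeholder value)
	CDomainAdaptationSVMLinear* dasvm =
		new CDomainAdaptationSVMLinear(1.0, feats_tgt, labels_tgt, presvm, 0.5);
	dasvm->train();

	// combined prediction: f(x) = f_new(x) + B*f_pre(x)
	CBinaryLabels* out = dasvm->apply_binary(feats_tgt);
	for (int32_t i=0; i<out->get_num_labels(); i++)
		SG_SPRINT("example %d: label %+.0f\n", i, out->get_label(i));

	// dasvm holds the only reference to presvm and unrefs it in its
	// destructor, so presvm must not be unref'd separately here
	SG_UNREF(out);
	SG_UNREF(dasvm);

	exit_shogun();
	return 0;
}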