SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2008 Gunnar Raetsch 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 00012 #include <stdio.h> 00013 #include <string.h> 00014 00015 #include <shogun/lib/config.h> 00016 #include <shogun/io/SGIO.h> 00017 #include <shogun/structure/Plif.h> 00018 00019 //#define PLIF_DEBUG 00020 00021 using namespace shogun; 00022 00023 CPlif::CPlif(int32_t l) 00024 : CPlifBase() 00025 { 00026 limits=SGVector<float64_t>(); 00027 penalties=SGVector<float64_t>(); 00028 cum_derivatives=SGVector<float64_t>(); 00029 id=-1; 00030 transform=T_LINEAR; 00031 name=NULL; 00032 max_value=0; 00033 min_value=0; 00034 cache=NULL; 00035 use_svm=0; 00036 use_cache=false; 00037 len=0; 00038 do_calc = true; 00039 if (l>0) 00040 set_plif_length(l); 00041 } 00042 00043 CPlif::~CPlif() 00044 { 00045 SG_FREE(name); 00046 SG_FREE(cache); 00047 } 00048 00049 bool CPlif::set_transform_type(const char *type_str) 00050 { 00051 invalidate_cache(); 00052 00053 if (strcmp(type_str, "linear")==0) 00054 transform = T_LINEAR ; 00055 else if (strcmp(type_str, "")==0) 00056 transform = T_LINEAR ; 00057 else if (strcmp(type_str, "log")==0) 00058 transform = T_LOG ; 00059 else if (strcmp(type_str, "log(+1)")==0) 00060 transform = T_LOG_PLUS1 ; 00061 else if (strcmp(type_str, "log(+3)")==0) 00062 transform = T_LOG_PLUS3 ; 00063 else if (strcmp(type_str, "(+3)")==0) 00064 transform = T_LINEAR_PLUS3 ; 00065 else 00066 { 00067 SG_ERROR( "unknown transform type (%s)\n", type_str) ; 00068 return false ; 00069 } 00070 return true ; 00071 } 00072 00073 void CPlif::init_penalty_struct_cache() 00074 { 00075 if (!use_cache) 00076 return ; 00077 if (cache || use_svm) 00078 return ; 00079 if (max_value<=0) 00080 return ; 00081 00082 float64_t* local_cache=SG_MALLOC(float64_t, ((int32_t) max_value) + 2); 00083 00084 if (local_cache) 00085 { 00086 for (int32_t i=0; i<=max_value; i++) 00087 { 00088 if (i<min_value) 00089 local_cache[i] = -CMath::INFTY ; 00090 else 00091 local_cache[i] = lookup_penalty(i, NULL) ; 00092 } 00093 } 00094 this->cache=local_cache ; 00095 } 00096 00097 void CPlif::set_plif_name(char *p_name) 00098 { 00099 SG_FREE(name); 00100 name=SG_MALLOC(char, strlen(p_name)+3); 00101 strcpy(name,p_name) ; 00102 } 00103 00104 void CPlif::delete_penalty_struct(CPlif** PEN, int32_t P) 00105 { 00106 for (int32_t i=0; i<P; i++) 00107 delete PEN[i] ; 00108 SG_FREE(PEN); 00109 } 00110 00111 float64_t CPlif::lookup_penalty_svm( 00112 float64_t p_value, float64_t *d_values) const 00113 { 00114 ASSERT(use_svm>0); 00115 float64_t d_value=d_values[use_svm-1] ; 00116 #ifdef PLIF_DEBUG 00117 SG_PRINT("%s.lookup_penalty_svm(%f)\n", get_name(), d_value) ; 00118 #endif 00119 00120 if (!do_calc) 00121 return d_value; 00122 switch (transform) 00123 { 00124 case T_LINEAR: 00125 break ; 00126 case T_LOG: 00127 d_value = log(d_value) ; 00128 break ; 00129 case T_LOG_PLUS1: 00130 d_value = log(d_value+1) ; 00131 break ; 00132 case T_LOG_PLUS3: 00133 d_value = log(d_value+3) ; 00134 break ; 00135 case T_LINEAR_PLUS3: 00136 d_value = d_value+3 ; 00137 break ; 00138 default: 00139 SG_ERROR("unknown transform\n"); 00140 break ; 00141 } 00142 00143 int32_t idx = 0 ; 00144 float64_t ret ; 00145 for (int32_t i=0; i<len; i++) 00146 if (limits[i]<=d_value) 00147 idx++ ; 00148 else 00149 break ; // assume it is monotonically increasing 00150 00151 #ifdef PLIF_DEBUG 00152 SG_PRINT(" -> idx = %i ", idx) ; 00153 #endif 00154 00155 if (idx==0) 00156 ret=penalties[0] ; 00157 else if (idx==len) 00158 ret=penalties[len-1] ; 00159 else 00160 { 00161 ret = (penalties[idx]*(d_value-limits[idx-1]) + penalties[idx-1]* 00162 (limits[idx]-d_value)) / (limits[idx]-limits[idx-1]) ; 00163 #ifdef PLIF_DEBUG 00164 SG_PRINT(" -> (%1.3f*%1.3f, %1.3f*%1.3f)", (d_value-limits[idx-1])/(limits[idx]-limits[idx-1]), penalties[idx], (limits[idx]-d_value)/(limits[idx]-limits[idx-1]), penalties[idx-1]) ; 00165 #endif 00166 } 00167 #ifdef PLIF_DEBUG 00168 SG_PRINT(" -> ret=%1.3f\n", ret) ; 00169 #endif 00170 00171 return ret ; 00172 } 00173 00174 float64_t CPlif::lookup_penalty(int32_t p_value, float64_t* svm_values) const 00175 { 00176 if (use_svm) 00177 return lookup_penalty_svm(p_value, svm_values) ; 00178 00179 if ((p_value<min_value) || (p_value>max_value)) 00180 { 00181 //SG_PRINT("Feature:%s, %s.lookup_penalty(%i): return -inf min_value: %f, max_value: %f\n", name, get_name(), p_value, min_value, max_value) ; 00182 return -CMath::INFTY ; 00183 } 00184 if (!do_calc) 00185 return p_value; 00186 if (cache!=NULL && (p_value>=0) && (p_value<=max_value)) 00187 { 00188 float64_t ret=cache[p_value] ; 00189 return ret ; 00190 } 00191 return lookup_penalty((float64_t) p_value, svm_values) ; 00192 } 00193 00194 float64_t CPlif::lookup_penalty(float64_t p_value, float64_t* svm_values) const 00195 { 00196 if (use_svm) 00197 return lookup_penalty_svm(p_value, svm_values) ; 00198 00199 #ifdef PLIF_DEBUG 00200 SG_PRINT("%s.lookup_penalty(%f)\n", get_name(), p_value) ; 00201 #endif 00202 00203 00204 if ((p_value<min_value) || (p_value>max_value)) 00205 { 00206 //SG_PRINT("Feature:%s, %s.lookup_penalty(%f): return -inf min_value: %f, max_value: %f\n", name, get_name(), p_value, min_value, max_value) ; 00207 return -CMath::INFTY ; 00208 } 00209 00210 if (!do_calc) 00211 return p_value; 00212 00213 float64_t d_value = (float64_t) p_value ; 00214 switch (transform) 00215 { 00216 case T_LINEAR: 00217 break ; 00218 case T_LOG: 00219 d_value = log(d_value) ; 00220 break ; 00221 case T_LOG_PLUS1: 00222 d_value = log(d_value+1) ; 00223 break ; 00224 case T_LOG_PLUS3: 00225 d_value = log(d_value+3) ; 00226 break ; 00227 case T_LINEAR_PLUS3: 00228 d_value = d_value+3 ; 00229 break ; 00230 default: 00231 SG_ERROR( "unknown transform\n") ; 00232 break ; 00233 } 00234 00235 #ifdef PLIF_DEBUG 00236 SG_PRINT(" -> value = %1.4f ", d_value) ; 00237 #endif 00238 00239 int32_t idx = 0 ; 00240 float64_t ret ; 00241 for (int32_t i=0; i<len; i++) 00242 if (limits[i]<=d_value) 00243 idx++ ; 00244 else 00245 break ; // assume it is monotonically increasing 00246 00247 #ifdef PLIF_DEBUG 00248 SG_PRINT(" -> idx = %i ", idx) ; 00249 #endif 00250 00251 if (idx==0) 00252 ret=penalties[0] ; 00253 else if (idx==len) 00254 ret=penalties[len-1] ; 00255 else 00256 { 00257 ret = (penalties[idx]*(d_value-limits[idx-1]) + penalties[idx-1]* 00258 (limits[idx]-d_value)) / (limits[idx]-limits[idx-1]) ; 00259 #ifdef PLIF_DEBUG 00260 SG_PRINT(" -> (%1.3f*%1.3f, %1.3f*%1.3f) ", (d_value-limits[idx-1])/(limits[idx]-limits[idx-1]), penalties[idx], (limits[idx]-d_value)/(limits[idx]-limits[idx-1]), penalties[idx-1]) ; 00261 #endif 00262 } 00263 //if (p_value>=30 && p_value<150) 00264 //SG_PRINT("%s %i(%i) -> %1.2f\n", PEN->name, p_value, idx, ret) ; 00265 #ifdef PLIF_DEBUG 00266 SG_PRINT(" -> ret=%1.3f\n", ret) ; 00267 #endif 00268 00269 return ret ; 00270 } 00271 00272 void CPlif::penalty_clear_derivative() 00273 { 00274 for (int32_t i=0; i<len; i++) 00275 cum_derivatives[i]=0.0 ; 00276 } 00277 00278 void CPlif::penalty_add_derivative(float64_t p_value, float64_t* svm_values, float64_t factor) 00279 { 00280 if (use_svm) 00281 { 00282 penalty_add_derivative_svm(p_value, svm_values, factor) ; 00283 return ; 00284 } 00285 00286 if ((p_value<min_value) || (p_value>max_value)) 00287 { 00288 return ; 00289 } 00290 float64_t d_value = (float64_t) p_value ; 00291 switch (transform) 00292 { 00293 case T_LINEAR: 00294 break ; 00295 case T_LOG: 00296 d_value = log(d_value) ; 00297 break ; 00298 case T_LOG_PLUS1: 00299 d_value = log(d_value+1) ; 00300 break ; 00301 case T_LOG_PLUS3: 00302 d_value = log(d_value+3) ; 00303 break ; 00304 case T_LINEAR_PLUS3: 00305 d_value = d_value+3 ; 00306 break ; 00307 default: 00308 SG_ERROR( "unknown transform\n") ; 00309 break ; 00310 } 00311 00312 int32_t idx = 0 ; 00313 for (int32_t i=0; i<len; i++) 00314 if (limits[i]<=d_value) 00315 idx++ ; 00316 else 00317 break ; // assume it is monotonically increasing 00318 00319 if (idx==0) 00320 cum_derivatives[0]+= factor ; 00321 else if (idx==len) 00322 cum_derivatives[len-1]+= factor ; 00323 else 00324 { 00325 cum_derivatives[idx] += factor * (d_value-limits[idx-1])/(limits[idx]-limits[idx-1]) ; 00326 cum_derivatives[idx-1]+= factor*(limits[idx]-d_value)/(limits[idx]-limits[idx-1]) ; 00327 } 00328 } 00329 00330 void CPlif::penalty_add_derivative_svm(float64_t p_value, float64_t *d_values, float64_t factor) 00331 { 00332 ASSERT(use_svm>0); 00333 float64_t d_value=d_values[use_svm-1] ; 00334 00335 if (d_value<-1e+20) 00336 return; 00337 00338 switch (transform) 00339 { 00340 case T_LINEAR: 00341 break ; 00342 case T_LOG: 00343 d_value = log(d_value) ; 00344 break ; 00345 case T_LOG_PLUS1: 00346 d_value = log(d_value+1) ; 00347 break ; 00348 case T_LOG_PLUS3: 00349 d_value = log(d_value+3) ; 00350 break ; 00351 case T_LINEAR_PLUS3: 00352 d_value = d_value+3 ; 00353 break ; 00354 default: 00355 SG_ERROR( "unknown transform\n") ; 00356 break ; 00357 } 00358 00359 int32_t idx = 0 ; 00360 for (int32_t i=0; i<len; i++) 00361 if (limits[i]<=d_value) 00362 idx++ ; 00363 else 00364 break ; // assume it is monotonically increasing 00365 00366 if (idx==0) 00367 cum_derivatives[0]+=factor ; 00368 else if (idx==len) 00369 cum_derivatives[len-1]+=factor ; 00370 else 00371 { 00372 cum_derivatives[idx] += factor*(d_value-limits[idx-1])/(limits[idx]-limits[idx-1]) ; 00373 cum_derivatives[idx-1] += factor*(limits[idx]-d_value)/(limits[idx]-limits[idx-1]) ; 00374 } 00375 } 00376 00377 void CPlif::get_used_svms(int32_t* num_svms, int32_t* svm_ids) 00378 { 00379 if (use_svm) 00380 { 00381 svm_ids[(*num_svms)] = use_svm; 00382 (*num_svms)++; 00383 } 00384 SG_PRINT("->use_svm:%i plif_id:%i name:%s trans_type:%s ",use_svm, get_id(), get_name(), get_transform_type()); 00385 } 00386 00387 bool CPlif::get_do_calc() 00388 { 00389 return do_calc; 00390 } 00391 00392 void CPlif::set_do_calc(bool b) 00393 { 00394 do_calc = b;; 00395 }