/* SHOGUN v2.0.0 */
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2008 Gunnar Raetsch 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #ifndef __PLIF_H__ 00012 #define __PLIF_H__ 00013 00014 #include <shogun/lib/common.h> 00015 #include <shogun/mathematics/Math.h> 00016 #include <shogun/structure/PlifBase.h> 00017 00018 namespace shogun 00019 { 00020 00022 enum ETransformType 00023 { 00025 T_LINEAR, 00027 T_LOG, 00029 T_LOG_PLUS1, 00031 T_LOG_PLUS3, 00033 T_LINEAR_PLUS3 00034 }; 00035 00037 class CPlif: public CPlifBase 00038 { 00039 public: 00044 CPlif(int32_t len=0); 00045 virtual ~CPlif(); 00046 00048 void init_penalty_struct_cache(); 00049 00056 float64_t lookup_penalty_svm( 00057 float64_t p_value, float64_t *d_values) const; 00058 00065 float64_t lookup_penalty( 00066 float64_t p_value, float64_t* svm_values) const; 00067 00074 float64_t lookup_penalty(int32_t p_value, float64_t* svm_values) const; 00075 00081 inline float64_t lookup(float64_t p_value) 00082 { 00083 ASSERT(use_svm == 0); 00084 return lookup_penalty(p_value, NULL); 00085 } 00086 00088 void penalty_clear_derivative(); 00089 00096 void penalty_add_derivative_svm( 00097 float64_t p_value, float64_t* svm_values, float64_t factor) ; 00098 00105 void penalty_add_derivative(float64_t p_value, float64_t* svm_values, float64_t factor); 00106 00112 const float64_t * get_cum_derivative(int32_t & p_len) const 00113 { 00114 p_len = len; 00115 return cum_derivatives.vector; 00116 } 00117 00123 bool set_transform_type(const char *type_str); 00124 00129 const char* get_transform_type() 00130 { 00131 if (transform== T_LINEAR) 00132 return "linear"; 00133 else if (transform== T_LOG) 00134 return "log"; 00135 else 
if (transform== T_LOG_PLUS1) 00136 return "log(+1)"; 00137 else if (transform== T_LOG_PLUS3) 00138 return "log(+3)"; 00139 else if (transform== T_LINEAR_PLUS3) 00140 return "(+3)"; 00141 else 00142 SG_ERROR("wrong type"); 00143 return ""; 00144 } 00145 00146 00151 void set_id(int32_t p_id) 00152 { 00153 id=p_id; 00154 } 00155 00160 int32_t get_id() const 00161 { 00162 return id; 00163 } 00164 00169 int32_t get_max_id() const 00170 { 00171 return get_id(); 00172 } 00173 00178 void set_use_svm(int32_t p_use_svm) 00179 { 00180 invalidate_cache(); 00181 use_svm=p_use_svm; 00182 } 00183 00188 int32_t get_use_svm() const 00189 { 00190 return use_svm; 00191 } 00192 00197 virtual bool uses_svm_values() const 00198 { 00199 return (get_use_svm()!=0); 00200 } 00201 00206 void set_use_cache(int32_t p_use_cache) 00207 { 00208 invalidate_cache(); 00209 use_cache=p_use_cache; 00210 } 00211 00214 void invalidate_cache() 00215 { 00216 SG_FREE(cache); 00217 cache=NULL; 00218 } 00219 00224 int32_t get_use_cache() 00225 { 00226 return use_cache; 00227 } 00228 00235 void set_plif( 00236 int32_t p_len, float64_t *p_limits, float64_t* p_penalties) 00237 { 00238 ASSERT(len==p_len); 00239 00240 for (int32_t i=0; i<len; i++) 00241 { 00242 limits[i]=p_limits[i]; 00243 penalties[i]=p_penalties[i]; 00244 } 00245 00246 invalidate_cache(); 00247 penalty_clear_derivative(); 00248 } 00249 00254 void set_plif_limits(SGVector<float64_t> p_limits) 00255 { 00256 ASSERT(len==p_limits.vlen); 00257 00258 limits = p_limits; 00259 00260 invalidate_cache(); 00261 penalty_clear_derivative(); 00262 } 00263 00264 00269 void set_plif_penalty(SGVector<float64_t> p_penalties) 00270 { 00271 ASSERT(len==p_penalties.vlen); 00272 00273 penalties = p_penalties; 00274 00275 invalidate_cache(); 00276 penalty_clear_derivative(); 00277 } 00278 00283 void set_plif_length(int32_t p_len) 00284 { 00285 if (len!=p_len) 00286 { 00287 len=p_len; 00288 00289 SG_DEBUG( "set_plif len=%i\n", p_len); 00290 limits = 
SGVector<float64_t>(len); 00291 penalties = SGVector<float64_t>(len); 00292 cum_derivatives = SGVector<float64_t>(len); 00293 } 00294 00295 for (int32_t i=0; i<len; i++) 00296 { 00297 limits[i]=0.0; 00298 penalties[i]=0.0; 00299 cum_derivatives[i]=0.0; 00300 } 00301 00302 invalidate_cache(); 00303 penalty_clear_derivative(); 00304 } 00305 00310 float64_t* get_plif_limits() 00311 { 00312 return limits.vector; 00313 } 00314 00319 float64_t* get_plif_penalties() 00320 { 00321 return penalties.vector; 00322 } 00323 00328 inline void set_max_value(float64_t p_max_value) 00329 { 00330 max_value=p_max_value; 00331 invalidate_cache(); 00332 } 00333 00338 virtual float64_t get_max_value() const 00339 { 00340 return max_value; 00341 } 00342 00347 inline void set_min_value(float64_t p_min_value) 00348 { 00349 min_value=p_min_value; 00350 invalidate_cache(); 00351 } 00352 00357 virtual float64_t get_min_value() const 00358 { 00359 return min_value; 00360 } 00361 00366 void set_plif_name(char *p_name); 00367 00372 inline char* get_plif_name() const 00373 { 00374 if (name) 00375 return name; 00376 else 00377 { 00378 char buf[20]; 00379 sprintf(buf, "plif%i", id); 00380 //name = strdup(buf); 00381 return strdup(buf); 00382 } 00383 } 00384 00389 bool get_do_calc(); 00390 00395 void set_do_calc(bool b); 00396 00400 void get_used_svms(int32_t* num_svms, int32_t* svm_ids); 00401 00406 inline int32_t get_plif_len() 00407 { 00408 return len; 00409 } 00410 00415 virtual void list_plif() const 00416 { 00417 SG_PRINT("CPlif(min_value=%1.2f, max_value=%1.2f, use_svm=%i)\n", min_value, max_value, use_svm) ; 00418 } 00419 00425 static void delete_penalty_struct(CPlif** PEN, int32_t P); 00426 00428 inline virtual const char* get_name() const { return "Plif"; } 00429 00430 protected: 00432 int32_t len; 00434 SGVector<float64_t> limits; 00436 SGVector<float64_t> penalties; 00438 SGVector<float64_t> cum_derivatives; 00440 float64_t max_value; 00442 float64_t min_value; 00444 float64_t *cache; 
00446 enum ETransformType transform; 00448 int32_t id; 00450 char * name; 00452 int32_t use_svm; 00454 bool use_cache; 00458 bool do_calc; 00459 }; 00460 } 00461 #endif