// SHOGUN v2.0.0
00001 // -*- C++ -*- 00002 // Main functions of the LaRank algorithm for soving Multiclass SVM 00003 // Copyright (C) 2008- Antoine Bordes 00004 // Shogun specific adjustments (w) 2009 Soeren Sonnenburg 00005 00006 // This library is free software; you can redistribute it and/or 00007 // modify it under the terms of the GNU Lesser General Public 00008 // License as published by the Free Software Foundation; either 00009 // version 2.1 of the License, or (at your option) any later version. 00010 // 00011 // This program is distributed in the hope that it will be useful, 00012 // but WITHOUT ANY WARRANTY; without even the implied warranty of 00013 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 00014 // GNU General Public License for more details. 00015 // 00016 // You should have received a copy of the GNU General Public License 00017 // along with this program; if not, write to the Free Software 00018 // Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111, USA 00019 // 00020 /*********************************************************************** 00021 * 00022 * LUSH Lisp Universal Shell 00023 * Copyright (C) 2002 Leon Bottou, Yann Le Cun, AT&T Corp, NECI. 00024 * Includes parts of TL3: 00025 * Copyright (C) 1987-1999 Leon Bottou and Neuristique. 00026 * Includes selected parts of SN3.2: 00027 * Copyright (C) 1991-2001 AT&T Corp. 00028 * 00029 * This program is free software; you can redistribute it and/or modify 00030 * it under the terms of the GNU General Public License as published by 00031 * the Free Software Foundation; either version 2 of the License, or 00032 * (at your option) any later version. 00033 * 00034 * This program is distributed in the hope that it will be useful, 00035 * but WITHOUT ANY WARRANTY; without even the implied warranty of 00036 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 00037 * GNU General Public License for more details. 
00038 * 00039 * You should have received a copy of the GNU General Public License 00040 * along with this program; if not, write to the Free Software 00041 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111, USA 00042 * 00043 ***********************************************************************/ 00044 00045 /*********************************************************************** 00046 * $Id: kcache.h,v 1.8 2007/01/25 22:42:09 leonb Exp $ 00047 **********************************************************************/ 00048 00049 #ifndef LARANK_H 00050 #define LARANK_H 00051 00052 #include <ctime> 00053 #include <vector> 00054 #include <algorithm> 00055 #include <sys/time.h> 00056 #include <set> 00057 #include <map> 00058 #define STDEXT_NAMESPACE __gnu_cxx 00059 #define std_hash_map std::map 00060 #define std_hash_set std::set 00061 00062 #include <shogun/io/SGIO.h> 00063 #include <shogun/kernel/Kernel.h> 00064 #include <shogun/multiclass/MulticlassSVM.h> 00065 00066 namespace shogun 00067 { 00068 #ifndef DOXYGEN_SHOULD_SKIP_THIS 00069 struct larank_kcache_s; 00070 typedef struct larank_kcache_s larank_kcache_t; 00071 struct larank_kcache_s 00072 { 00073 CKernel* func; 00074 larank_kcache_t *prevbuddy; 00075 larank_kcache_t *nextbuddy; 00076 int64_t maxsize; 00077 int64_t cursize; 00078 int32_t l; 00079 int32_t *i2r; 00080 int32_t *r2i; 00081 int32_t maxrowlen; 00082 /* Rows */ 00083 int32_t *rsize; 00084 float32_t *rdiag; 00085 float32_t **rdata; 00086 int32_t *rnext; 00087 int32_t *rprev; 00088 int32_t *qnext; 00089 int32_t *qprev; 00090 }; 00091 00092 /* 00093 ** OUTPUT: one per class of the raining set, keep tracks of support 00094 * vectors and their beta coefficients 00095 */ 00096 class LaRankOutput 00097 { 00098 public: 00099 LaRankOutput () : beta(NULL), g(NULL), kernel(NULL), l(0) 00100 { 00101 } 00102 virtual ~LaRankOutput () 00103 { 00104 destroy(); 00105 } 00106 00107 // Initializing an output class (basically creating a kernel cache for 
it) 00108 void initialize (CKernel* kfunc, int64_t cache); 00109 00110 // Destroying an output class (basically destroying the kernel cache) 00111 void destroy (); 00112 00113 // !Important! Computing the score of a given input vector for the actual output 00114 float64_t computeScore (int32_t x_id); 00115 00116 // !Important! Computing the gradient of a given input vector for the actual output 00117 float64_t computeGradient (int32_t xi_id, int32_t yi, int32_t ythis); 00118 00119 // Updating the solution in the actual output 00120 void update (int32_t x_id, float64_t lambda, float64_t gp); 00121 00122 // Linking the cache of this output to the cache of an other "buddy" output 00123 // so that if a requested value is not found in this cache, you can 00124 // ask your buddy if it has it. 00125 void set_kernel_buddy (larank_kcache_t * bud); 00126 00127 // Removing useless support vectors (for which beta=0) 00128 int32_t cleanup (); 00129 00130 // --- Below are information or "get" functions --- // 00131 00132 // 00133 inline larank_kcache_t *getKernel () const 00134 { 00135 return kernel; 00136 } 00137 // 00138 inline int32_t get_l () const 00139 { 00140 return l; 00141 } 00142 00143 // 00144 float64_t getW2 (); 00145 00146 // 00147 float64_t getKii (int32_t x_id); 00148 00149 // 00150 float64_t getBeta (int32_t x_id); 00151 00152 // 00153 inline float32_t* getBetas () const 00154 { 00155 return beta; 00156 } 00157 00158 // 00159 float64_t getGradient (int32_t x_id); 00160 00161 // 00162 bool isSupportVector (int32_t x_id) const; 00163 00164 // 00165 int32_t getSV (float32_t* &sv) const; 00166 00167 private: 00168 // the solution of LaRank relative to the actual class is stored in 00169 // this parameters 00170 float32_t* beta; // Beta coefficiens 00171 float32_t* g; // Strored gradient derivatives 00172 larank_kcache_t *kernel; // Cache for kernel values 00173 int32_t l; // Number of support vectors 00174 }; 00175 00176 /* 00177 ** LARANKPATTERN: to keep track of 
the support patterns 00178 */ 00179 class LaRankPattern 00180 { 00181 public: 00182 LaRankPattern (int32_t x_index, int32_t label) 00183 : x_id (x_index), y (label) {} 00184 LaRankPattern () 00185 : x_id (0) {} 00186 00187 bool exists () const 00188 { 00189 return x_id >= 0; 00190 } 00191 00192 void clear () 00193 { 00194 x_id = -1; 00195 } 00196 00197 int32_t x_id; 00198 int32_t y; 00199 }; 00200 00201 /* 00202 ** LARANKPATTERNS: the collection of support patterns 00203 */ 00204 class LaRankPatterns 00205 { 00206 public: 00207 LaRankPatterns () {} 00208 ~LaRankPatterns () {} 00209 00210 void insert (const LaRankPattern & pattern) 00211 { 00212 if (!isPattern (pattern.x_id)) 00213 { 00214 if (freeidx.size ()) 00215 { 00216 std_hash_set < uint32_t >::iterator it = freeidx.begin (); 00217 patterns[*it] = pattern; 00218 x_id2rank[pattern.x_id] = *it; 00219 freeidx.erase (it); 00220 } 00221 else 00222 { 00223 patterns.push_back (pattern); 00224 x_id2rank[pattern.x_id] = patterns.size () - 1; 00225 } 00226 } 00227 else 00228 { 00229 int32_t rank = getPatternRank (pattern.x_id); 00230 patterns[rank] = pattern; 00231 } 00232 } 00233 00234 void remove (uint32_t i) 00235 { 00236 x_id2rank[patterns[i].x_id] = 0; 00237 patterns[i].clear (); 00238 freeidx.insert (i); 00239 } 00240 00241 bool empty () const 00242 { 00243 return patterns.size () == freeidx.size (); 00244 } 00245 00246 uint32_t size () const 00247 { 00248 return patterns.size () - freeidx.size (); 00249 } 00250 00251 LaRankPattern & sample () 00252 { 00253 ASSERT (!empty ()); 00254 while (true) 00255 { 00256 uint32_t r = CMath::random(uint32_t(0), uint32_t(patterns.size ()-1)); 00257 if (patterns[r].exists ()) 00258 return patterns[r]; 00259 } 00260 return patterns[0]; 00261 } 00262 00263 uint32_t getPatternRank (int32_t x_id) 00264 { 00265 return x_id2rank[x_id]; 00266 } 00267 00268 bool isPattern (int32_t x_id) 00269 { 00270 return x_id2rank[x_id] != 0; 00271 } 00272 00273 LaRankPattern & getPattern (int32_t 
x_id) 00274 { 00275 uint32_t rank = x_id2rank[x_id]; 00276 return patterns[rank]; 00277 } 00278 00279 uint32_t maxcount () const 00280 { 00281 return patterns.size (); 00282 } 00283 00284 LaRankPattern & operator [] (uint32_t i) 00285 { 00286 return patterns[i]; 00287 } 00288 00289 const LaRankPattern & operator [] (uint32_t i) const 00290 { 00291 return patterns[i]; 00292 } 00293 00294 private: 00295 std_hash_set < uint32_t >freeidx; 00296 std::vector < LaRankPattern > patterns; 00297 std_hash_map < int32_t, uint32_t >x_id2rank; 00298 }; 00299 00300 00301 #endif // DOXYGEN_SHOULD_SKIP_THIS 00302 00303 00307 class CLaRank: public CMulticlassSVM 00308 { 00309 public: 00310 CLaRank (); 00311 00318 CLaRank(float64_t C, CKernel* k, CLabels* lab); 00319 00320 virtual ~CLaRank (); 00321 00322 // LEARNING FUNCTION: add new patterns and run optimization steps 00323 // selected with adaptative schedule 00328 virtual int32_t add (int32_t x_id, int32_t yi); 00329 00330 // PREDICTION FUNCTION: main function in la_rank_classify 00334 virtual int32_t predict (int32_t x_id); 00335 00337 virtual void destroy (); 00338 00339 // Compute Duality gap (costly but used in stopping criteria in batch mode) 00341 virtual float64_t computeGap (); 00342 00343 // Nuber of classes so far 00345 virtual uint32_t getNumOutputs () const; 00346 00347 // Number of Support Vectors 00349 int32_t getNSV (); 00350 00351 // Norm of the parameters vector 00353 float64_t computeW2 (); 00354 00355 // Compute Dual objective value 00357 float64_t getDual (); 00358 00363 virtual inline EMachineType get_classifier_type() { return CT_LARANK; } 00364 00366 inline virtual const char* get_name() const { return "LaRank"; } 00367 00371 void set_batch_mode(bool enable) { batch_mode=enable; }; 00373 bool get_batch_mode() { return batch_mode; }; 00377 void set_tau(float64_t t) { tau=t; }; 00381 float64_t get_tau() { return tau; }; 00382 00383 protected: 00385 bool train_machine(CFeatures* data); 00386 00387 private: 
	/*
	** MAIN DARK OPTIMIZATION PROCESSES
	*/

	// Hash Table used to store the different outputs
	typedef std_hash_map < int32_t, LaRankOutput > outputhash_t;	// class index -> LaRankOutput

	// one LaRankOutput (hyperplane + kernel cache) per class seen so far
	outputhash_t outputs;

	// Look up the output for a class index
	// NOTE(review): presumably returns NULL for an unknown class --
	// confirm against the implementation in the .cpp.
	LaRankOutput *getOutput (int32_t index);

	// support patterns shared by all outputs
	LaRankPatterns patterns;

	// Parameters
	int32_t nb_seen_examples;	// examples presented so far
	int32_t nb_removed;		// patterns removed by cleanup

	// Numbers of each operation performed so far
	int32_t n_pro;	// processNew steps
	int32_t n_rep;	// processOld (reprocess) steps
	int32_t n_opt;	// processOptimize steps

	// Running estimates for each operation
	float64_t w_pro;
	float64_t w_rep;
	float64_t w_opt;

	int32_t y0;		// NOTE(review): meaning not visible in this header
	float64_t dual;		// running dual objective value (see getDual)

	// An (output class, gradient) pair.  The comparison is deliberately
	// inverted so that sorting puts LARGER gradients first.
	struct outputgradient_t
	{
		outputgradient_t (int32_t result_output, float64_t result_gradient)
			: output (result_output), gradient (result_gradient) {}
		outputgradient_t ()
			: output (0), gradient (0) {}

		int32_t output;
		float64_t gradient;

		bool operator < (const outputgradient_t & og) const
		{
			// inverted on purpose: "smaller" means larger gradient
			return gradient > og.gradient;
		}
	};

	// 3 types of operations in LaRank
	enum process_type
	{
		processNew,
		processOld,
		processOptimize
	};

	// Result of one optimization step: the dual increase achieved and
	// the label predicted for the processed pattern.
	// NOTE(review): the default constructor leaves both fields
	// uninitialized.
	struct process_return_t
	{
		process_return_t (float64_t dual, int32_t yprediction)
			: dual_increase (dual), ypred (yprediction) {}
		process_return_t () {}
		float64_t dual_increase;
		int32_t ypred;
	};

	// IMPORTANT Main SMO optimization step
	process_return_t process (const LaRankPattern & pattern, process_type ptype);

	// ProcessOld operation
	float64_t reprocess ();

	// Optimize operation
	float64_t optimize ();

	// remove patterns and return the number of patterns that were removed
	uint32_t cleanup ();

 protected:

	// class labels seen so far
	std_hash_set < int32_t >classes;

	// number of distinct classes seen so far
	inline uint32_t class_count () const
	{
		return classes.size ();
	}

	// tau parameter of the LaRank algorithm (see set_tau/get_tau)
	float64_t tau;

	// number of training examples
	int32_t nb_train;
	// kernel cache size (passed to LaRankOutput::initialize)
	int64_t cache;
	// batch-mode flag (see set_batch_mode)
	bool batch_mode;

	// step value used during training
	// NOTE(review): exact use not visible in this header -- see .cpp
	int32_t step;
};
}
#endif // LARANK_H