SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Written (W) 1999-2008 Gunnar Raetsch 00009 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00010 */ 00011 00012 #ifndef _WEIGHTEDDEGREEPOSITIONSTRINGKERNEL_H___ 00013 #define _WEIGHTEDDEGREEPOSITIONSTRINGKERNEL_H___ 00014 00015 #include <shogun/lib/common.h> 00016 #include <shogun/kernel/string/StringKernel.h> 00017 #include <shogun/kernel/string/WeightedDegreeStringKernel.h> 00018 #include <shogun/lib/Trie.h> 00019 00020 namespace shogun 00021 { 00022 00023 class CSVM; 00024 00048 class CWeightedDegreePositionStringKernel: public CStringKernel<char> 00049 { 00050 public: 00052 CWeightedDegreePositionStringKernel(); 00053 00061 CWeightedDegreePositionStringKernel( 00062 int32_t size, int32_t degree, 00063 int32_t max_mismatch=0, int32_t mkl_stepsize=1); 00064 00075 CWeightedDegreePositionStringKernel( 00076 int32_t size, float64_t* weights, int32_t degree, 00077 int32_t max_mismatch, int32_t* shift, int32_t shift_len, 00078 int32_t mkl_stepsize=1); 00079 00086 CWeightedDegreePositionStringKernel( 00087 CStringFeatures<char>* l, CStringFeatures<char>* r, int32_t degree); 00088 00089 virtual ~CWeightedDegreePositionStringKernel(); 00090 00097 virtual bool init(CFeatures* l, CFeatures* r); 00098 00100 virtual void cleanup(); 00101 00106 virtual EKernelType get_kernel_type() { return K_WEIGHTEDDEGREEPOS; } 00107 00112 virtual const char* get_name() const { return "WeightedDegreePositionStringKernel"; } 00113 00121 inline virtual bool init_optimization( 00122 int32_t p_count, int32_t *IDX, float64_t * alphas) 00123 { 00124 return init_optimization(p_count, IDX, alphas, -1); 00125 } 00126 00138 virtual bool init_optimization( 00139 int32_t count, int32_t *IDX, float64_t * alphas, int32_t tree_num, 00140 int32_t upto_tree=-1); 00141 00146 virtual bool delete_optimization(); 00147 00153 inline virtual float64_t compute_optimized(int32_t idx) 00154 { 00155 ASSERT(get_is_initialized()); 00156 ASSERT(alphabet); 00157 ASSERT(alphabet->get_alphabet()==DNA || alphabet->get_alphabet()==RNA); 00158 return compute_by_tree(idx); 00159 } 00160 00165 static void* compute_batch_helper(void* p); 00166 00177 virtual void compute_batch( 00178 int32_t num_vec, int32_t* vec_idx, float64_t* target, 00179 int32_t num_suppvec, int32_t* IDX, float64_t* alphas, 00180 float64_t factor=1.0); 00181 00185 inline virtual void clear_normal() 00186 { 00187 if ((opt_type==FASTBUTMEMHUNGRY) && (tries.get_use_compact_terminal_nodes())) 00188 { 00189 tries.set_use_compact_terminal_nodes(false) ; 00190 SG_DEBUG( "disabling compact trie nodes with FASTBUTMEMHUNGRY\n") ; 00191 } 00192 00193 if (get_is_initialized()) 00194 { 00195 if (opt_type==SLOWBUTMEMEFFICIENT) 00196 tries.delete_trees(true); 00197 else if (opt_type==FASTBUTMEMHUNGRY) 00198 tries.delete_trees(false); // still buggy 00199 else 00200 SG_ERROR( "unknown optimization type\n"); 00201 00202 set_is_initialized(false); 00203 } 00204 } 00205 00211 inline virtual void add_to_normal(int32_t idx, float64_t weight) 00212 { 00213 add_example_to_tree(idx, weight); 00214 set_is_initialized(true); 00215 } 00216 00221 inline virtual int32_t get_num_subkernels() 00222 { 00223 if (position_weights!=NULL) 00224 return (int32_t) ceil(1.0*seq_length/mkl_stepsize) ; 00225 if (length==0) 00226 return (int32_t) ceil(1.0*get_degree()/mkl_stepsize); 00227 return (int32_t) ceil(1.0*get_degree()*length/mkl_stepsize) ; 00228 } 00229 00235 inline void compute_by_subkernel( 00236 int32_t idx, float64_t * subkernel_contrib) 00237 { 00238 if (get_is_initialized()) 00239 { 00240 compute_by_tree(idx, subkernel_contrib); 00241 return ; 00242 } 00243 00244 SG_ERROR( "CWeightedDegreePositionStringKernel optimization not initialized\n") ; 00245 } 00246 00252 inline const float64_t* get_subkernel_weights(int32_t& num_weights) 00253 { 00254 num_weights = get_num_subkernels() ; 00255 00256 SG_FREE(weights_buffer); 00257 weights_buffer = SG_MALLOC(float64_t, num_weights); 00258 00259 if (position_weights!=NULL) 00260 for (int32_t i=0; i<num_weights; i++) 00261 weights_buffer[i] = position_weights[i*mkl_stepsize] ; 00262 else 00263 for (int32_t i=0; i<num_weights; i++) 00264 weights_buffer[i] = weights[i*mkl_stepsize] ; 00265 00266 return weights_buffer ; 00267 } 00268 00273 virtual void set_subkernel_weights(SGVector<float64_t> w) 00274 { 00275 float64_t* weights2=w.vector; 00276 int32_t num_weights2=w.vlen; 00277 00278 int32_t num_weights = get_num_subkernels() ; 00279 if (num_weights!=num_weights2) 00280 SG_ERROR( "number of weights do not match\n") ; 00281 00282 if (position_weights!=NULL) 00283 for (int32_t i=0; i<num_weights; i++) 00284 for (int32_t j=0; j<mkl_stepsize; j++) 00285 { 00286 if (i*mkl_stepsize+j<seq_length) 00287 position_weights[i*mkl_stepsize+j] = weights2[i] ; 00288 } 00289 else if (length==0) 00290 { 00291 for (int32_t i=0; i<num_weights; i++) 00292 for (int32_t j=0; j<mkl_stepsize; j++) 00293 if (i*mkl_stepsize+j<get_degree()) 00294 weights[i*mkl_stepsize+j] = weights2[i] ; 00295 } 00296 else 00297 { 00298 for (int32_t i=0; i<num_weights; i++) 00299 for (int32_t j=0; j<mkl_stepsize; j++) 00300 if (i*mkl_stepsize+j<get_degree()*length) 00301 weights[i*mkl_stepsize+j] = weights2[i] ; 00302 } 00303 } 00304 00305 // other kernel tree operations 00311 float64_t* compute_abs_weights(int32_t & len); 00312 00317 bool is_tree_initialized() { return tree_initialized; } 00318 00323 inline int32_t get_max_mismatch() { return max_mismatch; } 00324 00329 inline int32_t get_degree() { return degree; } 00330 00336 inline float64_t *get_degree_weights(int32_t& d, int32_t& len) 00337 { 00338 d=degree; 00339 len=length; 00340 return weights; 00341 } 00342 00348 inline float64_t *get_weights(int32_t& num_weights) 00349 { 00350 if (position_weights!=NULL) 00351 { 00352 num_weights = seq_length ; 00353 return position_weights ; 00354 } 00355 if (length==0) 00356 num_weights = degree ; 00357 else 00358 num_weights = degree*length ; 00359 return weights; 00360 } 00361 00367 inline float64_t *get_position_weights(int32_t& len) 00368 { 00369 len=seq_length; 00370 return position_weights; 00371 } 00372 00377 void set_shifts(SGVector<int32_t> shifts); 00378 00383 bool set_weights(SGMatrix<float64_t> new_weights); 00384 00389 virtual bool set_wd_weights(); 00390 00396 virtual void set_position_weights(SGVector<float64_t> pws); 00397 00405 bool set_position_weights_lhs(float64_t* pws, int32_t len, int32_t num); 00406 00414 bool set_position_weights_rhs(float64_t* pws, int32_t len, int32_t num); 00415 00420 bool init_block_weights(); 00421 00426 bool init_block_weights_from_wd(); 00427 00432 bool init_block_weights_from_wd_external(); 00433 00438 bool init_block_weights_const(); 00439 00444 bool init_block_weights_linear(); 00445 00450 bool init_block_weights_sqpoly(); 00451 00456 bool init_block_weights_cubicpoly(); 00457 00462 bool init_block_weights_exp(); 00463 00468 bool init_block_weights_log(); 00469 00474 bool delete_position_weights() 00475 { 00476 SG_FREE(position_weights); 00477 position_weights=NULL; 00478 return true; 00479 } 00480 00485 bool delete_position_weights_lhs() 00486 { 00487 SG_FREE(position_weights_lhs); 00488 position_weights_lhs=NULL; 00489 return true; 00490 } 00491 00496 bool delete_position_weights_rhs() 00497 { 00498 SG_FREE(position_weights_rhs); 00499 position_weights_rhs=NULL; 00500 return true; 00501 } 00502 00508 virtual float64_t compute_by_tree(int32_t idx); 00509 00515 virtual void compute_by_tree(int32_t idx, float64_t* LevelContrib); 00516 00529 float64_t* compute_scoring( 00530 int32_t max_degree, int32_t& num_feat, int32_t& num_sym, 00531 float64_t* target, int32_t num_suppvec, int32_t* IDX, 00532 float64_t* weights); 00533 00542 char* compute_consensus( 00543 int32_t &num_feat, int32_t num_suppvec, int32_t* IDX, 00544 float64_t* alphas); 00545 00557 float64_t* extract_w( 00558 int32_t max_degree, int32_t& num_feat, int32_t& num_sym, 00559 float64_t* w_result, int32_t num_suppvec, int32_t* IDX, 00560 float64_t* alphas); 00561 00574 float64_t* compute_POIM( 00575 int32_t max_degree, int32_t& num_feat, int32_t& num_sym, 00576 float64_t* poim_result, int32_t num_suppvec, int32_t* IDX, 00577 float64_t* alphas, float64_t* distrib); 00578 00583 void prepare_POIM2(SGMatrix<float64_t> distrib); 00584 00591 void compute_POIM2(int32_t max_degree, CSVM* svm); 00592 00597 SGVector<float64_t> get_POIM2(); 00598 00600 void cleanup_POIM2(); 00601 00602 protected: 00604 void create_empty_tries(); 00605 00611 virtual void add_example_to_tree( 00612 int32_t idx, float64_t weight); 00613 00620 void add_example_to_single_tree( 00621 int32_t idx, float64_t weight, int32_t tree_num); 00622 00631 virtual float64_t compute(int32_t idx_a, int32_t idx_b); 00632 00641 float64_t compute_with_mismatch( 00642 char* avec, int32_t alen, char* bvec, int32_t blen); 00643 00652 float64_t compute_without_mismatch( 00653 char* avec, int32_t alen, char* bvec, int32_t blen); 00654 00663 float64_t compute_without_mismatch_matrix( 00664 char* avec, int32_t alen, char* bvec, int32_t blen); 00665 00676 float64_t compute_without_mismatch_position_weights( 00677 char* avec, float64_t *posweights_lhs, int32_t alen, 00678 char* bvec, float64_t *posweights_rhs, int32_t blen); 00679 00681 virtual void remove_lhs(); 00682 00691 virtual void load_serializable_post() throw (ShogunException); 00692 00693 private: 00696 void init(); 00697 00698 protected: 00700 float64_t* weights; 00702 int32_t weights_degree; 00704 int32_t weights_length; 00705 00707 float64_t* position_weights; 00709 int32_t position_weights_len; 00710 00712 float64_t* position_weights_lhs; 00714 int32_t position_weights_lhs_len; 00716 float64_t* position_weights_rhs; 00718 int32_t position_weights_rhs_len; 00720 bool* position_mask; 00721 00723 float64_t* weights_buffer; 00725 int32_t mkl_stepsize; 00726 00728 int32_t degree; 00730 int32_t length; 00731 00733 int32_t max_mismatch; 00735 int32_t seq_length; 00736 00738 int32_t *shift; 00740 int32_t shift_len; 00742 int32_t max_shift; 00743 00745 bool block_computation; 00746 00748 float64_t* block_weights; 00750 EWDKernType type; 00752 int32_t which_degree; 00753 00755 CTrie<DNATrie> tries; 00757 CTrie<POIMTrie> poim_tries; 00758 00760 bool tree_initialized; 00762 bool use_poim_tries; 00763 00765 float64_t* m_poim_distrib; 00767 float64_t* m_poim; 00768 00770 int32_t m_poim_num_sym; 00772 int32_t m_poim_num_feat; 00774 int32_t m_poim_result_len; 00775 00777 CAlphabet* alphabet; 00778 }; 00779 } 00780 #endif /* _WEIGHTEDDEGREEPOSITIONSTRINGKERNEL_H__ */