SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2009 Soeren Sonnenburg 00008 * Copyright (C) 2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #ifndef _SQRTDIAGKERNELNORMALIZER_H___ 00012 #define _SQRTDIAGKERNELNORMALIZER_H___ 00013 00014 #include <shogun/kernel/normalizer/KernelNormalizer.h> 00015 #include <shogun/kernel/string/CommWordStringKernel.h> 00016 00017 namespace shogun 00018 { 00029 class CSqrtDiagKernelNormalizer : public CKernelNormalizer 00030 { 00031 public: 00036 CSqrtDiagKernelNormalizer(bool use_opt_diag=false): CKernelNormalizer(), 00037 sqrtdiag_lhs(NULL), num_sqrtdiag_lhs(0), 00038 sqrtdiag_rhs(NULL), num_sqrtdiag_rhs(0), 00039 use_optimized_diagonal_computation(use_opt_diag) 00040 { 00041 m_parameters->add_vector(&sqrtdiag_lhs, &num_sqrtdiag_lhs, "sqrtdiag_lhs", 00042 "sqrt(K(x,x)) for left hand side examples."); 00043 m_parameters->add_vector(&sqrtdiag_rhs, &num_sqrtdiag_rhs, "sqrtdiag_rhs", 00044 "sqrt(K(x,x)) for right hand side examples."); 00045 SG_ADD(&use_optimized_diagonal_computation, 00046 "use_optimized_diagonal_computation", 00047 "flat if optimized diagonal computation is used", MS_NOT_AVAILABLE); 00048 } 00049 00051 virtual ~CSqrtDiagKernelNormalizer() 00052 { 00053 SG_FREE(sqrtdiag_lhs); 00054 SG_FREE(sqrtdiag_rhs); 00055 } 00056 00059 virtual bool init(CKernel* k) 00060 { 00061 ASSERT(k); 00062 num_sqrtdiag_lhs=k->get_num_vec_lhs(); 00063 num_sqrtdiag_rhs=k->get_num_vec_rhs(); 00064 ASSERT(num_sqrtdiag_lhs>0); 00065 ASSERT(num_sqrtdiag_rhs>0); 00066 00067 CFeatures* old_lhs=k->lhs; 00068 CFeatures* old_rhs=k->rhs; 00069 00070 k->lhs=old_lhs; 00071 k->rhs=old_lhs; 00072 bool r1=alloc_and_compute_diag(k, sqrtdiag_lhs, num_sqrtdiag_lhs); 00073 00074 k->lhs=old_rhs; 00075 k->rhs=old_rhs; 00076 bool r2=alloc_and_compute_diag(k, sqrtdiag_rhs, num_sqrtdiag_rhs); 00077 00078 k->lhs=old_lhs; 00079 k->rhs=old_rhs; 00080 00081 return r1 && r2; 00082 } 00083 00089 inline virtual float64_t normalize( 00090 float64_t value, int32_t idx_lhs, int32_t idx_rhs) 00091 { 00092 float64_t sqrt_both=sqrtdiag_lhs[idx_lhs]*sqrtdiag_rhs[idx_rhs]; 00093 return value/sqrt_both; 00094 } 00095 00100 inline virtual float64_t normalize_lhs(float64_t value, int32_t idx_lhs) 00101 { 00102 return value/sqrtdiag_lhs[idx_lhs]; 00103 } 00104 00109 inline virtual float64_t normalize_rhs(float64_t value, int32_t idx_rhs) 00110 { 00111 return value/sqrtdiag_rhs[idx_rhs]; 00112 } 00113 00114 public: 00119 bool alloc_and_compute_diag(CKernel* k, float64_t* &v, int32_t num) 00120 { 00121 SG_FREE(v); 00122 v=SG_MALLOC(float64_t, num); 00123 00124 for (int32_t i=0; i<num; i++) 00125 { 00126 if (k->get_kernel_type() == K_COMMWORDSTRING) 00127 { 00128 if (use_optimized_diagonal_computation) 00129 v[i]=sqrt(((CCommWordStringKernel*) k)->compute_diag(i)); 00130 else 00131 v[i]=sqrt(((CCommWordStringKernel*) k)->compute_helper(i,i, true)); 00132 } 00133 else 00134 v[i]=sqrt(k->compute(i,i)); 00135 00136 if (v[i]==0.0) 00137 v[i]=1e-16; /* avoid divide by zero exception */ 00138 } 00139 00140 return (v!=NULL); 00141 } 00142 00144 inline virtual const char* get_name() const { return "SqrtDiagKernelNormalizer"; } 00145 00146 protected: 00148 float64_t* sqrtdiag_lhs; 00149 00151 int32_t num_sqrtdiag_lhs; 00152 00154 float64_t* sqrtdiag_rhs; 00155 00157 int32_t num_sqrtdiag_rhs; 00158 00160 bool use_optimized_diagonal_computation; 00161 }; 00162 } 00163 #endif