SHOGUN
v2.0.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Written (W) 1999-2008 Gunnar Raetsch 00009 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00010 */ 00011 00012 #include <shogun/lib/common.h> 00013 #include <shogun/kernel/string/HistogramWordStringKernel.h> 00014 #include <shogun/features/Features.h> 00015 #include <shogun/features/StringFeatures.h> 00016 #include <shogun/classifier/PluginEstimate.h> 00017 #include <shogun/io/SGIO.h> 00018 00019 using namespace shogun; 00020 00021 CHistogramWordStringKernel::CHistogramWordStringKernel() 00022 : CStringKernel<uint16_t>() 00023 { 00024 init(); 00025 } 00026 00027 CHistogramWordStringKernel::CHistogramWordStringKernel(int32_t size, CPluginEstimate* pie) 00028 : CStringKernel<uint16_t>(size) 00029 { 00030 init(); 00031 SG_REF(pie); 00032 estimate=pie; 00033 00034 } 00035 00036 CHistogramWordStringKernel::CHistogramWordStringKernel( 00037 CStringFeatures<uint16_t>* l, CStringFeatures<uint16_t>* r, CPluginEstimate* pie) 00038 : CStringKernel<uint16_t>() 00039 { 00040 init(); 00041 SG_REF(pie); 00042 estimate=pie; 00043 init(l, r); 00044 } 00045 00046 CHistogramWordStringKernel::~CHistogramWordStringKernel() 00047 { 00048 SG_UNREF(estimate); 00049 00050 SG_FREE(variance); 00051 SG_FREE(mean); 00052 if (sqrtdiag_lhs != sqrtdiag_rhs) 00053 SG_FREE(sqrtdiag_rhs); 00054 SG_FREE(sqrtdiag_lhs); 00055 if (ld_mean_lhs!=ld_mean_rhs) 00056 SG_FREE(ld_mean_rhs); 00057 SG_FREE(ld_mean_lhs); 00058 if (plo_lhs!=plo_rhs) 00059 SG_FREE(plo_rhs); 00060 SG_FREE(plo_lhs); 00061 } 00062 00063 bool CHistogramWordStringKernel::init(CFeatures* p_l, CFeatures* p_r) 00064 { 00065 CStringKernel<uint16_t>::init(p_l,p_r); 00066 CStringFeatures<uint16_t>* l=(CStringFeatures<uint16_t>*) p_l; 00067 CStringFeatures<uint16_t>* r=(CStringFeatures<uint16_t>*) p_r; 00068 ASSERT(l); 00069 ASSERT(r); 00070 00071 SG_DEBUG( "init: lhs: %ld rhs: %ld\n", l, r) ; 00072 int32_t i; 00073 initialized=false; 00074 00075 if (sqrtdiag_lhs != sqrtdiag_rhs) 00076 SG_FREE(sqrtdiag_rhs); 00077 sqrtdiag_rhs=NULL ; 00078 SG_FREE(sqrtdiag_lhs); 00079 sqrtdiag_lhs=NULL ; 00080 if (ld_mean_lhs!=ld_mean_rhs) 00081 SG_FREE(ld_mean_rhs); 00082 ld_mean_rhs=NULL ; 00083 SG_FREE(ld_mean_lhs); 00084 ld_mean_lhs=NULL ; 00085 if (plo_lhs!=plo_rhs) 00086 SG_FREE(plo_rhs); 00087 plo_rhs=NULL ; 00088 SG_FREE(plo_lhs); 00089 plo_lhs=NULL ; 00090 00091 sqrtdiag_lhs= SG_MALLOC(float64_t, l->get_num_vectors()); 00092 ld_mean_lhs = SG_MALLOC(float64_t, l->get_num_vectors()); 00093 plo_lhs = SG_MALLOC(float64_t, l->get_num_vectors()); 00094 00095 for (i=0; i<l->get_num_vectors(); i++) 00096 sqrtdiag_lhs[i]=1; 00097 00098 if (l==r) 00099 { 00100 sqrtdiag_rhs=sqrtdiag_lhs; 00101 ld_mean_rhs=ld_mean_lhs; 00102 plo_rhs=plo_lhs; 00103 } 00104 else 00105 { 00106 sqrtdiag_rhs=SG_MALLOC(float64_t, r->get_num_vectors()); 00107 for (i=0; i<r->get_num_vectors(); i++) 00108 sqrtdiag_rhs[i]=1; 00109 00110 ld_mean_rhs=SG_MALLOC(float64_t, r->get_num_vectors()); 00111 plo_rhs=SG_MALLOC(float64_t, r->get_num_vectors()); 00112 } 00113 00114 float64_t* l_plo_lhs=plo_lhs; 00115 float64_t* l_plo_rhs=plo_rhs; 00116 float64_t* l_ld_mean_lhs=ld_mean_lhs; 00117 float64_t* l_ld_mean_rhs=ld_mean_rhs; 00118 00119 //from our knowledge first normalize variance to 1 and then norm=1 does the job 00120 if (!initialized) 00121 { 00122 int32_t num_vectors=l->get_num_vectors(); 00123 num_symbols=(int32_t) l->get_num_symbols(); 00124 int32_t llen=l->get_vector_length(0); 00125 int32_t rlen=r->get_vector_length(0); 00126 num_params=llen*((int32_t) l->get_num_symbols()); 00127 num_params2=llen*((int32_t) l->get_num_symbols())+rlen*((int32_t) r->get_num_symbols()); 00128 00129 if ((!estimate) || (!estimate->check_models())) 00130 { 00131 SG_ERROR( "no estimate available\n"); 00132 return false ; 00133 } ; 00134 if (num_params2!=estimate->get_num_params()) 00135 { 00136 SG_ERROR( "number of parameters of estimate and feature representation do not match\n"); 00137 return false ; 00138 } ; 00139 00140 //add 1 as we have the 'bias' also in this vector 00141 num_params2++; 00142 00143 SG_FREE(mean); 00144 mean=SG_MALLOC(float64_t, num_params2); 00145 SG_FREE(variance); 00146 variance=SG_MALLOC(float64_t, num_params2); 00147 00148 for (i=0; i<num_params2; i++) 00149 { 00150 mean[i]=0; 00151 variance[i]=0; 00152 } 00153 00154 // compute mean 00155 for (i=0; i<num_vectors; i++) 00156 { 00157 int32_t len; 00158 bool free_vec; 00159 uint16_t* vec=l->get_feature_vector(i, len, free_vec); 00160 00161 mean[0]+=estimate->posterior_log_odds_obsolete(vec, len)/num_vectors; 00162 00163 for (int32_t j=0; j<len; j++) 00164 { 00165 int32_t idx=compute_index(j, vec[j]); 00166 mean[idx] += estimate->log_derivative_pos_obsolete(vec[j], j)/num_vectors; 00167 mean[idx+num_params] += estimate->log_derivative_neg_obsolete(vec[j], j)/num_vectors; 00168 } 00169 00170 l->free_feature_vector(vec, i, free_vec); 00171 } 00172 00173 // compute variance 00174 for (i=0; i<num_vectors; i++) 00175 { 00176 int32_t len; 00177 bool free_vec; 00178 uint16_t* vec=l->get_feature_vector(i, len, free_vec); 00179 00180 variance[0] += CMath::sq(estimate->posterior_log_odds_obsolete(vec, len)-mean[0])/num_vectors; 00181 00182 for (int32_t j=0; j<len; j++) 00183 { 00184 for (int32_t k=0; k<4; k++) 00185 { 00186 int32_t idx=compute_index(j, k); 00187 if (k!=vec[j]) 00188 { 00189 variance[idx]+=mean[idx]*mean[idx]/num_vectors; 00190 variance[idx+num_params]+=mean[idx+num_params]*mean[idx+num_params]/num_vectors; 00191 } 00192 else 00193 { 00194 variance[idx] += CMath::sq(estimate->log_derivative_pos_obsolete(vec[j], j) 00195 -mean[idx])/num_vectors; 00196 variance[idx+num_params] += CMath::sq(estimate->log_derivative_neg_obsolete(vec[j], j) 00197 -mean[idx+num_params])/num_vectors; 00198 } 00199 } 00200 } 00201 00202 l->free_feature_vector(vec, i, free_vec); 00203 } 00204 00205 00206 // compute sum_i m_i^2/s_i^2 00207 sum_m2_s2=0 ; 00208 for (i=1; i<num_params2; i++) 00209 { 00210 if (variance[i]<1e-14) // then it is likely to be numerical inaccuracy 00211 variance[i]=1 ; 00212 00213 //fprintf(stderr, "%i: mean=%1.2e std=%1.2e\n", i, mean[i], std[i]) ; 00214 sum_m2_s2 += mean[i]*mean[i]/(variance[i]) ; 00215 } ; 00216 } 00217 00218 // compute sum of 00219 //result -= estimate->log_derivative_pos(avec[i], i)*mean[a_idx]/variance[a_idx] ; 00220 //result -= estimate->log_derivative_neg(avec[i], i)*mean[a_idx+num_params]/variance[a_idx+num_params] ; 00221 for (i=0; i<l->get_num_vectors(); i++) 00222 { 00223 int32_t alen; 00224 bool free_avec; 00225 uint16_t* avec = l->get_feature_vector(i, alen, free_avec); 00226 00227 float64_t result=0 ; 00228 for (int32_t j=0; j<alen; j++) 00229 { 00230 int32_t a_idx = compute_index(j, avec[j]); 00231 result -= estimate->log_derivative_pos_obsolete(avec[j], j)*mean[a_idx]/variance[a_idx] ; 00232 result -= estimate->log_derivative_neg_obsolete(avec[j], j)*mean[a_idx+num_params]/variance[a_idx+num_params] ; 00233 } 00234 ld_mean_lhs[i]=result ; 00235 00236 // precompute posterior-log-odds 00237 plo_lhs[i] = estimate->posterior_log_odds_obsolete(avec, alen)-mean[0] ; 00238 l->free_feature_vector(avec, i, free_avec); 00239 } ; 00240 00241 if (ld_mean_lhs!=ld_mean_rhs) 00242 { 00243 // compute sum of 00244 //result -= estimate->log_derivative_pos(bvec[i], i)*mean[b_idx]/variance[b_idx] ; 00245 //result -= estimate->log_derivative_neg(bvec[i], i)*mean[b_idx+num_params]/variance[b_idx+num_params] ; 00246 for (i=0; i < r->get_num_vectors(); i++) 00247 { 00248 int32_t alen; 00249 bool free_avec; 00250 uint16_t* avec=r->get_feature_vector(i, alen, free_avec); 00251 00252 float64_t result=0 ; 00253 for (int32_t j=0; j<alen; j++) 00254 { 00255 int32_t a_idx = compute_index(j, avec[j]) ; 00256 result -= estimate->log_derivative_pos_obsolete(avec[j], j)*mean[a_idx]/variance[a_idx] ; 00257 result -= estimate->log_derivative_neg_obsolete(avec[j], j)*mean[a_idx+num_params]/variance[a_idx+num_params] ; 00258 } 00259 ld_mean_rhs[i]=result ; 00260 00261 // precompute posterior-log-odds 00262 plo_rhs[i] = estimate->posterior_log_odds_obsolete(avec, alen)-mean[0] ; 00263 r->free_feature_vector(avec, i, free_avec); 00264 } ; 00265 } ; 00266 00267 //warning hacky 00268 // 00269 this->lhs=l; 00270 this->rhs=l; 00271 plo_lhs = l_plo_lhs ; 00272 plo_rhs = l_plo_lhs ; 00273 ld_mean_lhs = l_ld_mean_lhs ; 00274 ld_mean_rhs = l_ld_mean_lhs ; 00275 00276 //compute normalize to 1 values 00277 for (i=0; i<l->get_num_vectors(); i++) 00278 { 00279 sqrtdiag_lhs[i]=sqrt(compute(i,i)); 00280 00281 //trap divide by zero exception 00282 if (sqrtdiag_lhs[i]==0) 00283 sqrtdiag_lhs[i]=1e-16; 00284 } 00285 00286 // if lhs is different from rhs (train/test data) 00287 // compute also the normalization for rhs 00288 if (sqrtdiag_lhs!=sqrtdiag_rhs) 00289 { 00290 this->lhs=r; 00291 this->rhs=r; 00292 plo_lhs = l_plo_rhs ; 00293 plo_rhs = l_plo_rhs ; 00294 ld_mean_lhs = l_ld_mean_rhs ; 00295 ld_mean_rhs = l_ld_mean_rhs ; 00296 00297 //compute normalize to 1 values 00298 for (i=0; i<r->get_num_vectors(); i++) 00299 { 00300 sqrtdiag_rhs[i]=sqrt(compute(i,i)); 00301 00302 //trap divide by zero exception 00303 if (sqrtdiag_rhs[i]==0) 00304 sqrtdiag_rhs[i]=1e-16; 00305 } 00306 } 00307 00308 this->lhs=l; 00309 this->rhs=r; 00310 plo_lhs = l_plo_lhs ; 00311 plo_rhs = l_plo_rhs ; 00312 ld_mean_lhs = l_ld_mean_lhs ; 00313 ld_mean_rhs = l_ld_mean_rhs ; 00314 00315 initialized = true ; 00316 return init_normalizer(); 00317 } 00318 00319 void CHistogramWordStringKernel::cleanup() 00320 { 00321 SG_FREE(variance); 00322 variance=NULL; 00323 00324 SG_FREE(mean); 00325 mean=NULL; 00326 00327 if (sqrtdiag_lhs != sqrtdiag_rhs) 00328 SG_FREE(sqrtdiag_rhs); 00329 sqrtdiag_rhs=NULL; 00330 00331 SG_FREE(sqrtdiag_lhs); 00332 sqrtdiag_lhs=NULL; 00333 00334 if (ld_mean_lhs!=ld_mean_rhs) 00335 SG_FREE(ld_mean_rhs); 00336 ld_mean_rhs=NULL; 00337 00338 SG_FREE(ld_mean_lhs); 00339 ld_mean_lhs=NULL; 00340 00341 if (plo_lhs!=plo_rhs) 00342 SG_FREE(plo_rhs); 00343 plo_rhs=NULL; 00344 00345 SG_FREE(plo_lhs); 00346 plo_lhs=NULL; 00347 00348 num_params2=0; 00349 num_params=0; 00350 num_symbols=0; 00351 sum_m2_s2=0; 00352 initialized = false; 00353 00354 CKernel::cleanup(); 00355 } 00356 00357 float64_t CHistogramWordStringKernel::compute(int32_t idx_a, int32_t idx_b) 00358 { 00359 int32_t alen, blen; 00360 bool free_avec, free_bvec; 00361 uint16_t* avec=((CStringFeatures<uint16_t>*) lhs)->get_feature_vector(idx_a, alen, free_avec); 00362 uint16_t* bvec=((CStringFeatures<uint16_t>*) rhs)->get_feature_vector(idx_b, blen, free_bvec); 00363 // can only deal with strings of same length 00364 ASSERT(alen==blen); 00365 00366 float64_t result = plo_lhs[idx_a]*plo_rhs[idx_b]/variance[0]; 00367 result+= sum_m2_s2 ; // does not contain 0-th element 00368 00369 for (int32_t i=0; i<alen; i++) 00370 { 00371 if (avec[i]==bvec[i]) 00372 { 00373 int32_t a_idx = compute_index(i, avec[i]) ; 00374 float64_t dd = estimate->log_derivative_pos_obsolete(avec[i], i) ; 00375 result += dd*dd/variance[a_idx] ; 00376 dd = estimate->log_derivative_neg_obsolete(avec[i], i) ; 00377 result += dd*dd/variance[a_idx+num_params] ; 00378 } ; 00379 } 00380 result += ld_mean_lhs[idx_a] + ld_mean_rhs[idx_b] ; 00381 00382 if (initialized) 00383 result /= (sqrtdiag_lhs[idx_a]*sqrtdiag_rhs[idx_b]) ; 00384 00385 #ifdef DEBUG_HWSK_COMPUTATION 00386 float64_t result2 = compute_slow(idx_a, idx_b) ; 00387 if (fabs(result - result2)>1e-10) 00388 SG_ERROR("new=%e old = %e diff = %e\n", result, result2, result - result2); 00389 #endif 00390 ((CStringFeatures<uint16_t>*) lhs)->free_feature_vector(avec, idx_a, free_avec); 00391 ((CStringFeatures<uint16_t>*) rhs)->free_feature_vector(bvec, idx_b, free_bvec); 00392 return result; 00393 } 00394 00395 void CHistogramWordStringKernel::init() 00396 { 00397 estimate=NULL; 00398 mean=NULL; 00399 variance=NULL; 00400 00401 sqrtdiag_lhs=NULL; 00402 sqrtdiag_rhs=NULL; 00403 00404 ld_mean_lhs=NULL; 00405 ld_mean_rhs=NULL; 00406 00407 plo_lhs=NULL; 00408 plo_rhs=NULL; 00409 num_params=0; 00410 num_params2=0; 00411 00412 num_symbols=0; 00413 sum_m2_s2=0; 00414 initialized=false; 00415 00416 SG_ADD(&initialized, "initialized", "If kernel is initalized.", 00417 MS_NOT_AVAILABLE); 00418 m_parameters->add_vector(&plo_lhs, &num_lhs, "plo_lhs"); 00419 m_parameters->add_vector(&plo_rhs, &num_rhs, "plo_rhs"); 00420 m_parameters->add_vector(&ld_mean_lhs, &num_lhs, "ld_mean_lhs"); 00421 m_parameters->add_vector(&ld_mean_rhs, &num_rhs, "ld_mean_rhs"); 00422 m_parameters->add_vector(&sqrtdiag_lhs, &num_lhs, "sqrtdiag_lhs"); 00423 m_parameters->add_vector(&sqrtdiag_rhs, &num_rhs, "sqrtdiag_rhs"); 00424 m_parameters->add_vector(&mean, &num_params2, "mean"); 00425 m_parameters->add_vector(&variance, &num_params2, "variance"); 00426 00427 SG_ADD((CSGObject**) &estimate, "estimate", "Plugin Estimate.", 00428 MS_NOT_AVAILABLE); 00429 } 00430 00431 #ifdef DEBUG_HWSK_COMPUTATION 00432 float64_t CHistogramWordStringKernel::compute_slow(int32_t idx_a, int32_t idx_b) 00433 { 00434 int32_t alen, blen; 00435 bool free_avec, free_bvec; 00436 uint16_t* avec=((CStringFeatures<uint16_t>*) lhs)->get_feature_vector(idx_a, alen, free_avec); 00437 uint16_t* bvec=((CStringFeatures<uint16_t>*) rhs)->get_feature_vector(idx_b, blen, free_bvec); 00438 // can only deal with strings of same length 00439 ASSERT(alen==blen); 00440 00441 float64_t result=(estimate->posterior_log_odds_obsolete(avec, alen)-mean[0])* 00442 (estimate->posterior_log_odds_obsolete(bvec, blen)-mean[0])/(variance[0]); 00443 result+= sum_m2_s2 ; // does not contain 0-th element 00444 00445 for (int32_t i=0; i<alen; i++) 00446 { 00447 int32_t a_idx = compute_index(i, avec[i]) ; 00448 int32_t b_idx = compute_index(i, bvec[i]) ; 00449 00450 if (avec[i]==bvec[i]) 00451 { 00452 float64_t dd = estimate->log_derivative_pos_obsolete(avec[i], i) ; 00453 result += dd*dd/variance[a_idx] ; 00454 dd = estimate->log_derivative_neg_obsolete(avec[i], i) ; 00455 result += dd*dd/variance[a_idx+num_params] ; 00456 } ; 00457 00458 result -= estimate->log_derivative_pos_obsolete(avec[i], i)*mean[a_idx]/variance[a_idx] ; 00459 result -= estimate->log_derivative_pos_obsolete(bvec[i], i)*mean[b_idx]/variance[b_idx] ; 00460 result -= estimate->log_derivative_neg_obsolete(avec[i], i)*mean[a_idx+num_params]/variance[a_idx+num_params] ; 00461 result -= estimate->log_derivative_neg_obsolete(bvec[i], i)*mean[b_idx+num_params]/variance[b_idx+num_params] ; 00462 } 00463 00464 if (initialized) 00465 result /= (sqrtdiag_lhs[idx_a]*sqrtdiag_rhs[idx_b]) ; 00466 00467 ((CStringFeatures<uint16_t>*) lhs)->free_feature_vector(avec, idx_a, free_avec); 00468 ((CStringFeatures<uint16_t>*) rhs)->free_feature_vector(bvec, idx_b, free_bvec); 00469 return result; 00470 } 00471 00472 #endif