00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #include "lib/common.h"
00013 #include "kernel/HistogramWordStringKernel.h"
00014 #include "features/Features.h"
00015 #include "features/StringFeatures.h"
00016 #include "classifier/PluginEstimate.h"
00017 #include "lib/io.h"
00018
00019 using namespace shogun;
00020
00021 CHistogramWordStringKernel::CHistogramWordStringKernel(void)
00022 : CStringKernel<uint16_t>()
00023 {
00024 init();
00025 }
00026
00027 CHistogramWordStringKernel::CHistogramWordStringKernel(int32_t size, CPluginEstimate* pie)
00028 : CStringKernel<uint16_t>(size)
00029 {
00030 init();
00031 SG_REF(pie);
00032 estimate=pie;
00033
00034 }
00035
00036 CHistogramWordStringKernel::CHistogramWordStringKernel(
00037 CStringFeatures<uint16_t>* l, CStringFeatures<uint16_t>* r, CPluginEstimate* pie)
00038 : CStringKernel<uint16_t>()
00039 {
00040 init();
00041 SG_REF(pie);
00042 estimate=pie;
00043 init(l, r);
00044 }
00045
00046 CHistogramWordStringKernel::~CHistogramWordStringKernel()
00047 {
00048 SG_UNREF(estimate);
00049
00050 delete[] variance;
00051 delete[] mean;
00052 if (sqrtdiag_lhs != sqrtdiag_rhs)
00053 delete[] sqrtdiag_rhs;
00054 delete[] sqrtdiag_lhs;
00055 if (ld_mean_lhs!=ld_mean_rhs)
00056 delete[] ld_mean_rhs ;
00057 delete[] ld_mean_lhs ;
00058 if (plo_lhs!=plo_rhs)
00059 delete[] plo_rhs ;
00060 delete[] plo_lhs ;
00061 }
00062
00063 bool CHistogramWordStringKernel::init(CFeatures* p_l, CFeatures* p_r)
00064 {
00065 CStringKernel<uint16_t>::init(p_l,p_r);
00066 CStringFeatures<uint16_t>* l=(CStringFeatures<uint16_t>*) p_l;
00067 CStringFeatures<uint16_t>* r=(CStringFeatures<uint16_t>*) p_r;
00068 ASSERT(l);
00069 ASSERT(r);
00070
00071 SG_DEBUG( "init: lhs: %ld rhs: %ld\n", l, r) ;
00072 int32_t i;
00073 initialized=false;
00074
00075 if (sqrtdiag_lhs != sqrtdiag_rhs)
00076 delete[] sqrtdiag_rhs;
00077 sqrtdiag_rhs=NULL ;
00078 delete[] sqrtdiag_lhs;
00079 sqrtdiag_lhs=NULL ;
00080 if (ld_mean_lhs!=ld_mean_rhs)
00081 delete[] ld_mean_rhs ;
00082 ld_mean_rhs=NULL ;
00083 delete[] ld_mean_lhs ;
00084 ld_mean_lhs=NULL ;
00085 if (plo_lhs!=plo_rhs)
00086 delete[] plo_rhs ;
00087 plo_rhs=NULL ;
00088 delete[] plo_lhs ;
00089 plo_lhs=NULL ;
00090
00091 sqrtdiag_lhs= new float64_t[l->get_num_vectors()];
00092 ld_mean_lhs = new float64_t[l->get_num_vectors()];
00093 plo_lhs = new float64_t[l->get_num_vectors()];
00094
00095 for (i=0; i<l->get_num_vectors(); i++)
00096 sqrtdiag_lhs[i]=1;
00097
00098 if (l==r)
00099 {
00100 sqrtdiag_rhs=sqrtdiag_lhs;
00101 ld_mean_rhs=ld_mean_lhs;
00102 plo_rhs=plo_lhs;
00103 }
00104 else
00105 {
00106 sqrtdiag_rhs=new float64_t[r->get_num_vectors()];
00107 for (i=0; i<r->get_num_vectors(); i++)
00108 sqrtdiag_rhs[i]=1;
00109
00110 ld_mean_rhs=new float64_t[r->get_num_vectors()];
00111 plo_rhs=new float64_t[r->get_num_vectors()];
00112 }
00113
00114 float64_t* l_plo_lhs=plo_lhs;
00115 float64_t* l_plo_rhs=plo_rhs;
00116 float64_t* l_ld_mean_lhs=ld_mean_lhs;
00117 float64_t* l_ld_mean_rhs=ld_mean_rhs;
00118
00119
00120 if (!initialized)
00121 {
00122 int32_t num_vectors=l->get_num_vectors();
00123 num_symbols=(int32_t) l->get_num_symbols();
00124 int32_t llen=l->get_vector_length(0);
00125 int32_t rlen=r->get_vector_length(0);
00126 num_params=llen*((int32_t) l->get_num_symbols());
00127 num_params2=llen*((int32_t) l->get_num_symbols())+rlen*((int32_t) r->get_num_symbols());
00128
00129 if ((!estimate) || (!estimate->check_models()))
00130 {
00131 SG_ERROR( "no estimate available\n");
00132 return false ;
00133 } ;
00134 if (num_params2!=estimate->get_num_params())
00135 {
00136 SG_ERROR( "number of parameters of estimate and feature representation do not match\n");
00137 return false ;
00138 } ;
00139
00140
00141 num_params2++;
00142
00143 delete[] mean;
00144 mean=new float64_t[num_params2];
00145 delete[] variance;
00146 variance=new float64_t[num_params2];
00147
00148 for (i=0; i<num_params2; i++)
00149 {
00150 mean[i]=0;
00151 variance[i]=0;
00152 }
00153
00154
00155 for (i=0; i<num_vectors; i++)
00156 {
00157 int32_t len;
00158 bool free_vec;
00159 uint16_t* vec=l->get_feature_vector(i, len, free_vec);
00160
00161 mean[0]+=estimate->posterior_log_odds_obsolete(vec, len)/num_vectors;
00162
00163 for (int32_t j=0; j<len; j++)
00164 {
00165 int32_t idx=compute_index(j, vec[j]);
00166 mean[idx] += estimate->log_derivative_pos_obsolete(vec[j], j)/num_vectors;
00167 mean[idx+num_params] += estimate->log_derivative_neg_obsolete(vec[j], j)/num_vectors;
00168 }
00169
00170 l->free_feature_vector(vec, i, free_vec);
00171 }
00172
00173
00174 for (i=0; i<num_vectors; i++)
00175 {
00176 int32_t len;
00177 bool free_vec;
00178 uint16_t* vec=l->get_feature_vector(i, len, free_vec);
00179
00180 variance[0] += CMath::sq(estimate->posterior_log_odds_obsolete(vec, len)-mean[0])/num_vectors;
00181
00182 for (int32_t j=0; j<len; j++)
00183 {
00184 for (int32_t k=0; k<4; k++)
00185 {
00186 int32_t idx=compute_index(j, k);
00187 if (k!=vec[j])
00188 {
00189 variance[idx]+=mean[idx]*mean[idx]/num_vectors;
00190 variance[idx+num_params]+=mean[idx+num_params]*mean[idx+num_params]/num_vectors;
00191 }
00192 else
00193 {
00194 variance[idx] += CMath::sq(estimate->log_derivative_pos_obsolete(vec[j], j)
00195 -mean[idx])/num_vectors;
00196 variance[idx+num_params] += CMath::sq(estimate->log_derivative_neg_obsolete(vec[j], j)
00197 -mean[idx+num_params])/num_vectors;
00198 }
00199 }
00200 }
00201
00202 l->free_feature_vector(vec, i, free_vec);
00203 }
00204
00205
00206
00207 sum_m2_s2=0 ;
00208 for (i=1; i<num_params2; i++)
00209 {
00210 if (variance[i]<1e-14)
00211 variance[i]=1 ;
00212
00213
00214 sum_m2_s2 += mean[i]*mean[i]/(variance[i]) ;
00215 } ;
00216 }
00217
00218
00219
00220
00221 for (i=0; i<l->get_num_vectors(); i++)
00222 {
00223 int32_t alen;
00224 bool free_avec;
00225 uint16_t* avec = l->get_feature_vector(i, alen, free_avec);
00226
00227 float64_t result=0 ;
00228 for (int32_t j=0; j<alen; j++)
00229 {
00230 int32_t a_idx = compute_index(j, avec[j]) ;
00231 result -= estimate->log_derivative_pos_obsolete(avec[j], j)*mean[a_idx]/variance[a_idx] ;
00232 result -= estimate->log_derivative_neg_obsolete(avec[j], j)*mean[a_idx+num_params]/variance[a_idx+num_params] ;
00233 }
00234 ld_mean_lhs[i]=result ;
00235
00236
00237 plo_lhs[i] = estimate->posterior_log_odds_obsolete(avec, alen)-mean[0] ;
00238 l->free_feature_vector(avec, alen, free_avec);
00239 } ;
00240
00241 if (ld_mean_lhs!=ld_mean_rhs)
00242 {
00243
00244
00245
00246 for (i=0; i < r->get_num_vectors(); i++)
00247 {
00248 int32_t alen;
00249 bool free_avec;
00250 uint16_t* avec=r->get_feature_vector(i, alen, free_avec);
00251
00252 float64_t result=0 ;
00253 for (int32_t j=0; j<alen; j++)
00254 {
00255 int32_t a_idx = compute_index(j, avec[j]) ;
00256 result -= estimate->log_derivative_pos_obsolete(avec[j], j)*mean[a_idx]/variance[a_idx] ;
00257 result -= estimate->log_derivative_neg_obsolete(avec[j], j)*mean[a_idx+num_params]/variance[a_idx+num_params] ;
00258 }
00259 ld_mean_rhs[i]=result ;
00260
00261
00262 plo_rhs[i] = estimate->posterior_log_odds_obsolete(avec, alen)-mean[0] ;
00263 r->free_feature_vector(avec, alen, free_avec);
00264 } ;
00265 } ;
00266
00267
00268
00269 this->lhs=l;
00270 this->rhs=l;
00271 plo_lhs = l_plo_lhs ;
00272 plo_rhs = l_plo_lhs ;
00273 ld_mean_lhs = l_ld_mean_lhs ;
00274 ld_mean_rhs = l_ld_mean_lhs ;
00275
00276
00277 for (i=0; i<l->get_num_vectors(); i++)
00278 {
00279 sqrtdiag_lhs[i]=sqrt(compute(i,i));
00280
00281
00282 if (sqrtdiag_lhs[i]==0)
00283 sqrtdiag_lhs[i]=1e-16;
00284 }
00285
00286
00287
00288 if (sqrtdiag_lhs!=sqrtdiag_rhs)
00289 {
00290 this->lhs=r;
00291 this->rhs=r;
00292 plo_lhs = l_plo_rhs ;
00293 plo_rhs = l_plo_rhs ;
00294 ld_mean_lhs = l_ld_mean_rhs ;
00295 ld_mean_rhs = l_ld_mean_rhs ;
00296
00297
00298 for (i=0; i<r->get_num_vectors(); i++)
00299 {
00300 sqrtdiag_rhs[i]=sqrt(compute(i,i));
00301
00302
00303 if (sqrtdiag_rhs[i]==0)
00304 sqrtdiag_rhs[i]=1e-16;
00305 }
00306 }
00307
00308 this->lhs=l;
00309 this->rhs=r;
00310 plo_lhs = l_plo_lhs ;
00311 plo_rhs = l_plo_rhs ;
00312 ld_mean_lhs = l_ld_mean_lhs ;
00313 ld_mean_rhs = l_ld_mean_rhs ;
00314
00315 initialized = true ;
00316 return init_normalizer();
00317 }
00318
00319 void CHistogramWordStringKernel::cleanup()
00320 {
00321 delete[] variance;
00322 variance=NULL;
00323
00324 delete[] mean;
00325 mean=NULL;
00326
00327 if (sqrtdiag_lhs != sqrtdiag_rhs)
00328 delete[] sqrtdiag_rhs;
00329 sqrtdiag_rhs=NULL;
00330
00331 delete[] sqrtdiag_lhs;
00332 sqrtdiag_lhs=NULL;
00333
00334 if (ld_mean_lhs!=ld_mean_rhs)
00335 delete[] ld_mean_rhs ;
00336 ld_mean_rhs=NULL;
00337
00338 delete[] ld_mean_lhs ;
00339 ld_mean_lhs=NULL;
00340
00341 if (plo_lhs!=plo_rhs)
00342 delete[] plo_rhs ;
00343 plo_rhs=NULL;
00344
00345 delete[] plo_lhs ;
00346 plo_lhs=NULL;
00347
00348 num_params2=0;
00349 num_params=0;
00350 num_symbols=0;
00351 sum_m2_s2=0;
00352 initialized = false;
00353
00354 CKernel::cleanup();
00355 }
00356
00357 float64_t CHistogramWordStringKernel::compute(int32_t idx_a, int32_t idx_b)
00358 {
00359 int32_t alen, blen;
00360 bool free_avec, free_bvec;
00361 uint16_t* avec=((CStringFeatures<uint16_t>*) lhs)->get_feature_vector(idx_a, alen, free_avec);
00362 uint16_t* bvec=((CStringFeatures<uint16_t>*) rhs)->get_feature_vector(idx_b, blen, free_bvec);
00363
00364 ASSERT(alen==blen);
00365
00366 float64_t result = plo_lhs[idx_a]*plo_rhs[idx_b]/variance[0];
00367 result+= sum_m2_s2 ;
00368
00369 for (int32_t i=0; i<alen; i++)
00370 {
00371 if (avec[i]==bvec[i])
00372 {
00373 int32_t a_idx = compute_index(i, avec[i]) ;
00374 float64_t dd = estimate->log_derivative_pos_obsolete(avec[i], i) ;
00375 result += dd*dd/variance[a_idx] ;
00376 dd = estimate->log_derivative_neg_obsolete(avec[i], i) ;
00377 result += dd*dd/variance[a_idx+num_params] ;
00378 } ;
00379 }
00380 result += ld_mean_lhs[idx_a] + ld_mean_rhs[idx_b] ;
00381
00382 if (initialized)
00383 result /= (sqrtdiag_lhs[idx_a]*sqrtdiag_rhs[idx_b]) ;
00384
00385 #ifdef DEBUG_HWSK_COMPUTATION
00386 float64_t result2 = compute_slow(idx_a, idx_b) ;
00387 if (fabs(result - result2)>1e-10)
00388 SG_ERROR("new=%e old = %e diff = %e\n", result, result2, result - result2);
00389 #endif
00390 ((CStringFeatures<uint16_t>*) lhs)->free_feature_vector(avec, idx_a, free_avec);
00391 ((CStringFeatures<uint16_t>*) rhs)->free_feature_vector(bvec, idx_b, free_bvec);
00392 return result;
00393 }
00394
00395 void CHistogramWordStringKernel::init()
00396 {
00397 estimate=NULL;
00398 mean=NULL;
00399 variance=NULL;
00400
00401 sqrtdiag_lhs=NULL;
00402 sqrtdiag_rhs=NULL;
00403
00404 ld_mean_lhs=NULL;
00405 ld_mean_rhs=NULL;
00406
00407 plo_lhs=NULL;
00408 plo_rhs=NULL;
00409 num_params=0;
00410 num_params2=0;
00411
00412 num_symbols=0;
00413 sum_m2_s2=0;
00414 initialized=false;
00415
00416 m_parameters->add(&initialized, "initialized", "if kernel is initalized");
00417 m_parameters->add_vector(&plo_lhs, &num_lhs, "plo_lhs");
00418 m_parameters->add_vector(&plo_rhs, &num_rhs, "plo_rhs");
00419 m_parameters->add_vector(&ld_mean_lhs, &num_lhs, "ld_mean_lhs");
00420 m_parameters->add_vector(&ld_mean_rhs, &num_rhs, "ld_mean_rhs");
00421 m_parameters->add_vector(&sqrtdiag_lhs, &num_lhs, "sqrtdiag_lhs");
00422 m_parameters->add_vector(&sqrtdiag_rhs, &num_rhs, "sqrtdiag_rhs");
00423 m_parameters->add_vector(&mean, &num_params2, "mean");
00424 m_parameters->add_vector(&variance, &num_params2, "variance");
00425
00426 m_parameters->add((CSGObject**) &estimate, "estimate", "Plugin Estimate.");
00427 }
00428
#ifdef DEBUG_HWSK_COMPUTATION
/**
 * Reference implementation of compute(), built only when
 * DEBUG_HWSK_COMPUTATION is defined.
 *
 * It evaluates the posterior log odds and the mean-dependent derivative
 * terms directly from the estimate instead of reading the plo_* and
 * ld_mean_* buffers; compute() compares its own result against this one.
 *
 * @param idx_a index of the vector on the left-hand side
 * @param idx_b index of the vector on the right-hand side
 * @return kernel value
 */
float64_t CHistogramWordStringKernel::compute_slow(int32_t idx_a, int32_t idx_b)
{
	CStringFeatures<uint16_t>* feats_a=(CStringFeatures<uint16_t>*) lhs;
	CStringFeatures<uint16_t>* feats_b=(CStringFeatures<uint16_t>*) rhs;

	int32_t alen, blen;
	bool release_a, release_b;
	uint16_t* seq_a=feats_a->get_feature_vector(idx_a, alen, release_a);
	uint16_t* seq_b=feats_b->get_feature_vector(idx_b, blen, release_b);

	// both strings must have the same length
	ASSERT(alen==blen);

	float64_t kval=(estimate->posterior_log_odds_obsolete(seq_a, alen)-mean[0])*
		(estimate->posterior_log_odds_obsolete(seq_b, blen)-mean[0])/(variance[0]);
	kval += sum_m2_s2;

	for (int32_t pos=0; pos<alen; pos++)
	{
		int32_t slot_a = compute_index(pos, seq_a[pos]);
		int32_t slot_b = compute_index(pos, seq_b[pos]);

		// squared derivatives only where both strings agree
		if (seq_a[pos]==seq_b[pos])
		{
			float64_t deriv = estimate->log_derivative_pos_obsolete(seq_a[pos], pos);
			kval += deriv*deriv/variance[slot_a];
			deriv = estimate->log_derivative_neg_obsolete(seq_a[pos], pos);
			kval += deriv*deriv/variance[slot_a+num_params];
		}

		// mean-dependent terms, evaluated for every position of both strings
		kval -= estimate->log_derivative_pos_obsolete(seq_a[pos], pos)*mean[slot_a]/variance[slot_a];
		kval -= estimate->log_derivative_pos_obsolete(seq_b[pos], pos)*mean[slot_b]/variance[slot_b];
		kval -= estimate->log_derivative_neg_obsolete(seq_a[pos], pos)*mean[slot_a+num_params]/variance[slot_a+num_params];
		kval -= estimate->log_derivative_neg_obsolete(seq_b[pos], pos)*mean[slot_b+num_params]/variance[slot_b+num_params];
	}

	if (initialized)
		kval /= (sqrtdiag_lhs[idx_a]*sqrtdiag_rhs[idx_b]);

	feats_a->free_feature_vector(seq_a, idx_a, release_a);
	feats_b->free_feature_vector(seq_b, idx_b, release_b);
	return kval;
}

#endif