HistogramWordStringKernel.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Written (W) 1999-2008 Gunnar Raetsch
00009  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00010  */
00011 
00012 #include <shogun/lib/common.h>
00013 #include <shogun/kernel/HistogramWordStringKernel.h>
00014 #include <shogun/features/Features.h>
00015 #include <shogun/features/StringFeatures.h>
00016 #include <shogun/classifier/PluginEstimate.h>
00017 #include <shogun/io/SGIO.h>
00018 
00019 using namespace shogun;
00020 
00021 CHistogramWordStringKernel::CHistogramWordStringKernel()
00022 : CStringKernel<uint16_t>()
00023 {
00024     init();
00025 }
00026 
00027 CHistogramWordStringKernel::CHistogramWordStringKernel(int32_t size, CPluginEstimate* pie)
00028 : CStringKernel<uint16_t>(size)
00029 {
00030     init();
00031     SG_REF(pie);
00032     estimate=pie;
00033 
00034 }
00035 
00036 CHistogramWordStringKernel::CHistogramWordStringKernel(
00037     CStringFeatures<uint16_t>* l, CStringFeatures<uint16_t>* r, CPluginEstimate* pie)
00038 : CStringKernel<uint16_t>()
00039 {
00040     init();
00041     SG_REF(pie);
00042     estimate=pie;
00043     init(l, r);
00044 }
00045 
00046 CHistogramWordStringKernel::~CHistogramWordStringKernel()
00047 {
00048     SG_UNREF(estimate);
00049 
00050     SG_FREE(variance);
00051     SG_FREE(mean);
00052     if (sqrtdiag_lhs != sqrtdiag_rhs)
00053         SG_FREE(sqrtdiag_rhs);
00054     SG_FREE(sqrtdiag_lhs);
00055     if (ld_mean_lhs!=ld_mean_rhs)
00056         SG_FREE(ld_mean_rhs);
00057     SG_FREE(ld_mean_lhs);
00058     if (plo_lhs!=plo_rhs)
00059         SG_FREE(plo_rhs);
00060     SG_FREE(plo_lhs);
00061 }
00062 
00063 bool CHistogramWordStringKernel::init(CFeatures* p_l, CFeatures* p_r)
00064 {
00065     CStringKernel<uint16_t>::init(p_l,p_r);
00066     CStringFeatures<uint16_t>* l=(CStringFeatures<uint16_t>*) p_l;
00067     CStringFeatures<uint16_t>* r=(CStringFeatures<uint16_t>*) p_r;
00068     ASSERT(l);
00069     ASSERT(r);
00070 
00071     SG_DEBUG( "init: lhs: %ld   rhs: %ld\n", l, r) ;
00072     int32_t i;
00073     initialized=false;
00074 
00075     if (sqrtdiag_lhs != sqrtdiag_rhs)
00076         SG_FREE(sqrtdiag_rhs);
00077     sqrtdiag_rhs=NULL ;
00078     SG_FREE(sqrtdiag_lhs);
00079     sqrtdiag_lhs=NULL ;
00080     if (ld_mean_lhs!=ld_mean_rhs)
00081         SG_FREE(ld_mean_rhs);
00082     ld_mean_rhs=NULL ;
00083     SG_FREE(ld_mean_lhs);
00084     ld_mean_lhs=NULL ;
00085     if (plo_lhs!=plo_rhs)
00086         SG_FREE(plo_rhs);
00087     plo_rhs=NULL ;
00088     SG_FREE(plo_lhs);
00089     plo_lhs=NULL ;
00090 
00091     sqrtdiag_lhs= SG_MALLOC(float64_t, l->get_num_vectors());
00092     ld_mean_lhs = SG_MALLOC(float64_t, l->get_num_vectors());
00093     plo_lhs     = SG_MALLOC(float64_t, l->get_num_vectors());
00094 
00095     for (i=0; i<l->get_num_vectors(); i++)
00096         sqrtdiag_lhs[i]=1;
00097 
00098     if (l==r)
00099     {
00100         sqrtdiag_rhs=sqrtdiag_lhs;
00101         ld_mean_rhs=ld_mean_lhs;
00102         plo_rhs=plo_lhs;
00103     }
00104     else
00105     {
00106         sqrtdiag_rhs=SG_MALLOC(float64_t, r->get_num_vectors());
00107         for (i=0; i<r->get_num_vectors(); i++)
00108             sqrtdiag_rhs[i]=1;
00109 
00110         ld_mean_rhs=SG_MALLOC(float64_t, r->get_num_vectors());
00111         plo_rhs=SG_MALLOC(float64_t, r->get_num_vectors());
00112     }
00113 
00114     float64_t* l_plo_lhs=plo_lhs;
00115     float64_t* l_plo_rhs=plo_rhs;
00116     float64_t* l_ld_mean_lhs=ld_mean_lhs;
00117     float64_t* l_ld_mean_rhs=ld_mean_rhs;
00118     
00119     //from our knowledge first normalize variance to 1 and then norm=1 does the job
00120     if (!initialized)
00121     {
00122         int32_t num_vectors=l->get_num_vectors();
00123         num_symbols=(int32_t) l->get_num_symbols();
00124         int32_t llen=l->get_vector_length(0);
00125         int32_t rlen=r->get_vector_length(0);
00126         num_params=llen*((int32_t) l->get_num_symbols());
00127         num_params2=llen*((int32_t) l->get_num_symbols())+rlen*((int32_t) r->get_num_symbols());
00128 
00129         if ((!estimate) || (!estimate->check_models()))
00130         {
00131             SG_ERROR( "no estimate available\n");
00132             return false ;
00133         } ;
00134         if (num_params2!=estimate->get_num_params())
00135         {
00136             SG_ERROR( "number of parameters of estimate and feature representation do not match\n");
00137             return false ;
00138         } ;
00139 
00140         //add 1 as we have the 'bias' also in this vector
00141         num_params2++;
00142 
00143         SG_FREE(mean);
00144         mean=SG_MALLOC(float64_t, num_params2);
00145         SG_FREE(variance);
00146         variance=SG_MALLOC(float64_t, num_params2);
00147 
00148         for (i=0; i<num_params2; i++)
00149         {
00150             mean[i]=0;
00151             variance[i]=0;
00152         }
00153 
00154         // compute mean
00155         for (i=0; i<num_vectors; i++)
00156         {
00157             int32_t len;
00158             bool free_vec;
00159             uint16_t* vec=l->get_feature_vector(i, len, free_vec);
00160 
00161             mean[0]+=estimate->posterior_log_odds_obsolete(vec, len)/num_vectors;
00162 
00163             for (int32_t j=0; j<len; j++)
00164             {
00165                 int32_t idx=compute_index(j, vec[j]);
00166                 mean[idx]             += estimate->log_derivative_pos_obsolete(vec[j], j)/num_vectors;
00167                 mean[idx+num_params] += estimate->log_derivative_neg_obsolete(vec[j], j)/num_vectors;
00168             }
00169 
00170             l->free_feature_vector(vec, i, free_vec);
00171         }
00172 
00173         // compute variance
00174         for (i=0; i<num_vectors; i++)
00175         {
00176             int32_t len;
00177             bool free_vec;
00178             uint16_t* vec=l->get_feature_vector(i, len, free_vec);
00179 
00180             variance[0] += CMath::sq(estimate->posterior_log_odds_obsolete(vec, len)-mean[0])/num_vectors;
00181 
00182             for (int32_t j=0; j<len; j++)
00183             {
00184                 for (int32_t k=0; k<4; k++)
00185                 {
00186                     int32_t idx=compute_index(j, k);
00187                     if (k!=vec[j])
00188                     {
00189                         variance[idx]+=mean[idx]*mean[idx]/num_vectors;
00190                         variance[idx+num_params]+=mean[idx+num_params]*mean[idx+num_params]/num_vectors;
00191                     }
00192                     else
00193                     {
00194                         variance[idx] += CMath::sq(estimate->log_derivative_pos_obsolete(vec[j], j)
00195                                 -mean[idx])/num_vectors;
00196                         variance[idx+num_params] += CMath::sq(estimate->log_derivative_neg_obsolete(vec[j], j)
00197                                 -mean[idx+num_params])/num_vectors;
00198                     }
00199                 }
00200             }
00201 
00202             l->free_feature_vector(vec, i, free_vec);
00203         }
00204 
00205 
00206         // compute sum_i m_i^2/s_i^2
00207         sum_m2_s2=0 ;
00208         for (i=1; i<num_params2; i++)
00209         {
00210             if (variance[i]<1e-14) // then it is likely to be numerical inaccuracy
00211                 variance[i]=1 ;
00212 
00213             //fprintf(stderr, "%i: mean=%1.2e  std=%1.2e\n", i, mean[i], std[i]) ;
00214             sum_m2_s2 += mean[i]*mean[i]/(variance[i]) ;
00215         } ;
00216     } 
00217 
00218     // compute sum of 
00219     //result -= estimate->log_derivative_pos(avec[i], i)*mean[a_idx]/variance[a_idx] ;
00220     //result -= estimate->log_derivative_neg(avec[i], i)*mean[a_idx+num_params]/variance[a_idx+num_params] ;
00221     for (i=0; i<l->get_num_vectors(); i++)
00222     {
00223         int32_t alen;
00224         bool free_avec;
00225         uint16_t* avec = l->get_feature_vector(i, alen, free_avec);
00226 
00227         float64_t  result=0 ;
00228         for (int32_t j=0; j<alen; j++)
00229         {
00230             int32_t a_idx = compute_index(j, avec[j]);
00231             result -= estimate->log_derivative_pos_obsolete(avec[j], j)*mean[a_idx]/variance[a_idx] ;
00232             result -= estimate->log_derivative_neg_obsolete(avec[j], j)*mean[a_idx+num_params]/variance[a_idx+num_params] ;
00233         }
00234         ld_mean_lhs[i]=result ;
00235 
00236         // precompute posterior-log-odds
00237         plo_lhs[i] = estimate->posterior_log_odds_obsolete(avec, alen)-mean[0] ;
00238         l->free_feature_vector(avec, i, free_avec);
00239     } ;
00240 
00241     if (ld_mean_lhs!=ld_mean_rhs)
00242     {
00243         // compute sum of 
00244         //result -= estimate->log_derivative_pos(bvec[i], i)*mean[b_idx]/variance[b_idx] ;
00245         //result -= estimate->log_derivative_neg(bvec[i], i)*mean[b_idx+num_params]/variance[b_idx+num_params] ;    
00246         for (i=0; i < r->get_num_vectors(); i++)
00247         {
00248             int32_t alen;
00249             bool free_avec;
00250             uint16_t* avec=r->get_feature_vector(i, alen, free_avec);
00251 
00252             float64_t  result=0 ;
00253             for (int32_t j=0; j<alen; j++)
00254             {
00255                 int32_t a_idx = compute_index(j, avec[j]) ;
00256                 result -= estimate->log_derivative_pos_obsolete(avec[j], j)*mean[a_idx]/variance[a_idx] ;
00257                 result -= estimate->log_derivative_neg_obsolete(avec[j], j)*mean[a_idx+num_params]/variance[a_idx+num_params] ;
00258             }
00259             ld_mean_rhs[i]=result ;
00260 
00261             // precompute posterior-log-odds
00262             plo_rhs[i] = estimate->posterior_log_odds_obsolete(avec, alen)-mean[0] ;
00263             r->free_feature_vector(avec, i, free_avec);
00264         } ;
00265     } ;
00266 
00267     //warning hacky
00268     //
00269     this->lhs=l;
00270     this->rhs=l;
00271     plo_lhs = l_plo_lhs ;
00272     plo_rhs = l_plo_lhs ;
00273     ld_mean_lhs = l_ld_mean_lhs ;
00274     ld_mean_rhs = l_ld_mean_lhs ;
00275 
00276     //compute normalize to 1 values
00277     for (i=0; i<l->get_num_vectors(); i++)
00278     {
00279         sqrtdiag_lhs[i]=sqrt(compute(i,i));
00280 
00281         //trap divide by zero exception
00282         if (sqrtdiag_lhs[i]==0)
00283             sqrtdiag_lhs[i]=1e-16;
00284     }
00285 
00286     // if lhs is different from rhs (train/test data)
00287     // compute also the normalization for rhs
00288     if (sqrtdiag_lhs!=sqrtdiag_rhs)
00289     {
00290         this->lhs=r;
00291         this->rhs=r;
00292         plo_lhs = l_plo_rhs ;
00293         plo_rhs = l_plo_rhs ;
00294         ld_mean_lhs = l_ld_mean_rhs ;
00295         ld_mean_rhs = l_ld_mean_rhs ;
00296 
00297         //compute normalize to 1 values
00298         for (i=0; i<r->get_num_vectors(); i++)
00299         {
00300             sqrtdiag_rhs[i]=sqrt(compute(i,i));
00301 
00302             //trap divide by zero exception
00303             if (sqrtdiag_rhs[i]==0)
00304                 sqrtdiag_rhs[i]=1e-16;
00305         }
00306     }
00307 
00308     this->lhs=l;
00309     this->rhs=r;
00310     plo_lhs = l_plo_lhs ;
00311     plo_rhs = l_plo_rhs ;
00312     ld_mean_lhs = l_ld_mean_lhs ;
00313     ld_mean_rhs = l_ld_mean_rhs ;
00314 
00315     initialized = true ;
00316     return init_normalizer();
00317 }
00318 
00319 void CHistogramWordStringKernel::cleanup()
00320 {
00321     SG_FREE(variance);
00322     variance=NULL;
00323 
00324     SG_FREE(mean);
00325     mean=NULL;
00326 
00327     if (sqrtdiag_lhs != sqrtdiag_rhs)
00328         SG_FREE(sqrtdiag_rhs);
00329     sqrtdiag_rhs=NULL;
00330 
00331     SG_FREE(sqrtdiag_lhs);
00332     sqrtdiag_lhs=NULL;
00333 
00334     if (ld_mean_lhs!=ld_mean_rhs)
00335         SG_FREE(ld_mean_rhs);
00336     ld_mean_rhs=NULL;
00337 
00338     SG_FREE(ld_mean_lhs);
00339     ld_mean_lhs=NULL;
00340 
00341     if (plo_lhs!=plo_rhs)
00342         SG_FREE(plo_rhs);
00343     plo_rhs=NULL;
00344 
00345     SG_FREE(plo_lhs);
00346     plo_lhs=NULL;
00347 
00348     num_params2=0;
00349     num_params=0;
00350     num_symbols=0;
00351     sum_m2_s2=0;
00352     initialized = false;
00353 
00354     CKernel::cleanup();
00355 }
00356 
00357 float64_t CHistogramWordStringKernel::compute(int32_t idx_a, int32_t idx_b)
00358 {
00359     int32_t alen, blen;
00360     bool free_avec, free_bvec;
00361     uint16_t* avec=((CStringFeatures<uint16_t>*) lhs)->get_feature_vector(idx_a, alen, free_avec);
00362     uint16_t* bvec=((CStringFeatures<uint16_t>*) rhs)->get_feature_vector(idx_b, blen, free_bvec);
00363     // can only deal with strings of same length
00364     ASSERT(alen==blen);
00365 
00366     float64_t result = plo_lhs[idx_a]*plo_rhs[idx_b]/variance[0];
00367     result+= sum_m2_s2 ; // does not contain 0-th element
00368 
00369     for (int32_t i=0; i<alen; i++)
00370     {
00371         if (avec[i]==bvec[i])
00372         {
00373             int32_t a_idx = compute_index(i, avec[i]) ;
00374             float64_t dd = estimate->log_derivative_pos_obsolete(avec[i], i) ;
00375             result   += dd*dd/variance[a_idx] ;
00376             dd        = estimate->log_derivative_neg_obsolete(avec[i], i) ;
00377             result   += dd*dd/variance[a_idx+num_params] ;
00378         } ;
00379     }
00380     result += ld_mean_lhs[idx_a] + ld_mean_rhs[idx_b] ;
00381 
00382     if (initialized)
00383         result /=  (sqrtdiag_lhs[idx_a]*sqrtdiag_rhs[idx_b]) ;
00384 
00385 #ifdef DEBUG_HWSK_COMPUTATION
00386     float64_t result2 = compute_slow(idx_a, idx_b) ;
00387     if (fabs(result - result2)>1e-10)
00388         SG_ERROR("new=%e  old = %e  diff = %e\n", result, result2, result - result2);
00389 #endif
00390     ((CStringFeatures<uint16_t>*) lhs)->free_feature_vector(avec, idx_a, free_avec);
00391     ((CStringFeatures<uint16_t>*) rhs)->free_feature_vector(bvec, idx_b, free_bvec);
00392     return result;
00393 }
00394 
00395 void CHistogramWordStringKernel::init()
00396 {
00397     estimate=NULL;
00398     mean=NULL;
00399     variance=NULL;
00400 
00401     sqrtdiag_lhs=NULL;
00402     sqrtdiag_rhs=NULL;
00403 
00404     ld_mean_lhs=NULL;
00405     ld_mean_rhs=NULL;
00406 
00407     plo_lhs=NULL;
00408     plo_rhs=NULL;
00409     num_params=0;
00410     num_params2=0;
00411 
00412     num_symbols=0;
00413     sum_m2_s2=0;
00414     initialized=false;
00415 
00416     m_parameters->add(&initialized, "initialized", "if kernel is initalized");
00417     m_parameters->add_vector(&plo_lhs, &num_lhs, "plo_lhs");
00418     m_parameters->add_vector(&plo_rhs, &num_rhs, "plo_rhs");
00419     m_parameters->add_vector(&ld_mean_lhs, &num_lhs, "ld_mean_lhs");
00420     m_parameters->add_vector(&ld_mean_rhs, &num_rhs, "ld_mean_rhs");
00421     m_parameters->add_vector(&sqrtdiag_lhs, &num_lhs, "sqrtdiag_lhs");
00422     m_parameters->add_vector(&sqrtdiag_rhs, &num_rhs, "sqrtdiag_rhs");
00423     m_parameters->add_vector(&mean, &num_params2, "mean");
00424     m_parameters->add_vector(&variance, &num_params2, "variance");
00425 
00426     m_parameters->add((CSGObject**) &estimate, "estimate", "Plugin Estimate.");
00427 }
00428 
#ifdef DEBUG_HWSK_COMPUTATION
/** Reference implementation of compute(), built only with
 * DEBUG_HWSK_COMPUTATION.
 *
 * Recomputes the centred posterior log-odds and the mean cross terms
 * directly from the estimate, instead of reading the plo_* / ld_mean_*
 * caches. compute() compares its own result against this one.
 */
float64_t CHistogramWordStringKernel::compute_slow(int32_t idx_a, int32_t idx_b)
{
    CStringFeatures<uint16_t>* sf_a=(CStringFeatures<uint16_t>*) lhs;
    CStringFeatures<uint16_t>* sf_b=(CStringFeatures<uint16_t>*) rhs;

    int32_t alen, blen;
    bool release_a, release_b;
    uint16_t* avec=sf_a->get_feature_vector(idx_a, alen, release_a);
    uint16_t* bvec=sf_b->get_feature_vector(idx_b, blen, release_b);
    // can only deal with strings of same length
    ASSERT(alen==blen);

    const float64_t centred_a=estimate->posterior_log_odds_obsolete(avec, alen)-mean[0];
    const float64_t centred_b=estimate->posterior_log_odds_obsolete(bvec, blen)-mean[0];
    float64_t res=centred_a*centred_b/(variance[0]);
    res+=sum_m2_s2; // does not contain 0-th element

    for (int32_t pos=0; pos<alen; pos++)
    {
        const uint16_t sa=avec[pos];
        const uint16_t sb=bvec[pos];
        const int32_t ia=compute_index(pos, sa);
        const int32_t ib=compute_index(pos, sb);

        if (sa==sb)
        {
            float64_t d=estimate->log_derivative_pos_obsolete(sa, pos);
            res+=d*d/variance[ia];
            d=estimate->log_derivative_neg_obsolete(sa, pos);
            res+=d*d/variance[ia+num_params];
        }

        // mean cross terms for both strings
        res-=estimate->log_derivative_pos_obsolete(sa, pos)*mean[ia]/variance[ia];
        res-=estimate->log_derivative_pos_obsolete(sb, pos)*mean[ib]/variance[ib];
        res-=estimate->log_derivative_neg_obsolete(sa, pos)*mean[ia+num_params]/variance[ia+num_params];
        res-=estimate->log_derivative_neg_obsolete(sb, pos)*mean[ib+num_params]/variance[ib+num_params];
    }

    if (initialized)
        res/=(sqrtdiag_lhs[idx_a]*sqrtdiag_rhs[idx_b]);

    sf_a->free_feature_vector(avec, idx_a, release_a);
    sf_b->free_feature_vector(bvec, idx_b, release_b);
    return res;
}

#endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation