00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/lib/common.h>
00012 #include <shogun/io/SGIO.h>
00013 #include <shogun/kernel/string/SalzbergWordStringKernel.h>
00014 #include <shogun/features/Features.h>
00015 #include <shogun/features/StringFeatures.h>
00016 #include <shogun/labels/Labels.h>
00017 #include <shogun/labels/BinaryLabels.h>
00018 #include <shogun/classifier/PluginEstimate.h>
00019
00020 using namespace shogun;
00021
00022 CSalzbergWordStringKernel::CSalzbergWordStringKernel()
00023 : CStringKernel<uint16_t>(0)
00024 {
00025 init();
00026 }
00027
00028 CSalzbergWordStringKernel::CSalzbergWordStringKernel(int32_t size, CPluginEstimate* pie, CLabels* labels)
00029 : CStringKernel<uint16_t>(size)
00030 {
00031 init();
00032 estimate=pie;
00033
00034 if (labels)
00035 set_prior_probs_from_labels(labels);
00036 }
00037
00038 CSalzbergWordStringKernel::CSalzbergWordStringKernel(
00039 CStringFeatures<uint16_t>* l, CStringFeatures<uint16_t>* r,
00040 CPluginEstimate* pie, CLabels* labels)
00041 : CStringKernel<uint16_t>(10),estimate(pie)
00042 {
00043 init();
00044 estimate=pie;
00045
00046 if (labels)
00047 set_prior_probs_from_labels(labels);
00048
00049 init(l, r);
00050 }
00051
00052 CSalzbergWordStringKernel::~CSalzbergWordStringKernel()
00053 {
00054 cleanup();
00055 }
00056
00057 bool CSalzbergWordStringKernel::init(CFeatures* p_l, CFeatures* p_r)
00058 {
00059 CStringKernel<uint16_t>::init(p_l,p_r);
00060 CStringFeatures<uint16_t>* l=(CStringFeatures<uint16_t>*) p_l;
00061 ASSERT(l);
00062 CStringFeatures<uint16_t>* r=(CStringFeatures<uint16_t>*) p_r;
00063 ASSERT(r);
00064
00065 int32_t i;
00066 initialized=false;
00067
00068 if (sqrtdiag_lhs!=sqrtdiag_rhs)
00069 SG_FREE(sqrtdiag_rhs);
00070 sqrtdiag_rhs=NULL;
00071 SG_FREE(sqrtdiag_lhs);
00072 sqrtdiag_lhs=NULL;
00073 if (ld_mean_lhs!=ld_mean_rhs)
00074 SG_FREE(ld_mean_rhs);
00075 ld_mean_rhs=NULL;
00076 SG_FREE(ld_mean_lhs);
00077 ld_mean_lhs=NULL;
00078
00079 sqrtdiag_lhs=SG_MALLOC(float64_t, l->get_num_vectors());
00080 ld_mean_lhs=SG_MALLOC(float64_t, l->get_num_vectors());
00081
00082 for (i=0; i<l->get_num_vectors(); i++)
00083 sqrtdiag_lhs[i]=1;
00084
00085 if (l==r)
00086 {
00087 sqrtdiag_rhs=sqrtdiag_lhs;
00088 ld_mean_rhs=ld_mean_lhs;
00089 }
00090 else
00091 {
00092 sqrtdiag_rhs=SG_MALLOC(float64_t, r->get_num_vectors());
00093 for (i=0; i<r->get_num_vectors(); i++)
00094 sqrtdiag_rhs[i]=1;
00095
00096 ld_mean_rhs=SG_MALLOC(float64_t, r->get_num_vectors());
00097 }
00098
00099 float64_t* l_ld_mean_lhs=ld_mean_lhs;
00100 float64_t* l_ld_mean_rhs=ld_mean_rhs;
00101
00102
00103 if (!initialized)
00104 {
00105 int32_t num_vectors=l->get_num_vectors();
00106 num_symbols=(int32_t) l->get_num_symbols();
00107 int32_t llen=l->get_vector_length(0);
00108 int32_t rlen=r->get_vector_length(0);
00109 num_params=(int32_t) llen*l->get_num_symbols();
00110 int32_t num_params2=(int32_t) llen*l->get_num_symbols()+rlen*r->get_num_symbols();
00111 if ((!estimate) || (!estimate->check_models()))
00112 {
00113 SG_ERROR( "no estimate available\n");
00114 return false ;
00115 } ;
00116 if (num_params2!=estimate->get_num_params())
00117 {
00118 SG_ERROR( "number of parameters of estimate and feature representation do not match\n");
00119 return false ;
00120 } ;
00121
00122 SG_FREE(variance);
00123 SG_FREE(mean);
00124 mean=SG_MALLOC(float64_t, num_params);
00125 ASSERT(mean);
00126 variance=SG_MALLOC(float64_t, num_params);
00127 ASSERT(variance);
00128
00129 for (i=0; i<num_params; i++)
00130 {
00131 mean[i]=0;
00132 variance[i]=0;
00133 }
00134
00135
00136
00137 for (i=0; i<num_vectors; i++)
00138 {
00139 int32_t len;
00140 bool free_vec;
00141 uint16_t* vec=l->get_feature_vector(i, len, free_vec);
00142
00143 for (int32_t j=0; j<len; j++)
00144 {
00145 int32_t idx=compute_index(j, vec[j]);
00146 float64_t theta_p = 1/estimate->log_derivative_pos_obsolete(vec[j], j) ;
00147 float64_t theta_n = 1/estimate->log_derivative_neg_obsolete(vec[j], j) ;
00148 float64_t value = (theta_p/(pos_prior*theta_p+neg_prior*theta_n)) ;
00149
00150 mean[idx] += value/num_vectors ;
00151 }
00152 l->free_feature_vector(vec, i, free_vec);
00153 }
00154
00155
00156 for (i=0; i<num_vectors; i++)
00157 {
00158 int32_t len;
00159 bool free_vec;
00160 uint16_t* vec=l->get_feature_vector(i, len, free_vec);
00161
00162 for (int32_t j=0; j<len; j++)
00163 {
00164 for (int32_t k=0; k<4; k++)
00165 {
00166 int32_t idx=compute_index(j, k);
00167 if (k!=vec[j])
00168 variance[idx]+=mean[idx]*mean[idx]/num_vectors;
00169 else
00170 {
00171 float64_t theta_p = 1/estimate->log_derivative_pos_obsolete(vec[j], j) ;
00172 float64_t theta_n = 1/estimate->log_derivative_neg_obsolete(vec[j], j) ;
00173 float64_t value = (theta_p/(pos_prior*theta_p+neg_prior*theta_n)) ;
00174
00175 variance[idx] += CMath::sq(value-mean[idx])/num_vectors;
00176 }
00177 }
00178 }
00179 l->free_feature_vector(vec, i, free_vec);
00180 }
00181
00182
00183
00184 sum_m2_s2=0 ;
00185 for (i=0; i<num_params; i++)
00186 {
00187 if (variance[i]<1e-14)
00188 variance[i]=1 ;
00189
00190
00191 sum_m2_s2 += mean[i]*mean[i]/(variance[i]) ;
00192 } ;
00193 }
00194
00195
00196
00197
00198 for (i=0; i<l->get_num_vectors(); i++)
00199 {
00200 int32_t alen ;
00201 bool free_avec;
00202 uint16_t* avec=l->get_feature_vector(i, alen, free_avec);
00203 float64_t result=0 ;
00204 for (int32_t j=0; j<alen; j++)
00205 {
00206 int32_t a_idx = compute_index(j, avec[j]) ;
00207 float64_t theta_p = 1/estimate->log_derivative_pos_obsolete(avec[j], j) ;
00208 float64_t theta_n = 1/estimate->log_derivative_neg_obsolete(avec[j], j) ;
00209 float64_t value = (theta_p/(pos_prior*theta_p+neg_prior*theta_n)) ;
00210
00211 if (variance[a_idx]!=0)
00212 result-=value*mean[a_idx]/variance[a_idx];
00213 }
00214 ld_mean_lhs[i]=result ;
00215
00216 l->free_feature_vector(avec, i, free_avec);
00217 }
00218
00219 if (ld_mean_lhs!=ld_mean_rhs)
00220 {
00221
00222
00223 for (i=0; i<r->get_num_vectors(); i++)
00224 {
00225 int32_t alen;
00226 bool free_avec;
00227 uint16_t* avec=r->get_feature_vector(i, alen, free_avec);
00228 float64_t result=0;
00229
00230 for (int32_t j=0; j<alen; j++)
00231 {
00232 int32_t a_idx = compute_index(j, avec[j]) ;
00233 float64_t theta_p=1/estimate->log_derivative_pos_obsolete(
00234 avec[j], j) ;
00235 float64_t theta_n=1/estimate->log_derivative_neg_obsolete(
00236 avec[j], j) ;
00237 float64_t value=(theta_p/(pos_prior*theta_p+neg_prior*theta_n));
00238
00239 result -= value*mean[a_idx]/variance[a_idx] ;
00240 }
00241
00242 ld_mean_rhs[i]=result;
00243 r->free_feature_vector(avec, i, free_avec);
00244 }
00245 }
00246
00247
00248
00249 this->lhs=l;
00250 this->rhs=l;
00251 ld_mean_lhs = l_ld_mean_lhs ;
00252 ld_mean_rhs = l_ld_mean_lhs ;
00253
00254
00255 for (i=0; i<lhs->get_num_vectors(); i++)
00256 {
00257 sqrtdiag_lhs[i]=sqrt(compute(i,i));
00258
00259
00260 if (sqrtdiag_lhs[i]==0)
00261 sqrtdiag_lhs[i]=1e-16;
00262 }
00263
00264
00265
00266 if (sqrtdiag_lhs!=sqrtdiag_rhs)
00267 {
00268 this->lhs=r;
00269 this->rhs=r;
00270 ld_mean_lhs = l_ld_mean_rhs ;
00271 ld_mean_rhs = l_ld_mean_rhs ;
00272
00273
00274 for (i=0; i<rhs->get_num_vectors(); i++)
00275 {
00276 sqrtdiag_rhs[i]=sqrt(compute(i,i));
00277
00278
00279 if (sqrtdiag_rhs[i]==0)
00280 sqrtdiag_rhs[i]=1e-16;
00281 }
00282 }
00283
00284 this->lhs=l;
00285 this->rhs=r;
00286 ld_mean_lhs = l_ld_mean_lhs ;
00287 ld_mean_rhs = l_ld_mean_rhs ;
00288
00289 initialized = true ;
00290 return init_normalizer();
00291 }
00292
00293 void CSalzbergWordStringKernel::cleanup()
00294 {
00295 SG_FREE(variance);
00296 variance=NULL;
00297
00298 SG_FREE(mean);
00299 mean=NULL;
00300
00301 if (sqrtdiag_lhs != sqrtdiag_rhs)
00302 SG_FREE(sqrtdiag_rhs);
00303 sqrtdiag_rhs=NULL;
00304
00305 SG_FREE(sqrtdiag_lhs);
00306 sqrtdiag_lhs=NULL;
00307
00308 if (ld_mean_lhs!=ld_mean_rhs)
00309 SG_FREE(ld_mean_rhs);
00310 ld_mean_rhs=NULL;
00311
00312 SG_FREE(ld_mean_lhs);
00313 ld_mean_lhs=NULL;
00314
00315 CKernel::cleanup();
00316 }
00317
00318 float64_t CSalzbergWordStringKernel::compute(int32_t idx_a, int32_t idx_b)
00319 {
00320 int32_t alen, blen;
00321 bool free_avec, free_bvec;
00322 uint16_t* avec=((CStringFeatures<uint16_t>*) lhs)->get_feature_vector(idx_a, alen, free_avec);
00323 uint16_t* bvec=((CStringFeatures<uint16_t>*) rhs)->get_feature_vector(idx_b, blen, free_bvec);
00324
00325 ASSERT(alen==blen);
00326
00327 float64_t result = sum_m2_s2 ;
00328
00329 for (int32_t i=0; i<alen; i++)
00330 {
00331 if (avec[i]==bvec[i])
00332 {
00333 int32_t a_idx = compute_index(i, avec[i]) ;
00334
00335 float64_t theta_p = 1/estimate->log_derivative_pos_obsolete(avec[i], i) ;
00336 float64_t theta_n = 1/estimate->log_derivative_neg_obsolete(avec[i], i) ;
00337 float64_t value = (theta_p/(pos_prior*theta_p+neg_prior*theta_n)) ;
00338
00339 result += value*value/variance[a_idx] ;
00340 }
00341 }
00342 result += ld_mean_lhs[idx_a] + ld_mean_rhs[idx_b] ;
00343
00344 ((CStringFeatures<uint16_t>*) lhs)->free_feature_vector(avec, idx_a, free_avec);
00345 ((CStringFeatures<uint16_t>*) rhs)->free_feature_vector(bvec, idx_b, free_bvec);
00346
00347 if (initialized)
00348 result /= (sqrtdiag_lhs[idx_a]*sqrtdiag_rhs[idx_b]) ;
00349
00350 return result;
00351 }
00352
00353 void CSalzbergWordStringKernel::set_prior_probs_from_labels(CLabels* labels)
00354 {
00355 ASSERT(labels);
00356 ASSERT(labels->get_label_type() == LT_BINARY);
00357 labels->ensure_valid();
00358
00359 int32_t num_pos=0, num_neg=0;
00360 for (int32_t i=0; i<labels->get_num_labels(); i++)
00361 {
00362 if (((CBinaryLabels*) labels)->get_int_label(i)==1)
00363 num_pos++;
00364 if (((CBinaryLabels*) labels)->get_int_label(i)==-1)
00365 num_neg++;
00366 }
00367
00368 SG_INFO("priors: pos=%1.3f (%i) neg=%1.3f (%i)\n",
00369 (float64_t) num_pos/(num_pos+num_neg), num_pos,
00370 (float64_t) num_neg/(num_pos+num_neg), num_neg);
00371
00372 set_prior_probs(
00373 (float64_t)num_pos/(num_pos+num_neg),
00374 (float64_t)num_neg/(num_pos+num_neg));
00375 }
00376
00377 void CSalzbergWordStringKernel::init()
00378 {
00379 estimate=NULL;
00380 mean=NULL;
00381 variance=NULL;
00382
00383 sqrtdiag_lhs=NULL;
00384 sqrtdiag_rhs=NULL;
00385
00386 ld_mean_lhs=NULL;
00387 ld_mean_rhs=NULL;
00388
00389 num_params=0;
00390 num_symbols=0;
00391 sum_m2_s2=0;
00392 pos_prior=0.5;
00393
00394 neg_prior=0.5;
00395 initialized=false;
00396 }