CommWordStringKernel.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include <shogun/lib/common.h>
00012 #include <shogun/io/SGIO.h>
00013 
00014 #include <shogun/base/Parameter.h>
00015 
00016 #include <shogun/kernel/CommWordStringKernel.h>
00017 #include <shogun/kernel/SqrtDiagKernelNormalizer.h>
00018 #include <shogun/features/StringFeatures.h>
00019 
00020 using namespace shogun;
00021 
00022 CCommWordStringKernel::CCommWordStringKernel()
00023 : CStringKernel<uint16_t>()
00024 {
00025     init();
00026 }
00027 
00028 CCommWordStringKernel::CCommWordStringKernel(int32_t size, bool s)
00029 : CStringKernel<uint16_t>(size)
00030 {
00031     init();
00032     use_sign=s;
00033 }
00034 
00035 CCommWordStringKernel::CCommWordStringKernel(
00036     CStringFeatures<uint16_t>* l, CStringFeatures<uint16_t>* r,
00037     bool s, int32_t size) : CStringKernel<uint16_t>(size)
00038 {
00039     init();
00040     use_sign=s;
00041 
00042     init(l,r);
00043 }
00044 
00045 
00046 bool CCommWordStringKernel::init_dictionary(int32_t size)
00047 {
00048     dictionary_size= size;
00049     SG_FREE(dictionary_weights);
00050     dictionary_weights=SG_MALLOC(float64_t, size);
00051     SG_DEBUG( "using dictionary of %d words\n", size);
00052     clear_normal();
00053 
00054     return dictionary_weights!=NULL;
00055 }
00056 
00057 CCommWordStringKernel::~CCommWordStringKernel() 
00058 {
00059     cleanup();
00060 
00061     SG_FREE(dictionary_weights);
00062     SG_FREE(dict_diagonal_optimization);
00063 }
00064   
00065 bool CCommWordStringKernel::init(CFeatures* l, CFeatures* r)
00066 {
00067     CStringKernel<uint16_t>::init(l,r);
00068 
00069     if (use_dict_diagonal_optimization)
00070     {
00071         SG_FREE(dict_diagonal_optimization);
00072         dict_diagonal_optimization=SG_MALLOC(int32_t, int32_t(((CStringFeatures<uint16_t>*)l)->get_num_symbols()));
00073         ASSERT(((CStringFeatures<uint16_t>*)l)->get_num_symbols() == ((CStringFeatures<uint16_t>*)r)->get_num_symbols()) ;
00074     }
00075 
00076     return init_normalizer();
00077 }
00078 
00079 void CCommWordStringKernel::cleanup()
00080 {
00081     delete_optimization();
00082     CKernel::cleanup();
00083 }
00084 
00085 float64_t CCommWordStringKernel::compute_diag(int32_t idx_a)
00086 {
00087     int32_t alen;
00088     CStringFeatures<uint16_t>* l = (CStringFeatures<uint16_t>*) lhs;
00089     CStringFeatures<uint16_t>* r = (CStringFeatures<uint16_t>*) rhs;
00090 
00091     bool free_av;
00092     uint16_t* av=l->get_feature_vector(idx_a, alen, free_av);
00093 
00094     float64_t result=0.0 ;
00095     ASSERT(l==r);
00096     ASSERT(sizeof(uint16_t)<=sizeof(float64_t));
00097     ASSERT((1<<(sizeof(uint16_t)*8)) > alen);
00098 
00099     int32_t num_symbols=(int32_t) l->get_num_symbols();
00100     ASSERT(num_symbols<=dictionary_size);
00101 
00102     int32_t* dic = dict_diagonal_optimization;
00103     memset(dic, 0, num_symbols*sizeof(int32_t));
00104 
00105     for (int32_t i=0; i<alen; i++)
00106         dic[av[i]]++;
00107 
00108     if (use_sign)
00109     {
00110         for (int32_t i=0; i<(int32_t) l->get_num_symbols(); i++)
00111         {
00112             if (dic[i]!=0)
00113                 result++;
00114         }
00115     }
00116     else
00117     {
00118         for (int32_t i=0; i<num_symbols; i++)
00119         {
00120             if (dic[i]!=0)
00121                 result+=dic[i]*dic[i];
00122         }
00123     }
00124     l->free_feature_vector(av, idx_a, free_av);
00125 
00126     return result;
00127 }
00128 
00129 float64_t CCommWordStringKernel::compute_helper(
00130     int32_t idx_a, int32_t idx_b, bool do_sort)
00131 {
00132     int32_t alen, blen;
00133     bool free_av, free_bv;
00134 
00135     CStringFeatures<uint16_t>* l = (CStringFeatures<uint16_t>*) lhs;
00136     CStringFeatures<uint16_t>* r = (CStringFeatures<uint16_t>*) rhs;
00137 
00138     uint16_t* av=l->get_feature_vector(idx_a, alen, free_av);
00139     uint16_t* bv=r->get_feature_vector(idx_b, blen, free_bv);
00140 
00141     uint16_t* avec=av;
00142     uint16_t* bvec=bv;
00143 
00144     if (do_sort)
00145     {
00146         if (alen>0)
00147         {
00148             avec=SG_MALLOC(uint16_t, alen);
00149             memcpy(avec, av, sizeof(uint16_t)*alen);
00150             CMath::radix_sort(avec, alen);
00151         }
00152         else
00153             avec=NULL;
00154 
00155         if (blen>0)
00156         {
00157             bvec=SG_MALLOC(uint16_t, blen);
00158             memcpy(bvec, bv, sizeof(uint16_t)*blen);
00159             CMath::radix_sort(bvec, blen);
00160         }
00161         else
00162             bvec=NULL;
00163     }
00164     else
00165     {
00166         if ( (l->get_num_preprocessors() != l->get_num_preprocessed()) ||
00167                 (r->get_num_preprocessors() != r->get_num_preprocessed()))
00168         {
00169             SG_ERROR("not all preprocessors have been applied to training (%d/%d)"
00170                     " or test (%d/%d) data\n", l->get_num_preprocessed(), l->get_num_preprocessors(),
00171                     r->get_num_preprocessed(), r->get_num_preprocessors());
00172         }
00173     }
00174 
00175     float64_t result=0;
00176 
00177     int32_t left_idx=0;
00178     int32_t right_idx=0;
00179 
00180     if (use_sign)
00181     {
00182         while (left_idx < alen && right_idx < blen)
00183         {
00184             if (avec[left_idx]==bvec[right_idx])
00185             {
00186                 uint16_t sym=avec[left_idx];
00187 
00188                 while (left_idx< alen && avec[left_idx]==sym)
00189                     left_idx++;
00190 
00191                 while (right_idx< blen && bvec[right_idx]==sym)
00192                     right_idx++;
00193 
00194                 result++;
00195             }
00196             else if (avec[left_idx]<bvec[right_idx])
00197                 left_idx++;
00198             else
00199                 right_idx++;
00200         }
00201     }
00202     else
00203     {
00204         while (left_idx < alen && right_idx < blen)
00205         {
00206             if (avec[left_idx]==bvec[right_idx])
00207             {
00208                 int32_t old_left_idx=left_idx;
00209                 int32_t old_right_idx=right_idx;
00210 
00211                 uint16_t sym=avec[left_idx];
00212 
00213                 while (left_idx< alen && avec[left_idx]==sym)
00214                     left_idx++;
00215 
00216                 while (right_idx< blen && bvec[right_idx]==sym)
00217                     right_idx++;
00218 
00219                 result+=((float64_t) (left_idx-old_left_idx))*
00220                     ((float64_t) (right_idx-old_right_idx));
00221             }
00222             else if (avec[left_idx]<bvec[right_idx])
00223                 left_idx++;
00224             else
00225                 right_idx++;
00226         }
00227     }
00228 
00229     if (do_sort)
00230     {
00231         SG_FREE(avec);
00232         SG_FREE(bvec);
00233     }
00234 
00235     l->free_feature_vector(av, idx_a, free_av);
00236     r->free_feature_vector(bv, idx_b, free_bv);
00237 
00238     return result;
00239 }
00240 
00241 void CCommWordStringKernel::add_to_normal(int32_t vec_idx, float64_t weight)
00242 {
00243     int32_t len=-1;
00244     bool free_vec;
00245     uint16_t* vec=((CStringFeatures<uint16_t>*) lhs)->
00246         get_feature_vector(vec_idx, len, free_vec);
00247 
00248     if (len>0)
00249     {
00250         int32_t j, last_j=0;
00251         if (use_sign)
00252         {
00253             for (j=1; j<len; j++)
00254             {
00255                 if (vec[j]==vec[j-1])
00256                     continue;
00257 
00258                 dictionary_weights[(int32_t) vec[j-1]]+=normalizer->
00259                     normalize_lhs(weight, vec_idx);
00260             }
00261 
00262             dictionary_weights[(int32_t) vec[len-1]]+=normalizer->
00263                 normalize_lhs(weight, vec_idx);
00264         }
00265         else
00266         {
00267             for (j=1; j<len; j++)
00268             {
00269                 if (vec[j]==vec[j-1])
00270                     continue;
00271 
00272                 dictionary_weights[(int32_t) vec[j-1]]+=normalizer->
00273                     normalize_lhs(weight*(j-last_j), vec_idx);
00274                 last_j = j;
00275             }
00276 
00277             dictionary_weights[(int32_t) vec[len-1]]+=normalizer->
00278                 normalize_lhs(weight*(len-last_j), vec_idx);
00279         }
00280         set_is_initialized(true);
00281     }
00282 
00283     ((CStringFeatures<uint16_t>*) lhs)->free_feature_vector(vec, vec_idx, free_vec);
00284 }
00285 
00286 void CCommWordStringKernel::clear_normal()
00287 {
00288     memset(dictionary_weights, 0, dictionary_size*sizeof(float64_t));
00289     set_is_initialized(false);
00290 }
00291 
00292 bool CCommWordStringKernel::init_optimization(
00293     int32_t count, int32_t* IDX, float64_t* weights)
00294 {
00295     delete_optimization();
00296 
00297     if (count<=0)
00298     {
00299         set_is_initialized(true);
00300         SG_DEBUG("empty set of SVs\n");
00301         return true;
00302     }
00303 
00304     SG_DEBUG("initializing CCommWordStringKernel optimization\n");
00305 
00306     for (int32_t i=0; i<count; i++)
00307     {
00308         if ( (i % (count/10+1)) == 0)
00309             SG_PROGRESS(i, 0, count);
00310 
00311         add_to_normal(IDX[i], weights[i]);
00312     }
00313 
00314     set_is_initialized(true);
00315     return true;
00316 }
00317 
00318 bool CCommWordStringKernel::delete_optimization() 
00319 {
00320     SG_DEBUG( "deleting CCommWordStringKernel optimization\n");
00321 
00322     clear_normal();
00323     return true;
00324 }
00325 
00326 float64_t CCommWordStringKernel::compute_optimized(int32_t i)
00327 { 
00328     if (!get_is_initialized())
00329     {
00330       SG_ERROR( "CCommWordStringKernel optimization not initialized\n");
00331         return 0 ; 
00332     }
00333 
00334     float64_t result = 0;
00335     int32_t len = -1;
00336     bool free_vec;
00337     uint16_t* vec=((CStringFeatures<uint16_t>*) rhs)->
00338         get_feature_vector(i, len, free_vec);
00339 
00340     int32_t j, last_j=0;
00341     if (vec && len>0)
00342     {
00343         if (use_sign)
00344         {
00345             for (j=1; j<len; j++)
00346             {
00347                 if (vec[j]==vec[j-1])
00348                     continue;
00349 
00350                 result += dictionary_weights[(int32_t) vec[j-1]];
00351             }
00352 
00353             result += dictionary_weights[(int32_t) vec[len-1]];
00354         }
00355         else
00356         {
00357             for (j=1; j<len; j++)
00358             {
00359                 if (vec[j]==vec[j-1])
00360                     continue;
00361 
00362                 result += dictionary_weights[(int32_t) vec[j-1]]*(j-last_j);
00363                 last_j = j;
00364             }
00365 
00366             result += dictionary_weights[(int32_t) vec[len-1]]*(len-last_j);
00367         }
00368 
00369         result=normalizer->normalize_rhs(result, i);
00370     }
00371     ((CStringFeatures<uint16_t>*) rhs)->free_feature_vector(vec, i, free_vec);
00372     return result;
00373 }
00374 
00375 float64_t* CCommWordStringKernel::compute_scoring(
00376     int32_t max_degree, int32_t& num_feat, int32_t& num_sym, float64_t* target,
00377     int32_t num_suppvec, int32_t* IDX, float64_t* alphas, bool do_init)
00378 {
00379     ASSERT(lhs);
00380     CStringFeatures<uint16_t>* str=((CStringFeatures<uint16_t>*) lhs);
00381     num_feat=1;//str->get_max_vector_length();
00382     CAlphabet* alpha=str->get_alphabet();
00383     ASSERT(alpha);
00384     int32_t num_bits=alpha->get_num_bits();
00385     int32_t order=str->get_order();
00386     ASSERT(max_degree<=order);
00387     //int32_t num_words=(int32_t) str->get_num_symbols();
00388     int32_t num_words=(int32_t) str->get_original_num_symbols();
00389     int32_t offset=0;
00390 
00391     num_sym=0;
00392     
00393     for (int32_t i=0; i<order; i++)
00394         num_sym+=CMath::pow((int32_t) num_words,i+1);
00395 
00396     SG_DEBUG("num_words:%d, order:%d, len:%d sz:%d (len*sz:%d)\n", num_words, order,
00397             num_feat, num_sym, num_feat*num_sym);
00398 
00399     if (!target)
00400         target=SG_MALLOC(float64_t, num_feat*num_sym);
00401     memset(target, 0, num_feat*num_sym*sizeof(float64_t));
00402 
00403     if (do_init)
00404         init_optimization(num_suppvec, IDX, alphas);
00405 
00406     uint32_t kmer_mask=0;
00407     uint32_t words=CMath::pow((int32_t) num_words,(int32_t) order);
00408 
00409     for (int32_t o=0; o<max_degree; o++)
00410     {
00411         float64_t* contrib=&target[offset];
00412         offset+=CMath::pow((int32_t) num_words,(int32_t) o+1);
00413 
00414         kmer_mask=(kmer_mask<<(num_bits)) | str->get_masked_symbols(0xffff, 1);
00415 
00416         for (int32_t p=-o; p<order; p++)
00417         {
00418             int32_t o_sym=0, m_sym=0, il=0,ir=0, jl=0;
00419             uint32_t imer_mask=kmer_mask;
00420             uint32_t jmer_mask=kmer_mask;
00421 
00422             if (p<0)
00423             {
00424                 il=-p;
00425                 m_sym=order-o-p-1;
00426                 o_sym=-p;
00427             }
00428             else if (p<order-o)
00429             {
00430                 ir=p;
00431                 m_sym=order-o-1;
00432             }
00433             else
00434             {
00435                 ir=p;
00436                 m_sym=p;
00437                 o_sym=p-order+o+1;
00438                 jl=order-ir;
00439                 imer_mask=(kmer_mask>>(num_bits*o_sym));
00440                 jmer_mask=(kmer_mask>>(num_bits*jl));
00441             }
00442 
00443             float64_t marginalizer=
00444                 1.0/CMath::pow((int32_t) num_words,(int32_t) m_sym);
00445             
00446             for (uint32_t i=0; i<words; i++)
00447             {
00448                 uint16_t x= ((i << (num_bits*il)) >> (num_bits*ir)) & imer_mask;
00449 
00450                 if (p>=0 && p<order-o)
00451                 {
00452 //#define DEBUG_COMMSCORING
00453 #ifdef DEBUG_COMMSCORING
00454                     SG_PRINT("o=%d/%d p=%d/%d i=0x%x x=0x%x imask=%x jmask=%x kmask=%x il=%d ir=%d marg=%g o_sym:%d m_sym:%d weight(",
00455                             o,order, p,order, i, x, imer_mask, jmer_mask, kmer_mask, il, ir, marginalizer, o_sym, m_sym);
00456 
00457                     SG_PRINT("%c%c%c%c/%c%c%c%c)+=%g/%g\n", 
00458                             alpha->remap_to_char((x>>(3*num_bits))&0x03), alpha->remap_to_char((x>>(2*num_bits))&0x03),
00459                             alpha->remap_to_char((x>>num_bits)&0x03), alpha->remap_to_char(x&0x03),
00460                             alpha->remap_to_char((i>>(3*num_bits))&0x03), alpha->remap_to_char((i>>(2*num_bits))&0x03),
00461                             alpha->remap_to_char((i>>(1*num_bits))&0x03), alpha->remap_to_char(i&0x03),
00462                             dictionary_weights[i]*marginalizer, dictionary_weights[i]);
00463 #endif
00464                     contrib[x]+=dictionary_weights[i]*marginalizer;
00465                 }
00466                 else
00467                 {
00468                     for (uint32_t j=0; j< (uint32_t) CMath::pow((int32_t) num_words, (int32_t) o_sym); j++)
00469                     {
00470                         uint32_t c=x | ((j & jmer_mask) << (num_bits*jl));
00471 #ifdef DEBUG_COMMSCORING
00472 
00473                         SG_PRINT("o=%d/%d p=%d/%d i=0x%x j=0x%x x=0x%x c=0x%x imask=%x jmask=%x kmask=%x il=%d ir=%d jl=%d marg=%g o_sym:%d m_sym:%d weight(",
00474                                 o,order, p,order, i, j, x, c, imer_mask, jmer_mask, kmer_mask, il, ir, jl, marginalizer, o_sym, m_sym);
00475                         SG_PRINT("%c%c%c%c/%c%c%c%c)+=%g/%g\n", 
00476                                 alpha->remap_to_char((c>>(3*num_bits))&0x03), alpha->remap_to_char((c>>(2*num_bits))&0x03),
00477                                 alpha->remap_to_char((c>>num_bits)&0x03), alpha->remap_to_char(c&0x03),
00478                                 alpha->remap_to_char((i>>(3*num_bits))&0x03), alpha->remap_to_char((i>>(2*num_bits))&0x03),
00479                                 alpha->remap_to_char((i>>(1*num_bits))&0x03), alpha->remap_to_char(i&0x03),
00480                                 dictionary_weights[i]*marginalizer, dictionary_weights[i]);
00481 #endif
00482                         contrib[c]+=dictionary_weights[i]*marginalizer;
00483                     }
00484                 }
00485             }
00486         }
00487     }
00488 
00489     for (int32_t i=1; i<num_feat; i++)
00490         memcpy(&target[num_sym*i], target, num_sym*sizeof(float64_t));
00491 
00492     SG_UNREF(alpha);
00493 
00494     return target;
00495 }
00496 
00497 
00498 char* CCommWordStringKernel::compute_consensus(
00499     int32_t &result_len, int32_t num_suppvec, int32_t* IDX, float64_t* alphas)
00500 {
00501     ASSERT(lhs);
00502     ASSERT(IDX);
00503     ASSERT(alphas);
00504 
00505     CStringFeatures<uint16_t>* str=((CStringFeatures<uint16_t>*) lhs);
00506     int32_t num_words=(int32_t) str->get_num_symbols();
00507     int32_t num_feat=str->get_max_vector_length();
00508     int64_t total_len=((int64_t) num_feat) * num_words;
00509     CAlphabet* alpha=((CStringFeatures<uint16_t>*) lhs)->get_alphabet();
00510     ASSERT(alpha);
00511     int32_t num_bits=alpha->get_num_bits();
00512     int32_t order=str->get_order();
00513     int32_t max_idx=-1;
00514     float64_t max_score=0; 
00515     result_len=num_feat+order-1;
00516 
00517     //init
00518     init_optimization(num_suppvec, IDX, alphas);
00519 
00520     char* result=SG_MALLOC(char, result_len);
00521     int32_t* bt=SG_MALLOC(int32_t, total_len);
00522     float64_t* score=SG_MALLOC(float64_t, total_len);
00523 
00524     for (int64_t i=0; i<total_len; i++)
00525     {
00526         bt[i]=-1;
00527         score[i]=0;
00528     }
00529 
00530     for (int32_t t=0; t<num_words; t++)
00531         score[t]=dictionary_weights[t];
00532 
00533     //dynamic program
00534     for (int32_t i=1; i<num_feat; i++)
00535     {
00536         for (int32_t t1=0; t1<num_words; t1++)
00537         {
00538             max_idx=-1;
00539             max_score=0; 
00540 
00541             /* ignore weights the svm does not care about 
00542              * (has not seen in training). note that this assumes that zero 
00543              * weights are very unlikely to appear elsewise */
00544 
00545             //if (dictionary_weights[t1]==0.0)
00546                 //continue;
00547 
00548             /* iterate over words t ending on t1 and find the highest scoring
00549              * pair */
00550             uint16_t suffix=(uint16_t) t1 >> num_bits;
00551 
00552             for (int32_t sym=0; sym<str->get_original_num_symbols(); sym++)
00553             {
00554                 uint16_t t=suffix | sym << (num_bits*(order-1));
00555 
00556                 //if (dictionary_weights[t]==0.0)
00557                 //  continue;
00558 
00559                 float64_t sc=score[num_words*(i-1) + t]+dictionary_weights[t1];
00560                 if (sc > max_score || max_idx==-1)
00561                 {
00562                     max_idx=t;
00563                     max_score=sc;
00564                 }
00565             }
00566             ASSERT(max_idx!=-1);
00567 
00568             score[num_words*i + t1]=max_score;
00569             bt[num_words*i + t1]=max_idx;
00570         }
00571     }
00572 
00573     //backtracking
00574     max_idx=0;
00575     max_score=score[num_words*(num_feat-1) + 0];
00576     for (int32_t t=1; t<num_words; t++)
00577     {
00578         float64_t sc=score[num_words*(num_feat-1) + t];
00579         if (sc>max_score)
00580         {
00581             max_idx=t;
00582             max_score=sc;
00583         }
00584     }
00585 
00586     SG_PRINT("max_idx:%i, max_score:%f\n", max_idx, max_score);
00587     
00588     for (int32_t i=result_len-1; i>=num_feat; i--)
00589         result[i]=alpha->remap_to_char( (uint8_t) str->get_masked_symbols( (uint16_t) max_idx >> (num_bits*(result_len-1-i)), 1) );
00590 
00591     for (int32_t i=num_feat-1; i>=0; i--)
00592     {
00593         result[i]=alpha->remap_to_char( (uint8_t) str->get_masked_symbols( (uint16_t) max_idx >> (num_bits*(order-1)), 1) );
00594         max_idx=bt[num_words*i + max_idx];
00595     }
00596 
00597     SG_FREE(bt);
00598     SG_FREE(score);
00599     SG_UNREF(alpha);
00600     return result;
00601 }
00602 
00603 void CCommWordStringKernel::init()
00604 {
00605     dictionary_size=0;
00606     dictionary_weights=NULL;
00607 
00608     use_sign=false;
00609     use_dict_diagonal_optimization=false;
00610     dict_diagonal_optimization=NULL;
00611 
00612     properties |= KP_LINADD;
00613     init_dictionary(1<<(sizeof(uint16_t)*8));
00614     set_normalizer(new CSqrtDiagKernelNormalizer(use_dict_diagonal_optimization));
00615 
00616     m_parameters->add_vector(&dictionary_weights, &dictionary_size, "dictionary_weights",
00617             "Dictionary for applying kernel.");
00618     m_parameters->add(&use_sign, "use_sign",
00619             "If signum(counts) is used instead of counts.");
00620     m_parameters->add(&use_dict_diagonal_optimization, "use_dict_diagonal_optimization",
00621             "If K(x,x) is computed potentially more efficiently.");
00622 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation