SpectrumMismatchRBFKernel.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Written (W) 1999-2008 Gunnar Raetsch
00009  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00010  */
00011 
00012 #include <vector>
00013 
00014 #include "lib/common.h"
00015 #include "lib/io.h"
00016 #include "lib/Signal.h"
00017 #include "lib/Trie.h"
00018 #include "base/Parallel.h"
00019 
00020 #include "kernel/SpectrumMismatchRBFKernel.h"
00021 #include "features/Features.h"
00022 #include "features/StringFeatures.h"
00023 
00024 
00025 #include <vector>
00026 #include <string>
00027 
00028 #include <assert.h>
00029 
00030 #ifndef WIN32
00031 #include <pthread.h>
00032 #endif
00033 
00034 using namespace shogun;
00035 
00036 CSpectrumMismatchRBFKernel::CSpectrumMismatchRBFKernel(void)
00037     :CStringKernel<char>(0)
00038 {
00039     SG_UNSTABLE("CSpectrumMismatchRBFKernel::"
00040                 "CSpectrumMismatchRBFKernel(void)", "\n");
00041 
00042     alphabet = NULL;
00043     degree = 0;
00044     max_mismatch = 0;
00045     AA_matrix = NULL;
00046     width = 0.0;
00047 
00048     initialized = false;
00049     target_letter_0 = 0;
00050 }
00051 
00052 CSpectrumMismatchRBFKernel::CSpectrumMismatchRBFKernel (int32_t size,
00053         float64_t* AA_matrix_, int32_t nr, int32_t nc,
00054         int32_t degree_, int32_t max_mismatch_, float64_t width_) : CStringKernel<char>(size),
00055     alphabet(NULL), degree(degree_), max_mismatch(max_mismatch_), width(width_)
00056 {
00057     lhs=NULL;
00058     rhs=NULL;
00059 
00060     target_letter_0=-1 ;
00061 
00062     AA_matrix=NULL;
00063     set_AA_matrix(AA_matrix_, nr, nc);
00064 }
00065 
00066 CSpectrumMismatchRBFKernel::CSpectrumMismatchRBFKernel(
00067                                                        CStringFeatures<char>* l, CStringFeatures<char>* r, int32_t size, float64_t* AA_matrix_, int32_t nr, int32_t nc, int32_t degree_, int32_t max_mismatch_, float64_t width_)
00068 : CStringKernel<char>(size), alphabet(NULL), degree(degree_), max_mismatch(max_mismatch_), width(width_)
00069 {
00070     target_letter_0=-1 ;
00071 
00072     AA_matrix=NULL;
00073     set_AA_matrix(AA_matrix_, nr, nc);
00074     init(l, r);
00075 }
00076 
00077 CSpectrumMismatchRBFKernel::~CSpectrumMismatchRBFKernel()
00078 {
00079     cleanup();
00080     delete[] AA_matrix ;
00081 }
00082 
00083 
00084 void CSpectrumMismatchRBFKernel::remove_lhs()
00085 {
00086 
00087     CKernel::remove_lhs();
00088 }
00089 
00090 bool CSpectrumMismatchRBFKernel::init(CFeatures* l, CFeatures* r)
00091 {
00092     int32_t lhs_changed=(lhs!=l);
00093     int32_t rhs_changed=(rhs!=r);
00094 
00095     CStringKernel<char>::init(l,r);
00096 
00097     SG_DEBUG("lhs_changed: %i\n", lhs_changed);
00098     SG_DEBUG("rhs_changed: %i\n", rhs_changed);
00099 
00100     CStringFeatures<char>* sf_l=(CStringFeatures<char>*) l;
00101     CStringFeatures<char>* sf_r=(CStringFeatures<char>*) r;
00102 
00103     SG_UNREF(alphabet);
00104     alphabet=sf_l->get_alphabet();
00105     CAlphabet* ralphabet=sf_r->get_alphabet();
00106 
00107     if (!((alphabet->get_alphabet()==DNA) || (alphabet->get_alphabet()==RNA)))
00108         properties &= ((uint64_t) (-1)) ^ (KP_LINADD | KP_BATCHEVALUATION);
00109 
00110     ASSERT(ralphabet->get_alphabet()==alphabet->get_alphabet());
00111     SG_UNREF(ralphabet);
00112 
00113     compute_all() ;
00114     
00115     return init_normalizer();
00116 }
00117 
00118 void CSpectrumMismatchRBFKernel::cleanup()
00119 {
00120 
00121     SG_UNREF(alphabet);
00122     alphabet=NULL;
00123 
00124     CKernel::cleanup();
00125 }
00126 
00127 float64_t CSpectrumMismatchRBFKernel::AA_helper(std::string &path, const char* joint_seq, unsigned int index)
00128 {
00129     float64_t diff=0.0 ;
00130 
00131     for (unsigned int i=0; i<path.size(); i++)
00132     {
00133         if (path[i]!=joint_seq[index+i])
00134         {
00135             diff += AA_matrix[ (path[i]-1)*128 + path[i] - 1] ;
00136             diff -= 2*AA_matrix[ (path[i]-1)*128 + joint_seq[index+i] - 1] ;
00137             diff += AA_matrix[ (joint_seq[index+i]-1)*128 + joint_seq[index+i] - 1] ;
00138         }
00139     }
00140 
00141     return exp( - diff/width) ;
00142 }
00143 
00144 /*
00145 float64_t CSpectrumMismatchRBFKernel::compute_helper(const char* joint_seq, 
00146                                                       std::vector<unsigned int> joint_index, std::vector<unsigned int> joint_mismatch, 
00147                                                       std::string path, unsigned int d, 
00148                                                       const int & alen) 
00149 {
00150     const char* AA = "ACDEFGHIKLMNPQRSTVWY" ;
00151     const unsigned int num_AA = strlen(AA) ;
00152 
00153     assert(path.size()==d) ;
00154     assert(joint_mismatch.size()==joint_index.size()) ;
00155     
00156     float64_t res = 0.0 ;
00157     
00158     for (unsigned int i=0; i<num_AA; i++)
00159     {
00160         std::vector<unsigned int> joint_mismatch_ ;
00161         std::vector<unsigned int> joint_index_ ;
00162 
00163         for (unsigned int j=0; j<joint_index.size(); j++)
00164         {
00165             if (joint_seq[joint_index[j]+d] != AA[i])
00166             {
00167                 if (joint_mismatch[j]+1 <= (unsigned int) max_mismatch)
00168                 {
00169                     joint_mismatch_.push_back(joint_mismatch[j]+1) ;
00170                     joint_index_.push_back(joint_index[j]) ;
00171                 }
00172             }
00173             else
00174             {
00175                 joint_mismatch_.push_back(joint_mismatch[j]) ;
00176                 joint_index_.push_back(joint_index[j]) ;
00177             }
00178         }
00179         if (joint_mismatch_.size()>0)
00180         {
00181             std::string path_ = path + AA[i] ;
00182 
00183             if (d+1 < (unsigned int) degree)
00184             {
00185                 res += compute_helper(joint_seq,  joint_index_, joint_mismatch_, path_, d+1, alen) ;
00186             }
00187             else
00188             {
00189                 int anum=0, bnum=0;
00190                 for (unsigned int j=0; j<joint_index_.size(); j++)
00191                     if (joint_index_[j] < (unsigned int)alen)
00192                     {
00193                         if (1)
00194                         {
00195                             anum++ ;
00196                             if (joint_mismatch_[j]==0)
00197                                 anum+=3 ;
00198                         }
00199                         else
00200                         {
00201                             if (joint_mismatch_[j]!=0)
00202                                 anum += AA_helper(path_, joint_seq, joint_index_[j]) ;
00203                             else
00204                                 anum++ ;
00205                         }
00206                     }
00207                     else
00208                     {
00209                         if (1)
00210                         {
00211                             bnum++ ;
00212                             if (joint_mismatch_[j]==0)
00213                                 bnum+=3 ;
00214                         }
00215                         else
00216                         {
00217                             if (joint_mismatch_[j]!=0)
00218                                 bnum += AA_helper(path_, joint_seq, joint_index_[j]) ;
00219                             else
00220                                 bnum++ ;
00221                         }
00222                     }
00223                 
00224                 //fprintf(stdout, "%s: %i x %i\n", path_.c_str(), anum, bnum) ;
00225                 
00226                 res+= anum*bnum ;
00227             }
00228         }
00229     }
00230     return res ;
00231 }
00232 */
00233 
00234 void CSpectrumMismatchRBFKernel::compute_helper_all(const char *joint_seq, 
00235                                                      std::vector<struct joint_list_struct> &joint_list,
00236                                                      std::string path, unsigned int d) 
00237 {
00238     const char* AA = "ACDEFGHIKLMNPQRSTVWY" ;
00239     const unsigned int num_AA = strlen(AA) ;
00240 
00241     assert(path.size()==d) ;
00242     
00243     for (unsigned int i=0; i<num_AA; i++)
00244     {
00245         std::vector<struct joint_list_struct> joint_list_ ;
00246         
00247         if (d==0)
00248             fprintf(stderr, "i=%i: ", i) ;
00249         if (d==0 && target_letter_0!=-1 && (int)i != target_letter_0 )
00250             continue ;
00251         
00252         if (d==1)
00253         {
00254             fprintf(stdout, "*") ;
00255             fflush(stdout) ;
00256         }
00257         if (d==2)
00258         {
00259             fprintf(stdout, "+") ;
00260             fflush(stdout) ;
00261         }
00262 
00263         for (unsigned int j=0; j<joint_list.size(); j++)
00264         {
00265             if (joint_seq[joint_list[j].index+d] != AA[i])
00266             {
00267                 if (joint_list[j].mismatch+1 <= (unsigned int) max_mismatch)
00268                 {
00269                     struct joint_list_struct list_item ;
00270                     list_item = joint_list[j] ;
00271                     list_item.mismatch = joint_list[j].mismatch+1 ;
00272                     joint_list_.push_back(list_item) ;
00273                 }
00274             }
00275             else
00276                 joint_list_.push_back(joint_list[j]) ;
00277         }
00278 
00279         if (joint_list_.size()>0)
00280         {
00281             std::string path_ = path + AA[i] ;
00282 
00283             if (d+1 < (unsigned int) degree)
00284             {
00285                 compute_helper_all(joint_seq,  joint_list_, path_, d+1) ;
00286             }
00287             else
00288             {
00289                 CArray<float64_t> feats ;
00290                 feats.resize_array(kernel_matrix.get_dim1()) ;
00291                 feats.zero() ;
00292                 
00293                 for (unsigned int j=0; j<joint_list_.size(); j++)
00294                 {
00295                     if (width==0.0)
00296                     {
00297                         feats[joint_list_[j].ex_index]++ ;
00298                         //if (joint_mismatch_[j]==0)
00299                         //  feats[joint_ex_index_[j]]+=3 ;
00300                     }
00301                     else
00302                     {
00303                         if (joint_list_[j].mismatch!=0)
00304                             feats[joint_list_[j].ex_index] += AA_helper(path_, joint_seq, joint_list_[j].index) ;
00305                         else
00306                             feats[joint_list_[j].ex_index] ++ ;
00307                     }
00308                 }
00309 
00310                 std::vector<int> idx ;
00311                 for (int r=0; r<feats.get_array_size(); r++)
00312                     if (feats[r]!=0.0)
00313                         idx.push_back(r) ;
00314 
00315                 for (unsigned int r=0; r<idx.size(); r++)
00316                     for (unsigned int s=r; s<idx.size(); s++)
00317                         if (s==r)
00318                             kernel_matrix.set_element(feats[idx[r]]*feats[idx[s]] + kernel_matrix.get_element(idx[r],idx[s]), idx[r], idx[s])  ;
00319                         else
00320                         {
00321                             kernel_matrix.set_element(feats[idx[r]]*feats[idx[s]] + kernel_matrix.get_element(idx[r],idx[s]), idx[r], idx[s])  ;
00322                             kernel_matrix.set_element(feats[idx[r]]*feats[idx[s]] + kernel_matrix.get_element(idx[s],idx[r]), idx[s], idx[r])  ;
00323                         }
00324             }
00325         }
00326         if (d==0)
00327             fprintf(stdout, "\n") ;
00328     }
00329 }
00330 
00331 void CSpectrumMismatchRBFKernel::compute_all()
00332 {
00333     std::string joint_seq ; 
00334     std::vector<struct joint_list_struct> joint_list ;
00335 
00336     assert(lhs->get_num_vectors()==rhs->get_num_vectors()) ;
00337     kernel_matrix.resize_array(lhs->get_num_vectors(), lhs->get_num_vectors()) ;
00338     for (int i=0; i<lhs->get_num_vectors(); i++)
00339         for (int j=0; j<lhs->get_num_vectors(); j++)
00340             kernel_matrix.set_element(0, i, j) ;
00341     
00342     for (int i=0; i<lhs->get_num_vectors(); i++)
00343     {
00344         int32_t alen ;
00345         bool free_avec ;
00346         char* avec = ((CStringFeatures<char>*) lhs)->get_feature_vector(i, alen, free_avec);
00347 
00348         for (int apos=0; apos+degree-1<alen; apos++)
00349         {
00350             struct joint_list_struct list_item ;
00351             list_item.ex_index = i ;
00352             list_item.index = apos+joint_seq.size() ;
00353             list_item.mismatch = 0 ;
00354             
00355             joint_list.push_back(list_item) ;
00356         }
00357         joint_seq += std::string(avec, alen) ;
00358         
00359         ((CStringFeatures<char>*) lhs)->free_feature_vector(avec, i, free_avec);
00360     }
00361     
00362     compute_helper_all(joint_seq.c_str(), joint_list, "", 0) ;
00363 }
00364 
00365 
00366 float64_t CSpectrumMismatchRBFKernel::compute(int32_t idx_a, int32_t idx_b)
00367 {
00368     return kernel_matrix.element(idx_a, idx_b) ;
00369 }
00370 /*
00371 bool CSpectrumMismatchRBFKernel::set_weights(
00372     float64_t* ws, int32_t d, int32_t len)
00373 {
00374     if (d==128 && len==128)
00375     {
00376         SG_DEBUG("Setting AA_matrix\n") ;
00377         memcpy(AA_matrix, ws, 128*128*sizeof(float64_t)) ;
00378         return true ;
00379     }
00380 
00381     if (d==1 && len==1)
00382     {
00383         sigma=ws[0] ;
00384         SG_DEBUG("Setting sigma to %e\n", sigma) ;
00385         return true ;
00386     }
00387 
00388     if (d==2 && len==2)
00389     {
00390         target_letter_0=ws[0] ;
00391         SG_DEBUG("Setting target letter to %c\n", target_letter_0) ;
00392         return true ;
00393     }
00394 
00395     if (d!=degree || len<1)
00396         SG_ERROR("Dimension mismatch (should be de(seq_length | 1) x degree)\n");
00397 
00398     length=len;
00399 
00400     if (length==0)
00401         length=1;
00402 
00403     int32_t num_weights=degree*(length+max_mismatch);
00404     delete[] weights;
00405     weights=new float64_t[num_weights];
00406 
00407     if (weights)
00408     {
00409         for (int32_t i=0; i<num_weights; i++) {
00410             if (ws[i]) // len(ws) might be != num_weights?
00411                 weights[i]=ws[i];
00412         }
00413         return true;
00414     }
00415     else
00416         return false;
00417 }
00418 */
00419 
00420 bool CSpectrumMismatchRBFKernel::set_AA_matrix(float64_t* AA_matrix_, int32_t nr, int32_t nc)
00421 {
00422     if (AA_matrix_)
00423     {
00424         if (nr!=128 || nc!=128)
00425             SG_ERROR("AA_matrix should be of shape 128x128\n");
00426         delete[] AA_matrix;
00427         AA_matrix=new float64_t[nc*nr];
00428         memcpy(AA_matrix, AA_matrix_, nc*nr*sizeof(float64_t)) ;
00429         SG_DEBUG("Setting AA_matrix\n") ;
00430         memcpy(AA_matrix, AA_matrix_, 128*128*sizeof(float64_t)) ;
00431         return true ;
00432     }
00433 
00434     return false;
00435 }
00436 
00437 bool CSpectrumMismatchRBFKernel::set_max_mismatch(int32_t max)
00438 {
00439     max_mismatch=max;
00440 
00441     if (lhs!=NULL && rhs!=NULL)
00442         return init(lhs, rhs);
00443     else
00444         return true;
00445 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation