00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #include <vector>
00013
00014 #include <shogun/lib/common.h>
00015 #include <shogun/io/SGIO.h>
00016 #include <shogun/lib/Signal.h>
00017 #include <shogun/lib/Trie.h>
00018 #include <shogun/base/Parallel.h>
00019
00020 #include <shogun/kernel/SpectrumMismatchRBFKernel.h>
00021 #include <shogun/features/Features.h>
00022 #include <shogun/features/StringFeatures.h>
00023
00024
00025 #include <vector>
00026 #include <string>
00027
00028 #include <assert.h>
00029
00030 #ifndef WIN32
00031 #include <pthread.h>
00032 #endif
00033
00034 using namespace shogun;
00035
00036 CSpectrumMismatchRBFKernel::CSpectrumMismatchRBFKernel(void)
00037 :CStringKernel<char>(0)
00038 {
00039 init();
00040 register_params();
00041 }
00042
00043 CSpectrumMismatchRBFKernel::CSpectrumMismatchRBFKernel (int32_t size,
00044 float64_t* AA_matrix_, int32_t nr, int32_t nc,
00045 int32_t degree_, int32_t max_mismatch_, float64_t width_) : CStringKernel<char>(size),
00046 alphabet(NULL), degree(degree_), max_mismatch(max_mismatch_), width(width_)
00047 {
00048 lhs=NULL;
00049 rhs=NULL;
00050
00051 target_letter_0=-1 ;
00052
00053 AA_matrix=NULL;
00054 set_AA_matrix(AA_matrix_, nr, nc);
00055 register_params();
00056 }
00057
00058 CSpectrumMismatchRBFKernel::CSpectrumMismatchRBFKernel(
00059 CStringFeatures<char>* l, CStringFeatures<char>* r, int32_t size, float64_t* AA_matrix_, int32_t nr, int32_t nc, int32_t degree_, int32_t max_mismatch_, float64_t width_)
00060 : CStringKernel<char>(size), alphabet(NULL), degree(degree_), max_mismatch(max_mismatch_), width(width_)
00061 {
00062 target_letter_0=-1 ;
00063
00064 AA_matrix=NULL;
00065 set_AA_matrix(AA_matrix_, nr, nc);
00066 init(l, r);
00067 register_params();
00068 }
00069
00070 CSpectrumMismatchRBFKernel::~CSpectrumMismatchRBFKernel()
00071 {
00072 cleanup();
00073 SG_FREE(AA_matrix);
00074 }
00075
00076
00077 void CSpectrumMismatchRBFKernel::remove_lhs()
00078 {
00079
00080 CKernel::remove_lhs();
00081 }
00082
00083 bool CSpectrumMismatchRBFKernel::init(CFeatures* l, CFeatures* r)
00084 {
00085 int32_t lhs_changed=(lhs!=l);
00086 int32_t rhs_changed=(rhs!=r);
00087
00088 CStringKernel<char>::init(l,r);
00089
00090 SG_DEBUG("lhs_changed: %i\n", lhs_changed);
00091 SG_DEBUG("rhs_changed: %i\n", rhs_changed);
00092
00093 CStringFeatures<char>* sf_l=(CStringFeatures<char>*) l;
00094 CStringFeatures<char>* sf_r=(CStringFeatures<char>*) r;
00095
00096 SG_UNREF(alphabet);
00097 alphabet=sf_l->get_alphabet();
00098 CAlphabet* ralphabet=sf_r->get_alphabet();
00099
00100 if (!((alphabet->get_alphabet()==DNA) || (alphabet->get_alphabet()==RNA)))
00101 properties &= ((uint64_t) (-1)) ^ (KP_LINADD | KP_BATCHEVALUATION);
00102
00103 ASSERT(ralphabet->get_alphabet()==alphabet->get_alphabet());
00104 SG_UNREF(ralphabet);
00105
00106 compute_all() ;
00107
00108 return init_normalizer();
00109 }
00110
00111 void CSpectrumMismatchRBFKernel::cleanup()
00112 {
00113
00114 SG_UNREF(alphabet);
00115 alphabet=NULL;
00116
00117 CKernel::cleanup();
00118 }
00119
00120 float64_t CSpectrumMismatchRBFKernel::AA_helper(std::string &path, const char* joint_seq, unsigned int index)
00121 {
00122 float64_t diff=0.0 ;
00123
00124 for (unsigned int i=0; i<path.size(); i++)
00125 {
00126 if (path[i]!=joint_seq[index+i])
00127 {
00128 diff += AA_matrix[ (path[i]-1)*128 + path[i] - 1] ;
00129 diff -= 2*AA_matrix[ (path[i]-1)*128 + joint_seq[index+i] - 1] ;
00130 diff += AA_matrix[ (joint_seq[index+i]-1)*128 + joint_seq[index+i] - 1] ;
00131 }
00132 }
00133
00134 return exp( - diff/width) ;
00135 }
00136
00137
00138
00139
00140
00141
00142
00143
00144
00145
00146
00147
00148
00149
00150
00151
00152
00153
00154
00155
00156
00157
00158
00159
00160
00161
00162
00163
00164
00165
00166
00167
00168
00169
00170
00171
00172
00173
00174
00175
00176
00177
00178
00179
00180
00181
00182
00183
00184
00185
00186
00187
00188
00189
00190
00191
00192
00193
00194
00195
00196
00197
00198
00199
00200
00201
00202
00203
00204
00205
00206
00207
00208
00209
00210
00211
00212
00213
00214
00215
00216
00217
00218
00219
00220
00221
00222
00223
00224
00225
00226
00227 void CSpectrumMismatchRBFKernel::compute_helper_all(const char *joint_seq,
00228 std::vector<struct joint_list_struct> &joint_list,
00229 std::string path, unsigned int d)
00230 {
00231 const char* AA = "ACDEFGHIKLMNPQRSTVWY" ;
00232 const unsigned int num_AA = strlen(AA) ;
00233
00234 assert(path.size()==d) ;
00235
00236 for (unsigned int i=0; i<num_AA; i++)
00237 {
00238 std::vector<struct joint_list_struct> joint_list_ ;
00239
00240 if (d==0)
00241 fprintf(stderr, "i=%i: ", i) ;
00242 if (d==0 && target_letter_0!=-1 && (int)i != target_letter_0 )
00243 continue ;
00244
00245 if (d==1)
00246 {
00247 fprintf(stdout, "*") ;
00248 fflush(stdout) ;
00249 }
00250 if (d==2)
00251 {
00252 fprintf(stdout, "+") ;
00253 fflush(stdout) ;
00254 }
00255
00256 for (unsigned int j=0; j<joint_list.size(); j++)
00257 {
00258 if (joint_seq[joint_list[j].index+d] != AA[i])
00259 {
00260 if (joint_list[j].mismatch+1 <= (unsigned int) max_mismatch)
00261 {
00262 struct joint_list_struct list_item ;
00263 list_item = joint_list[j] ;
00264 list_item.mismatch = joint_list[j].mismatch+1 ;
00265 joint_list_.push_back(list_item) ;
00266 }
00267 }
00268 else
00269 joint_list_.push_back(joint_list[j]) ;
00270 }
00271
00272 if (joint_list_.size()>0)
00273 {
00274 std::string path_ = path + AA[i] ;
00275
00276 if (d+1 < (unsigned int) degree)
00277 {
00278 compute_helper_all(joint_seq, joint_list_, path_, d+1) ;
00279 }
00280 else
00281 {
00282 CArray<float64_t> feats ;
00283 feats.resize_array(kernel_matrix.get_dim1()) ;
00284 feats.zero() ;
00285
00286 for (unsigned int j=0; j<joint_list_.size(); j++)
00287 {
00288 if (width==0.0)
00289 {
00290 feats[joint_list_[j].ex_index]++ ;
00291
00292
00293 }
00294 else
00295 {
00296 if (joint_list_[j].mismatch!=0)
00297 feats[joint_list_[j].ex_index] += AA_helper(path_, joint_seq, joint_list_[j].index) ;
00298 else
00299 feats[joint_list_[j].ex_index] ++ ;
00300 }
00301 }
00302
00303 std::vector<int> idx ;
00304 for (int r=0; r<feats.get_array_size(); r++)
00305 if (feats[r]!=0.0)
00306 idx.push_back(r) ;
00307
00308 for (unsigned int r=0; r<idx.size(); r++)
00309 for (unsigned int s=r; s<idx.size(); s++)
00310 if (s==r)
00311 kernel_matrix.set_element(feats[idx[r]]*feats[idx[s]] + kernel_matrix.get_element(idx[r],idx[s]), idx[r], idx[s]) ;
00312 else
00313 {
00314 kernel_matrix.set_element(feats[idx[r]]*feats[idx[s]] + kernel_matrix.get_element(idx[r],idx[s]), idx[r], idx[s]) ;
00315 kernel_matrix.set_element(feats[idx[r]]*feats[idx[s]] + kernel_matrix.get_element(idx[s],idx[r]), idx[s], idx[r]) ;
00316 }
00317 }
00318 }
00319 if (d==0)
00320 fprintf(stdout, "\n") ;
00321 }
00322 }
00323
00324 void CSpectrumMismatchRBFKernel::compute_all()
00325 {
00326 std::string joint_seq ;
00327 std::vector<struct joint_list_struct> joint_list ;
00328
00329 assert(lhs->get_num_vectors()==rhs->get_num_vectors()) ;
00330 kernel_matrix.resize_array(lhs->get_num_vectors(), lhs->get_num_vectors()) ;
00331 kernel_matrix_length = lhs->get_num_vectors()*rhs->get_num_vectors();
00332 for (int i=0; i<lhs->get_num_vectors(); i++)
00333 for (int j=0; j<lhs->get_num_vectors(); j++)
00334 kernel_matrix.set_element(0, i, j) ;
00335
00336 for (int i=0; i<lhs->get_num_vectors(); i++)
00337 {
00338 int32_t alen ;
00339 bool free_avec ;
00340 char* avec = ((CStringFeatures<char>*) lhs)->get_feature_vector(i, alen, free_avec);
00341
00342 for (int apos=0; apos+degree-1<alen; apos++)
00343 {
00344 struct joint_list_struct list_item ;
00345 list_item.ex_index = i ;
00346 list_item.index = apos+joint_seq.size() ;
00347 list_item.mismatch = 0 ;
00348
00349 joint_list.push_back(list_item) ;
00350 }
00351 joint_seq += std::string(avec, alen) ;
00352
00353 ((CStringFeatures<char>*) lhs)->free_feature_vector(avec, i, free_avec);
00354 }
00355
00356 compute_helper_all(joint_seq.c_str(), joint_list, "", 0) ;
00357 }
00358
00359
00360 float64_t CSpectrumMismatchRBFKernel::compute(int32_t idx_a, int32_t idx_b)
00361 {
00362 return kernel_matrix.element(idx_a, idx_b) ;
00363 }
00364
00365
00366
00367
00368
00369
00370
00371
00372
00373
00374
00375
00376
00377
00378
00379
00380
00381
00382
00383
00384
00385
00386
00387
00388
00389
00390
00391
00392
00393
00394
00395
00396
00397
00398
00399
00400
00401
00402
00403
00404
00405
00406
00407
00408
00409
00410
00411
00412
00413
00414 bool CSpectrumMismatchRBFKernel::set_AA_matrix(float64_t* AA_matrix_, int32_t nr, int32_t nc)
00415 {
00416 if (AA_matrix_)
00417 {
00418 if (nr!=128 || nc!=128)
00419 SG_ERROR("AA_matrix should be of shape 128x128\n");
00420 SG_FREE(AA_matrix);
00421 AA_matrix=SG_MALLOC(float64_t, nc*nr);
00422 memcpy(AA_matrix, AA_matrix_, nc*nr*sizeof(float64_t)) ;
00423 SG_DEBUG("Setting AA_matrix\n") ;
00424 memcpy(AA_matrix, AA_matrix_, 128*128*sizeof(float64_t)) ;
00425 return true ;
00426 }
00427
00428 return false;
00429 }
00430
00431 bool CSpectrumMismatchRBFKernel::set_max_mismatch(int32_t max)
00432 {
00433 max_mismatch=max;
00434
00435 if (lhs!=NULL && rhs!=NULL)
00436 return init(lhs, rhs);
00437 else
00438 return true;
00439 }
00440
00441 void CSpectrumMismatchRBFKernel::register_params()
00442 {
00443 m_parameters->add(°ree, "degree", "degree of the kernel");
00444 m_parameters->add(&AA_matrix_length, "AA_matrix_length", "the length of AA matrix");
00445 m_parameters->add_vector(&AA_matrix, &AA_matrix_length, "AA_matrix", "128*128 scalar product matrix");
00446 m_parameters->add(&width,"width","width of Gaussian");
00447 m_parameters->add(&target_letter_0, "target_letter_0","target letter 0");
00448 m_parameters->add(&initialized, "initialized", "the mark of initialization status");
00449 m_parameters->add_vector((SGString<float64_t>**)&kernel_matrix, &kernel_matrix_length, "kernel_matrix", "the kernel matrix with its length defined by the number of vectors of the string features");
00450 }
00451
00452 void CSpectrumMismatchRBFKernel::register_alphabet()
00453 {
00454 m_parameters->add((CSGObject**)&alphabet, "alphabet", "the alphabet used by kernel");
00455 }
00456
00457 void CSpectrumMismatchRBFKernel::init()
00458 {
00459 alphabet = NULL;
00460 degree = 0;
00461 max_mismatch = 0;
00462 AA_matrix = NULL;
00463 width = 0.0;
00464
00465 initialized = false;
00466 target_letter_0 = 0;
00467 }
00468