OligoStringKernel.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2008 Christian Igel, Tobias Glasmachers
00008  * Copyright (C) 2008 Christian Igel, Tobias Glasmachers
00009  *
00010  * Shogun adjustments (W) 2008-2009 Soeren Sonnenburg
00011  * Copyright (C) 2008-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00012  *
00013  */
00014 #include "kernel/OligoStringKernel.h"
00015 #include "kernel/SqrtDiagKernelNormalizer.h"
00016 #include "features/StringFeatures.h"
00017 
00018 #include <map>
00019 #include <vector>
00020 #include <algorithm>
00021 
00022 using namespace shogun;
00023 
00024 COligoStringKernel::COligoStringKernel(void)
00025   : CStringKernel<char>()
00026 {
00027     init();
00028 }
00029 
00030 COligoStringKernel::COligoStringKernel(int32_t cache_sz, int32_t kmer_len, float64_t w)
00031 : CStringKernel<char>(cache_sz)
00032 {
00033     init();
00034 
00035     k=kmer_len;
00036     width=w;
00037 }
00038 
00039 COligoStringKernel::~COligoStringKernel()
00040 {
00041     cleanup();
00042 }
00043 
00044 void COligoStringKernel::cleanup()
00045 {
00046     delete[] gauss_table;
00047     gauss_table=NULL;
00048     gauss_table_len=0;
00049 
00050     CKernel::cleanup();
00051 }
00052 
00053 bool COligoStringKernel::init(CFeatures* l, CFeatures* r)
00054 {
00055     cleanup();
00056 
00057     CStringKernel<char>::init(l,r);
00058     int32_t max_len=CMath::max(
00059             ((CStringFeatures<char>*) l)->get_max_vector_length(),
00060             ((CStringFeatures<char>*) r)->get_max_vector_length()
00061             );
00062 
00063     getExpFunctionCache(max_len);
00064     return init_normalizer();
00065 }
00066 
00067 void COligoStringKernel::encodeOligo(
00068     const std::string& sequence, uint32_t k_mer_length,
00069     const std::string& allowed_characters,
00070     std::vector< std::pair<int32_t, float64_t> >& values)
00071 {
00072     float64_t oligo_value = 0.;
00073     float64_t factor      = 1.;
00074     std::map<std::string::value_type, uint32_t> residue_values;
00075     uint32_t counter = 0;
00076     uint32_t number_of_residues = allowed_characters.size();
00077     uint32_t sequence_length = sequence.size();
00078     bool sequence_ok = true;
00079 
00080     // checking if sequence contains illegal characters
00081     for (uint32_t i = 0; i < sequence.size(); ++i)
00082     {
00083         if (allowed_characters.find(sequence.at(i)) == std::string::npos)
00084             sequence_ok = false;
00085     }
00086 
00087     if (sequence_ok && k_mer_length <= sequence_length)
00088     {
00089         values.resize(sequence_length - k_mer_length + 1,
00090             std::pair<int32_t, float64_t>());
00091         for (uint32_t i = 0; i < number_of_residues; ++i)
00092         {   
00093             residue_values.insert(std::make_pair(allowed_characters[i], counter));
00094             ++counter;
00095         }
00096         for (int32_t k = k_mer_length - 1; k >= 0; k--)
00097         {
00098             oligo_value += factor * residue_values[sequence[k]];
00099             factor *= number_of_residues;
00100         }
00101         factor /= number_of_residues;
00102         counter = 0;
00103         values[counter].first = 1;
00104         values[counter].second = oligo_value;
00105         ++counter;
00106 
00107         for (uint32_t j = 1; j < sequence_length - k_mer_length + 1; j++)
00108         {
00109             oligo_value -= factor * residue_values[sequence[j - 1]];
00110             oligo_value = oligo_value * number_of_residues +
00111                 residue_values[sequence[j + k_mer_length - 1]];
00112 
00113             values[counter].first = j + 1;
00114             values[counter].second = oligo_value ;
00115             ++counter;
00116         }
00117         stable_sort(values.begin(), values.end(), cmpOligos_);
00118     }
00119     else
00120     {
00121         values.clear();
00122     }   
00123 }
00124 
00125 void COligoStringKernel::getSequences(
00126     const std::vector<std::string>& sequences, uint32_t k_mer_length,
00127     const std::string& allowed_characters,
00128     std::vector< std::vector< std::pair<int32_t, float64_t> > >& encoded_sequences)
00129 {
00130     std::vector< std::pair<int32_t, float64_t> > temp_vector;
00131     encoded_sequences.resize(sequences.size(),
00132         std::vector< std::pair<int32_t, float64_t> >());
00133 
00134     for (uint32_t i = 0; i < sequences.size(); ++i)
00135     {
00136         encodeOligo(sequences[i], k_mer_length, allowed_characters, temp_vector);
00137         encoded_sequences[i] = temp_vector;
00138     }
00139 }
00140 
00141 void COligoStringKernel::getExpFunctionCache(uint32_t sequence_length)
00142 {
00143     delete[] gauss_table;
00144     gauss_table=new float64_t[sequence_length];
00145 
00146     gauss_table[0] = 1;
00147     for (uint32_t i = 1; i < sequence_length - 1; i++)
00148         gauss_table[i] = exp((-1 / (CMath::sq(width))) * CMath::sq(i));
00149 
00150     gauss_table_len=sequence_length;
00151 }
00152 
00153 float64_t COligoStringKernel::kernelOligoFast(
00154     const std::vector< std::pair<int32_t, float64_t> >& x,
00155     const std::vector< std::pair<int32_t, float64_t> >& y,
00156     int32_t max_distance)
00157 {
00158     float64_t result = 0;
00159     int32_t  i1     = 0;
00160     int32_t  i2     = 0;
00161     int32_t  c1     = 0;
00162     uint32_t x_size = x.size();
00163     uint32_t y_size = y.size();
00164 
00165     while ((uint32_t) i1 < x_size && (uint32_t) i2 < y_size)
00166     {
00167         if (x[i1].second == y[i2].second)
00168         {
00169             if (max_distance < 0
00170                     || (abs(x[i1].first - y[i2].first)) <= max_distance)
00171             {
00172                 result += gauss_table[abs((x[i1].first - y[i2].first))];
00173                 if (x[i1].second == x[i1 + 1].second)
00174                 {
00175                     i1++;
00176                     c1++;
00177                 }
00178                 else if (y[i2].second == y[i2 + 1].second)
00179                 {
00180                     i2++;
00181                     i1 -= c1;
00182                     c1 = 0;
00183                 }
00184                 else
00185                 {
00186                     i1++;
00187                     i2++;
00188                 }
00189             }
00190             else
00191             {
00192                 if (x[i1].first < y[i2].first)
00193                 {
00194                     if (x[i1].second == x[i1 + 1].second)
00195                     {
00196                         i1++;
00197                     }
00198                     else if (y[i2].second == y[i2 + 1].second)
00199                     {
00200                         while(y[i2++].second == y[i2].second)
00201                         {
00202                             ;
00203                         }
00204                         ++i1;
00205                         c1 = 0;
00206                     }
00207                     else
00208                     {
00209                         i1++;
00210                         i2++;
00211                         c1 = 0;
00212                     }
00213                 }
00214                 else
00215                 {
00216                     i2++;
00217                     i1 -= c1;
00218                     c1 = 0;
00219                 }
00220             }
00221         }
00222         else
00223         {
00224             if (x[i1].second < y[i2].second)
00225                 i1++;
00226             else
00227                 i2++;
00228             c1 = 0;
00229         }
00230     }
00231     return result;
00232 }       
00233 
00234 
00235 float64_t COligoStringKernel::compute(int32_t idx_a, int32_t idx_b)
00236 {
00237     int32_t alen, blen;
00238     bool free_a, free_b;
00239     char* avec=((CStringFeatures<char>*) lhs)->get_feature_vector(idx_a, alen, free_a);
00240     char* bvec=((CStringFeatures<char>*) rhs)->get_feature_vector(idx_b, blen, free_b);
00241     std::vector< std::pair<int32_t, float64_t> > aenc;
00242     std::vector< std::pair<int32_t, float64_t> > benc;
00243     encodeOligo(std::string(avec, alen), k, "ACGT", aenc);
00244     encodeOligo(std::string(bvec, alen), k, "ACGT", benc);
00245     float64_t result=kernelOligoFast(aenc, benc);
00246     ((CStringFeatures<char>*) lhs)->free_feature_vector(avec, idx_a, free_a);
00247     ((CStringFeatures<char>*) rhs)->free_feature_vector(bvec, idx_b, free_b);
00248     return result;
00249 }
00250 
00251 void COligoStringKernel::init()
00252 {
00253     k=0;
00254     width=0.0;
00255     gauss_table=NULL;
00256     gauss_table_len=0;
00257 
00258     set_normalizer(new CSqrtDiagKernelNormalizer());
00259 
00260     m_parameters->add(&k, "k", "K-mer length.");
00261     m_parameters->add(&width, "width", "Width of Gaussian.");
00262     m_parameters->add_vector(&gauss_table, &gauss_table_len, "gauss_table", "Gauss Cache Table.");
00263 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation