PositionalPWM.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2011 Soeren Sonnenburg
00008  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
00009  */
00010 #include <shogun/distributions/PositionalPWM.h>
00011 #include <shogun/mathematics/Math.h>
00012 #include <shogun/base/Parameter.h>
00013 #include <shogun/features/Alphabet.h>
00014 #include <shogun/features/StringFeatures.h>
00015 
00016 using namespace shogun;
00017 
00018 CPositionalPWM::CPositionalPWM() : CDistribution(),
00019     m_sigma(0), m_mean(0)
00020 {
00021     m_pwm = SGMatrix<float64_t>();
00022     m_w = SGMatrix<float64_t>();
00023     m_poim = SGVector<float64_t>();
00024 
00025     register_params();
00026 }
00027 
00028 CPositionalPWM::~CPositionalPWM()
00029 {
00030 }
00031 
00032 bool CPositionalPWM::train(CFeatures* data)
00033 {
00034     SG_NOTIMPLEMENTED;
00035     return true;
00036 }
00037 
00038 int32_t CPositionalPWM::get_num_model_parameters()
00039 {
00040     return m_pwm.num_rows*m_pwm.num_cols+2;
00041 }
00042 
00043 float64_t CPositionalPWM::get_log_model_parameter(int32_t num_param)
00044 {
00045     ASSERT(num_param>0 && num_param<=m_pwm.num_rows*m_pwm.num_cols+2);
00046 
00047     if (num_param<m_pwm.num_rows*m_pwm.num_cols)
00048     {
00049         return m_pwm[num_param];
00050     }
00051     else if (num_param<m_pwm.num_rows*m_pwm.num_cols+1)
00052         return CMath::log(m_sigma);
00053     else
00054         return CMath::log(m_mean);
00055 }
00056 
00057 float64_t CPositionalPWM::get_log_derivative(int32_t num_param, int32_t num_example)
00058 {
00059     SG_NOTIMPLEMENTED;
00060     return 0;
00061 }
00062 
00063 float64_t CPositionalPWM::get_log_likelihood_example(int32_t num_example)
00064 {
00065     ASSERT(features);
00066     ASSERT(features->get_feature_class() == C_STRING);
00067     ASSERT(features->get_feature_type()==F_BYTE);
00068 
00069     CStringFeatures<uint8_t>* strs=(CStringFeatures<uint8_t>*) features;
00070 
00071     float64_t lik=0;
00072     int32_t len=0;
00073     bool do_free=false;
00074 
00075     uint8_t* str = strs->get_feature_vector(num_example, len, do_free);
00076 
00077     if (!(m_w.num_cols==len))
00078         return 0; //TODO
00079 
00080     for (int32_t i=0; i<len; i++)
00081         lik+=m_w[4*i+str[i]];
00082 
00083     strs->free_feature_vector(str, num_example, do_free);
00084     return lik;
00085 }
00086 
00087 float64_t CPositionalPWM::get_log_likelihood_window(uint8_t* window, int32_t len, float64_t pos)
00088 {
00089     ASSERT(m_pwm.num_cols == len);
00090     float64_t score = CMath::log(1/(m_sigma*CMath::sqrt(2*M_PI))) -
00091             CMath::sq(pos-m_mean)/(2*CMath::sq(m_sigma));
00092 
00093     for (int32_t i=0; i<m_pwm.num_cols; i++)
00094         score+=m_pwm[m_pwm.num_rows*i+window[i]];
00095 
00096     return score;
00097 }
00098 
00099 void CPositionalPWM::compute_w(int32_t num_pos)
00100 {
00101     ASSERT(m_pwm.num_rows>0 && m_pwm.num_cols>0);
00102 
00103     int32_t m_w_rows = CMath::pow(m_pwm.num_rows, m_pwm.num_cols);
00104     int32_t m_w_cols = num_pos;
00105 
00106     m_w = SGMatrix<float64_t>(m_w_cols,m_w_rows);
00107 
00108     uint8_t* window=SG_MALLOC(uint8_t, m_pwm.num_cols);
00109     SGVector<uint8_t>::fill_vector(window, m_pwm.num_cols, (uint8_t) 0);
00110 
00111     const int32_t last_idx=m_pwm.num_cols-1;
00112     for (int32_t i=0; i<m_w_rows; i++)
00113     {
00114         for (int32_t j=0; j<m_w_cols; j++)
00115             m_w[j*m_w_rows+i]=get_log_likelihood_window(window, m_pwm.num_cols, j);
00116 
00117         window[last_idx]++;
00118         int32_t window_ptr=last_idx;
00119         while (window[window_ptr]==m_pwm.num_rows && window_ptr>0)
00120         {
00121             window[window_ptr]=0;
00122             window_ptr--;
00123             window[window_ptr]++;
00124         }
00125 
00126     }
00127 }
00128 
00129 void CPositionalPWM::register_params()
00130 {
00131     m_parameters->add(&m_poim, "poim", "POIM Scoring Matrix");
00132     m_parameters->add(&m_w, "w", "Scoring Matrix");
00133     m_parameters->add(&m_pwm, "pwm", "Positional Weight Matrix.");
00134     m_parameters->add(&m_sigma, "sigma", "Standard Deviation.");
00135     m_parameters->add(&m_mean, "mean", "Mean.");
00136 }
00137 
00138 void CPositionalPWM::compute_scoring(int32_t max_degree)
00139 {
00140     int32_t num_feat=m_w.num_cols;
00141     int32_t num_sym=0;
00142     int32_t order=m_pwm.num_rows;
00143     int32_t num_words=m_pwm.num_cols;
00144 
00145     CAlphabet* alpha=new CAlphabet(DNA);
00146     CStringFeatures<uint16_t>* str= new CStringFeatures<uint16_t>(alpha);
00147     int32_t num_bits=alpha->get_num_bits();
00148     str->compute_symbol_mask_table(num_bits);
00149 
00150     for (int32_t i=0; i<order; i++)
00151         num_sym+=CMath::pow((int32_t) num_words,i+1);
00152 
00153     m_poim = SGVector<float64_t>(num_feat*num_sym);
00154     memset(m_poim.vector,0, size_t(num_feat)*size_t(num_sym));
00155 
00156     uint32_t kmer_mask=0;
00157     uint32_t words=CMath::pow((int32_t) num_words,(int32_t) order);
00158     int32_t offset=0;
00159 
00160     for (int32_t o=0; o<max_degree; o++)
00161     {
00162         float64_t* contrib=&m_poim[offset];
00163         offset+=CMath::pow((int32_t) num_words,(int32_t) o+1);
00164 
00165         kmer_mask=(kmer_mask<<(num_bits)) | str->get_masked_symbols(0xffff, 1);
00166 
00167         for (int32_t p=-o; p<order; p++)
00168         {
00169             int32_t o_sym=0, m_sym=0, il=0,ir=0, jl=0;
00170             uint32_t imer_mask=kmer_mask;
00171             uint32_t jmer_mask=kmer_mask;
00172 
00173             if (p<0)
00174             {
00175                 il=-p;
00176                 m_sym=order-o-p-1;
00177                 o_sym=-p;
00178             }
00179             else if (p<order-o)
00180             {
00181                 ir=p;
00182                 m_sym=order-o-1;
00183             }
00184             else
00185             {
00186                 ir=p;
00187                 m_sym=p;
00188                 o_sym=p-order+o+1;
00189                 jl=order-ir;
00190                 imer_mask=(kmer_mask>>(num_bits*o_sym));
00191                 jmer_mask=(kmer_mask>>(num_bits*jl));
00192             }
00193 
00194             float64_t marginalizer=
00195                 1.0/CMath::pow((int32_t) num_words,(int32_t) m_sym);
00196 
00197             for (uint32_t i=0; i<words; i++)
00198             {
00199                 uint16_t x= ((i << (num_bits*il)) >> (num_bits*ir)) & imer_mask;
00200 
00201                 if (p>=0 && p<order-o)
00202                 {
00203                     contrib[x]+=m_w[m_w.num_cols*ir+i]*marginalizer;
00204                 }
00205                 else
00206                 {
00207                     for (uint32_t j=0; j< (uint32_t) CMath::pow((int32_t) num_words, (int32_t) o_sym); j++)
00208                     {
00209                         uint32_t c=x | ((j & jmer_mask) << (num_bits*jl));
00210                         contrib[c]+=m_w[m_w.num_cols*il+i]*marginalizer;
00211                     }
00212                 }
00213             }
00214         }
00215     }
00216 }
00217 
00218 SGMatrix<float64_t> CPositionalPWM::get_scoring(int32_t d)
00219 {
00220     int32_t offs=0;
00221     for (int32_t i=0; i<d-1; i++)
00222         offs+=CMath::pow((int32_t) m_w.num_rows,i+1);
00223     int32_t rows=CMath::pow((int32_t) m_w.num_rows,d);
00224     int32_t cols=m_w.num_cols;
00225     float64_t* scoring_matrix = SG_MALLOC(float64_t, rows*cols);
00226     memcpy(scoring_matrix,m_poim.vector+offs,rows*cols*sizeof(float64_t));
00227     return SGMatrix<float64_t>(scoring_matrix,rows,cols);
00228 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation