HMM.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Written (W) 1999-2008 Gunnar Raetsch
00009  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00010  */
00011 
00012 #ifndef __CHMM_H__
00013 #define __CHMM_H__
00014 
00015 #include "lib/Mathematics.h"
00016 #include "lib/common.h"
00017 #include "lib/io.h"
00018 #include "lib/config.h"
00019 #include "features/Features.h"
00020 #include "features/StringFeatures.h"
00021 #include "distributions/Distribution.h"
00022 
00023 #include <stdio.h>
00024 
// Builds that request parallel HMM computation also need the per-thread
// (array-of-slots) variants of the scratch structures declared below.
#if defined(USE_HMMPARALLEL)
#define USE_HMMPARALLEL_STRUCTURES 1
#endif
00028 
00029 namespace shogun
00030 {
00031     class CFeatures;
00032     template <class ST> class CStringFeatures;
00035 
00037 typedef float64_t T_ALPHA_BETA_TABLE;
00038 
00039 #ifndef DOXYGEN_SHOULD_SKIP_THIS
00040 
00041 struct T_ALPHA_BETA
00042 {
00044     int32_t dimension;
00045 
00047     T_ALPHA_BETA_TABLE* table;
00048 
00050     bool updated;
00051 
00053     float64_t sum;
00054 };
00055 #endif // DOXYGEN_SHOULD_SKIP_THIS
00056 
// State index type: 16 bit when USE_BIGSTATES is defined, otherwise 8 bit.
#if defined(USE_BIGSTATES)
typedef uint16_t T_STATES;
#else
typedef uint8_t T_STATES;
#endif

/** pointer to a sequence of state indices */
typedef T_STATES *P_STATES;
00067 
00069 
/** Training variant selected by CHMM::baum_welch_viterbi_train(). */
enum BaumWelchViterbiType
{
    BW_NORMAL   = 0,
    BW_TRANS    = 1,
    BW_DEFINED  = 2,
    VIT_NORMAL  = 3,
    VIT_DEFINED = 4
};
00078 
00079 
00081 class Model
00082 {
00083     public:
00085         Model();
00086 
00088         virtual ~Model();
00089 
00091         inline void sort_learn_a()
00092         {
00093             CMath::sort(learn_a,2) ;
00094         }
00095 
00097         inline void sort_learn_b()
00098         {
00099             CMath::sort(learn_b,2) ;
00100         }
00101 
00106 
00107         inline int32_t get_learn_a(int32_t line, int32_t column) const
00108         {
00109             return learn_a[line*2 + column];
00110         }
00111 
00113         inline int32_t get_learn_b(int32_t line, int32_t column) const 
00114         {
00115             return learn_b[line*2 + column];
00116         }
00117 
00119         inline int32_t get_learn_p(int32_t offset) const 
00120         {
00121             return learn_p[offset];
00122         }
00123 
00125         inline int32_t get_learn_q(int32_t offset) const 
00126         {
00127             return learn_q[offset];
00128         }
00129 
00131         inline int32_t get_const_a(int32_t line, int32_t column) const
00132         {
00133             return const_a[line*2 + column];
00134         }
00135 
00137         inline int32_t get_const_b(int32_t line, int32_t column) const 
00138         {
00139             return const_b[line*2 + column];
00140         }
00141 
00143         inline int32_t get_const_p(int32_t offset) const 
00144         {
00145             return const_p[offset];
00146         }
00147 
00149         inline int32_t get_const_q(int32_t offset) const
00150         {
00151             return const_q[offset];
00152         }
00153 
00155         inline float64_t get_const_a_val(int32_t line) const
00156         {
00157             return const_a_val[line];
00158         }
00159 
00161         inline float64_t get_const_b_val(int32_t line) const 
00162         {
00163             return const_b_val[line];
00164         }
00165 
00167         inline float64_t get_const_p_val(int32_t offset) const 
00168         {
00169             return const_p_val[offset];
00170         }
00171 
00173         inline float64_t get_const_q_val(int32_t offset) const
00174         {
00175             return const_q_val[offset];
00176         }
00177 #ifdef FIX_POS
00178 
00179         inline char get_fix_pos_state(int32_t pos, T_STATES state, T_STATES num_states)
00180         {
00181 #ifdef HMM_DEBUG
00182             if ((pos<0)||(pos*num_states+state>65336))
00183                 SG_DEBUG("index out of range in get_fix_pos_state(%i,%i,%i) \n", pos,state,num_states) ;
00184 #endif
00185             return fix_pos_state[pos*num_states+state] ;
00186         }
00187 #endif
00188 
00189 
00194 
00195         inline void set_learn_a(int32_t offset, int32_t value)
00196         {
00197             learn_a[offset]=value;
00198         }
00199 
00201         inline void set_learn_b(int32_t offset, int32_t value)
00202         {
00203             learn_b[offset]=value;
00204         }
00205 
00207         inline void set_learn_p(int32_t offset, int32_t value)
00208         {
00209             learn_p[offset]=value;
00210         }
00211 
00213         inline void set_learn_q(int32_t offset, int32_t value)
00214         {
00215             learn_q[offset]=value;
00216         }
00217 
00219         inline void set_const_a(int32_t offset, int32_t value)
00220         {
00221             const_a[offset]=value;
00222         }
00223 
00225         inline void set_const_b(int32_t offset, int32_t value)
00226         {
00227             const_b[offset]=value;
00228         }
00229 
00231         inline void set_const_p(int32_t offset, int32_t value)
00232         {
00233             const_p[offset]=value;
00234         }
00235 
00237         inline void set_const_q(int32_t offset, int32_t value)
00238         {
00239             const_q[offset]=value;
00240         }
00241 
00243         inline void set_const_a_val(int32_t offset, float64_t value)
00244         {
00245             const_a_val[offset]=value;
00246         }
00247 
00249         inline void set_const_b_val(int32_t offset, float64_t value)
00250         {
00251             const_b_val[offset]=value;
00252         }
00253 
00255         inline void set_const_p_val(int32_t offset, float64_t value)
00256         {
00257             const_p_val[offset]=value;
00258         }
00259 
00261         inline void set_const_q_val(int32_t offset, float64_t value)
00262         {
00263             const_q_val[offset]=value;
00264         }
00265 #ifdef FIX_POS
00266 
00267         inline void set_fix_pos_state(
00268             int32_t pos, T_STATES state, T_STATES num_states, char value)
00269         {
00270 #ifdef HMM_DEBUG
00271             if ((pos<0)||(pos*num_states+state>65336))
00272                 SG_DEBUG("index out of range in set_fix_pos_state(%i,%i,%i,%i) [%i]\n", pos,state,num_states,(int)value, pos*num_states+state) ;
00273 #endif
00274             fix_pos_state[pos*num_states+state]=value;
00275             if (value==FIX_ALLOWED)
00276                 for (int32_t i=0; i<num_states; i++)
00277                     if (get_fix_pos_state(pos,i,num_states)==FIX_DEFAULT)
00278                         set_fix_pos_state(pos,i,num_states,FIX_DISALLOWED) ;
00279         }
00281 
00283         const static char FIX_DISALLOWED ;
00284 
00286         const static char FIX_ALLOWED ;
00287 
00289         const static char FIX_DEFAULT ;
00290 
00292         const static float64_t DISALLOWED_PENALTY ;
00293 #endif
00294     protected:
00301 
00302         int32_t* learn_a;
00303 
00305         int32_t* learn_b;
00306 
00308         int32_t* learn_p;
00309 
00311         int32_t* learn_q;
00313 
00320 
00321         int32_t* const_a;
00322 
00324         int32_t* const_b;
00325 
00327         int32_t* const_p;
00328 
00330         int32_t* const_q;       
00331 
00332 
00334         float64_t* const_a_val;
00335 
00337         float64_t* const_b_val;
00338 
00340         float64_t* const_p_val;
00341 
00343         float64_t* const_q_val;     
00344 
00345 #ifdef FIX_POS
00346 
00349         char* fix_pos_state;
00350 #endif
00351 
00352 };
00353 
00354 
00365 class CHMM : public CDistribution
00366 {
00367     private:
00368 
00369         T_STATES trans_list_len ;
00370         T_STATES **trans_list_forward  ;
00371         T_STATES *trans_list_forward_cnt  ;
00372         float64_t **trans_list_forward_val ;
00373         T_STATES **trans_list_backward  ;
00374         T_STATES *trans_list_backward_cnt  ;
00375         bool mem_initialized ;
00376 
00377 #ifdef USE_HMMPARALLEL_STRUCTURES
00378 
00380         struct S_DIM_THREAD_PARAM
00381         {
00382             CHMM* hmm;
00383             int32_t dim;
00384             float64_t prob_sum;
00385         };
00386 
00388         struct S_BW_THREAD_PARAM
00389         {
00390             CHMM* hmm;
00391             int32_t dim_start;
00392             int32_t dim_stop;
00393 
00394             float64_t ret;
00395 
00396             float64_t* p_buf;
00397             float64_t* q_buf;
00398             float64_t* a_buf;
00399             float64_t* b_buf;
00400         };
00401 
00402         inline T_ALPHA_BETA & ALPHA_CACHE(int32_t dim) {
00403             return alpha_cache[dim%parallel->get_num_threads()] ; } ;
00404         inline T_ALPHA_BETA & BETA_CACHE(int32_t dim) {
00405             return beta_cache[dim%parallel->get_num_threads()] ; } ;
00406 #ifdef USE_LOGSUMARRAY 
00407         inline float64_t* ARRAYS(int32_t dim) {
00408             return arrayS[dim%parallel->get_num_threads()] ; } ;
00409 #endif
00410         inline float64_t* ARRAYN1(int32_t dim) {
00411             return arrayN1[dim%parallel->get_num_threads()] ; } ;
00412         inline float64_t* ARRAYN2(int32_t dim) {
00413             return arrayN2[dim%parallel->get_num_threads()] ; } ;
00414         inline T_STATES* STATES_PER_OBSERVATION_PSI(int32_t dim) {
00415             return states_per_observation_psi[dim%parallel->get_num_threads()] ; } ;
00416         inline const T_STATES* STATES_PER_OBSERVATION_PSI(int32_t dim) const {
00417             return states_per_observation_psi[dim%parallel->get_num_threads()] ; } ;
00418         inline T_STATES* PATH(int32_t dim) {
00419             return path[dim%parallel->get_num_threads()] ; } ;
00420         inline bool & PATH_PROB_UPDATED(int32_t dim) {
00421             return path_prob_updated[dim%parallel->get_num_threads()] ; } ;
00422         inline int32_t & PATH_PROB_DIMENSION(int32_t dim) {
00423             return path_prob_dimension[dim%parallel->get_num_threads()] ; } ;
00424 #else
00425         inline T_ALPHA_BETA & ALPHA_CACHE(int32_t /*dim*/) {
00426             return alpha_cache ; } ;
00427         inline T_ALPHA_BETA & BETA_CACHE(int32_t /*dim*/) {
00428             return beta_cache ; } ;
00429 #ifdef USE_LOGSUMARRAY
00430         inline float64_t* ARRAYS(int32_t dim) {
00431             return arrayS ; } ;
00432 #endif
00433         inline float64_t* ARRAYN1(int32_t /*dim*/) {
00434             return arrayN1 ; } ;
00435         inline float64_t* ARRAYN2(int32_t /*dim*/) {
00436             return arrayN2 ; } ;
00437         inline T_STATES* STATES_PER_OBSERVATION_PSI(int32_t /*dim*/) {
00438             return states_per_observation_psi ; } ;
00439         inline const T_STATES* STATES_PER_OBSERVATION_PSI(int32_t /*dim*/) const {
00440             return states_per_observation_psi ; } ;
00441         inline T_STATES* PATH(int32_t /*dim*/) {
00442             return path ; } ;
00443         inline bool & PATH_PROB_UPDATED(int32_t /*dim*/) {
00444             return path_prob_updated ; } ;
00445         inline int32_t & PATH_PROB_DIMENSION(int32_t /*dim*/) {
00446             return path_prob_dimension ; } ;
00447 #endif
00448 
00453         bool converged(float64_t x, float64_t y);
00454 
00460     public:
00462         CHMM(void);
00463 
00474         CHMM(
00475             int32_t N, int32_t M, Model* model, float64_t PSEUDO);
00476         CHMM(
00477             CStringFeatures<uint16_t>* obs, int32_t N, int32_t M,
00478             float64_t PSEUDO);
00479         CHMM(
00480             int32_t N, float64_t* p, float64_t* q, float64_t* a);
00481         CHMM(
00482             int32_t N, float64_t* p, float64_t* q, int32_t num_trans,
00483             float64_t* a_trans);
00484 
00489         CHMM(FILE* model_file, float64_t PSEUDO);
00490 
00492         CHMM(CHMM* h);
00493 
00495         virtual ~CHMM();
00496 
00505         virtual bool train(CFeatures* data=NULL);
00506         virtual inline int32_t get_num_model_parameters() { return N*(N+M+2); }
00507         virtual float64_t get_log_model_parameter(int32_t num_param);
00508         virtual float64_t get_log_derivative(int32_t num_param, int32_t num_example);
00509         virtual float64_t get_log_likelihood_example(int32_t num_example)
00510         {
00511             return model_probability(num_example);
00512         }
00513 
00519         bool initialize(Model* model, float64_t PSEUDO, FILE* model_file=NULL);
00521 
00523         bool alloc_state_dependend_arrays();
00524 
00526         void free_state_dependend_arrays();
00527 
00539         float64_t forward_comp(int32_t time, int32_t state, int32_t dimension);
00540         float64_t forward_comp_old(
00541             int32_t time, int32_t state, int32_t dimension);
00542 
00550         float64_t backward_comp(int32_t time, int32_t state, int32_t dimension);
00551         float64_t backward_comp_old(
00552             int32_t time, int32_t state, int32_t dimension);
00553 
00558         float64_t best_path(int32_t dimension);
00559         inline uint16_t get_best_path_state(int32_t dim, int32_t t)
00560         {
00561             ASSERT(PATH(dim));
00562             return PATH(dim)[t];
00563         }
00564 
00567         float64_t model_probability_comp() ;
00568 
00570         inline float64_t model_probability(int32_t dimension=-1)
00571         {
00572             //for faster calculation cache model probability
00573             if (dimension==-1)
00574             {
00575                 if (mod_prob_updated)
00576                     return mod_prob/p_observations->get_num_vectors();
00577                 else
00578                     return model_probability_comp()/p_observations->get_num_vectors();
00579             }
00580             else
00581                 return forward(p_observations->get_vector_length(dimension), 0, dimension);
00582         }
00583 
00589         inline float64_t linear_model_probability(int32_t dimension)
00590         {
00591             float64_t lik=0;
00592             int32_t len=0;
00593             bool free_vec;
00594             uint16_t* o=p_observations->get_feature_vector(dimension, len, free_vec);
00595             float64_t* obs_b=observation_matrix_b;
00596 
00597             ASSERT(N==len);
00598 
00599             for (int32_t i=0; i<N; i++)
00600             {
00601                 lik+=obs_b[*o++];
00602                 obs_b+=M;
00603             }
00604             p_observations->free_feature_vector(o, dimension, free_vec);
00605             return lik;
00606 
00607             // sorry, the above code is the speed optimized version of :
00608             /*  float64_t lik=0;
00609 
00610                 for (int32_t i=0; i<N; i++)
00611                 lik+=get_b(i, p_observations->get_feature(dimension, i));
00612                 return lik;
00613                 */
00614             // : that
00615         }
00616 
00618 
00621         inline bool set_iterations(int32_t num) { iterations=num; return true; }
00622         inline int32_t get_iterations() { return iterations; }
00623         inline bool set_epsilon (float64_t eps) { epsilon=eps; return true; }
00624         inline float64_t get_epsilon() { return epsilon; }
00625 
00629         bool baum_welch_viterbi_train(BaumWelchViterbiType type);
00630 
00637         void estimate_model_baum_welch(CHMM* train);
00638         void estimate_model_baum_welch_trans(CHMM* train);
00639 
00640 #ifdef USE_HMMPARALLEL_STRUCTURES
00641         void ab_buf_comp(
00642             float64_t* p_buf, float64_t* q_buf, float64_t* a_buf,
00643             float64_t* b_buf, int32_t dim) ;
00644 #else
00645         void estimate_model_baum_welch_old(CHMM* train);
00646 #endif
00647 
00651         void estimate_model_baum_welch_defined(CHMM* train);
00652 
00656         void estimate_model_viterbi(CHMM* train);
00657 
00661         void estimate_model_viterbi_defined(CHMM* train);
00662 
00664 
00666         bool linear_train(bool right_align=false);
00667 
00669         bool permutation_entropy(int32_t window_width, int32_t sequence_number);
00670 
00677         void output_model(bool verbose=false);
00678 
00680         void output_model_defined(bool verbose=false);
00682 
00683 
00686 
00688         void normalize(bool keep_dead_states=false);
00689 
00693         void add_states(int32_t num_states, float64_t default_val=0);
00694 
00700         bool append_model(
00701             CHMM* append_model, float64_t* cur_out, float64_t* app_out);
00702 
00706         bool append_model(CHMM* append_model);
00707 
00709         void chop(float64_t value);
00710 
00712         void convert_to_log();
00713 
00715         void init_model_random();
00716 
00722         void init_model_defined();
00723 
00725         void clear_model();
00726 
00728         void clear_model_defined();
00729 
00731         void copy_model(CHMM* l);
00732 
00737         void invalidate_model();
00738 
00742         inline bool get_status() const 
00743         {   
00744             return status; 
00745         } 
00746 
00748         inline float64_t get_pseudo() const
00749         {
00750             return PSEUDO ;
00751         }
00752 
00754         inline void set_pseudo(float64_t pseudo) 
00755         {
00756             PSEUDO=pseudo ;
00757         }
00758 
00759 #ifdef USE_HMMPARALLEL_STRUCTURES
00760         static void* bw_dim_prefetch(void * params);
00761         static void* bw_single_dim_prefetch(void * params);
00762         static void* vit_dim_prefetch(void * params);
00763 #endif
00764 
00765 #ifdef FIX_POS
00766 
00769         inline bool set_fix_pos_state(int32_t pos, T_STATES state, char value)
00770         {
00771             if (!model)
00772                 return false ;
00773             model->set_fix_pos_state(pos, state, N, value) ;
00774             return true ;
00775         } ;
00776 #endif  
00777 
00778 
00787         void set_observations(CStringFeatures<uint16_t>* obs, CHMM* hmm=NULL);
00788 
00792         void set_observation_nocache(CStringFeatures<uint16_t>* obs);
00793 
00795         inline CStringFeatures<uint16_t>* get_observations()
00796         {
00797             SG_REF(p_observations);
00798             return p_observations;
00799         }
00801 
00869         bool load_definitions(FILE* file, bool verbose, bool initialize=true);
00870 
00906         bool load_model(FILE* file);
00907 
00911         bool save_model(FILE* file);
00912 
00916         bool save_model_derivatives(FILE* file);
00917 
00921         bool save_model_derivatives_bin(FILE* file);
00922 
00926         bool save_model_bin(FILE* file);
00927 
00929         bool check_model_derivatives() ;
00930         bool check_model_derivatives_combined() ;
00931 
00937         T_STATES* get_path(int32_t dim, float64_t& prob);
00938 
00942         bool save_path(FILE* file);
00943 
00947         bool save_path_derivatives(FILE* file);
00948 
00952         bool save_path_derivatives_bin(FILE* file);
00953 
00954 #ifdef USE_HMMDEBUG
00955 
00956         bool check_path_derivatives() ;
00957 #endif //USE_HMMDEBUG
00958 
00962         bool save_likelihood_bin(FILE* file);
00963 
00967         bool save_likelihood(FILE* file);
00969 
00975 
00977         inline T_STATES get_N() const { return N ; }
00978 
00980         inline int32_t get_M() const { return M ; }
00981 
00986         inline void set_q(T_STATES offset, float64_t value)
00987         {
00988 #ifdef HMM_DEBUG
00989             if (offset>=N)
00990                 SG_DEBUG("index out of range in set_q(%i,%e) [%i]\n", offset,value,N) ;
00991 #endif
00992             end_state_distribution_q[offset]=value;
00993         }
00994 
00999         inline void set_p(T_STATES offset, float64_t value)
01000         {
01001 #ifdef HMM_DEBUG
01002             if (offset>=N)
01003                 SG_DEBUG("index out of range in set_p(%i,.) [%i]\n", offset,N) ;
01004 #endif
01005             initial_state_distribution_p[offset]=value;
01006         }
01007 
01013         inline void set_A(T_STATES line_, T_STATES column, float64_t value)
01014         {
01015 #ifdef HMM_DEBUG
01016             if ((line_>N)||(column>N))
01017                 SG_DEBUG("index out of range in set_A(%i,%i,.) [%i,%i]\n",line_,column,N,N) ;
01018 #endif
01019             transition_matrix_A[line_+column*N]=value;
01020         }
01021 
01027         inline void set_a(T_STATES line_, T_STATES column, float64_t value)
01028         {
01029 #ifdef HMM_DEBUG
01030             if ((line_>N)||(column>N))
01031                 SG_DEBUG("index out of range in set_a(%i,%i,.) [%i,%i]\n",line_,column,N,N) ;
01032 #endif
01033             transition_matrix_a[line_+column*N]=value; // look also best_path!
01034         }
01035 
01041         inline void set_B(T_STATES line_, uint16_t column, float64_t value)
01042         {
01043 #ifdef HMM_DEBUG
01044             if ((line_>=N)||(column>=M))
01045                 SG_DEBUG("index out of range in set_B(%i,%i) [%i,%i]\n", line_, column,N,M) ;
01046 #endif
01047             observation_matrix_B[line_*M+column]=value;
01048         }
01049 
01055         inline void set_b(T_STATES line_, uint16_t column, float64_t value)
01056         {
01057 #ifdef HMM_DEBUG
01058             if ((line_>=N)||(column>=M))
01059                 SG_DEBUG("index out of range in set_b(%i,%i) [%i,%i]\n", line_, column,N,M) ;
01060 #endif
01061             observation_matrix_b[line_*M+column]=value;
01062         }
01063 
01070         inline void set_psi(
01071             int32_t time, T_STATES state, T_STATES value, int32_t dimension)
01072         {
01073 #ifdef HMM_DEBUG
01074             if ((time>=p_observations->get_max_vector_length())||(state>N))
01075                 SG_DEBUG("index out of range in set_psi(%i,%i,.) [%i,%i]\n",time,state,p_observations->get_max_vector_length(),N) ;
01076 #endif
01077             STATES_PER_OBSERVATION_PSI(dimension)[time*N+state]=value;
01078         }
01079 
01084         inline float64_t get_q(T_STATES offset) const 
01085         {
01086 #ifdef HMM_DEBUG
01087             if (offset>=N)
01088                 SG_DEBUG("index out of range in %e=get_q(%i) [%i]\n", end_state_distribution_q[offset],offset,N) ;
01089 #endif
01090             return end_state_distribution_q[offset];
01091         }
01092 
01097         inline float64_t get_p(T_STATES offset) const 
01098         {
01099 #ifdef HMM_DEBUG
01100             if (offset>=N)
01101                 SG_DEBUG("index out of range in get_p(%i,.) [%i]\n", offset,N) ;
01102 #endif
01103             return initial_state_distribution_p[offset];
01104         }
01105 
01111         inline float64_t get_A(T_STATES line_, T_STATES column) const
01112         {
01113 #ifdef HMM_DEBUG
01114             if ((line_>N)||(column>N))
01115                 SG_DEBUG("index out of range in get_A(%i,%i) [%i,%i]\n",line_,column,N,N) ;
01116 #endif
01117             return transition_matrix_A[line_+column*N];
01118         }
01119 
01125         inline float64_t get_a(T_STATES line_, T_STATES column) const
01126         {
01127 #ifdef HMM_DEBUG
01128             if ((line_>N)||(column>N))
01129                 SG_DEBUG("index out of range in get_a(%i,%i) [%i,%i]\n",line_,column,N,N) ;
01130 #endif
01131             return transition_matrix_a[line_+column*N]; // look also best_path()!
01132         }
01133 
01139         inline float64_t get_B(T_STATES line_, uint16_t column) const
01140         {
01141 #ifdef HMM_DEBUG
01142             if ((line_>=N)||(column>=M))
01143                 SG_DEBUG("index out of range in get_B(%i,%i) [%i,%i]\n", line_, column,N,M) ;
01144 #endif
01145             return observation_matrix_B[line_*M+column];
01146         }
01147 
01153         inline float64_t get_b(T_STATES line_, uint16_t column) const 
01154         {
01155 #ifdef HMM_DEBUG
01156             if ((line_>=N)||(column>=M))
01157                 SG_DEBUG("index out of range in get_b(%i,%i) [%i,%i]\n", line_, column,N,M) ;
01158 #endif
01159             //SG_PRINT("idx %d\n", line_*M+column);
01160             return observation_matrix_b[line_*M+column];
01161         }
01162 
01169         inline T_STATES get_psi(
01170             int32_t time, T_STATES state, int32_t dimension) const
01171         {
01172 #ifdef HMM_DEBUG
01173             if ((time>=p_observations->get_max_vector_length())||(state>N))
01174                 SG_DEBUG("index out of range in get_psi(%i,%i) [%i,%i]\n",time,state,p_observations->get_max_vector_length(),N) ;
01175 #endif
01176             return STATES_PER_OBSERVATION_PSI(dimension)[time*N+state];
01177         }
01178 
01180 
01182         inline virtual const char* get_name() const { return "HMM"; }
01183 
01184     protected:
01189 
01190         int32_t M;
01191 
01193         int32_t N;
01194 
01196         float64_t PSEUDO;
01197 
01198         // line number during processing input files
01199         int32_t line;
01200 
01202         CStringFeatures<uint16_t>* p_observations;
01203 
01204         //train definition for HMM
01205         Model* model;
01206 
01208         float64_t* transition_matrix_A;
01209 
01211         float64_t* observation_matrix_B;
01212 
01214         float64_t* transition_matrix_a;
01215 
01217         float64_t* initial_state_distribution_p;
01218 
01220         float64_t* end_state_distribution_q;        
01221 
01223         float64_t* observation_matrix_b;    
01224 
01226         int32_t iterations;
01227         int32_t iteration_count;
01228 
01230         float64_t epsilon;
01231         int32_t conv_it;
01232 
01234         float64_t all_pat_prob; 
01235 
01237         float64_t pat_prob; 
01238 
01240         float64_t mod_prob; 
01241 
01243         bool mod_prob_updated;  
01244 
01246         bool all_path_prob_updated; 
01247 
01249         int32_t path_deriv_dimension;
01250 
01252         bool path_deriv_updated;
01253 
01254         // true if model is using log likelihood
01255         bool loglikelihood;     
01256 
01257         // true->ok, false->error
01258         bool status;            
01259 
01260         // true->stolen from other HMMs, false->got own
01261         bool reused_caches;
01263 
01264 #ifdef USE_HMMPARALLEL_STRUCTURES
01265 
01266         float64_t** arrayN1 /*[parallel.get_num_threads()]*/ ;
01268         float64_t** arrayN2 /*[parallel.get_num_threads()]*/ ;
01269 #else //USE_HMMPARALLEL_STRUCTURES
01270 
01271         float64_t* arrayN1;
01273         float64_t* arrayN2;
01274 #endif //USE_HMMPARALLEL_STRUCTURES
01275 
01276 #ifdef USE_LOGSUMARRAY
01277 #ifdef USE_HMMPARALLEL_STRUCTURES
01278 
01279         float64_t** arrayS /*[parallel.get_num_threads()]*/;
01280 #else
01281 
01282         float64_t* arrayS;
01283 #endif // USE_HMMPARALLEL_STRUCTURES
01284 #endif // USE_LOGSUMARRAY
01285 
01286 #ifdef USE_HMMPARALLEL_STRUCTURES
01287 
01289         T_ALPHA_BETA* alpha_cache /*[parallel.get_num_threads()]*/ ;
01291         T_ALPHA_BETA* beta_cache /*[parallel.get_num_threads()]*/ ;
01292 
01294         T_STATES** states_per_observation_psi /*[parallel.get_num_threads()]*/ ;
01295 
01297         T_STATES** path /*[parallel.get_num_threads()]*/ ;
01298 
01300         bool* path_prob_updated /*[parallel.get_num_threads()]*/;
01301 
01303         int32_t* path_prob_dimension /*[parallel.get_num_threads()]*/ ; 
01304 
01305 #else //USE_HMMPARALLEL_STRUCTURES
01306 
01307         T_ALPHA_BETA alpha_cache;
01309         T_ALPHA_BETA beta_cache;
01310 
01312         T_STATES* states_per_observation_psi;
01313 
01315         T_STATES* path;
01316 
01318         bool path_prob_updated;
01319 
01321         int32_t path_prob_dimension;
01322 
01323 #endif //USE_HMMPARALLEL_STRUCTURES
01324 
01325 
01327         static const int32_t GOTN;
01329         static const int32_t GOTM;
01331         static const int32_t GOTO;
01333         static const int32_t GOTa;
01335         static const int32_t GOTb;
01337         static const int32_t GOTp;
01339         static const int32_t GOTq;
01340 
01342         static const int32_t GOTlearn_a;
01344         static const int32_t GOTlearn_b;
01346         static const int32_t GOTlearn_p;
01348         static const int32_t GOTlearn_q;
01350         static const int32_t GOTconst_a;
01352         static const int32_t GOTconst_b;
01354         static const int32_t GOTconst_p;
01356         static const int32_t GOTconst_q;
01357 
01358         public:
01363 
01365 inline float64_t state_probability(
01366     int32_t time, int32_t state, int32_t dimension)
01367 {
01368     return forward(time, state, dimension) + backward(time, state, dimension) - model_probability(dimension);
01369 }
01370 
01372 inline float64_t transition_probability(
01373     int32_t time, int32_t state_i, int32_t state_j, int32_t dimension)
01374 {
01375     return forward(time, state_i, dimension) + 
01376         backward(time+1, state_j, dimension) + 
01377         get_a(state_i,state_j) + get_b(state_j,p_observations->get_feature(dimension ,time+1)) - model_probability(dimension);
01378 }
01379 
01386 
01389 inline float64_t linear_model_derivative(
01390     T_STATES i, uint16_t j, int32_t dimension)
01391 {
01392     float64_t der=0;
01393 
01394     for (int32_t k=0; k<N; k++)
01395     {
01396         if (k!=i || p_observations->get_feature(dimension, k) != j)
01397             der+=get_b(k, p_observations->get_feature(dimension, k));
01398     }
01399 
01400     return der;
01401 }
01402 
01406 inline float64_t model_derivative_p(T_STATES i, int32_t dimension)
01407 {
01408     return backward(0,i,dimension)+get_b(i, p_observations->get_feature(dimension, 0));     
01409 }
01410 
01414 inline float64_t model_derivative_q(T_STATES i, int32_t dimension)
01415 {
01416     return forward(p_observations->get_vector_length(dimension)-1,i,dimension) ;
01417 }
01418 
01420 inline float64_t model_derivative_a(T_STATES i, T_STATES j, int32_t dimension)
01421 {
01422     float64_t sum=-CMath::INFTY;
01423     for (int32_t t=0; t<p_observations->get_vector_length(dimension)-1; t++)
01424         sum= CMath::logarithmic_sum(sum, forward(t, i, dimension) + backward(t+1, j, dimension) + get_b(j, p_observations->get_feature(dimension,t+1)));
01425 
01426     return sum;
01427 }
01428 
01429 
01431 inline float64_t model_derivative_b(T_STATES i, uint16_t j, int32_t dimension)
01432 {
01433     float64_t sum=-CMath::INFTY;
01434     for (int32_t t=0; t<p_observations->get_vector_length(dimension); t++)
01435     {
01436         if (p_observations->get_feature(dimension,t)==j)
01437             sum= CMath::logarithmic_sum(sum, forward(t,i,dimension)+backward(t,i,dimension)-get_b(i,p_observations->get_feature(dimension,t)));
01438     }
01439     //if (sum==-CMath::INFTY)
01440     // SG_DEBUG( "log derivative is -inf: dim=%i, state=%i, obs=%i\n",dimension, i, j) ;
01441     return sum;
01442 } 
01444 
01451 
01453 inline float64_t path_derivative_p(T_STATES i, int32_t dimension)
01454 {
01455     best_path(dimension);
01456     return (i==PATH(dimension)[0]) ? (exp(-get_p(PATH(dimension)[0]))) : (0) ;
01457 }
01458 
01460 inline float64_t path_derivative_q(T_STATES i, int32_t dimension)
01461 {
01462     best_path(dimension);
01463     return (i==PATH(dimension)[p_observations->get_vector_length(dimension)-1]) ? (exp(-get_q(PATH(dimension)[p_observations->get_vector_length(dimension)-1]))) : 0 ;
01464 }
01465 
01467 inline float64_t path_derivative_a(T_STATES i, T_STATES j, int32_t dimension)
01468 {
01469     prepare_path_derivative(dimension) ;
01470     return (get_A(i,j)==0) ? (0) : (get_A(i,j)*exp(-get_a(i,j))) ;
01471 }
01472 
01474 inline float64_t path_derivative_b(T_STATES i, uint16_t j, int32_t dimension)
01475 {
01476     prepare_path_derivative(dimension) ;
01477     return (get_B(i,j)==0) ? (0) : (get_B(i,j)*exp(-get_b(i,j))) ;
01478 } 
01479 
01481 
01482 
01483 protected:
01488 
01489     bool get_numbuffer(FILE* file, char* buffer, int32_t length);
01490 
01492     void open_bracket(FILE* file);
01493 
01495     void close_bracket(FILE* file);
01496 
01498     bool comma_or_space(FILE* file);
01499 
01501     inline void error(int32_t p_line, const char* str)
01502     {
01503         if (p_line)
01504             SG_ERROR( "error in line %d %s\n", p_line, str);
01505         else
01506             SG_ERROR( "error %s\n", str);
01507     }
01509 
01511     inline void prepare_path_derivative(int32_t dim)
01512     {
01513         if (path_deriv_updated && (path_deriv_dimension==dim))
01514             return ;
01515         int32_t i,j,t ;
01516         best_path(dim);
01517         //initialize with zeros
01518         for (i=0; i<N; i++)
01519         {
01520             for (j=0; j<N; j++)
01521                 set_A(i,j, 0);
01522             for (j=0; j<M; j++)
01523                 set_B(i,j, 0);
01524         }
01525 
01526         //counting occurences for A and B
01527         for (t=0; t<p_observations->get_vector_length(dim)-1; t++)
01528         {
01529             set_A(PATH(dim)[t], PATH(dim)[t+1], get_A(PATH(dim)[t], PATH(dim)[t+1])+1);
01530             set_B(PATH(dim)[t], p_observations->get_feature(dim,t),  get_B(PATH(dim)[t], p_observations->get_feature(dim,t))+1);
01531         }
01532         set_B(PATH(dim)[p_observations->get_vector_length(dim)-1], p_observations->get_feature(dim,p_observations->get_vector_length(dim)-1),  get_B(PATH(dim)[p_observations->get_vector_length(dim)-1], p_observations->get_feature(dim,p_observations->get_vector_length(dim)-1)) + 1);
01533         path_deriv_dimension=dim ;
01534         path_deriv_updated=true ;
01535     } ;
01537 
01539     inline float64_t forward(int32_t time, int32_t state, int32_t dimension)
01540     {
01541         if (time<1)
01542             time=0;
01543 
01544         if (ALPHA_CACHE(dimension).table && (dimension==ALPHA_CACHE(dimension).dimension) && ALPHA_CACHE(dimension).updated)
01545         {
01546             if (time<p_observations->get_vector_length(dimension))
01547                 return ALPHA_CACHE(dimension).table[time*N+state];
01548             else
01549                 return ALPHA_CACHE(dimension).sum;
01550         }
01551         else
01552             return forward_comp(time, state, dimension) ;
01553     }
01554 
01556     inline float64_t backward(int32_t time, int32_t state, int32_t dimension)
01557     {
01558         if (BETA_CACHE(dimension).table && (dimension==BETA_CACHE(dimension).dimension) && (BETA_CACHE(dimension).updated))
01559         {
01560             if (time<0)
01561                 return BETA_CACHE(dimension).sum;
01562             if (time<p_observations->get_vector_length(dimension))
01563                 return BETA_CACHE(dimension).table[time*N+state];
01564             else
01565                 return -CMath::INFTY;
01566         }
01567         else
01568             return backward_comp(time, state, dimension) ;
01569     }
01570 
01571 };
01572 }
01573 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation