HMM.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Written (W) 1999-2008 Gunnar Raetsch
00009  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00010  */
00011 
00012 #ifndef __CHMM_H__
00013 #define __CHMM_H__
00014 
00015 #include <shogun/mathematics/Math.h>
00016 #include <shogun/lib/common.h>
00017 #include <shogun/io/SGIO.h>
00018 #include <shogun/lib/config.h>
00019 #include <shogun/features/Features.h>
00020 #include <shogun/features/StringFeatures.h>
00021 #include <shogun/distributions/Distribution.h>
00022 
00023 #include <stdio.h>
00024 
// Parallel HMM builds (USE_HMMPARALLEL) also switch on the per-thread
// cache/buffer layout used by CHMM (USE_HMMPARALLEL_STRUCTURES).
#if defined(USE_HMMPARALLEL)
#define USE_HMMPARALLEL_STRUCTURES 1
#endif
00028 
00029 namespace shogun
00030 {
00031     class CFeatures;
00032     template <class ST> class CStringFeatures;
00035 
00037 typedef float64_t T_ALPHA_BETA_TABLE;
00038 
00039 #ifndef DOXYGEN_SHOULD_SKIP_THIS
00040 
00041 struct T_ALPHA_BETA
00042 {
00044     int32_t dimension;
00045 
00047     T_ALPHA_BETA_TABLE* table;
00048 
00050     bool updated;
00051 
00053     float64_t sum;
00054 };
00055 #endif // DOXYGEN_SHOULD_SKIP_THIS
00056 
// State indices are 16 bit with USE_BIGSTATES, 8 bit otherwise; this caps
// the number of representable states per build.
#if defined(USE_BIGSTATES)
typedef uint16_t T_STATES;
#else
typedef uint8_t T_STATES;
#endif
/// pointer to a run of state indices (e.g. a state path)
typedef T_STATES* P_STATES;
00067 
00069 
/// Training algorithm selector for CHMM::baum_welch_viterbi_train().
enum BaumWelchViterbiType
{
    BW_NORMAL = 0,   ///< Baum-Welch
    BW_TRANS = 1,    ///< Baum-Welch, transition variant
    BW_DEFINED = 2,  ///< Baum-Welch restricted by the training definition
    VIT_NORMAL = 3,  ///< Viterbi training
    VIT_DEFINED = 4  ///< Viterbi training restricted by the training definition
};
00084 
00085 
00087 class Model
00088 {
00089     public:
00091         Model();
00092 
00094         virtual ~Model();
00095 
00097         inline void sort_learn_a()
00098         {
00099             CMath::sort(learn_a,2) ;
00100         }
00101 
00103         inline void sort_learn_b()
00104         {
00105             CMath::sort(learn_b,2) ;
00106         }
00107 
00112 
00113         inline int32_t get_learn_a(int32_t line, int32_t column) const
00114         {
00115             return learn_a[line*2 + column];
00116         }
00117 
00119         inline int32_t get_learn_b(int32_t line, int32_t column) const
00120         {
00121             return learn_b[line*2 + column];
00122         }
00123 
00125         inline int32_t get_learn_p(int32_t offset) const
00126         {
00127             return learn_p[offset];
00128         }
00129 
00131         inline int32_t get_learn_q(int32_t offset) const
00132         {
00133             return learn_q[offset];
00134         }
00135 
00137         inline int32_t get_const_a(int32_t line, int32_t column) const
00138         {
00139             return const_a[line*2 + column];
00140         }
00141 
00143         inline int32_t get_const_b(int32_t line, int32_t column) const
00144         {
00145             return const_b[line*2 + column];
00146         }
00147 
00149         inline int32_t get_const_p(int32_t offset) const
00150         {
00151             return const_p[offset];
00152         }
00153 
00155         inline int32_t get_const_q(int32_t offset) const
00156         {
00157             return const_q[offset];
00158         }
00159 
00161         inline float64_t get_const_a_val(int32_t line) const
00162         {
00163             return const_a_val[line];
00164         }
00165 
00167         inline float64_t get_const_b_val(int32_t line) const
00168         {
00169             return const_b_val[line];
00170         }
00171 
00173         inline float64_t get_const_p_val(int32_t offset) const
00174         {
00175             return const_p_val[offset];
00176         }
00177 
00179         inline float64_t get_const_q_val(int32_t offset) const
00180         {
00181             return const_q_val[offset];
00182         }
00183 #ifdef FIX_POS
00184 
00185         inline char get_fix_pos_state(int32_t pos, T_STATES state, T_STATES num_states)
00186         {
00187 #ifdef HMM_DEBUG
00188             if ((pos<0)||(pos*num_states+state>65336))
00189                 SG_DEBUG("index out of range in get_fix_pos_state(%i,%i,%i) \n", pos,state,num_states) ;
00190 #endif
00191             return fix_pos_state[pos*num_states+state] ;
00192         }
00193 #endif
00194 
00195 
00200 
00201         inline void set_learn_a(int32_t offset, int32_t value)
00202         {
00203             learn_a[offset]=value;
00204         }
00205 
00207         inline void set_learn_b(int32_t offset, int32_t value)
00208         {
00209             learn_b[offset]=value;
00210         }
00211 
00213         inline void set_learn_p(int32_t offset, int32_t value)
00214         {
00215             learn_p[offset]=value;
00216         }
00217 
00219         inline void set_learn_q(int32_t offset, int32_t value)
00220         {
00221             learn_q[offset]=value;
00222         }
00223 
00225         inline void set_const_a(int32_t offset, int32_t value)
00226         {
00227             const_a[offset]=value;
00228         }
00229 
00231         inline void set_const_b(int32_t offset, int32_t value)
00232         {
00233             const_b[offset]=value;
00234         }
00235 
00237         inline void set_const_p(int32_t offset, int32_t value)
00238         {
00239             const_p[offset]=value;
00240         }
00241 
00243         inline void set_const_q(int32_t offset, int32_t value)
00244         {
00245             const_q[offset]=value;
00246         }
00247 
00249         inline void set_const_a_val(int32_t offset, float64_t value)
00250         {
00251             const_a_val[offset]=value;
00252         }
00253 
00255         inline void set_const_b_val(int32_t offset, float64_t value)
00256         {
00257             const_b_val[offset]=value;
00258         }
00259 
00261         inline void set_const_p_val(int32_t offset, float64_t value)
00262         {
00263             const_p_val[offset]=value;
00264         }
00265 
00267         inline void set_const_q_val(int32_t offset, float64_t value)
00268         {
00269             const_q_val[offset]=value;
00270         }
00271 #ifdef FIX_POS
00272 
00273         inline void set_fix_pos_state(
00274             int32_t pos, T_STATES state, T_STATES num_states, char value)
00275         {
00276 #ifdef HMM_DEBUG
00277             if ((pos<0)||(pos*num_states+state>65336))
00278                 SG_DEBUG("index out of range in set_fix_pos_state(%i,%i,%i,%i) [%i]\n", pos,state,num_states,(int)value, pos*num_states+state) ;
00279 #endif
00280             fix_pos_state[pos*num_states+state]=value;
00281             if (value==FIX_ALLOWED)
00282                 for (int32_t i=0; i<num_states; i++)
00283                     if (get_fix_pos_state(pos,i,num_states)==FIX_DEFAULT)
00284                         set_fix_pos_state(pos,i,num_states,FIX_DISALLOWED) ;
00285         }
00287 
00289         const static char FIX_DISALLOWED ;
00290 
00292         const static char FIX_ALLOWED ;
00293 
00295         const static char FIX_DEFAULT ;
00296 
00298         const static float64_t DISALLOWED_PENALTY ;
00299 #endif
00300     protected:
00307 
00308         int32_t* learn_a;
00309 
00311         int32_t* learn_b;
00312 
00314         int32_t* learn_p;
00315 
00317         int32_t* learn_q;
00319 
00326 
00327         int32_t* const_a;
00328 
00330         int32_t* const_b;
00331 
00333         int32_t* const_p;
00334 
00336         int32_t* const_q;
00337 
00338 
00340         float64_t* const_a_val;
00341 
00343         float64_t* const_b_val;
00344 
00346         float64_t* const_p_val;
00347 
00349         float64_t* const_q_val;
00350 
00351 #ifdef FIX_POS
00352 
00355         char* fix_pos_state;
00356 #endif
00357 
00358 };
00359 
00360 
00371 class CHMM : public CDistribution
00372 {
00373     private:
00374 
00375         T_STATES trans_list_len ;
00376         T_STATES **trans_list_forward  ;
00377         T_STATES *trans_list_forward_cnt  ;
00378         float64_t **trans_list_forward_val ;
00379         T_STATES **trans_list_backward  ;
00380         T_STATES *trans_list_backward_cnt  ;
00381         bool mem_initialized ;
00382 
00383 #ifdef USE_HMMPARALLEL_STRUCTURES
00384 
00386         struct S_DIM_THREAD_PARAM
00387         {
00388             CHMM* hmm;
00389             int32_t dim;
00390             float64_t prob_sum;
00391         };
00392 
00394         struct S_BW_THREAD_PARAM
00395         {
00396             CHMM* hmm;
00397             int32_t dim_start;
00398             int32_t dim_stop;
00399 
00400             float64_t ret;
00401 
00402             float64_t* p_buf;
00403             float64_t* q_buf;
00404             float64_t* a_buf;
00405             float64_t* b_buf;
00406         };
00407 
00408         inline T_ALPHA_BETA & ALPHA_CACHE(int32_t dim) {
00409             return alpha_cache[dim%parallel->get_num_threads()] ; } ;
00410         inline T_ALPHA_BETA & BETA_CACHE(int32_t dim) {
00411             return beta_cache[dim%parallel->get_num_threads()] ; } ;
00412 #ifdef USE_LOGSUMARRAY
00413         inline float64_t* ARRAYS(int32_t dim) {
00414             return arrayS[dim%parallel->get_num_threads()] ; } ;
00415 #endif
00416         inline float64_t* ARRAYN1(int32_t dim) {
00417             return arrayN1[dim%parallel->get_num_threads()] ; } ;
00418         inline float64_t* ARRAYN2(int32_t dim) {
00419             return arrayN2[dim%parallel->get_num_threads()] ; } ;
00420         inline T_STATES* STATES_PER_OBSERVATION_PSI(int32_t dim) {
00421             return states_per_observation_psi[dim%parallel->get_num_threads()] ; } ;
00422         inline const T_STATES* STATES_PER_OBSERVATION_PSI(int32_t dim) const {
00423             return states_per_observation_psi[dim%parallel->get_num_threads()] ; } ;
00424         inline T_STATES* PATH(int32_t dim) {
00425             return path[dim%parallel->get_num_threads()] ; } ;
00426         inline bool & PATH_PROB_UPDATED(int32_t dim) {
00427             return path_prob_updated[dim%parallel->get_num_threads()] ; } ;
00428         inline int32_t & PATH_PROB_DIMENSION(int32_t dim) {
00429             return path_prob_dimension[dim%parallel->get_num_threads()] ; } ;
00430 #else
00431         inline T_ALPHA_BETA & ALPHA_CACHE(int32_t /*dim*/) {
00432             return alpha_cache ; } ;
00433         inline T_ALPHA_BETA & BETA_CACHE(int32_t /*dim*/) {
00434             return beta_cache ; } ;
00435 #ifdef USE_LOGSUMARRAY
00436         inline float64_t* ARRAYS(int32_t dim) {
00437             return arrayS ; } ;
00438 #endif
00439         inline float64_t* ARRAYN1(int32_t /*dim*/) {
00440             return arrayN1 ; } ;
00441         inline float64_t* ARRAYN2(int32_t /*dim*/) {
00442             return arrayN2 ; } ;
00443         inline T_STATES* STATES_PER_OBSERVATION_PSI(int32_t /*dim*/) {
00444             return states_per_observation_psi ; } ;
00445         inline const T_STATES* STATES_PER_OBSERVATION_PSI(int32_t /*dim*/) const {
00446             return states_per_observation_psi ; } ;
00447         inline T_STATES* PATH(int32_t /*dim*/) {
00448             return path ; } ;
00449         inline bool & PATH_PROB_UPDATED(int32_t /*dim*/) {
00450             return path_prob_updated ; } ;
00451         inline int32_t & PATH_PROB_DIMENSION(int32_t /*dim*/) {
00452             return path_prob_dimension ; } ;
00453 #endif
00454 
00459         bool converged(float64_t x, float64_t y);
00460 
00466     public:
00468         CHMM();
00469 
00480         CHMM(
00481             int32_t N, int32_t M, Model* model, float64_t PSEUDO);
00482         CHMM(
00483             CStringFeatures<uint16_t>* obs, int32_t N, int32_t M,
00484             float64_t PSEUDO);
00485         CHMM(
00486             int32_t N, float64_t* p, float64_t* q, float64_t* a);
00487         CHMM(
00488             int32_t N, float64_t* p, float64_t* q, int32_t num_trans,
00489             float64_t* a_trans);
00490 
00495         CHMM(FILE* model_file, float64_t PSEUDO);
00496 
00498         CHMM(CHMM* h);
00499 
00501         virtual ~CHMM();
00502 
00511         virtual bool train(CFeatures* data=NULL);
00512         virtual int32_t get_num_model_parameters() { return N*(N+M+2); }
00513         virtual float64_t get_log_model_parameter(int32_t num_param);
00514         virtual float64_t get_log_derivative(int32_t num_param, int32_t num_example);
00515         virtual float64_t get_log_likelihood_example(int32_t num_example)
00516         {
00517             return model_probability(num_example);
00518         }
00519 
00525         bool initialize(Model* model, float64_t PSEUDO, FILE* model_file=NULL);
00527 
00529         bool alloc_state_dependend_arrays();
00530 
00532         void free_state_dependend_arrays();
00533 
00545         float64_t forward_comp(int32_t time, int32_t state, int32_t dimension);
00546         float64_t forward_comp_old(
00547             int32_t time, int32_t state, int32_t dimension);
00548 
00556         float64_t backward_comp(int32_t time, int32_t state, int32_t dimension);
00557         float64_t backward_comp_old(
00558             int32_t time, int32_t state, int32_t dimension);
00559 
00564         float64_t best_path(int32_t dimension);
00565         inline uint16_t get_best_path_state(int32_t dim, int32_t t)
00566         {
00567             ASSERT(PATH(dim));
00568             return PATH(dim)[t];
00569         }
00570 
00573         float64_t model_probability_comp() ;
00574 
00576         inline float64_t model_probability(int32_t dimension=-1)
00577         {
00578             //for faster calculation cache model probability
00579             if (dimension==-1)
00580             {
00581                 if (mod_prob_updated)
00582                     return mod_prob/p_observations->get_num_vectors();
00583                 else
00584                     return model_probability_comp()/p_observations->get_num_vectors();
00585             }
00586             else
00587                 return forward(p_observations->get_vector_length(dimension), 0, dimension);
00588         }
00589 
00595         inline float64_t linear_model_probability(int32_t dimension)
00596         {
00597             float64_t lik=0;
00598             int32_t len=0;
00599             bool free_vec;
00600             uint16_t* o=p_observations->get_feature_vector(dimension, len, free_vec);
00601             float64_t* obs_b=observation_matrix_b;
00602 
00603             ASSERT(N==len);
00604 
00605             for (int32_t i=0; i<N; i++)
00606             {
00607                 lik+=obs_b[*o++];
00608                 obs_b+=M;
00609             }
00610             p_observations->free_feature_vector(o, dimension, free_vec);
00611             return lik;
00612 
00613             // sorry, the above code is the speed optimized version of :
00614             /*  float64_t lik=0;
00615 
00616                 for (int32_t i=0; i<N; i++)
00617                 lik+=get_b(i, p_observations->get_feature(dimension, i));
00618                 return lik;
00619                 */
00620             // : that
00621         }
00622 
00624 
00627         inline bool set_iterations(int32_t num) { iterations=num; return true; }
00628         inline int32_t get_iterations() { return iterations; }
00629         inline bool set_epsilon (float64_t eps) { epsilon=eps; return true; }
00630         inline float64_t get_epsilon() { return epsilon; }
00631 
00635         bool baum_welch_viterbi_train(BaumWelchViterbiType type);
00636 
00643         void estimate_model_baum_welch(CHMM* train);
00644         void estimate_model_baum_welch_trans(CHMM* train);
00645 
00646 #ifdef USE_HMMPARALLEL_STRUCTURES
00647         void ab_buf_comp(
00648             float64_t* p_buf, float64_t* q_buf, float64_t* a_buf,
00649             float64_t* b_buf, int32_t dim) ;
00650 #else
00651         void estimate_model_baum_welch_old(CHMM* train);
00652 #endif
00653 
00657         void estimate_model_baum_welch_defined(CHMM* train);
00658 
00662         void estimate_model_viterbi(CHMM* train);
00663 
00667         void estimate_model_viterbi_defined(CHMM* train);
00668 
00670 
00672         bool linear_train(bool right_align=false);
00673 
00675         bool permutation_entropy(int32_t window_width, int32_t sequence_number);
00676 
00683         void output_model(bool verbose=false);
00684 
00686         void output_model_defined(bool verbose=false);
00688 
00689 
00692 
00694         void normalize(bool keep_dead_states=false);
00695 
00699         void add_states(int32_t num_states, float64_t default_val=0);
00700 
00706         bool append_model(
00707             CHMM* append_model, float64_t* cur_out, float64_t* app_out);
00708 
00712         bool append_model(CHMM* append_model);
00713 
00715         void chop(float64_t value);
00716 
00718         void convert_to_log();
00719 
00721         void init_model_random();
00722 
00728         void init_model_defined();
00729 
00731         void clear_model();
00732 
00734         void clear_model_defined();
00735 
00737         void copy_model(CHMM* l);
00738 
00743         void invalidate_model();
00744 
00748         inline bool get_status() const
00749         {
00750             return status;
00751         }
00752 
00754         inline float64_t get_pseudo() const
00755         {
00756             return PSEUDO ;
00757         }
00758 
00760         inline void set_pseudo(float64_t pseudo)
00761         {
00762             PSEUDO=pseudo ;
00763         }
00764 
00765 #ifdef USE_HMMPARALLEL_STRUCTURES
00766         static void* bw_dim_prefetch(void * params);
00767         static void* bw_single_dim_prefetch(void * params);
00768         static void* vit_dim_prefetch(void * params);
00769 #endif
00770 
00771 #ifdef FIX_POS
00772 
00775         inline bool set_fix_pos_state(int32_t pos, T_STATES state, char value)
00776         {
00777             if (!model)
00778                 return false ;
00779             model->set_fix_pos_state(pos, state, N, value) ;
00780             return true ;
00781         } ;
00782 #endif
00783 
00784 
00793         void set_observations(CStringFeatures<uint16_t>* obs, CHMM* hmm=NULL);
00794 
00798         void set_observation_nocache(CStringFeatures<uint16_t>* obs);
00799 
00801         inline CStringFeatures<uint16_t>* get_observations()
00802         {
00803             SG_REF(p_observations);
00804             return p_observations;
00805         }
00807 
00875         bool load_definitions(FILE* file, bool verbose, bool initialize=true);
00876 
00912         bool load_model(FILE* file);
00913 
00917         bool save_model(FILE* file);
00918 
00922         bool save_model_derivatives(FILE* file);
00923 
00927         bool save_model_derivatives_bin(FILE* file);
00928 
00932         bool save_model_bin(FILE* file);
00933 
00935         bool check_model_derivatives() ;
00936         bool check_model_derivatives_combined() ;
00937 
00943         T_STATES* get_path(int32_t dim, float64_t& prob);
00944 
00948         bool save_path(FILE* file);
00949 
00953         bool save_path_derivatives(FILE* file);
00954 
00958         bool save_path_derivatives_bin(FILE* file);
00959 
00960 #ifdef USE_HMMDEBUG
00961 
00962         bool check_path_derivatives() ;
00963 #endif //USE_HMMDEBUG
00964 
00968         bool save_likelihood_bin(FILE* file);
00969 
00973         bool save_likelihood(FILE* file);
00975 
00981 
00983         inline T_STATES get_N() const { return N ; }
00984 
00986         inline int32_t get_M() const { return M ; }
00987 
00992         inline void set_q(T_STATES offset, float64_t value)
00993         {
00994 #ifdef HMM_DEBUG
00995             if (offset>=N)
00996                 SG_DEBUG("index out of range in set_q(%i,%e) [%i]\n", offset,value,N) ;
00997 #endif
00998             end_state_distribution_q[offset]=value;
00999         }
01000 
01005         inline void set_p(T_STATES offset, float64_t value)
01006         {
01007 #ifdef HMM_DEBUG
01008             if (offset>=N)
01009                 SG_DEBUG("index out of range in set_p(%i,.) [%i]\n", offset,N) ;
01010 #endif
01011             initial_state_distribution_p[offset]=value;
01012         }
01013 
01019         inline void set_A(T_STATES line_, T_STATES column, float64_t value)
01020         {
01021 #ifdef HMM_DEBUG
01022             if ((line_>N)||(column>N))
01023                 SG_DEBUG("index out of range in set_A(%i,%i,.) [%i,%i]\n",line_,column,N,N) ;
01024 #endif
01025             transition_matrix_A[line_+column*N]=value;
01026         }
01027 
01033         inline void set_a(T_STATES line_, T_STATES column, float64_t value)
01034         {
01035 #ifdef HMM_DEBUG
01036             if ((line_>N)||(column>N))
01037                 SG_DEBUG("index out of range in set_a(%i,%i,.) [%i,%i]\n",line_,column,N,N) ;
01038 #endif
01039             transition_matrix_a[line_+column*N]=value; // look also best_path!
01040         }
01041 
01047         inline void set_B(T_STATES line_, uint16_t column, float64_t value)
01048         {
01049 #ifdef HMM_DEBUG
01050             if ((line_>=N)||(column>=M))
01051                 SG_DEBUG("index out of range in set_B(%i,%i) [%i,%i]\n", line_, column,N,M) ;
01052 #endif
01053             observation_matrix_B[line_*M+column]=value;
01054         }
01055 
01061         inline void set_b(T_STATES line_, uint16_t column, float64_t value)
01062         {
01063 #ifdef HMM_DEBUG
01064             if ((line_>=N)||(column>=M))
01065                 SG_DEBUG("index out of range in set_b(%i,%i) [%i,%i]\n", line_, column,N,M) ;
01066 #endif
01067             observation_matrix_b[line_*M+column]=value;
01068         }
01069 
01076         inline void set_psi(
01077             int32_t time, T_STATES state, T_STATES value, int32_t dimension)
01078         {
01079 #ifdef HMM_DEBUG
01080             if ((time>=p_observations->get_max_vector_length())||(state>N))
01081                 SG_DEBUG("index out of range in set_psi(%i,%i,.) [%i,%i]\n",time,state,p_observations->get_max_vector_length(),N) ;
01082 #endif
01083             STATES_PER_OBSERVATION_PSI(dimension)[time*N+state]=value;
01084         }
01085 
01090         inline float64_t get_q(T_STATES offset) const
01091         {
01092 #ifdef HMM_DEBUG
01093             if (offset>=N)
01094                 SG_DEBUG("index out of range in %e=get_q(%i) [%i]\n", end_state_distribution_q[offset],offset,N) ;
01095 #endif
01096             return end_state_distribution_q[offset];
01097         }
01098 
01103         inline float64_t get_p(T_STATES offset) const
01104         {
01105 #ifdef HMM_DEBUG
01106             if (offset>=N)
01107                 SG_DEBUG("index out of range in get_p(%i,.) [%i]\n", offset,N) ;
01108 #endif
01109             return initial_state_distribution_p[offset];
01110         }
01111 
01117         inline float64_t get_A(T_STATES line_, T_STATES column) const
01118         {
01119 #ifdef HMM_DEBUG
01120             if ((line_>N)||(column>N))
01121                 SG_DEBUG("index out of range in get_A(%i,%i) [%i,%i]\n",line_,column,N,N) ;
01122 #endif
01123             return transition_matrix_A[line_+column*N];
01124         }
01125 
01131         inline float64_t get_a(T_STATES line_, T_STATES column) const
01132         {
01133 #ifdef HMM_DEBUG
01134             if ((line_>N)||(column>N))
01135                 SG_DEBUG("index out of range in get_a(%i,%i) [%i,%i]\n",line_,column,N,N) ;
01136 #endif
01137             return transition_matrix_a[line_+column*N]; // look also best_path()!
01138         }
01139 
01145         inline float64_t get_B(T_STATES line_, uint16_t column) const
01146         {
01147 #ifdef HMM_DEBUG
01148             if ((line_>=N)||(column>=M))
01149                 SG_DEBUG("index out of range in get_B(%i,%i) [%i,%i]\n", line_, column,N,M) ;
01150 #endif
01151             return observation_matrix_B[line_*M+column];
01152         }
01153 
01159         inline float64_t get_b(T_STATES line_, uint16_t column) const
01160         {
01161 #ifdef HMM_DEBUG
01162             if ((line_>=N)||(column>=M))
01163                 SG_DEBUG("index out of range in get_b(%i,%i) [%i,%i]\n", line_, column,N,M) ;
01164 #endif
01165             //SG_PRINT("idx %d\n", line_*M+column);
01166             return observation_matrix_b[line_*M+column];
01167         }
01168 
01175         inline T_STATES get_psi(
01176             int32_t time, T_STATES state, int32_t dimension) const
01177         {
01178 #ifdef HMM_DEBUG
01179             if ((time>=p_observations->get_max_vector_length())||(state>N))
01180                 SG_DEBUG("index out of range in get_psi(%i,%i) [%i,%i]\n",time,state,p_observations->get_max_vector_length(),N) ;
01181 #endif
01182             return STATES_PER_OBSERVATION_PSI(dimension)[time*N+state];
01183         }
01184 
01186 
01188         virtual const char* get_name() const { return "HMM"; }
01189 
01190     protected:
01195 
01196         int32_t M;
01197 
01199         int32_t N;
01200 
01202         float64_t PSEUDO;
01203 
01204         // line number during processing input files
01205         int32_t line;
01206 
01208         CStringFeatures<uint16_t>* p_observations;
01209 
01210         //train definition for HMM
01211         Model* model;
01212 
01214         float64_t* transition_matrix_A;
01215 
01217         float64_t* observation_matrix_B;
01218 
01220         float64_t* transition_matrix_a;
01221 
01223         float64_t* initial_state_distribution_p;
01224 
01226         float64_t* end_state_distribution_q;
01227 
01229         float64_t* observation_matrix_b;
01230 
01232         int32_t iterations;
01233         int32_t iteration_count;
01234 
01236         float64_t epsilon;
01237         int32_t conv_it;
01238 
01240         float64_t all_pat_prob;
01241 
01243         float64_t pat_prob;
01244 
01246         float64_t mod_prob;
01247 
01249         bool mod_prob_updated;
01250 
01252         bool all_path_prob_updated;
01253 
01255         int32_t path_deriv_dimension;
01256 
01258         bool path_deriv_updated;
01259 
01260         // true if model is using log likelihood
01261         bool loglikelihood;
01262 
01263         // true->ok, false->error
01264         bool status;
01265 
01266         // true->stolen from other HMMs, false->got own
01267         bool reused_caches;
01269 
01270 #ifdef USE_HMMPARALLEL_STRUCTURES
01271 
01272         float64_t** arrayN1 /*[parallel.get_num_threads()]*/ ;
01274         float64_t** arrayN2 /*[parallel.get_num_threads()]*/ ;
01275 #else //USE_HMMPARALLEL_STRUCTURES
01276 
01277         float64_t* arrayN1;
01279         float64_t* arrayN2;
01280 #endif //USE_HMMPARALLEL_STRUCTURES
01281 
01282 #ifdef USE_LOGSUMARRAY
01283 #ifdef USE_HMMPARALLEL_STRUCTURES
01284 
01285         float64_t** arrayS /*[parallel.get_num_threads()]*/;
01286 #else
01287 
01288         float64_t* arrayS;
01289 #endif // USE_HMMPARALLEL_STRUCTURES
01290 #endif // USE_LOGSUMARRAY
01291 
01292 #ifdef USE_HMMPARALLEL_STRUCTURES
01293 
01295         T_ALPHA_BETA* alpha_cache /*[parallel.get_num_threads()]*/ ;
01297         T_ALPHA_BETA* beta_cache /*[parallel.get_num_threads()]*/ ;
01298 
01300         T_STATES** states_per_observation_psi /*[parallel.get_num_threads()]*/ ;
01301 
01303         T_STATES** path /*[parallel.get_num_threads()]*/ ;
01304 
01306         bool* path_prob_updated /*[parallel.get_num_threads()]*/;
01307 
01309         int32_t* path_prob_dimension /*[parallel.get_num_threads()]*/ ;
01310 
01311 #else //USE_HMMPARALLEL_STRUCTURES
01312 
01313         T_ALPHA_BETA alpha_cache;
01315         T_ALPHA_BETA beta_cache;
01316 
01318         T_STATES* states_per_observation_psi;
01319 
01321         T_STATES* path;
01322 
01324         bool path_prob_updated;
01325 
01327         int32_t path_prob_dimension;
01328 
01329 #endif //USE_HMMPARALLEL_STRUCTURES
01330 
01331 
01333         static const int32_t GOTN;
01335         static const int32_t GOTM;
01337         static const int32_t GOTO;
01339         static const int32_t GOTa;
01341         static const int32_t GOTb;
01343         static const int32_t GOTp;
01345         static const int32_t GOTq;
01346 
01348         static const int32_t GOTlearn_a;
01350         static const int32_t GOTlearn_b;
01352         static const int32_t GOTlearn_p;
01354         static const int32_t GOTlearn_q;
01356         static const int32_t GOTconst_a;
01358         static const int32_t GOTconst_b;
01360         static const int32_t GOTconst_p;
01362         static const int32_t GOTconst_q;
01363 
01364         public:
01369 
01371 inline float64_t state_probability(
01372     int32_t time, int32_t state, int32_t dimension)
01373 {
01374     return forward(time, state, dimension) + backward(time, state, dimension) - model_probability(dimension);
01375 }
01376 
01378 inline float64_t transition_probability(
01379     int32_t time, int32_t state_i, int32_t state_j, int32_t dimension)
01380 {
01381     return forward(time, state_i, dimension) +
01382         backward(time+1, state_j, dimension) +
01383         get_a(state_i,state_j) + get_b(state_j,p_observations->get_feature(dimension ,time+1)) - model_probability(dimension);
01384 }
01385 
01392 
01395 inline float64_t linear_model_derivative(
01396     T_STATES i, uint16_t j, int32_t dimension)
01397 {
01398     float64_t der=0;
01399 
01400     for (int32_t k=0; k<N; k++)
01401     {
01402         if (k!=i || p_observations->get_feature(dimension, k) != j)
01403             der+=get_b(k, p_observations->get_feature(dimension, k));
01404     }
01405 
01406     return der;
01407 }
01408 
01412 inline float64_t model_derivative_p(T_STATES i, int32_t dimension)
01413 {
01414     return backward(0,i,dimension)+get_b(i, p_observations->get_feature(dimension, 0));
01415 }
01416 
01420 inline float64_t model_derivative_q(T_STATES i, int32_t dimension)
01421 {
01422     return forward(p_observations->get_vector_length(dimension)-1,i,dimension) ;
01423 }
01424 
01426 inline float64_t model_derivative_a(T_STATES i, T_STATES j, int32_t dimension)
01427 {
01428     float64_t sum=-CMath::INFTY;
01429     for (int32_t t=0; t<p_observations->get_vector_length(dimension)-1; t++)
01430         sum= CMath::logarithmic_sum(sum, forward(t, i, dimension) + backward(t+1, j, dimension) + get_b(j, p_observations->get_feature(dimension,t+1)));
01431 
01432     return sum;
01433 }
01434 
01435 
01437 inline float64_t model_derivative_b(T_STATES i, uint16_t j, int32_t dimension)
01438 {
01439     float64_t sum=-CMath::INFTY;
01440     for (int32_t t=0; t<p_observations->get_vector_length(dimension); t++)
01441     {
01442         if (p_observations->get_feature(dimension,t)==j)
01443             sum= CMath::logarithmic_sum(sum, forward(t,i,dimension)+backward(t,i,dimension)-get_b(i,p_observations->get_feature(dimension,t)));
01444     }
01445     //if (sum==-CMath::INFTY)
01446     // SG_DEBUG( "log derivative is -inf: dim=%i, state=%i, obs=%i\n",dimension, i, j) ;
01447     return sum;
01448 }
01450 
01457 
01459 inline float64_t path_derivative_p(T_STATES i, int32_t dimension)
01460 {
01461     best_path(dimension);
01462     return (i==PATH(dimension)[0]) ? (exp(-get_p(PATH(dimension)[0]))) : (0) ;
01463 }
01464 
01466 inline float64_t path_derivative_q(T_STATES i, int32_t dimension)
01467 {
01468     best_path(dimension);
01469     return (i==PATH(dimension)[p_observations->get_vector_length(dimension)-1]) ? (exp(-get_q(PATH(dimension)[p_observations->get_vector_length(dimension)-1]))) : 0 ;
01470 }
01471 
01473 inline float64_t path_derivative_a(T_STATES i, T_STATES j, int32_t dimension)
01474 {
01475     prepare_path_derivative(dimension) ;
01476     return (get_A(i,j)==0) ? (0) : (get_A(i,j)*exp(-get_a(i,j))) ;
01477 }
01478 
01480 inline float64_t path_derivative_b(T_STATES i, uint16_t j, int32_t dimension)
01481 {
01482     prepare_path_derivative(dimension) ;
01483     return (get_B(i,j)==0) ? (0) : (get_B(i,j)*exp(-get_b(i,j))) ;
01484 }
01485 
01487 
01488 
01489 protected:
01494 
01495     bool get_numbuffer(FILE* file, char* buffer, int32_t length);
01496 
01498     void open_bracket(FILE* file);
01499 
01501     void close_bracket(FILE* file);
01502 
01504     bool comma_or_space(FILE* file);
01505 
01507     inline void error(int32_t p_line, const char* str)
01508     {
01509         if (p_line)
01510             SG_ERROR( "error in line %d %s\n", p_line, str);
01511         else
01512             SG_ERROR( "error %s\n", str);
01513     }
01515 
01517     inline void prepare_path_derivative(int32_t dim)
01518     {
01519         if (path_deriv_updated && (path_deriv_dimension==dim))
01520             return ;
01521         int32_t i,j,t ;
01522         best_path(dim);
01523         //initialize with zeros
01524         for (i=0; i<N; i++)
01525         {
01526             for (j=0; j<N; j++)
01527                 set_A(i,j, 0);
01528             for (j=0; j<M; j++)
01529                 set_B(i,j, 0);
01530         }
01531 
01532         //counting occurences for A and B
01533         for (t=0; t<p_observations->get_vector_length(dim)-1; t++)
01534         {
01535             set_A(PATH(dim)[t], PATH(dim)[t+1], get_A(PATH(dim)[t], PATH(dim)[t+1])+1);
01536             set_B(PATH(dim)[t], p_observations->get_feature(dim,t),  get_B(PATH(dim)[t], p_observations->get_feature(dim,t))+1);
01537         }
01538         set_B(PATH(dim)[p_observations->get_vector_length(dim)-1], p_observations->get_feature(dim,p_observations->get_vector_length(dim)-1),  get_B(PATH(dim)[p_observations->get_vector_length(dim)-1], p_observations->get_feature(dim,p_observations->get_vector_length(dim)-1)) + 1);
01539         path_deriv_dimension=dim ;
01540         path_deriv_updated=true ;
01541     } ;
01543 
01545     inline float64_t forward(int32_t time, int32_t state, int32_t dimension)
01546     {
01547         if (time<1)
01548             time=0;
01549 
01550         if (ALPHA_CACHE(dimension).table && (dimension==ALPHA_CACHE(dimension).dimension) && ALPHA_CACHE(dimension).updated)
01551         {
01552             if (time<p_observations->get_vector_length(dimension))
01553                 return ALPHA_CACHE(dimension).table[time*N+state];
01554             else
01555                 return ALPHA_CACHE(dimension).sum;
01556         }
01557         else
01558             return forward_comp(time, state, dimension) ;
01559     }
01560 
01562     inline float64_t backward(int32_t time, int32_t state, int32_t dimension)
01563     {
01564         if (BETA_CACHE(dimension).table && (dimension==BETA_CACHE(dimension).dimension) && (BETA_CACHE(dimension).updated))
01565         {
01566             if (time<0)
01567                 return BETA_CACHE(dimension).sum;
01568             if (time<p_observations->get_vector_length(dimension))
01569                 return BETA_CACHE(dimension).table[time*N+state];
01570             else
01571                 return -CMath::INFTY;
01572         }
01573         else
01574             return backward_comp(time, state, dimension) ;
01575     }
01576 
01577 };
01578 }
01579 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation