00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #ifndef __CHMM_H__
00013 #define __CHMM_H__
00014
00015 #include "lib/Mathematics.h"
00016 #include "lib/common.h"
00017 #include "lib/io.h"
00018 #include "lib/config.h"
00019 #include "features/Features.h"
00020 #include "features/StringFeatures.h"
00021 #include "distributions/Distribution.h"
00022
00023 #include <stdio.h>
00024
00025 #ifdef USE_HMMPARALLEL
00026 #define USE_HMMPARALLEL_STRUCTURES 1
00027 #endif
00028
00029 namespace shogun
00030 {
00031 class CFeatures;
00032 template <class ST> class CStringFeatures;
00035
00037 typedef float64_t T_ALPHA_BETA_TABLE;
00038
00039 #ifndef DOXYGEN_SHOULD_SKIP_THIS
00040
00041 struct T_ALPHA_BETA
00042 {
00044 int32_t dimension;
00045
00047 T_ALPHA_BETA_TABLE* table;
00048
00050 bool updated;
00051
00053 float64_t sum;
00054 };
00055 #endif // DOXYGEN_SHOULD_SKIP_THIS
00056
/** Integer type used to store an HMM state index.
 *
 * 16 bit wide when USE_BIGSTATES is defined and 8 bit wide otherwise, which
 * bounds how many distinct states can be represented.
 */
#ifdef USE_BIGSTATES
typedef uint16_t T_STATES;
#else
typedef uint8_t T_STATES;
#endif
/// pointer to state indices (e.g. a stored state path)
typedef T_STATES* P_STATES;
00067
00069
/** Training algorithm selector passed to CHMM::baum_welch_viterbi_train().
 *
 * NOTE(review): by name these presumably select estimate_model_baum_welch,
 * estimate_model_baum_welch_trans, estimate_model_baum_welch_defined,
 * estimate_model_viterbi and estimate_model_viterbi_defined respectively —
 * confirm against the .cpp dispatch.
 */
enum BaumWelchViterbiType
{
	BW_NORMAL = 0,
	BW_TRANS = 1,
	BW_DEFINED = 2,
	VIT_NORMAL = 3,
	VIT_DEFINED = 4
};
00078
00079
00081 class Model
00082 {
00083 public:
00085 Model();
00086
00088 virtual ~Model();
00089
00091 inline void sort_learn_a()
00092 {
00093 CMath::sort(learn_a,2) ;
00094 }
00095
00097 inline void sort_learn_b()
00098 {
00099 CMath::sort(learn_b,2) ;
00100 }
00101
00106
00107 inline int32_t get_learn_a(int32_t line, int32_t column) const
00108 {
00109 return learn_a[line*2 + column];
00110 }
00111
00113 inline int32_t get_learn_b(int32_t line, int32_t column) const
00114 {
00115 return learn_b[line*2 + column];
00116 }
00117
00119 inline int32_t get_learn_p(int32_t offset) const
00120 {
00121 return learn_p[offset];
00122 }
00123
00125 inline int32_t get_learn_q(int32_t offset) const
00126 {
00127 return learn_q[offset];
00128 }
00129
00131 inline int32_t get_const_a(int32_t line, int32_t column) const
00132 {
00133 return const_a[line*2 + column];
00134 }
00135
00137 inline int32_t get_const_b(int32_t line, int32_t column) const
00138 {
00139 return const_b[line*2 + column];
00140 }
00141
00143 inline int32_t get_const_p(int32_t offset) const
00144 {
00145 return const_p[offset];
00146 }
00147
00149 inline int32_t get_const_q(int32_t offset) const
00150 {
00151 return const_q[offset];
00152 }
00153
00155 inline float64_t get_const_a_val(int32_t line) const
00156 {
00157 return const_a_val[line];
00158 }
00159
00161 inline float64_t get_const_b_val(int32_t line) const
00162 {
00163 return const_b_val[line];
00164 }
00165
00167 inline float64_t get_const_p_val(int32_t offset) const
00168 {
00169 return const_p_val[offset];
00170 }
00171
00173 inline float64_t get_const_q_val(int32_t offset) const
00174 {
00175 return const_q_val[offset];
00176 }
00177 #ifdef FIX_POS
00178
00179 inline char get_fix_pos_state(int32_t pos, T_STATES state, T_STATES num_states)
00180 {
00181 #ifdef HMM_DEBUG
00182 if ((pos<0)||(pos*num_states+state>65336))
00183 SG_DEBUG("index out of range in get_fix_pos_state(%i,%i,%i) \n", pos,state,num_states) ;
00184 #endif
00185 return fix_pos_state[pos*num_states+state] ;
00186 }
00187 #endif
00188
00189
00194
00195 inline void set_learn_a(int32_t offset, int32_t value)
00196 {
00197 learn_a[offset]=value;
00198 }
00199
00201 inline void set_learn_b(int32_t offset, int32_t value)
00202 {
00203 learn_b[offset]=value;
00204 }
00205
00207 inline void set_learn_p(int32_t offset, int32_t value)
00208 {
00209 learn_p[offset]=value;
00210 }
00211
00213 inline void set_learn_q(int32_t offset, int32_t value)
00214 {
00215 learn_q[offset]=value;
00216 }
00217
00219 inline void set_const_a(int32_t offset, int32_t value)
00220 {
00221 const_a[offset]=value;
00222 }
00223
00225 inline void set_const_b(int32_t offset, int32_t value)
00226 {
00227 const_b[offset]=value;
00228 }
00229
00231 inline void set_const_p(int32_t offset, int32_t value)
00232 {
00233 const_p[offset]=value;
00234 }
00235
00237 inline void set_const_q(int32_t offset, int32_t value)
00238 {
00239 const_q[offset]=value;
00240 }
00241
00243 inline void set_const_a_val(int32_t offset, float64_t value)
00244 {
00245 const_a_val[offset]=value;
00246 }
00247
00249 inline void set_const_b_val(int32_t offset, float64_t value)
00250 {
00251 const_b_val[offset]=value;
00252 }
00253
00255 inline void set_const_p_val(int32_t offset, float64_t value)
00256 {
00257 const_p_val[offset]=value;
00258 }
00259
00261 inline void set_const_q_val(int32_t offset, float64_t value)
00262 {
00263 const_q_val[offset]=value;
00264 }
00265 #ifdef FIX_POS
00266
00267 inline void set_fix_pos_state(
00268 int32_t pos, T_STATES state, T_STATES num_states, char value)
00269 {
00270 #ifdef HMM_DEBUG
00271 if ((pos<0)||(pos*num_states+state>65336))
00272 SG_DEBUG("index out of range in set_fix_pos_state(%i,%i,%i,%i) [%i]\n", pos,state,num_states,(int)value, pos*num_states+state) ;
00273 #endif
00274 fix_pos_state[pos*num_states+state]=value;
00275 if (value==FIX_ALLOWED)
00276 for (int32_t i=0; i<num_states; i++)
00277 if (get_fix_pos_state(pos,i,num_states)==FIX_DEFAULT)
00278 set_fix_pos_state(pos,i,num_states,FIX_DISALLOWED) ;
00279 }
00281
00283 const static char FIX_DISALLOWED ;
00284
00286 const static char FIX_ALLOWED ;
00287
00289 const static char FIX_DEFAULT ;
00290
00292 const static float64_t DISALLOWED_PENALTY ;
00293 #endif
00294 protected:
00301
00302 int32_t* learn_a;
00303
00305 int32_t* learn_b;
00306
00308 int32_t* learn_p;
00309
00311 int32_t* learn_q;
00313
00320
00321 int32_t* const_a;
00322
00324 int32_t* const_b;
00325
00327 int32_t* const_p;
00328
00330 int32_t* const_q;
00331
00332
00334 float64_t* const_a_val;
00335
00337 float64_t* const_b_val;
00338
00340 float64_t* const_p_val;
00341
00343 float64_t* const_q_val;
00344
00345 #ifdef FIX_POS
00346
00349 char* fix_pos_state;
00350 #endif
00351
00352 };
00353
00354
00365 class CHMM : public CDistribution
00366 {
00367 private:
00368
00369 T_STATES trans_list_len ;
00370 T_STATES **trans_list_forward ;
00371 T_STATES *trans_list_forward_cnt ;
00372 float64_t **trans_list_forward_val ;
00373 T_STATES **trans_list_backward ;
00374 T_STATES *trans_list_backward_cnt ;
00375 bool mem_initialized ;
00376
00377 #ifdef USE_HMMPARALLEL_STRUCTURES
00378
00380 struct S_DIM_THREAD_PARAM
00381 {
00382 CHMM* hmm;
00383 int32_t dim;
00384 float64_t prob_sum;
00385 };
00386
00388 struct S_BW_THREAD_PARAM
00389 {
00390 CHMM* hmm;
00391 int32_t dim_start;
00392 int32_t dim_stop;
00393
00394 float64_t ret;
00395
00396 float64_t* p_buf;
00397 float64_t* q_buf;
00398 float64_t* a_buf;
00399 float64_t* b_buf;
00400 };
00401
00402 inline T_ALPHA_BETA & ALPHA_CACHE(int32_t dim) {
00403 return alpha_cache[dim%parallel->get_num_threads()] ; } ;
00404 inline T_ALPHA_BETA & BETA_CACHE(int32_t dim) {
00405 return beta_cache[dim%parallel->get_num_threads()] ; } ;
00406 #ifdef USE_LOGSUMARRAY
00407 inline float64_t* ARRAYS(int32_t dim) {
00408 return arrayS[dim%parallel->get_num_threads()] ; } ;
00409 #endif
00410 inline float64_t* ARRAYN1(int32_t dim) {
00411 return arrayN1[dim%parallel->get_num_threads()] ; } ;
00412 inline float64_t* ARRAYN2(int32_t dim) {
00413 return arrayN2[dim%parallel->get_num_threads()] ; } ;
00414 inline T_STATES* STATES_PER_OBSERVATION_PSI(int32_t dim) {
00415 return states_per_observation_psi[dim%parallel->get_num_threads()] ; } ;
00416 inline const T_STATES* STATES_PER_OBSERVATION_PSI(int32_t dim) const {
00417 return states_per_observation_psi[dim%parallel->get_num_threads()] ; } ;
00418 inline T_STATES* PATH(int32_t dim) {
00419 return path[dim%parallel->get_num_threads()] ; } ;
00420 inline bool & PATH_PROB_UPDATED(int32_t dim) {
00421 return path_prob_updated[dim%parallel->get_num_threads()] ; } ;
00422 inline int32_t & PATH_PROB_DIMENSION(int32_t dim) {
00423 return path_prob_dimension[dim%parallel->get_num_threads()] ; } ;
00424 #else
00425 inline T_ALPHA_BETA & ALPHA_CACHE(int32_t ) {
00426 return alpha_cache ; } ;
00427 inline T_ALPHA_BETA & BETA_CACHE(int32_t ) {
00428 return beta_cache ; } ;
00429 #ifdef USE_LOGSUMARRAY
00430 inline float64_t* ARRAYS(int32_t dim) {
00431 return arrayS ; } ;
00432 #endif
00433 inline float64_t* ARRAYN1(int32_t ) {
00434 return arrayN1 ; } ;
00435 inline float64_t* ARRAYN2(int32_t ) {
00436 return arrayN2 ; } ;
00437 inline T_STATES* STATES_PER_OBSERVATION_PSI(int32_t ) {
00438 return states_per_observation_psi ; } ;
00439 inline const T_STATES* STATES_PER_OBSERVATION_PSI(int32_t ) const {
00440 return states_per_observation_psi ; } ;
00441 inline T_STATES* PATH(int32_t ) {
00442 return path ; } ;
00443 inline bool & PATH_PROB_UPDATED(int32_t ) {
00444 return path_prob_updated ; } ;
00445 inline int32_t & PATH_PROB_DIMENSION(int32_t ) {
00446 return path_prob_dimension ; } ;
00447 #endif
00448
00453 bool converged(float64_t x, float64_t y);
00454
00460 public:
00462 CHMM(void);
00463
00474 CHMM(
00475 int32_t N, int32_t M, Model* model, float64_t PSEUDO);
00476 CHMM(
00477 CStringFeatures<uint16_t>* obs, int32_t N, int32_t M,
00478 float64_t PSEUDO);
00479 CHMM(
00480 int32_t N, float64_t* p, float64_t* q, float64_t* a);
00481 CHMM(
00482 int32_t N, float64_t* p, float64_t* q, int32_t num_trans,
00483 float64_t* a_trans);
00484
00489 CHMM(FILE* model_file, float64_t PSEUDO);
00490
00492 CHMM(CHMM* h);
00493
00495 virtual ~CHMM();
00496
00505 virtual bool train(CFeatures* data=NULL);
00506 virtual inline int32_t get_num_model_parameters() { return N*(N+M+2); }
00507 virtual float64_t get_log_model_parameter(int32_t num_param);
00508 virtual float64_t get_log_derivative(int32_t num_param, int32_t num_example);
00509 virtual float64_t get_log_likelihood_example(int32_t num_example)
00510 {
00511 return model_probability(num_example);
00512 }
00513
00519 bool initialize(Model* model, float64_t PSEUDO, FILE* model_file=NULL);
00521
00523 bool alloc_state_dependend_arrays();
00524
00526 void free_state_dependend_arrays();
00527
00539 float64_t forward_comp(int32_t time, int32_t state, int32_t dimension);
00540 float64_t forward_comp_old(
00541 int32_t time, int32_t state, int32_t dimension);
00542
00550 float64_t backward_comp(int32_t time, int32_t state, int32_t dimension);
00551 float64_t backward_comp_old(
00552 int32_t time, int32_t state, int32_t dimension);
00553
00558 float64_t best_path(int32_t dimension);
00559 inline uint16_t get_best_path_state(int32_t dim, int32_t t)
00560 {
00561 ASSERT(PATH(dim));
00562 return PATH(dim)[t];
00563 }
00564
00567 float64_t model_probability_comp() ;
00568
00570 inline float64_t model_probability(int32_t dimension=-1)
00571 {
00572
00573 if (dimension==-1)
00574 {
00575 if (mod_prob_updated)
00576 return mod_prob/p_observations->get_num_vectors();
00577 else
00578 return model_probability_comp()/p_observations->get_num_vectors();
00579 }
00580 else
00581 return forward(p_observations->get_vector_length(dimension), 0, dimension);
00582 }
00583
00589 inline float64_t linear_model_probability(int32_t dimension)
00590 {
00591 float64_t lik=0;
00592 int32_t len=0;
00593 bool free_vec;
00594 uint16_t* o=p_observations->get_feature_vector(dimension, len, free_vec);
00595 float64_t* obs_b=observation_matrix_b;
00596
00597 ASSERT(N==len);
00598
00599 for (int32_t i=0; i<N; i++)
00600 {
00601 lik+=obs_b[*o++];
00602 obs_b+=M;
00603 }
00604 p_observations->free_feature_vector(o, dimension, free_vec);
00605 return lik;
00606
00607
00608
00609
00610
00611
00612
00613
00614
00615 }
00616
00618
00621 inline bool set_iterations(int32_t num) { iterations=num; return true; }
00622 inline int32_t get_iterations() { return iterations; }
00623 inline bool set_epsilon (float64_t eps) { epsilon=eps; return true; }
00624 inline float64_t get_epsilon() { return epsilon; }
00625
00629 bool baum_welch_viterbi_train(BaumWelchViterbiType type);
00630
00637 void estimate_model_baum_welch(CHMM* train);
00638 void estimate_model_baum_welch_trans(CHMM* train);
00639
00640 #ifdef USE_HMMPARALLEL_STRUCTURES
00641 void ab_buf_comp(
00642 float64_t* p_buf, float64_t* q_buf, float64_t* a_buf,
00643 float64_t* b_buf, int32_t dim) ;
00644 #else
00645 void estimate_model_baum_welch_old(CHMM* train);
00646 #endif
00647
00651 void estimate_model_baum_welch_defined(CHMM* train);
00652
00656 void estimate_model_viterbi(CHMM* train);
00657
00661 void estimate_model_viterbi_defined(CHMM* train);
00662
00664
00666 bool linear_train(bool right_align=false);
00667
00669 bool permutation_entropy(int32_t window_width, int32_t sequence_number);
00670
00677 void output_model(bool verbose=false);
00678
00680 void output_model_defined(bool verbose=false);
00682
00683
00686
00688 void normalize(bool keep_dead_states=false);
00689
00693 void add_states(int32_t num_states, float64_t default_val=0);
00694
00700 bool append_model(
00701 CHMM* append_model, float64_t* cur_out, float64_t* app_out);
00702
00706 bool append_model(CHMM* append_model);
00707
00709 void chop(float64_t value);
00710
00712 void convert_to_log();
00713
00715 void init_model_random();
00716
00722 void init_model_defined();
00723
00725 void clear_model();
00726
00728 void clear_model_defined();
00729
00731 void copy_model(CHMM* l);
00732
00737 void invalidate_model();
00738
00742 inline bool get_status() const
00743 {
00744 return status;
00745 }
00746
00748 inline float64_t get_pseudo() const
00749 {
00750 return PSEUDO ;
00751 }
00752
00754 inline void set_pseudo(float64_t pseudo)
00755 {
00756 PSEUDO=pseudo ;
00757 }
00758
00759 #ifdef USE_HMMPARALLEL_STRUCTURES
00760 static void* bw_dim_prefetch(void * params);
00761 static void* bw_single_dim_prefetch(void * params);
00762 static void* vit_dim_prefetch(void * params);
00763 #endif
00764
00765 #ifdef FIX_POS
00766
00769 inline bool set_fix_pos_state(int32_t pos, T_STATES state, char value)
00770 {
00771 if (!model)
00772 return false ;
00773 model->set_fix_pos_state(pos, state, N, value) ;
00774 return true ;
00775 } ;
00776 #endif
00777
00778
00787 void set_observations(CStringFeatures<uint16_t>* obs, CHMM* hmm=NULL);
00788
00792 void set_observation_nocache(CStringFeatures<uint16_t>* obs);
00793
00795 inline CStringFeatures<uint16_t>* get_observations()
00796 {
00797 SG_REF(p_observations);
00798 return p_observations;
00799 }
00801
00869 bool load_definitions(FILE* file, bool verbose, bool initialize=true);
00870
00906 bool load_model(FILE* file);
00907
00911 bool save_model(FILE* file);
00912
00916 bool save_model_derivatives(FILE* file);
00917
00921 bool save_model_derivatives_bin(FILE* file);
00922
00926 bool save_model_bin(FILE* file);
00927
00929 bool check_model_derivatives() ;
00930 bool check_model_derivatives_combined() ;
00931
00937 T_STATES* get_path(int32_t dim, float64_t& prob);
00938
00942 bool save_path(FILE* file);
00943
00947 bool save_path_derivatives(FILE* file);
00948
00952 bool save_path_derivatives_bin(FILE* file);
00953
00954 #ifdef USE_HMMDEBUG
00955
00956 bool check_path_derivatives() ;
00957 #endif //USE_HMMDEBUG
00958
00962 bool save_likelihood_bin(FILE* file);
00963
00967 bool save_likelihood(FILE* file);
00969
00975
00977 inline T_STATES get_N() const { return N ; }
00978
00980 inline int32_t get_M() const { return M ; }
00981
00986 inline void set_q(T_STATES offset, float64_t value)
00987 {
00988 #ifdef HMM_DEBUG
00989 if (offset>=N)
00990 SG_DEBUG("index out of range in set_q(%i,%e) [%i]\n", offset,value,N) ;
00991 #endif
00992 end_state_distribution_q[offset]=value;
00993 }
00994
00999 inline void set_p(T_STATES offset, float64_t value)
01000 {
01001 #ifdef HMM_DEBUG
01002 if (offset>=N)
01003 SG_DEBUG("index out of range in set_p(%i,.) [%i]\n", offset,N) ;
01004 #endif
01005 initial_state_distribution_p[offset]=value;
01006 }
01007
01013 inline void set_A(T_STATES line_, T_STATES column, float64_t value)
01014 {
01015 #ifdef HMM_DEBUG
01016 if ((line_>N)||(column>N))
01017 SG_DEBUG("index out of range in set_A(%i,%i,.) [%i,%i]\n",line_,column,N,N) ;
01018 #endif
01019 transition_matrix_A[line_+column*N]=value;
01020 }
01021
01027 inline void set_a(T_STATES line_, T_STATES column, float64_t value)
01028 {
01029 #ifdef HMM_DEBUG
01030 if ((line_>N)||(column>N))
01031 SG_DEBUG("index out of range in set_a(%i,%i,.) [%i,%i]\n",line_,column,N,N) ;
01032 #endif
01033 transition_matrix_a[line_+column*N]=value;
01034 }
01035
01041 inline void set_B(T_STATES line_, uint16_t column, float64_t value)
01042 {
01043 #ifdef HMM_DEBUG
01044 if ((line_>=N)||(column>=M))
01045 SG_DEBUG("index out of range in set_B(%i,%i) [%i,%i]\n", line_, column,N,M) ;
01046 #endif
01047 observation_matrix_B[line_*M+column]=value;
01048 }
01049
01055 inline void set_b(T_STATES line_, uint16_t column, float64_t value)
01056 {
01057 #ifdef HMM_DEBUG
01058 if ((line_>=N)||(column>=M))
01059 SG_DEBUG("index out of range in set_b(%i,%i) [%i,%i]\n", line_, column,N,M) ;
01060 #endif
01061 observation_matrix_b[line_*M+column]=value;
01062 }
01063
01070 inline void set_psi(
01071 int32_t time, T_STATES state, T_STATES value, int32_t dimension)
01072 {
01073 #ifdef HMM_DEBUG
01074 if ((time>=p_observations->get_max_vector_length())||(state>N))
01075 SG_DEBUG("index out of range in set_psi(%i,%i,.) [%i,%i]\n",time,state,p_observations->get_max_vector_length(),N) ;
01076 #endif
01077 STATES_PER_OBSERVATION_PSI(dimension)[time*N+state]=value;
01078 }
01079
01084 inline float64_t get_q(T_STATES offset) const
01085 {
01086 #ifdef HMM_DEBUG
01087 if (offset>=N)
01088 SG_DEBUG("index out of range in %e=get_q(%i) [%i]\n", end_state_distribution_q[offset],offset,N) ;
01089 #endif
01090 return end_state_distribution_q[offset];
01091 }
01092
01097 inline float64_t get_p(T_STATES offset) const
01098 {
01099 #ifdef HMM_DEBUG
01100 if (offset>=N)
01101 SG_DEBUG("index out of range in get_p(%i,.) [%i]\n", offset,N) ;
01102 #endif
01103 return initial_state_distribution_p[offset];
01104 }
01105
01111 inline float64_t get_A(T_STATES line_, T_STATES column) const
01112 {
01113 #ifdef HMM_DEBUG
01114 if ((line_>N)||(column>N))
01115 SG_DEBUG("index out of range in get_A(%i,%i) [%i,%i]\n",line_,column,N,N) ;
01116 #endif
01117 return transition_matrix_A[line_+column*N];
01118 }
01119
01125 inline float64_t get_a(T_STATES line_, T_STATES column) const
01126 {
01127 #ifdef HMM_DEBUG
01128 if ((line_>N)||(column>N))
01129 SG_DEBUG("index out of range in get_a(%i,%i) [%i,%i]\n",line_,column,N,N) ;
01130 #endif
01131 return transition_matrix_a[line_+column*N];
01132 }
01133
01139 inline float64_t get_B(T_STATES line_, uint16_t column) const
01140 {
01141 #ifdef HMM_DEBUG
01142 if ((line_>=N)||(column>=M))
01143 SG_DEBUG("index out of range in get_B(%i,%i) [%i,%i]\n", line_, column,N,M) ;
01144 #endif
01145 return observation_matrix_B[line_*M+column];
01146 }
01147
01153 inline float64_t get_b(T_STATES line_, uint16_t column) const
01154 {
01155 #ifdef HMM_DEBUG
01156 if ((line_>=N)||(column>=M))
01157 SG_DEBUG("index out of range in get_b(%i,%i) [%i,%i]\n", line_, column,N,M) ;
01158 #endif
01159
01160 return observation_matrix_b[line_*M+column];
01161 }
01162
01169 inline T_STATES get_psi(
01170 int32_t time, T_STATES state, int32_t dimension) const
01171 {
01172 #ifdef HMM_DEBUG
01173 if ((time>=p_observations->get_max_vector_length())||(state>N))
01174 SG_DEBUG("index out of range in get_psi(%i,%i) [%i,%i]\n",time,state,p_observations->get_max_vector_length(),N) ;
01175 #endif
01176 return STATES_PER_OBSERVATION_PSI(dimension)[time*N+state];
01177 }
01178
01180
01182 inline virtual const char* get_name() const { return "HMM"; }
01183
01184 protected:
01189
01190 int32_t M;
01191
01193 int32_t N;
01194
01196 float64_t PSEUDO;
01197
01198
01199 int32_t line;
01200
01202 CStringFeatures<uint16_t>* p_observations;
01203
01204
01205 Model* model;
01206
01208 float64_t* transition_matrix_A;
01209
01211 float64_t* observation_matrix_B;
01212
01214 float64_t* transition_matrix_a;
01215
01217 float64_t* initial_state_distribution_p;
01218
01220 float64_t* end_state_distribution_q;
01221
01223 float64_t* observation_matrix_b;
01224
01226 int32_t iterations;
01227 int32_t iteration_count;
01228
01230 float64_t epsilon;
01231 int32_t conv_it;
01232
01234 float64_t all_pat_prob;
01235
01237 float64_t pat_prob;
01238
01240 float64_t mod_prob;
01241
01243 bool mod_prob_updated;
01244
01246 bool all_path_prob_updated;
01247
01249 int32_t path_deriv_dimension;
01250
01252 bool path_deriv_updated;
01253
01254
01255 bool loglikelihood;
01256
01257
01258 bool status;
01259
01260
01261 bool reused_caches;
01263
01264 #ifdef USE_HMMPARALLEL_STRUCTURES
01265
01266 float64_t** arrayN1 ;
01268 float64_t** arrayN2 ;
01269 #else //USE_HMMPARALLEL_STRUCTURES
01270
01271 float64_t* arrayN1;
01273 float64_t* arrayN2;
01274 #endif //USE_HMMPARALLEL_STRUCTURES
01275
01276 #ifdef USE_LOGSUMARRAY
01277 #ifdef USE_HMMPARALLEL_STRUCTURES
01278
01279 float64_t** arrayS ;
01280 #else
01281
01282 float64_t* arrayS;
01283 #endif // USE_HMMPARALLEL_STRUCTURES
01284 #endif // USE_LOGSUMARRAY
01285
01286 #ifdef USE_HMMPARALLEL_STRUCTURES
01287
01289 T_ALPHA_BETA* alpha_cache ;
01291 T_ALPHA_BETA* beta_cache ;
01292
01294 T_STATES** states_per_observation_psi ;
01295
01297 T_STATES** path ;
01298
01300 bool* path_prob_updated ;
01301
01303 int32_t* path_prob_dimension ;
01304
01305 #else //USE_HMMPARALLEL_STRUCTURES
01306
01307 T_ALPHA_BETA alpha_cache;
01309 T_ALPHA_BETA beta_cache;
01310
01312 T_STATES* states_per_observation_psi;
01313
01315 T_STATES* path;
01316
01318 bool path_prob_updated;
01319
01321 int32_t path_prob_dimension;
01322
01323 #endif //USE_HMMPARALLEL_STRUCTURES
01324
01325
01327 static const int32_t GOTN;
01329 static const int32_t GOTM;
01331 static const int32_t GOTO;
01333 static const int32_t GOTa;
01335 static const int32_t GOTb;
01337 static const int32_t GOTp;
01339 static const int32_t GOTq;
01340
01342 static const int32_t GOTlearn_a;
01344 static const int32_t GOTlearn_b;
01346 static const int32_t GOTlearn_p;
01348 static const int32_t GOTlearn_q;
01350 static const int32_t GOTconst_a;
01352 static const int32_t GOTconst_b;
01354 static const int32_t GOTconst_p;
01356 static const int32_t GOTconst_q;
01357
01358 public:
01363
01365 inline float64_t state_probability(
01366 int32_t time, int32_t state, int32_t dimension)
01367 {
01368 return forward(time, state, dimension) + backward(time, state, dimension) - model_probability(dimension);
01369 }
01370
01372 inline float64_t transition_probability(
01373 int32_t time, int32_t state_i, int32_t state_j, int32_t dimension)
01374 {
01375 return forward(time, state_i, dimension) +
01376 backward(time+1, state_j, dimension) +
01377 get_a(state_i,state_j) + get_b(state_j,p_observations->get_feature(dimension ,time+1)) - model_probability(dimension);
01378 }
01379
01386
01389 inline float64_t linear_model_derivative(
01390 T_STATES i, uint16_t j, int32_t dimension)
01391 {
01392 float64_t der=0;
01393
01394 for (int32_t k=0; k<N; k++)
01395 {
01396 if (k!=i || p_observations->get_feature(dimension, k) != j)
01397 der+=get_b(k, p_observations->get_feature(dimension, k));
01398 }
01399
01400 return der;
01401 }
01402
01406 inline float64_t model_derivative_p(T_STATES i, int32_t dimension)
01407 {
01408 return backward(0,i,dimension)+get_b(i, p_observations->get_feature(dimension, 0));
01409 }
01410
01414 inline float64_t model_derivative_q(T_STATES i, int32_t dimension)
01415 {
01416 return forward(p_observations->get_vector_length(dimension)-1,i,dimension) ;
01417 }
01418
01420 inline float64_t model_derivative_a(T_STATES i, T_STATES j, int32_t dimension)
01421 {
01422 float64_t sum=-CMath::INFTY;
01423 for (int32_t t=0; t<p_observations->get_vector_length(dimension)-1; t++)
01424 sum= CMath::logarithmic_sum(sum, forward(t, i, dimension) + backward(t+1, j, dimension) + get_b(j, p_observations->get_feature(dimension,t+1)));
01425
01426 return sum;
01427 }
01428
01429
01431 inline float64_t model_derivative_b(T_STATES i, uint16_t j, int32_t dimension)
01432 {
01433 float64_t sum=-CMath::INFTY;
01434 for (int32_t t=0; t<p_observations->get_vector_length(dimension); t++)
01435 {
01436 if (p_observations->get_feature(dimension,t)==j)
01437 sum= CMath::logarithmic_sum(sum, forward(t,i,dimension)+backward(t,i,dimension)-get_b(i,p_observations->get_feature(dimension,t)));
01438 }
01439
01440
01441 return sum;
01442 }
01444
01451
01453 inline float64_t path_derivative_p(T_STATES i, int32_t dimension)
01454 {
01455 best_path(dimension);
01456 return (i==PATH(dimension)[0]) ? (exp(-get_p(PATH(dimension)[0]))) : (0) ;
01457 }
01458
01460 inline float64_t path_derivative_q(T_STATES i, int32_t dimension)
01461 {
01462 best_path(dimension);
01463 return (i==PATH(dimension)[p_observations->get_vector_length(dimension)-1]) ? (exp(-get_q(PATH(dimension)[p_observations->get_vector_length(dimension)-1]))) : 0 ;
01464 }
01465
01467 inline float64_t path_derivative_a(T_STATES i, T_STATES j, int32_t dimension)
01468 {
01469 prepare_path_derivative(dimension) ;
01470 return (get_A(i,j)==0) ? (0) : (get_A(i,j)*exp(-get_a(i,j))) ;
01471 }
01472
01474 inline float64_t path_derivative_b(T_STATES i, uint16_t j, int32_t dimension)
01475 {
01476 prepare_path_derivative(dimension) ;
01477 return (get_B(i,j)==0) ? (0) : (get_B(i,j)*exp(-get_b(i,j))) ;
01478 }
01479
01481
01482
01483 protected:
01488
01489 bool get_numbuffer(FILE* file, char* buffer, int32_t length);
01490
01492 void open_bracket(FILE* file);
01493
01495 void close_bracket(FILE* file);
01496
01498 bool comma_or_space(FILE* file);
01499
01501 inline void error(int32_t p_line, const char* str)
01502 {
01503 if (p_line)
01504 SG_ERROR( "error in line %d %s\n", p_line, str);
01505 else
01506 SG_ERROR( "error %s\n", str);
01507 }
01509
01511 inline void prepare_path_derivative(int32_t dim)
01512 {
01513 if (path_deriv_updated && (path_deriv_dimension==dim))
01514 return ;
01515 int32_t i,j,t ;
01516 best_path(dim);
01517
01518 for (i=0; i<N; i++)
01519 {
01520 for (j=0; j<N; j++)
01521 set_A(i,j, 0);
01522 for (j=0; j<M; j++)
01523 set_B(i,j, 0);
01524 }
01525
01526
01527 for (t=0; t<p_observations->get_vector_length(dim)-1; t++)
01528 {
01529 set_A(PATH(dim)[t], PATH(dim)[t+1], get_A(PATH(dim)[t], PATH(dim)[t+1])+1);
01530 set_B(PATH(dim)[t], p_observations->get_feature(dim,t), get_B(PATH(dim)[t], p_observations->get_feature(dim,t))+1);
01531 }
01532 set_B(PATH(dim)[p_observations->get_vector_length(dim)-1], p_observations->get_feature(dim,p_observations->get_vector_length(dim)-1), get_B(PATH(dim)[p_observations->get_vector_length(dim)-1], p_observations->get_feature(dim,p_observations->get_vector_length(dim)-1)) + 1);
01533 path_deriv_dimension=dim ;
01534 path_deriv_updated=true ;
01535 } ;
01537
01539 inline float64_t forward(int32_t time, int32_t state, int32_t dimension)
01540 {
01541 if (time<1)
01542 time=0;
01543
01544 if (ALPHA_CACHE(dimension).table && (dimension==ALPHA_CACHE(dimension).dimension) && ALPHA_CACHE(dimension).updated)
01545 {
01546 if (time<p_observations->get_vector_length(dimension))
01547 return ALPHA_CACHE(dimension).table[time*N+state];
01548 else
01549 return ALPHA_CACHE(dimension).sum;
01550 }
01551 else
01552 return forward_comp(time, state, dimension) ;
01553 }
01554
01556 inline float64_t backward(int32_t time, int32_t state, int32_t dimension)
01557 {
01558 if (BETA_CACHE(dimension).table && (dimension==BETA_CACHE(dimension).dimension) && (BETA_CACHE(dimension).updated))
01559 {
01560 if (time<0)
01561 return BETA_CACHE(dimension).sum;
01562 if (time<p_observations->get_vector_length(dimension))
01563 return BETA_CACHE(dimension).table[time*N+state];
01564 else
01565 return -CMath::INFTY;
01566 }
01567 else
01568 return backward_comp(time, state, dimension) ;
01569 }
01570
01571 };
01572 }
01573 #endif