Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #ifndef _LINEARHMM_H__
00013 #define _LINEARHMM_H__
00014
00015 #include <shogun/features/StringFeatures.h>
00016 #include <shogun/features/Labels.h>
00017 #include <shogun/distributions/Distribution.h>
00018
00019 namespace shogun
00020 {
00039 class CLinearHMM : public CDistribution
00040 {
00041 public:
00043 CLinearHMM();
00044
00049 CLinearHMM(CStringFeatures<uint16_t>* f);
00050
00056 CLinearHMM(int32_t p_num_features, int32_t p_num_symbols);
00057
00058 virtual ~CLinearHMM();
00059
00068 virtual bool train(CFeatures* data=NULL);
00069
00077 bool train(
00078 const int32_t* indizes, int32_t num_indizes,
00079 float64_t pseudo_count);
00080
00087 float64_t get_log_likelihood_example(uint16_t* vector, int32_t len);
00088
00095 float64_t get_likelihood_example(uint16_t* vector, int32_t len);
00096
00102 virtual float64_t get_log_likelihood_example(int32_t num_example);
00103
00110 virtual float64_t get_log_derivative(
00111 int32_t num_param, int32_t num_example);
00112
00119 virtual inline float64_t get_log_derivative_obsolete(
00120 uint16_t obs, int32_t pos)
00121 {
00122 return 1.0/transition_probs[pos*num_symbols+obs];
00123 }
00124
00131 virtual inline float64_t get_derivative_obsolete(
00132 uint16_t* vector, int32_t len, int32_t pos)
00133 {
00134 ASSERT(pos<len);
00135 return get_likelihood_example(vector, len)/transition_probs[pos*num_symbols+vector[pos]];
00136 }
00137
00142 virtual inline int32_t get_sequence_length() { return sequence_length; }
00143
00148 virtual inline int32_t get_num_symbols() { return num_symbols; }
00149
00154 virtual inline int32_t get_num_model_parameters() { return num_params; }
00155
00162 virtual inline float64_t get_positional_log_parameter(
00163 uint16_t obs, int32_t position)
00164 {
00165 return log_transition_probs[position*num_symbols+obs];
00166 }
00167
00173 virtual inline float64_t get_log_model_parameter(int32_t num_param)
00174 {
00175 ASSERT(log_transition_probs);
00176 ASSERT(num_param<num_params);
00177
00178 return log_transition_probs[num_param];
00179 }
00180
00185 virtual SGVector<float64_t> get_log_transition_probs();
00186
00192 virtual bool set_log_transition_probs(SGVector<float64_t> probs);
00193
00198 virtual SGVector<float64_t> get_transition_probs();
00199
00205 virtual bool set_transition_probs(SGVector<float64_t> probs);
00206
00208 inline virtual const char* get_name() const { return "LinearHMM"; }
00209
00210 protected:
00211 virtual void load_serializable_post() throw (ShogunException);
00212
00213 private:
00214 void init();
00215
00216 protected:
00218 int32_t sequence_length;
00220 int32_t num_symbols;
00222 int32_t num_params;
00224 float64_t* transition_probs;
00226 float64_t* log_transition_probs;
00227 };
00228 }
00229 #endif