Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #ifndef _PLUGINESTIMATE_H___
00012 #define _PLUGINESTIMATE_H___
00013
00014 #include "classifier/Classifier.h"
00015 #include "features/StringFeatures.h"
00016 #include "features/Labels.h"
00017 #include "distributions/LinearHMM.h"
00018
00019 namespace shogun
00020 {
00034 class CPluginEstimate: public CClassifier
00035 {
00036 public:
00041 CPluginEstimate(float64_t pos_pseudo=1e-10, float64_t neg_pseudo=1e-10);
00042 virtual ~CPluginEstimate();
00043
00052 virtual bool train(CFeatures* data=NULL);
00053
00058 CLabels* classify();
00059
00065 virtual CLabels* classify(CFeatures* data);
00066
00071 virtual inline void set_features(CStringFeatures<uint16_t>* feat)
00072 {
00073 SG_UNREF(features);
00074 SG_REF(feat);
00075 features=feat;
00076 }
00077
00082 virtual CStringFeatures<uint16_t>* get_features() { SG_REF(features); return features; }
00083
00085 float64_t classify_example(int32_t vec_idx);
00086
00093 inline float64_t posterior_log_odds_obsolete(
00094 uint16_t* vector, int32_t len)
00095 {
00096 return pos_model->get_log_likelihood_example(vector, len) - neg_model->get_log_likelihood_example(vector, len);
00097 }
00098
00105 inline float64_t get_parameterwise_log_odds(
00106 uint16_t obs, int32_t position)
00107 {
00108 return pos_model->get_positional_log_parameter(obs, position) - neg_model->get_positional_log_parameter(obs, position);
00109 }
00110
00117 inline float64_t log_derivative_pos_obsolete(uint16_t obs, int32_t pos)
00118 {
00119 return pos_model->get_log_derivative_obsolete(obs, pos);
00120 }
00121
00128 inline float64_t log_derivative_neg_obsolete(uint16_t obs, int32_t pos)
00129 {
00130 return neg_model->get_log_derivative_obsolete(obs, pos);
00131 }
00132
00141 inline bool get_model_params(
00142 float64_t*& pos_params, float64_t*& neg_params,
00143 int32_t &seq_length, int32_t &num_symbols)
00144 {
00145 int32_t num;
00146
00147 if ((!pos_model) || (!neg_model))
00148 {
00149 SG_ERROR( "no model available\n");
00150 return false;
00151 }
00152
00153 pos_model->get_log_transition_probs(&pos_params, &num);
00154 neg_model->get_log_transition_probs(&neg_params, &num);
00155
00156 seq_length = pos_model->get_sequence_length();
00157 num_symbols = pos_model->get_num_symbols();
00158 ASSERT(pos_model->get_num_model_parameters()==neg_model->get_num_model_parameters());
00159 ASSERT(pos_model->get_num_symbols()==neg_model->get_num_symbols());
00160 return true;
00161 }
00162
00169 inline void set_model_params(
00170 const float64_t* pos_params, const float64_t* neg_params,
00171 int32_t seq_length, int32_t num_symbols)
00172 {
00173 int32_t num_params;
00174
00175 SG_UNREF(pos_model);
00176 pos_model=new CLinearHMM(seq_length, num_symbols);
00177 SG_REF(pos_model);
00178
00179
00180 SG_UNREF(neg_model);
00181 neg_model=new CLinearHMM(seq_length, num_symbols);
00182 SG_REF(neg_model);
00183
00184 num_params=pos_model->get_num_model_parameters();
00185 ASSERT(seq_length*num_symbols==num_params);
00186 ASSERT(num_params==neg_model->get_num_model_parameters());
00187
00188 pos_model->set_log_transition_probs(pos_params, num_params);
00189 neg_model->set_log_transition_probs(neg_params, num_params);
00190 }
00191
00196 inline int32_t get_num_params()
00197 {
00198 return pos_model->get_num_model_parameters()+neg_model->get_num_model_parameters();
00199 }
00200
00205 inline bool check_models()
00206 {
00207 return ( (pos_model!=NULL) && (neg_model!=NULL) );
00208 }
00209
00211 inline virtual const char* get_name() const { return "PluginEstimate"; }
00212
00213 protected:
00215 float64_t m_pos_pseudo;
00217 float64_t m_neg_pseudo;
00218
00220 CLinearHMM* pos_model;
00222 CLinearHMM* neg_model;
00223
00225 CStringFeatures<uint16_t>* features;
00226 };
00227 }
00228 #endif