Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #ifndef _PLUGINESTIMATE_H___
00012 #define _PLUGINESTIMATE_H___
00013
00014 #include <shogun/machine/Machine.h>
00015 #include <shogun/features/StringFeatures.h>
00016 #include <shogun/features/Labels.h>
00017 #include <shogun/distributions/LinearHMM.h>
00018
00019 namespace shogun
00020 {
00034 class CPluginEstimate: public CMachine
00035 {
00036 public:
00041 CPluginEstimate(float64_t pos_pseudo=1e-10, float64_t neg_pseudo=1e-10);
00042 virtual ~CPluginEstimate();
00043
00048 CLabels* apply();
00049
00055 virtual CLabels* apply(CFeatures* data);
00056
00061 virtual inline void set_features(CStringFeatures<uint16_t>* feat)
00062 {
00063 SG_UNREF(features);
00064 SG_REF(feat);
00065 features=feat;
00066 }
00067
00072 virtual CStringFeatures<uint16_t>* get_features() { SG_REF(features); return features; }
00073
00075 float64_t apply(int32_t vec_idx);
00076
00083 inline float64_t posterior_log_odds_obsolete(
00084 uint16_t* vector, int32_t len)
00085 {
00086 return pos_model->get_log_likelihood_example(vector, len) - neg_model->get_log_likelihood_example(vector, len);
00087 }
00088
00095 inline float64_t get_parameterwise_log_odds(
00096 uint16_t obs, int32_t position)
00097 {
00098 return pos_model->get_positional_log_parameter(obs, position) - neg_model->get_positional_log_parameter(obs, position);
00099 }
00100
00107 inline float64_t log_derivative_pos_obsolete(uint16_t obs, int32_t pos)
00108 {
00109 return pos_model->get_log_derivative_obsolete(obs, pos);
00110 }
00111
00118 inline float64_t log_derivative_neg_obsolete(uint16_t obs, int32_t pos)
00119 {
00120 return neg_model->get_log_derivative_obsolete(obs, pos);
00121 }
00122
00131 inline bool get_model_params(
00132 float64_t*& pos_params, float64_t*& neg_params,
00133 int32_t &seq_length, int32_t &num_symbols)
00134 {
00135 if ((!pos_model) || (!neg_model))
00136 {
00137 SG_ERROR( "no model available\n");
00138 return false;
00139 }
00140
00141 SGVector<float64_t> log_pos_trans = pos_model->get_log_transition_probs();
00142 pos_params = log_pos_trans.vector;
00143 SGVector<float64_t> log_neg_trans = neg_model->get_log_transition_probs();
00144 neg_params = log_neg_trans.vector;
00145
00146 seq_length = pos_model->get_sequence_length();
00147 num_symbols = pos_model->get_num_symbols();
00148 ASSERT(pos_model->get_num_model_parameters()==neg_model->get_num_model_parameters());
00149 ASSERT(pos_model->get_num_symbols()==neg_model->get_num_symbols());
00150 return true;
00151 }
00152
00159 inline void set_model_params(
00160 float64_t* pos_params, float64_t* neg_params,
00161 int32_t seq_length, int32_t num_symbols)
00162 {
00163 int32_t num_params;
00164
00165 SG_UNREF(pos_model);
00166 pos_model=new CLinearHMM(seq_length, num_symbols);
00167 SG_REF(pos_model);
00168
00169
00170 SG_UNREF(neg_model);
00171 neg_model=new CLinearHMM(seq_length, num_symbols);
00172 SG_REF(neg_model);
00173
00174 num_params=pos_model->get_num_model_parameters();
00175 ASSERT(seq_length*num_symbols==num_params);
00176 ASSERT(num_params==neg_model->get_num_model_parameters());
00177
00178 pos_model->set_log_transition_probs(SGVector<float64_t>(pos_params, num_params));
00179 neg_model->set_log_transition_probs(SGVector<float64_t>(neg_params, num_params));
00180 }
00181
00186 inline int32_t get_num_params()
00187 {
00188 return pos_model->get_num_model_parameters()+neg_model->get_num_model_parameters();
00189 }
00190
00195 inline bool check_models()
00196 {
00197 return ( (pos_model!=NULL) && (neg_model!=NULL) );
00198 }
00199
00201 inline virtual const char* get_name() const { return "PluginEstimate"; }
00202
00203 protected:
00212 virtual bool train_machine(CFeatures* data=NULL);
00213
00214 protected:
00216 float64_t m_pos_pseudo;
00218 float64_t m_neg_pseudo;
00219
00221 CLinearHMM* pos_model;
00223 CLinearHMM* neg_model;
00224
00226 CStringFeatures<uint16_t>* features;
00227 };
00228 }
00229 #endif