Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #ifndef _PLUGINESTIMATE_H___
00012 #define _PLUGINESTIMATE_H___
00013
00014 #include <shogun/machine/Machine.h>
00015 #include <shogun/features/StringFeatures.h>
00016 #include <shogun/labels/BinaryLabels.h>
00017 #include <shogun/distributions/LinearHMM.h>
00018
00019 namespace shogun
00020 {
00034 class CPluginEstimate: public CMachine
00035 {
00036 public:
00037
00039 MACHINE_PROBLEM_TYPE(PT_BINARY);
00040
00045 CPluginEstimate(float64_t pos_pseudo=1e-10, float64_t neg_pseudo=1e-10);
00046 virtual ~CPluginEstimate();
00047
00053 virtual CBinaryLabels* apply_binary(CFeatures* data=NULL);
00054
00059 virtual void set_features(CStringFeatures<uint16_t>* feat)
00060 {
00061 SG_UNREF(features);
00062 SG_REF(feat);
00063 features=feat;
00064 }
00065
00070 virtual CStringFeatures<uint16_t>* get_features() { SG_REF(features); return features; }
00071
00073 float64_t apply_one(int32_t vec_idx);
00074
00081 inline float64_t posterior_log_odds_obsolete(
00082 uint16_t* vector, int32_t len)
00083 {
00084 return pos_model->get_log_likelihood_example(vector, len) - neg_model->get_log_likelihood_example(vector, len);
00085 }
00086
00093 inline float64_t get_parameterwise_log_odds(
00094 uint16_t obs, int32_t position)
00095 {
00096 return pos_model->get_positional_log_parameter(obs, position) - neg_model->get_positional_log_parameter(obs, position);
00097 }
00098
00105 inline float64_t log_derivative_pos_obsolete(uint16_t obs, int32_t pos)
00106 {
00107 return pos_model->get_log_derivative_obsolete(obs, pos);
00108 }
00109
00116 inline float64_t log_derivative_neg_obsolete(uint16_t obs, int32_t pos)
00117 {
00118 return neg_model->get_log_derivative_obsolete(obs, pos);
00119 }
00120
00129 inline bool get_model_params(
00130 float64_t*& pos_params, float64_t*& neg_params,
00131 int32_t &seq_length, int32_t &num_symbols)
00132 {
00133 if ((!pos_model) || (!neg_model))
00134 {
00135 SG_ERROR( "no model available\n");
00136 return false;
00137 }
00138
00139 SGVector<float64_t> log_pos_trans = pos_model->get_log_transition_probs();
00140 pos_params = log_pos_trans.vector;
00141 SGVector<float64_t> log_neg_trans = neg_model->get_log_transition_probs();
00142 neg_params = log_neg_trans.vector;
00143
00144 seq_length = pos_model->get_sequence_length();
00145 num_symbols = pos_model->get_num_symbols();
00146 ASSERT(pos_model->get_num_model_parameters()==neg_model->get_num_model_parameters());
00147 ASSERT(pos_model->get_num_symbols()==neg_model->get_num_symbols());
00148 return true;
00149 }
00150
00157 inline void set_model_params(
00158 float64_t* pos_params, float64_t* neg_params,
00159 int32_t seq_length, int32_t num_symbols)
00160 {
00161 int32_t num_params;
00162
00163 SG_UNREF(pos_model);
00164 pos_model=new CLinearHMM(seq_length, num_symbols);
00165 SG_REF(pos_model);
00166
00167
00168 SG_UNREF(neg_model);
00169 neg_model=new CLinearHMM(seq_length, num_symbols);
00170 SG_REF(neg_model);
00171
00172 num_params=pos_model->get_num_model_parameters();
00173 ASSERT(seq_length*num_symbols==num_params);
00174 ASSERT(num_params==neg_model->get_num_model_parameters());
00175
00176 pos_model->set_log_transition_probs(SGVector<float64_t>(pos_params, num_params));
00177 neg_model->set_log_transition_probs(SGVector<float64_t>(neg_params, num_params));
00178 }
00179
00184 inline int32_t get_num_params()
00185 {
00186 return pos_model->get_num_model_parameters()+neg_model->get_num_model_parameters();
00187 }
00188
00193 inline bool check_models()
00194 {
00195 return ( (pos_model!=NULL) && (neg_model!=NULL) );
00196 }
00197
00199 virtual const char* get_name() const { return "PluginEstimate"; }
00200
00201 protected:
00210 virtual bool train_machine(CFeatures* data=NULL);
00211
00212 protected:
00214 float64_t m_pos_pseudo;
00216 float64_t m_neg_pseudo;
00217
00219 CLinearHMM* pos_model;
00221 CLinearHMM* neg_model;
00222
00224 CStringFeatures<uint16_t>* features;
00225 };
00226 }
00227 #endif