PluginEstimate.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #ifndef _PLUGINESTIMATE_H___
00012 #define _PLUGINESTIMATE_H___
00013 
00014 #include <shogun/machine/Machine.h>
00015 #include <shogun/features/StringFeatures.h>
00016 #include <shogun/labels/BinaryLabels.h>
00017 #include <shogun/distributions/LinearHMM.h>
00018 
00019 namespace shogun
00020 {
00034 class CPluginEstimate: public CMachine
00035 {
00036     public:
00037 
00039         MACHINE_PROBLEM_TYPE(PT_BINARY);
00040 
00045         CPluginEstimate(float64_t pos_pseudo=1e-10, float64_t neg_pseudo=1e-10);
00046         virtual ~CPluginEstimate();
00047 
00053         virtual CBinaryLabels* apply_binary(CFeatures* data=NULL);
00054 
00059         virtual void set_features(CStringFeatures<uint16_t>* feat)
00060         {
00061             SG_UNREF(features);
00062             SG_REF(feat);
00063             features=feat;
00064         }
00065 
00070         virtual CStringFeatures<uint16_t>* get_features() { SG_REF(features); return features; }
00071 
00073         float64_t apply_one(int32_t vec_idx);
00074 
00081         inline float64_t posterior_log_odds_obsolete(
00082             uint16_t* vector, int32_t len)
00083         {
00084             return pos_model->get_log_likelihood_example(vector, len) - neg_model->get_log_likelihood_example(vector, len);
00085         }
00086 
00093         inline float64_t get_parameterwise_log_odds(
00094             uint16_t obs, int32_t position)
00095         {
00096             return pos_model->get_positional_log_parameter(obs, position) - neg_model->get_positional_log_parameter(obs, position);
00097         }
00098 
00105         inline float64_t log_derivative_pos_obsolete(uint16_t obs, int32_t pos)
00106         {
00107             return pos_model->get_log_derivative_obsolete(obs, pos);
00108         }
00109 
00116         inline float64_t log_derivative_neg_obsolete(uint16_t obs, int32_t pos)
00117         {
00118             return neg_model->get_log_derivative_obsolete(obs, pos);
00119         }
00120 
00129         inline bool get_model_params(
00130             float64_t*& pos_params, float64_t*& neg_params,
00131             int32_t &seq_length, int32_t &num_symbols)
00132         {
00133             if ((!pos_model) || (!neg_model))
00134             {
00135                 SG_ERROR( "no model available\n");
00136                 return false;
00137             }
00138 
00139             SGVector<float64_t> log_pos_trans = pos_model->get_log_transition_probs();
00140             pos_params = log_pos_trans.vector;
00141             SGVector<float64_t> log_neg_trans = neg_model->get_log_transition_probs();
00142             neg_params = log_neg_trans.vector;
00143 
00144             seq_length = pos_model->get_sequence_length();
00145             num_symbols = pos_model->get_num_symbols();
00146             ASSERT(pos_model->get_num_model_parameters()==neg_model->get_num_model_parameters());
00147             ASSERT(pos_model->get_num_symbols()==neg_model->get_num_symbols());
00148             return true;
00149         }
00150 
00157         inline void set_model_params(
00158             float64_t* pos_params, float64_t* neg_params,
00159             int32_t seq_length, int32_t num_symbols)
00160         {
00161             int32_t num_params;
00162 
00163             SG_UNREF(pos_model);
00164             pos_model=new CLinearHMM(seq_length, num_symbols);
00165             SG_REF(pos_model);
00166 
00167 
00168             SG_UNREF(neg_model);
00169             neg_model=new CLinearHMM(seq_length, num_symbols);
00170             SG_REF(neg_model);
00171 
00172             num_params=pos_model->get_num_model_parameters();
00173             ASSERT(seq_length*num_symbols==num_params);
00174             ASSERT(num_params==neg_model->get_num_model_parameters());
00175 
00176             pos_model->set_log_transition_probs(SGVector<float64_t>(pos_params, num_params));
00177             neg_model->set_log_transition_probs(SGVector<float64_t>(neg_params, num_params));
00178         }
00179 
00184         inline int32_t get_num_params()
00185         {
00186             return pos_model->get_num_model_parameters()+neg_model->get_num_model_parameters();
00187         }
00188 
00193         inline bool check_models()
00194         {
00195             return ( (pos_model!=NULL) && (neg_model!=NULL) );
00196         }
00197 
00199         virtual const char* get_name() const { return "PluginEstimate"; }
00200 
00201     protected:
00210         virtual bool train_machine(CFeatures* data=NULL);
00211 
00212     protected:
00214         float64_t m_pos_pseudo;
00216         float64_t m_neg_pseudo;
00217 
00219         CLinearHMM* pos_model;
00221         CLinearHMM* neg_model;
00222 
00224         CStringFeatures<uint16_t>* features;
00225 };
00226 }
00227 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation