PluginEstimate.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #ifndef _PLUGINESTIMATE_H___
00012 #define _PLUGINESTIMATE_H___
00013 
00014 #include <shogun/machine/Machine.h>
00015 #include <shogun/features/StringFeatures.h>
00016 #include <shogun/features/Labels.h>
00017 #include <shogun/distributions/LinearHMM.h>
00018 
00019 namespace shogun
00020 {
00034 class CPluginEstimate: public CMachine
00035 {
00036     public:
00041         CPluginEstimate(float64_t pos_pseudo=1e-10, float64_t neg_pseudo=1e-10);
00042         virtual ~CPluginEstimate();
00043 
00048         CLabels* apply();
00049 
00055         virtual CLabels* apply(CFeatures* data);
00056 
00061         virtual inline void set_features(CStringFeatures<uint16_t>* feat)
00062         {
00063             SG_UNREF(features);
00064             SG_REF(feat);
00065             features=feat;
00066         }
00067 
00072         virtual CStringFeatures<uint16_t>* get_features() { SG_REF(features); return features; }
00073 
00075         float64_t apply(int32_t vec_idx);
00076 
00083         inline float64_t posterior_log_odds_obsolete(
00084             uint16_t* vector, int32_t len)
00085         {
00086             return pos_model->get_log_likelihood_example(vector, len) - neg_model->get_log_likelihood_example(vector, len);
00087         }
00088 
00095         inline float64_t get_parameterwise_log_odds(
00096             uint16_t obs, int32_t position)
00097         {
00098             return pos_model->get_positional_log_parameter(obs, position) - neg_model->get_positional_log_parameter(obs, position);
00099         }
00100 
00107         inline float64_t log_derivative_pos_obsolete(uint16_t obs, int32_t pos)
00108         {
00109             return pos_model->get_log_derivative_obsolete(obs, pos);
00110         }
00111 
00118         inline float64_t log_derivative_neg_obsolete(uint16_t obs, int32_t pos)
00119         {
00120             return neg_model->get_log_derivative_obsolete(obs, pos);
00121         }
00122 
00131         inline bool get_model_params(
00132             float64_t*& pos_params, float64_t*& neg_params,
00133             int32_t &seq_length, int32_t &num_symbols)
00134         {
00135             if ((!pos_model) || (!neg_model))
00136             {
00137                 SG_ERROR( "no model available\n");
00138                 return false;
00139             }
00140 
00141             SGVector<float64_t> log_pos_trans = pos_model->get_log_transition_probs();
00142             pos_params = log_pos_trans.vector;
00143             SGVector<float64_t> log_neg_trans = neg_model->get_log_transition_probs();
00144             neg_params = log_neg_trans.vector;
00145 
00146             seq_length = pos_model->get_sequence_length();
00147             num_symbols = pos_model->get_num_symbols();
00148             ASSERT(pos_model->get_num_model_parameters()==neg_model->get_num_model_parameters());
00149             ASSERT(pos_model->get_num_symbols()==neg_model->get_num_symbols());
00150             return true;
00151         }
00152 
00159         inline void set_model_params(
00160             float64_t* pos_params, float64_t* neg_params,
00161             int32_t seq_length, int32_t num_symbols)
00162         {
00163             int32_t num_params;
00164 
00165             SG_UNREF(pos_model);
00166             pos_model=new CLinearHMM(seq_length, num_symbols);
00167             SG_REF(pos_model);
00168 
00169 
00170             SG_UNREF(neg_model);
00171             neg_model=new CLinearHMM(seq_length, num_symbols);
00172             SG_REF(neg_model);
00173 
00174             num_params=pos_model->get_num_model_parameters();
00175             ASSERT(seq_length*num_symbols==num_params);
00176             ASSERT(num_params==neg_model->get_num_model_parameters());
00177 
00178             pos_model->set_log_transition_probs(SGVector<float64_t>(pos_params, num_params));
00179             neg_model->set_log_transition_probs(SGVector<float64_t>(neg_params, num_params));
00180         }
00181 
00186         inline int32_t get_num_params()
00187         {
00188             return pos_model->get_num_model_parameters()+neg_model->get_num_model_parameters();
00189         }
00190         
00195         inline bool check_models()
00196         {
00197             return ( (pos_model!=NULL) && (neg_model!=NULL) );
00198         }
00199 
00201         inline virtual const char* get_name() const { return "PluginEstimate"; }
00202 
00203     protected:
00212         virtual bool train_machine(CFeatures* data=NULL);
00213 
00214     protected:
00216         float64_t m_pos_pseudo;
00218         float64_t m_neg_pseudo;
00219 
00221         CLinearHMM* pos_model;
00223         CLinearHMM* neg_model;
00224 
00226         CStringFeatures<uint16_t>* features;
00227 };
00228 }
00229 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation