PluginEstimate.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #ifndef _PLUGINESTIMATE_H___
00012 #define _PLUGINESTIMATE_H___
00013 
00014 #include "classifier/Classifier.h"
00015 #include "features/StringFeatures.h"
00016 #include "features/Labels.h"
00017 #include "distributions/LinearHMM.h"
00018 
00019 namespace shogun
00020 {
00034 class CPluginEstimate: public CClassifier
00035 {
00036     public:
00041         CPluginEstimate(float64_t pos_pseudo=1e-10, float64_t neg_pseudo=1e-10);
00042         virtual ~CPluginEstimate();
00043 
00052         virtual bool train(CFeatures* data=NULL);
00053 
00058         CLabels* classify();
00059 
00065         virtual CLabels* classify(CFeatures* data);
00066 
00071         virtual inline void set_features(CStringFeatures<uint16_t>* feat)
00072         {
00073             SG_UNREF(features);
00074             SG_REF(feat);
00075             features=feat;
00076         }
00077 
00082         virtual CStringFeatures<uint16_t>* get_features() { SG_REF(features); return features; }
00083 
00085         float64_t classify_example(int32_t vec_idx);
00086 
00093         inline float64_t posterior_log_odds_obsolete(
00094             uint16_t* vector, int32_t len)
00095         {
00096             return pos_model->get_log_likelihood_example(vector, len) - neg_model->get_log_likelihood_example(vector, len);
00097         }
00098 
00105         inline float64_t get_parameterwise_log_odds(
00106             uint16_t obs, int32_t position)
00107         {
00108             return pos_model->get_positional_log_parameter(obs, position) - neg_model->get_positional_log_parameter(obs, position);
00109         }
00110 
00117         inline float64_t log_derivative_pos_obsolete(uint16_t obs, int32_t pos)
00118         {
00119             return pos_model->get_log_derivative_obsolete(obs, pos);
00120         }
00121 
00128         inline float64_t log_derivative_neg_obsolete(uint16_t obs, int32_t pos)
00129         {
00130             return neg_model->get_log_derivative_obsolete(obs, pos);
00131         }
00132 
00141         inline bool get_model_params(
00142             float64_t*& pos_params, float64_t*& neg_params,
00143             int32_t &seq_length, int32_t &num_symbols)
00144         {
00145             int32_t num;
00146 
00147             if ((!pos_model) || (!neg_model))
00148             {
00149                 SG_ERROR( "no model available\n");
00150                 return false;
00151             }
00152 
00153             pos_model->get_log_transition_probs(&pos_params, &num);
00154             neg_model->get_log_transition_probs(&neg_params, &num);
00155 
00156             seq_length = pos_model->get_sequence_length();
00157             num_symbols = pos_model->get_num_symbols();
00158             ASSERT(pos_model->get_num_model_parameters()==neg_model->get_num_model_parameters());
00159             ASSERT(pos_model->get_num_symbols()==neg_model->get_num_symbols());
00160             return true;
00161         }
00162 
00169         inline void set_model_params(
00170             const float64_t* pos_params, const float64_t* neg_params,
00171             int32_t seq_length, int32_t num_symbols)
00172         {
00173             int32_t num_params;
00174 
00175             SG_UNREF(pos_model);
00176             pos_model=new CLinearHMM(seq_length, num_symbols);
00177             SG_REF(pos_model);
00178 
00179 
00180             SG_UNREF(neg_model);
00181             neg_model=new CLinearHMM(seq_length, num_symbols);
00182             SG_REF(neg_model);
00183 
00184             num_params=pos_model->get_num_model_parameters();
00185             ASSERT(seq_length*num_symbols==num_params);
00186             ASSERT(num_params==neg_model->get_num_model_parameters());
00187 
00188             pos_model->set_log_transition_probs(pos_params, num_params);
00189             neg_model->set_log_transition_probs(neg_params, num_params);
00190         }
00191 
00196         inline int32_t get_num_params()
00197         {
00198             return pos_model->get_num_model_parameters()+neg_model->get_num_model_parameters();
00199         }
00200         
00205         inline bool check_models()
00206         {
00207             return ( (pos_model!=NULL) && (neg_model!=NULL) );
00208         }
00209 
00211         inline virtual const char* get_name() const { return "PluginEstimate"; }
00212 
00213     protected:
00215         float64_t m_pos_pseudo;
00217         float64_t m_neg_pseudo;
00218 
00220         CLinearHMM* pos_model;
00222         CLinearHMM* neg_model;
00223 
00225         CStringFeatures<uint16_t>* features;
00226 };
00227 }
00228 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation