SHOGUN  4.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
PluginEstimate.h
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 1999-2009 Soeren Sonnenburg
8  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
9  */
10 
11 #ifndef _PLUGINESTIMATE_H___
12 #define _PLUGINESTIMATE_H___
13 
14 #include <shogun/lib/config.h>
15 
16 #include <shogun/machine/Machine.h>
20 
21 namespace shogun
22 {
37 {
38  public:
39 
42 
/** Constructor.
 *
 * @param pos_pseudo pseudo value for the positive model (default 1e-10);
 *        NOTE(review): presumably a pseudo-count applied when training the
 *        positive CLinearHMM — confirm in PluginEstimate.cpp
 * @param neg_pseudo pseudo value for the negative model (default 1e-10);
 *        same caveat as pos_pseudo
 */
47  CPluginEstimate(float64_t pos_pseudo=1e-10, float64_t neg_pseudo=1e-10);
/** Destructor (virtual). */
48  virtual ~CPluginEstimate();
49 
/** Classify data and return binary labels.
 *
 * @param data features to classify (default NULL);
 *        NOTE(review): presumably the currently set features are used when
 *        data is NULL — confirm in PluginEstimate.cpp
 * @return binary labels for the classified vectors
 */
55  virtual CBinaryLabels* apply_binary(CFeatures* data=NULL);
56 
62  {
63  SG_REF(feat);
65  features=feat;
66  }
67 
73 
/** Classify the test feature vector indexed by vec_idx.
 *
 * @param vec_idx index of the test feature vector to classify
 * @return real-valued classifier output for that vector
 *         (NOTE(review): presumably the sign encodes the class — confirm)
 */
75  float64_t apply_one(int32_t vec_idx);
76 
84  uint16_t* vector, int32_t len)
85  {
86  return pos_model->get_log_likelihood_example(vector, len) - neg_model->get_log_likelihood_example(vector, len);
87  }
88 
96  uint16_t obs, int32_t position)
97  {
98  return pos_model->get_positional_log_parameter(obs, position) - neg_model->get_positional_log_parameter(obs, position);
99  }
100 
107  inline float64_t log_derivative_pos_obsolete(uint16_t obs, int32_t pos)
108  {
109  return pos_model->get_log_derivative_obsolete(obs, pos);
110  }
111 
118  inline float64_t log_derivative_neg_obsolete(uint16_t obs, int32_t pos)
119  {
120  return neg_model->get_log_derivative_obsolete(obs, pos);
121  }
122 
/** Expose the log-parameters and shape of both models.
 *
 * @param pos_params out: set to the data pointer of log_pos_trans
 * @param neg_params out: set to the data pointer of log_neg_trans
 * @param seq_length out: pos_model->get_sequence_length()
 * @param num_symbols out: pos_model->get_num_symbols()
 * @return true once the outputs are filled; false when either model is
 *         missing (after SG_ERROR — NOTE(review): SG_ERROR may not return,
 *         in which case `return false` is unreachable; confirm in SGIO.h)
 *
 * The returned pointers are not copies; they alias the vectors' storage.
 * NOTE(review): confirm they stay valid after log_pos_trans/log_neg_trans go
 * out of scope (depends on whether get_log_transition_probs() returns a
 * non-owning view of the model's own buffer).
 */
131  inline bool get_model_params(
132  float64_t*& pos_params, float64_t*& neg_params,
133  int32_t &seq_length, int32_t &num_symbols)
134  {
135  if ((!pos_model) || (!neg_model))
136  {
137  SG_ERROR("no model available\n")
138  return false;
139  }
140 
// NOTE(review): source lines 141/143 are missing from this listing; they
// presumably declare log_pos_trans/log_neg_trans from each model's
// get_log_transition_probs() — confirm against the real header.
142  pos_params = log_pos_trans.vector;
144  neg_params = log_neg_trans.vector;
145 
// Shape information is read from the positive model only.
146  seq_length = pos_model->get_sequence_length();
147  num_symbols = pos_model->get_num_symbols();
150  return true;
151  }
152 
/** Replace both models with fresh CLinearHMMs built from raw log-parameters.
 *
 * @param pos_params log transition parameters for the positive model;
 *        num_params values are read, where num_params is asserted to equal
 *        seq_length*num_symbols
 * @param neg_params log transition parameters for the negative model
 *        (same length as pos_params)
 * @param seq_length sequence length passed to both CLinearHMM constructors
 * @param num_symbols alphabet size passed to both CLinearHMM constructors
 */
159  inline void set_model_params(
160  float64_t* pos_params, float64_t* neg_params,
161  int32_t seq_length, int32_t num_symbols)
162  {
163  int32_t num_params;
164 
// NOTE(review): source line 165 is missing from this listing; presumably it
// releases the previous pos_model (SG_UNREF) — if not, the old model leaks.
166  pos_model=new CLinearHMM(seq_length, num_symbols);
167  SG_REF(pos_model);
168 
169 
// NOTE(review): same for source line 170 and the previous neg_model.
171  neg_model=new CLinearHMM(seq_length, num_symbols);
172  SG_REF(neg_model);
173 
174  num_params=pos_model->get_num_model_parameters();
175  ASSERT(seq_length*num_symbols==num_params)
177 
// The bool returned by set_log_transition_probs() is ignored here.
// NOTE(review): SGVector(ptr, len) may default to reference counting and so
// take ownership of pos_params/neg_params — confirm against SGVector.h
// before passing buffers that the caller frees itself.
178  pos_model->set_log_transition_probs(SGVector<float64_t>(pos_params, num_params));
179  neg_model->set_log_transition_probs(SGVector<float64_t>(neg_params, num_params));
180  }
181 
186  inline int32_t get_num_params()
187  {
189  }
190 
195  inline bool check_models()
196  {
197  return ( (pos_model!=NULL) && (neg_model!=NULL) );
198  }
199 
201  virtual const char* get_name() const { return "PluginEstimate"; }
202 
203  protected:
/** Train the positive and negative models.
 *
 * @param data training features (default NULL);
 *        NOTE(review): presumably the currently set features are used when
 *        data is NULL — confirm in PluginEstimate.cpp
 * @return whether training succeeded
 */
212  virtual bool train_machine(CFeatures* data=NULL);
213 
214  protected:
219 
224 
227 };
228 }
229 #endif
virtual int32_t get_sequence_length()
Definition: LinearHMM.h:151
virtual const char * get_name() const
virtual float64_t get_positional_log_parameter(uint16_t obs, int32_t position)
Definition: LinearHMM.h:171
float64_t log_derivative_neg_obsolete(uint16_t obs, int32_t pos)
float64_t apply_one(int32_t vec_idx)
Classify the test feature vector indexed by vec_idx.
virtual float64_t get_log_derivative_obsolete(uint16_t obs, int32_t pos)
Definition: LinearHMM.h:128
virtual int32_t get_num_symbols()
Definition: LinearHMM.h:157
virtual CBinaryLabels * apply_binary(CFeatures *data=NULL)
#define SG_ERROR(...)
Definition: SGIO.h:129
virtual int32_t get_num_model_parameters()
Definition: LinearHMM.h:163
float64_t get_parameterwise_log_odds(uint16_t obs, int32_t position)
bool get_model_params(float64_t *&pos_params, float64_t *&neg_params, int32_t &seq_length, int32_t &num_symbols)
#define SG_REF(x)
Definition: SGObject.h:54
A generic learning machine interface.
Definition: Machine.h:143
float64_t get_log_likelihood_example(uint16_t *vector, int32_t len)
Definition: LinearHMM.cpp:186
MACHINE_PROBLEM_TYPE(PT_BINARY)
#define ASSERT(x)
Definition: SGIO.h:201
double float64_t
Definition: common.h:50
virtual void set_features(CStringFeatures< uint16_t > *feat)
CPluginEstimate(float64_t pos_pseudo=1e-10, float64_t neg_pseudo=1e-10)
float64_t log_derivative_pos_obsolete(uint16_t obs, int32_t pos)
virtual bool set_log_transition_probs(const SGVector< float64_t > probs)
Definition: LinearHMM.cpp:283
float64_t posterior_log_odds_obsolete(uint16_t *vector, int32_t len)
#define SG_UNREF(x)
Definition: SGObject.h:55
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
The class Features is the base class of all feature objects.
Definition: Features.h:68
virtual bool train_machine(CFeatures *data=NULL)
class PluginEstimate
virtual CStringFeatures< uint16_t > * get_features()
CStringFeatures< uint16_t > * features
Binary Labels for binary classification.
Definition: BinaryLabels.h:37
void set_model_params(float64_t *pos_params, float64_t *neg_params, int32_t seq_length, int32_t num_symbols)
virtual SGVector< float64_t > get_log_transition_probs()
Definition: LinearHMM.cpp:278
The class LinearHMM is for learning Higher Order Markov chains.
Definition: LinearHMM.h:41

SHOGUN Machine Learning Toolbox - Documentation