SHOGUN  3.2.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
PluginEstimate.h
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 1999-2009 Soeren Sonnenburg
8  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
9  */
10 
11 #ifndef _PLUGINESTIMATE_H___
12 #define _PLUGINESTIMATE_H___
13 
14 #include <shogun/lib/config.h>
15 
16 #include <shogun/machine/Machine.h>
20 
21 namespace shogun
22 {
37 {
38  public:
39 
42 
47  CPluginEstimate(float64_t pos_pseudo=1e-10, float64_t neg_pseudo=1e-10);
48  virtual ~CPluginEstimate();
49 
55  virtual CBinaryLabels* apply_binary(CFeatures* data=NULL);
56 
62  {
63  SG_REF(feat);
65  features=feat;
66  }
67 
73 
75  float64_t apply_one(int32_t vec_idx);
76 
84  uint16_t* vector, int32_t len)
85  {
86  return pos_model->get_log_likelihood_example(vector, len) - neg_model->get_log_likelihood_example(vector, len);
87  }
88 
96  uint16_t obs, int32_t position)
97  {
98  return pos_model->get_positional_log_parameter(obs, position) - neg_model->get_positional_log_parameter(obs, position);
99  }
100 
107  inline float64_t log_derivative_pos_obsolete(uint16_t obs, int32_t pos)
108  {
109  return pos_model->get_log_derivative_obsolete(obs, pos);
110  }
111 
118  inline float64_t log_derivative_neg_obsolete(uint16_t obs, int32_t pos)
119  {
120  return neg_model->get_log_derivative_obsolete(obs, pos);
121  }
122 
/** Expose the parameter buffers and dimensions of both models.
 *
 * @param pos_params [out] set to the data pointer of log_pos_trans
 * @param neg_params [out] set to the data pointer of log_neg_trans
 * @param seq_length [out] sequence length, read from pos_model only
 * @param num_symbols [out] number of symbols, read from pos_model only
 * @return true once the outputs are filled; false if either model is
 *         missing (the `return false` after SG_ERROR is only reached if
 *         SG_ERROR does not abort — TODO confirm)
 *
 * NOTE(review): log_pos_trans / log_neg_trans are not declared anywhere in
 * this listing (source lines 141 and 143 are absent). Confirm where they come
 * from, and that the buffers behind the returned raw pointers outlive this
 * call.
 */
131  inline bool get_model_params(
132  float64_t*& pos_params, float64_t*& neg_params,
133  int32_t &seq_length, int32_t &num_symbols)
134  {
135  if ((!pos_model) || (!neg_model))
136  {
137  SG_ERROR("no model available\n")
138  return false;
139  }
140 
// hand out raw pointers into the transition-parameter vectors (no copy)
142  pos_params = log_pos_trans.vector;
144  neg_params = log_neg_trans.vector;
145 
// dimensions come from pos_model only; assumes neg_model has the same
// shape — TODO confirm
146  seq_length = pos_model->get_sequence_length();
147  num_symbols = pos_model->get_num_symbols();
150  return true;
151  }
152 
159  inline void set_model_params(
160  float64_t* pos_params, float64_t* neg_params,
161  int32_t seq_length, int32_t num_symbols)
162  {
163  int32_t num_params;
164 
166  pos_model=new CLinearHMM(seq_length, num_symbols);
167  SG_REF(pos_model);
168 
169 
171  neg_model=new CLinearHMM(seq_length, num_symbols);
172  SG_REF(neg_model);
173 
174  num_params=pos_model->get_num_model_parameters();
175  ASSERT(seq_length*num_symbols==num_params)
177 
178  pos_model->set_log_transition_probs(SGVector<float64_t>(pos_params, num_params));
179  neg_model->set_log_transition_probs(SGVector<float64_t>(neg_params, num_params));
180  }
181 
/** Number of model parameters.
 *
 * NOTE(review): no return statement is visible here (source line 188 is
 * absent from this listing). As shown, control falls off the end of a
 * non-void function, which is undefined behavior. Restore the return
 * expression from the original header.
 */
186  inline int32_t get_num_params()
187  {
189  }
190 
195  inline bool check_models()
196  {
197  return ( (pos_model!=NULL) && (neg_model!=NULL) );
198  }
199 
201  virtual const char* get_name() const { return "PluginEstimate"; }
202 
203  protected:
212  virtual bool train_machine(CFeatures* data=NULL);
213 
214  protected:
219 
224 
227 };
228 }
229 #endif

SHOGUN Machine Learning Toolbox - Documentation