SHOGUN  4.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
GaussianProcessClassification.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) The Shogun Machine Learning Toolbox
3  * Written (W) 2014 Wu Lin
4  * Written (W) 2013 Roman Votyakov
5  * All rights reserved.
6  *
7  * Redistribution and use in source and binary forms, with or without
8  * modification, are permitted provided that the following conditions are met:
9  *
10  * 1. Redistributions of source code must retain the above copyright notice, this
11  * list of conditions and the following disclaimer.
12  * 2. Redistributions in binary form must reproduce the above copyright notice,
13  * this list of conditions and the following disclaimer in the documentation
14  * and/or other materials provided with the distribution.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
20  * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  *
27  * The views and conclusions contained in the software and documentation are those
28  * of the authors and should not be interpreted as representing official policies,
29  * either expressed or implied, of the Shogun Development Team.
30  *
31  * Code adapted from
32  * Gaussian Process Machine Learning Toolbox
33  * http://www.gaussianprocess.org/gpml/code/matlab/doc/
34  * and
35  * https://gist.github.com/yorkerlin/8a36e8f9b298aa0246a4
36  */
37 
38 
39 #include <shogun/lib/config.h>
43 
44 using namespace shogun;
45 
48 {
49 }
50 
52  CInference* method) : CGaussianProcessMachine(method)
53 {
54  // set labels
55  m_labels=method->get_labels();
56 }
57 
59 {
60 }
61 
63 {
64  // check whether given combination of inference method and likelihood
65  // function supports classification
66  REQUIRE(m_method, "Inference method should not be NULL\n")
68  REQUIRE(m_method->supports_multiclass(), "%s with %s doesn't support "
69  "multi classification\n", m_method->get_name(), lik->get_name())
70  SG_UNREF(lik);
71 
72  // if regression data equals to NULL, then apply classification on training
73  // features
74  if (!data)
75  {
77  {
79  }
80  else
81  data=m_method->get_features();
82  }
83  else
84  SG_REF(data);
85 
86  const index_t n=data->get_num_vectors();
88  const index_t C=mean.vlen/n;
89  SGVector<index_t> lab(n);
90  for (index_t idx=0; idx<n; idx++)
91  {
92  int32_t cate=CMath::arg_max(mean.vector+idx*C, 1, C);
93  lab[idx]=cate;
94  }
96  result->set_int_labels(lab);
97 
98  SG_UNREF(data);
99 
100  return result;
101 }
102 
104  CFeatures* data)
105 {
106  // check whether given combination of inference method and likelihood
107  // function supports classification
108  REQUIRE(m_method, "Inference method should not be NULL\n")
110  REQUIRE(m_method->supports_binary(), "%s with %s doesn't support "
111  "binary classification\n", m_method->get_name(), lik->get_name())
112  SG_UNREF(lik);
113 
114  // if regression data equals to NULL, then apply classification on training
115  // features
116  if (!data)
117  {
119  {
122  data=fitc_method->get_inducing_features();
123  SG_UNREF(fitc_method);
124  }
125  else
126  data=m_method->get_features();
127  }
128  else
129  SG_REF(data);
130 
131  CBinaryLabels* result=new CBinaryLabels(get_mean_vector(data));
132  SG_UNREF(data);
133 
134  return result;
135 }
136 
138 {
139  // check whether given combination of inference method and likelihood
140  // function supports classification
141  REQUIRE(m_method, "Inference method should not be NULL\n")
143  REQUIRE(m_method->supports_binary() || m_method->supports_multiclass(), "%s with %s doesn't support "
144  "classification\n", m_method->get_name(), lik->get_name())
145  SG_UNREF(lik);
146 
147  if (data)
148  {
149  // set inducing features for FITC inference method
151  {
154  fitc_method->set_inducing_features(data);
155  SG_UNREF(fitc_method);
156  }
157  else
158  m_method->set_features(data);
159  }
160 
161  // perform inference
162  m_method->update();
163 
164  return true;
165 }
166 
168  CFeatures* data)
169 {
170  // check whether given combination of inference method and likelihood
171  // function supports classification
172  REQUIRE(m_method, "Inference method should not be NULL\n")
175  "%s with %s doesn't support classification\n", m_method->get_name(), lik->get_name())
176 
177  SG_REF(data);
180  SG_UNREF(data);
181 
182  // evaluate mean
183  mu=lik->get_predictive_means(mu, s2);
184  SG_UNREF(lik);
185 
186  return mu;
187 }
188 
190  CFeatures* data)
191 {
192  // check whether given combination of inference method and
193  // likelihood function supports classification
194  REQUIRE(m_method, "Inference method should not be NULL\n")
197  "%s with %s doesn't support classification\n", m_method->get_name(), lik->get_name())
198 
199  SG_REF(data);
202  SG_UNREF(data);
203 
204  // evaluate variance
205  s2=lik->get_predictive_variances(mu, s2);
206  SG_UNREF(lik);
207 
208  return s2;
209 }
210 
212  CFeatures* data)
213 {
214  // check whether given combination of inference method and likelihood
215  // function supports classification
216  REQUIRE(m_method, "Inference method should not be NULL\n")
219  "%s with %s doesn't support classification\n", m_method->get_name(), lik->get_name())
220 
221  SG_REF(data);
224  SG_UNREF(data);
225 
226  // evaluate log probabilities
228  SG_UNREF(lik);
229 
230  // evaluate probabilities
231  for (index_t idx=0; idx<p.vlen; idx++)
232  p[idx]=CMath::exp(p[idx]);
233 
234  return p;
235 }
virtual const char * get_name() const =0
virtual void update()
Definition: Inference.cpp:316
virtual void set_inducing_features(CFeatures *feat)
SGVector< float64_t > get_variance_vector(CFeatures *data)
void set_int_labels(SGVector< int32_t > labels)
static int32_t arg_max(T *vec, int32_t inc, int32_t len, T *maxv_ptr=NULL)
Definition: Math.h:262
int32_t index_t
Definition: common.h:62
A base class for Gaussian Processes.
virtual EInferenceType get_inference_type() const
Definition: Inference.h:104
SGVector< float64_t > get_posterior_variances(CFeatures *data)
virtual bool supports_binary() const
Definition: Inference.h:371
virtual int32_t get_num_vectors() const =0
CLabels * m_labels
Definition: Machine.h:361
#define REQUIRE(x,...)
Definition: SGIO.h:206
#define SG_NOTIMPLEMENTED
Definition: SGIO.h:139
SGVector< float64_t > get_posterior_means(CFeatures *data)
virtual SGVector< float64_t > get_predictive_variances(SGVector< float64_t > mu, SGVector< float64_t > s2, const CLabels *lab=NULL) const =0
virtual CLabels * get_labels()
Definition: Inference.h:317
virtual CFeatures * get_features()
Definition: Inference.h:266
static CSingleFITCLaplaceInferenceMethod * obtain_from_generic(CInference *inference)
#define SG_REF(x)
Definition: SGObject.h:54
SGVector< float64_t > get_mean_vector(CFeatures *data)
Multiclass Labels for multi-class classification.
SGVector< float64_t > get_probabilities(CFeatures *data)
index_t vlen
Definition: SGVector.h:494
virtual SGVector< float64_t > get_predictive_log_probabilities(SGVector< float64_t > mu, SGVector< float64_t > s2, const CLabels *lab=NULL)
virtual CBinaryLabels * apply_binary(CFeatures *data=NULL)
virtual CFeatures * get_inducing_features()
CLikelihoodModel * get_model()
Definition: Inference.h:334
virtual bool train_machine(CFeatures *data=NULL)
virtual void set_features(CFeatures *feat)
Definition: Inference.h:272
#define SG_UNREF(x)
Definition: SGObject.h:55
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
The Inference Method base class.
Definition: Inference.h:81
The class Features is the base class of all feature objects.
Definition: Features.h:68
static float64_t exp(float64_t x)
Definition: Math.h:621
virtual SGVector< float64_t > get_predictive_means(SGVector< float64_t > mu, SGVector< float64_t > s2, const CLabels *lab=NULL) const =0
Binary Labels for binary classification.
Definition: BinaryLabels.h:37
virtual CMulticlassLabels * apply_multiclass(CFeatures *data=NULL)
The FITC approximation inference method class for regression and binary classification. Note that the number of inducing points (m) is usually far smaller than the number of input points (n); the stated time complexity assumes m < n.
virtual bool supports_multiclass() const
Definition: Inference.h:378
The Likelihood model base class.

SHOGUN Machine Learning Toolbox - Documentation