GaussianProcessRegression.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Copyright (C) 2012 Jacob Walker
00008  */
00009 
00010 #ifndef _GAUSSIANPROCESSREGRESSION_H__
00011 #define _GAUSSIANPROCESSREGRESSION_H__
00012 
00013 #include <shogun/lib/config.h>
00014 #ifdef HAVE_EIGEN3
00015 #ifdef HAVE_LAPACK
00016 
00017 #include <shogun/regression/Regression.h>
00018 #include <shogun/machine/Machine.h>
00019 #include <shogun/features/DenseFeatures.h>
00020 #include <shogun/regression/gp/InferenceMethod.h>
00021 
00022 namespace shogun
00023 {
00024 
00025 class CInferenceMethod;
00026 class CFeatures;
00027 class CLabels;
00028 
00034 class CGaussianProcessRegression : public CMachine
00035 {
00036 
00037     public:
00039         MACHINE_PROBLEM_TYPE(PT_REGRESSION);
00040 
00042         enum EGPReturnType
00043         {
00044             GP_RETURN_MEANS,
00045             GP_RETURN_COV,
00046             GP_RETURN_BOTH
00047         };
00048 
00055         CGaussianProcessRegression(CInferenceMethod* inf,
00056                        CFeatures* data, CLabels* lab);
00057 
00059         CGaussianProcessRegression();
00060 
00061         virtual ~CGaussianProcessRegression();
00062         
00067         virtual void set_features(CFeatures* feat)
00068         {
00069             SG_UNREF(m_features);
00070             SG_REF(feat);
00071             m_features = feat;
00072             update_kernel_matrices();
00073         }
00074         
00079         virtual CFeatures* get_features()
00080         {
00081             SG_REF(m_features);
00082             return m_features;
00083         }
00084         
00089         inline void set_method(CInferenceMethod* inf)
00090         {
00091             SG_UNREF(m_method);
00092             SG_REF(inf);
00093             m_method = inf;
00094         };
00095         
00100         inline CInferenceMethod* get_method()
00101         {
00102             SG_REF(m_method);
00103             return m_method;
00104         };
00105             
00111         virtual bool load(FILE* srcfile);
00112         
00118         virtual bool save(FILE* dstfile);
00119 
00124         void set_kernel(CKernel* k);
00125         
00130         CKernel* get_kernel();
00131         
00137         virtual CRegressionLabels* apply_regression(CFeatures* data = NULL);
00138         
00143         virtual EMachineType get_classifier_type()
00144         {
00145           return CT_GAUSSIANPROCESSREGRESSION;
00146         }
00147         
00152         SGVector<float64_t> get_covariance_vector();
00153         
00158         SGVector<float64_t> get_mean_vector();
00159 
00161         virtual const char* get_name() const
00162         {
00163             return "GaussianProcessRegression";
00164         }
00165 
00170         inline void set_return_type(EGPReturnType t)
00171         {
00172             m_return = t;
00173         };
00174 
00180         inline EGPReturnType get_return_type()
00181         {
00182             return m_return;
00183         };
00184 
00185     
00186     protected:
00193         virtual bool train_machine(CFeatures* data = NULL);
00194     private:
00195 
00197         void init();
00198 
00199         /* Update kernel matrices */
00200         void update_kernel_matrices();
00201 
00202     private:
00203 
00205         CFeatures* m_features;
00206         
00208         CFeatures* m_data;
00209 
00210         /*Kernel matrix from testing and training
00211          * features
00212          */
00213         SGMatrix<float64_t> m_k_trts;
00214 
00215         /*Kernel matrix from testing
00216          * features
00217          */
00218         SGMatrix<float64_t> m_k_tsts;
00219 
00221         CInferenceMethod* m_method;
00222 
00223         /*What should apply_regression return?*/
00224         EGPReturnType m_return;
00225 };
00226 
00227 }
00228 
00229 #endif 
00230 #endif
00231 #endif /* _GAUSSIANPROCESSREGRESSION_H__ */
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation