00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Copyright (C) 2012 Jacob Walker 00008 */ 00009 00010 #ifndef CEXACTINFERENCEMETHOD_H_ 00011 #define CEXACTINFERENCEMETHOD_H_ 00012 00013 #include <shogun/lib/config.h> 00014 #include <shogun/regression/gp/InferenceMethod.h> 00015 #ifdef HAVE_EIGEN3 00016 #ifdef HAVE_LAPACK 00017 namespace shogun 00018 { 00019 00020 class CInferenceMethod; 00021 00045 class CExactInferenceMethod: public CInferenceMethod 00046 { 00047 00048 public: 00049 00051 CExactInferenceMethod(); 00052 00053 /* Constructor 00054 * @param kernel covariance function 00055 * @param features features to use in inference 00056 * @param labels labels of the features 00057 * @param model Likelihood model to use 00058 */ 00059 CExactInferenceMethod(CKernel* kernel, CFeatures* features, 00060 CMeanFunction* mean, CLabels* labels, CLikelihoodModel* model); 00061 00063 virtual ~CExactInferenceMethod(); 00064 00074 virtual float64_t get_negative_marginal_likelihood(); 00075 00084 virtual CMap<TParameter*, SGVector<float64_t> > get_marginal_likelihood_derivatives( 00085 CMap<TParameter*, CSGObject*>& para_dict); 00086 00096 virtual SGVector<float64_t> get_alpha(); 00097 00098 00109 virtual SGMatrix<float64_t> get_cholesky(); 00110 00121 virtual SGVector<float64_t> get_diagonal_vector(); 00122 00128 virtual const char* get_name() const 00129 { 00130 return "ExactInferenceMethod"; 00131 } 00132 00138 virtual CMap<TParameter*, SGVector<float64_t> > get_gradient( 00139 CMap<TParameter*, CSGObject*>& para_dict) 00140 { 00141 return get_marginal_likelihood_derivatives(para_dict); 00142 } 00143 00148 virtual SGVector<float64_t> get_quantity() 00149 { 00150 SGVector<float64_t> result(1); 00151 result[0] = get_negative_marginal_likelihood(); 00152 return result; 00153 } 00154 00155 protected: 00157 virtual void update_alpha(); 00158 00160 virtual void update_chol(); 00161 00163 virtual void update_train_kernel(); 00164 00166 virtual void update_all(); 00167 00168 private: 00169 00173 void check_members(); 00174 00176 SGMatrix<float64_t> m_kern_with_noise; 00177 }; 00178 00179 } 00180 #endif 00181 #endif // HAVE_LAPACK 00182 00183 #endif /* CEXACTINFERENCEMETHOD_H_ */