InferenceMethod.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Copyright (C) 2012 Jacob Walker
00008  */
00009 
00010 #include <shogun/regression/gp/InferenceMethod.h>
00011 #ifdef HAVE_EIGEN3
00012 #include <shogun/mathematics/lapack.h>
00013 #include <shogun/mathematics/Math.h>
00014 #include <shogun/labels/RegressionLabels.h>
00015 #include <shogun/kernel/GaussianKernel.h>
00016 #include <shogun/features/CombinedFeatures.h>
00017 
00018 using namespace shogun;
00019 
00020 CInferenceMethod::CInferenceMethod()
00021 {
00022     init();
00023 
00024     m_kernel = NULL;
00025     m_model = NULL;
00026     m_labels = NULL;
00027     m_features = NULL;
00028     m_latent_features = NULL;
00029     m_mean = NULL;
00030 }
00031 
/** Construct an inference method from all of its components.
 *
 * init() first registers the parameters and resets every owned pointer to
 * NULL (and the scale to 1.0). Each component is then installed through its
 * setter, which takes a reference on it (SG_REF) and runs the update_* hooks.
 *
 * @param kern covariance kernel
 * @param feat training features
 * @param m    mean function
 * @param lab  training labels
 * @param mod  likelihood model
 *
 * NOTE(review): the update_* hooks are invoked here from the base-class
 * constructor. If they are virtual (declared in InferenceMethod.h, not
 * visible here), derived-class overrides are not reached during this call;
 * confirm that derived constructors recompute their state afterwards.
 */
00032 CInferenceMethod::CInferenceMethod(CKernel* kern, CFeatures* feat,
00033         CMeanFunction* m, CLabels* lab, CLikelihoodModel* mod)
00034 {
00035     init();
00036 
00037     set_kernel(kern);
00038     set_features(feat);
00039     set_labels(lab);
00040     set_model(mod);
00041     set_mean(m); // last: earlier update_data_means() calls are no-ops while m_mean is NULL, so labels are centred (if at all) only here
00042 }
00043 
/** Destructor: releases the references this object holds on its components.
 *
 * The setters take a reference (SG_REF) on every kernel, feature, label,
 * model and mean object they store; each one is released here.
 * After the default constructor all of these pointers are still NULL, so
 * this presumably relies on SG_UNREF ignoring NULL — confirm against the
 * SG_UNREF definition.
 */
00044 CInferenceMethod::~CInferenceMethod() 
00045 {
00046     SG_UNREF(m_kernel);
00047     SG_UNREF(m_features);
00048     SG_UNREF(m_latent_features);
00049     SG_UNREF(m_labels);
00050     SG_UNREF(m_model);
00051     SG_UNREF(m_mean);
00052 }
00053 
/** Shared initialisation used by every constructor.
 *
 * Registers the member parameters with SG_ADD. The kernel, the kernel scale
 * and the likelihood model are marked MS_AVAILABLE for model selection; the
 * labels, features, latent features and mean function are MS_NOT_AVAILABLE.
 * Afterwards every owned pointer is reset to NULL and the scale to 1.0.
 *
 * NOTE(review): keep the SG_ADD calls in their current order; it presumably
 * determines the order in which parameters are registered/serialised —
 * confirm before reordering.
 */
00054 void CInferenceMethod::init()
00055 {
00056     SG_ADD((CSGObject**)&m_kernel, "kernel", "Kernel", MS_AVAILABLE);
00057     SG_ADD(&m_scale, "scale", "Kernel Scale", MS_AVAILABLE);
00058     SG_ADD((CSGObject**)&m_model, "likelihood_model", "Likelihood model",
00059             MS_AVAILABLE);
00060     SG_ADD((CSGObject**)&m_labels, "labels", "Labels", MS_NOT_AVAILABLE);
00061     SG_ADD((CSGObject**)&m_features, "features", "Features", MS_NOT_AVAILABLE);
00062     SG_ADD((CSGObject**)&m_latent_features, "latent_features", "latent Features", MS_NOT_AVAILABLE);
00063     SG_ADD((CSGObject**)&m_mean, "mean_function", "Mean Function", MS_NOT_AVAILABLE);
00064 
00065     m_kernel = NULL; // pointers start unset; setters take references later
00066     m_model = NULL;
00067     m_labels = NULL;
00068     m_features = NULL;
00069     m_latent_features = NULL;
00070     m_mean = NULL;
00071     m_scale = 1.0;
00072 }
00073 
00074 void CInferenceMethod::set_features(CFeatures* feat)
00075 {
00076     SG_REF(feat);
00077     SG_UNREF(m_features);
00078     m_features=feat;
00079 
00080     if (m_features && m_features->has_property(FP_DOT) && m_features->get_num_vectors())
00081         m_feature_matrix =
00082                 ((CDotFeatures*)m_features)->get_computed_dot_feature_matrix();
00083 
00084     else if (m_features && m_features->get_feature_class() == C_COMBINED)
00085     {
00086         CDotFeatures* subfeat =
00087                 (CDotFeatures*)((CCombinedFeatures*)m_features)->
00088                 get_first_feature_obj();
00089         
00090         if (m_features->get_num_vectors())
00091             m_feature_matrix = subfeat->get_computed_dot_feature_matrix();
00092 
00093         SG_UNREF(subfeat);
00094     }
00095 
00096     update_data_means();
00097     update_train_kernel();
00098     update_chol();
00099     update_alpha();
00100 }
00101 
00102 void CInferenceMethod::set_latent_features(CFeatures* feat)
00103 {
00104     SG_REF(feat);
00105     SG_UNREF(m_latent_features);
00106     m_latent_features=feat;
00107 
00108     if (m_latent_features && m_latent_features->has_property(FP_DOT) && m_latent_features->get_num_vectors())
00109         m_latent_matrix =
00110                 ((CDotFeatures*)m_latent_features)->get_computed_dot_feature_matrix();
00111 
00112     else if (m_latent_features && m_latent_features->get_feature_class() == C_COMBINED)
00113     {
00114         CDotFeatures* subfeat =
00115                 (CDotFeatures*)((CCombinedFeatures*)m_latent_features)->
00116                 get_first_feature_obj();
00117 
00118         if (m_latent_features->get_num_vectors())
00119             m_latent_matrix = subfeat->get_computed_dot_feature_matrix();
00120 
00121         SG_UNREF(subfeat);
00122     }
00123 
00124     update_data_means();
00125     update_train_kernel();
00126     update_chol();
00127     update_alpha();
00128 }
00129 
00130 void CInferenceMethod::set_kernel(CKernel* kern)
00131 {
00132     SG_REF(kern);
00133     SG_UNREF(m_kernel);
00134     m_kernel = kern;
00135     update_train_kernel();
00136     update_chol();
00137     update_alpha();
00138 }
00139 
00140 void CInferenceMethod::set_mean(CMeanFunction* m)
00141 {
00142     SG_REF(m);
00143     SG_UNREF(m_mean);
00144     m_mean = m;
00145 
00146     update_data_means();
00147     update_chol();
00148     update_alpha();
00149 }
00150 
00151 void CInferenceMethod::set_labels(CLabels* lab)
00152 {
00153     SG_REF(lab);
00154     SG_UNREF(m_labels);
00155     m_labels = lab;
00156 
00157     if (m_labels)
00158     {
00159         m_label_vector =
00160             ((CRegressionLabels*) m_labels)->get_labels().clone();
00161     }
00162 
00163     update_data_means();
00164     update_alpha();
00165 }
00166 
00167 void CInferenceMethod::set_model(CLikelihoodModel* mod)
00168 {
00169     SG_REF(mod);
00170     SG_UNREF(m_model);
00171     m_model = mod;
00172     update_train_kernel();
00173     update_chol();
00174     update_alpha();
00175 }
00176 
00177 void CInferenceMethod::set_scale(float64_t s)
00178 {
00179     update_train_kernel();
00180     m_scale = s;
00181     update_chol();
00182     update_alpha();
00183 }
00184 
00185 void CInferenceMethod::update_data_means()
00186 {
00187     if (m_mean)
00188     {
00189         m_data_means =
00190             m_mean->get_mean_vector(m_feature_matrix);
00191 
00192 
00193         if (m_label_vector.vlen == m_data_means.vlen)
00194         {
00195             for (index_t i = 0; i < m_label_vector.vlen; i++)
00196                 m_label_vector[i] -= m_data_means[i];
00197         }
00198     }
00199 }
00200 #endif /* HAVE_EIGEN3 */
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation