LatentModel.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Viktor Gal
00008  * Copyright (C) 2012 Viktor Gal
00009  */
00010 
00011 #include <shogun/latent/LatentModel.h>
00012 #include <shogun/labels/BinaryLabels.h>
00013 
00014 using namespace shogun;
00015 
00016 CLatentModel::CLatentModel()
00017     : m_features(NULL),
00018     m_labels(NULL),
00019     m_do_caching(false),
00020     m_cached_psi(NULL)
00021 {
00022     register_parameters();
00023 }
00024 
00025 CLatentModel::CLatentModel(CLatentFeatures* feats, CLatentLabels* labels, bool do_caching)
00026     : m_features(feats),
00027     m_labels(labels),
00028     m_do_caching(do_caching),
00029     m_cached_psi(NULL)
00030 {
00031     register_parameters();
00032     SG_REF(m_features);
00033     SG_REF(m_labels);
00034 }
00035 
00036 CLatentModel::~CLatentModel()
00037 {
00038     SG_UNREF(m_labels);
00039     SG_UNREF(m_features);
00040     SG_UNREF(m_cached_psi);
00041 }
00042 
00043 int32_t CLatentModel::get_num_vectors() const
00044 {
00045     return m_features->get_num_vectors();
00046 }
00047 
00048 void CLatentModel::set_labels(CLatentLabels* labs)
00049 {
00050     SG_UNREF(m_labels);
00051     SG_REF(labs);
00052     m_labels = labs;
00053 }
00054 
00055 CLatentLabels* CLatentModel::get_labels() const
00056 {
00057     SG_REF(m_labels);
00058     return m_labels;
00059 }
00060 
00061 void CLatentModel::set_features(CLatentFeatures* feats)
00062 {
00063     SG_UNREF(m_features);
00064     SG_REF(feats);
00065     m_features = feats;
00066 }
00067 
00068 void CLatentModel::argmax_h(const SGVector<float64_t>& w)
00069 {
00070     int32_t num = get_num_vectors();
00071     CBinaryLabels* y = CBinaryLabels::obtain_from_generic(m_labels->get_labels());
00072     ASSERT(num > 0);
00073     ASSERT(num == m_labels->get_num_labels());
00074     
00075 
00076     // argmax_h only for positive examples
00077     for (int32_t i = 0; i < num; ++i)
00078     {
00079         if (y->get_label(i) == 1)
00080         {
00081             // infer h and set it for the argmax_h <w,psi(x,h)>
00082             CData* latent_data = infer_latent_variable(w, i);
00083             m_labels->set_latent_label(i, latent_data);
00084         }
00085     }
00086 }
00087 
00088 void CLatentModel::register_parameters()
00089 {
00090     m_parameters->add((CSGObject**) &m_features, "features", "Latent features");
00091     m_parameters->add((CSGObject**) &m_labels, "labels", "Latent labels");
00092     m_parameters->add((CSGObject**) &m_cached_psi, "cached psi", "Cached PSI features after argmax_h");
00093     m_parameters->add(&m_do_caching, "do caching", "Indicate whether or not do PSI vector caching after argmax_h");
00094 }
00095 
00096 
00097 CLatentFeatures* CLatentModel::get_features() const
00098 {
00099     SG_REF(m_features);
00100     return m_features;
00101 }
00102 
00103 void CLatentModel::cache_psi_features()
00104 {
00105     if (m_do_caching)
00106     {
00107         if (m_cached_psi)
00108             SG_UNREF(m_cached_psi);
00109         m_cached_psi = this->get_psi_feature_vectors();
00110         SG_REF(m_cached_psi);
00111     }
00112 }
00113 
00114 CDotFeatures* CLatentModel::get_cached_psi_features() const
00115 {
00116     if (m_do_caching)
00117     {
00118         SG_REF(m_cached_psi);
00119         return m_cached_psi;
00120     }
00121     return NULL;
00122 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation