Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/latent/LatentModel.h>
00012 #include <shogun/labels/BinaryLabels.h>
00013
00014 using namespace shogun;
00015
00016 CLatentModel::CLatentModel()
00017 : m_features(NULL),
00018 m_labels(NULL),
00019 m_do_caching(false),
00020 m_cached_psi(NULL)
00021 {
00022 register_parameters();
00023 }
00024
00025 CLatentModel::CLatentModel(CLatentFeatures* feats, CLatentLabels* labels, bool do_caching)
00026 : m_features(feats),
00027 m_labels(labels),
00028 m_do_caching(do_caching),
00029 m_cached_psi(NULL)
00030 {
00031 register_parameters();
00032 SG_REF(m_features);
00033 SG_REF(m_labels);
00034 }
00035
00036 CLatentModel::~CLatentModel()
00037 {
00038 SG_UNREF(m_labels);
00039 SG_UNREF(m_features);
00040 SG_UNREF(m_cached_psi);
00041 }
00042
00043 int32_t CLatentModel::get_num_vectors() const
00044 {
00045 return m_features->get_num_vectors();
00046 }
00047
00048 void CLatentModel::set_labels(CLatentLabels* labs)
00049 {
00050 SG_UNREF(m_labels);
00051 SG_REF(labs);
00052 m_labels = labs;
00053 }
00054
00055 CLatentLabels* CLatentModel::get_labels() const
00056 {
00057 SG_REF(m_labels);
00058 return m_labels;
00059 }
00060
00061 void CLatentModel::set_features(CLatentFeatures* feats)
00062 {
00063 SG_UNREF(m_features);
00064 SG_REF(feats);
00065 m_features = feats;
00066 }
00067
00068 void CLatentModel::argmax_h(const SGVector<float64_t>& w)
00069 {
00070 int32_t num = get_num_vectors();
00071 CBinaryLabels* y = CBinaryLabels::obtain_from_generic(m_labels->get_labels());
00072 ASSERT(num > 0);
00073 ASSERT(num == m_labels->get_num_labels());
00074
00075
00076
00077 for (int32_t i = 0; i < num; ++i)
00078 {
00079 if (y->get_label(i) == 1)
00080 {
00081
00082 CData* latent_data = infer_latent_variable(w, i);
00083 m_labels->set_latent_label(i, latent_data);
00084 }
00085 }
00086 }
00087
00088 void CLatentModel::register_parameters()
00089 {
00090 m_parameters->add((CSGObject**) &m_features, "features", "Latent features");
00091 m_parameters->add((CSGObject**) &m_labels, "labels", "Latent labels");
00092 m_parameters->add((CSGObject**) &m_cached_psi, "cached psi", "Cached PSI features after argmax_h");
00093 m_parameters->add(&m_do_caching, "do caching", "Indicate whether or not do PSI vector caching after argmax_h");
00094 }
00095
00096
00097 CLatentFeatures* CLatentModel::get_features() const
00098 {
00099 SG_REF(m_features);
00100 return m_features;
00101 }
00102
00103 void CLatentModel::cache_psi_features()
00104 {
00105 if (m_do_caching)
00106 {
00107 if (m_cached_psi)
00108 SG_UNREF(m_cached_psi);
00109 m_cached_psi = this->get_psi_feature_vectors();
00110 SG_REF(m_cached_psi);
00111 }
00112 }
00113
00114 CDotFeatures* CLatentModel::get_cached_psi_features() const
00115 {
00116 if (m_do_caching)
00117 {
00118 SG_REF(m_cached_psi);
00119 return m_cached_psi;
00120 }
00121 return NULL;
00122 }