SHOGUN  4.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
LatentModel.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2012 Viktor Gal
8  * Copyright (C) 2012 Viktor Gal
9  */
10 
13 
14 using namespace shogun;
15 
// Member-initializer list and body of a CLatentModel constructor.
// NOTE(review): the signature line (listing line 16) is missing from this
// extract; presumably this is the default constructor -- confirm.
// The model starts with no features, no labels, PSI caching disabled and
// no cached PSI features.
17  : m_features(NULL),
18  m_labels(NULL),
19  m_do_caching(false),
20  m_cached_psi(NULL)
21 {
   // Expose the members through the parameter framework.
22  register_parameters();
23 }
24 
// Constructs a latent model from the given latent features and labels.
//   feats      - latent features stored in m_features
//   labels     - latent labels stored in m_labels
//   do_caching - stored in m_do_caching; enables PSI feature caching
// The PSI cache itself starts out empty (m_cached_psi is NULL).
25 CLatentModel::CLatentModel(CLatentFeatures* feats, CLatentLabels* labels, bool do_caching)
26  : m_features(feats),
27  m_labels(labels),
28  m_do_caching(do_caching),
29  m_cached_psi(NULL)
30 {
31  register_parameters();
   // NOTE(review): listing lines 32-33 are missing from this extract. Since the
   // setters below SG_REF their arguments, these lines presumably take
   // references on m_features and m_labels -- confirm against the original.
34 }
35 
// NOTE(review): the signature (listing line 36) and the statements (listing
// lines 38-40) of this definition are missing from this extract. By position
// it is presumably the destructor releasing the held object references --
// confirm against the original source.
37 {
41 }
42 
// Body of CLatentModel::get_num_vectors() const (signature on listing line 43
// is not shown in this extract).
// Returns the number of vectors held by the attached latent features.
// m_features is dereferenced without a NULL check, so features must have been
// set before this is called.
44 {
45  return m_features->get_num_vectors();
46 }
47 
// Body of CLatentModel::set_labels(CLatentLabels* labs) (signature on listing
// line 48 is not shown in this extract).
// Takes a reference on the new labels and stores them in m_labels.
49 {
50  SG_REF(labs);
   // NOTE(review): listing line 51 is missing from this extract; presumably it
   // releases the previously held m_labels (SG_UNREF) -- confirm.
52  m_labels = labs;
53 }
54 
// Body of CLatentModel::get_labels() const (signature on listing line 55 is
// not shown in this extract).
// Returns the stored latent labels pointer.
56 {
   // NOTE(review): listing line 57 is missing from this extract; it may take a
   // reference on m_labels before returning -- confirm who owns the result.
58  return m_labels;
59 }
60 
// Body of CLatentModel::set_features(CLatentFeatures* feats) (signature on
// listing line 61 is not shown in this extract).
// Takes a reference on the new features and stores them in m_features.
62 {
63  SG_REF(feats);
   // NOTE(review): listing line 64 is missing from this extract; presumably it
   // releases the previously held m_features (SG_UNREF) -- confirm.
65  m_features = feats;
66 }
67 
// Body of CLatentModel::argmax_h(const SGVector<float64_t>& w) (signature on
// listing line 68 is not shown in this extract).
// For every example whose label equals 1, infers the latent variable under
// the weight vector w and stores it as that example's latent label.
69 {
70  int32_t num = get_num_vectors();
   // NOTE(review): listing line 71 is missing from this extract; it presumably
   // declares `y` as a binary-label view of the labels held by m_labels
   // (the page's cross-reference lists CLabels::to_binary) -- confirm.
   // Sanity checks: there is at least one example, and the number of labels
   // matches the number of feature vectors.
72  ASSERT(num > 0)
73  ASSERT(num == m_labels->get_num_labels())
74 
75  // argmax_h only for positive examples
76  for (int32_t i = 0; i < num; ++i)
77  {
78  if (y->get_label(i) == 1)
79  {
80  // infer h and set it for the argmax_h <w,psi(x,h)>
81  CData* latent_data = infer_latent_variable(w, i);
   // The bool result of set_latent_label() is not checked here.
82  m_labels->set_latent_label(i, latent_data);
83  }
84  }
85 }
86 
// Registers the model's members with the m_parameters container under the
// names "features", "labels", "cached_psi" and "do_caching", each with a
// short description. Called from both constructors.
87 void CLatentModel::register_parameters()
88 {
   // Object-valued members are passed by address, cast to CSGObject**.
89  m_parameters->add((CSGObject**) &m_features, "features", "Latent features");
90  m_parameters->add((CSGObject**) &m_labels, "labels", "Latent labels");
91  m_parameters->add((CSGObject**) &m_cached_psi, "cached_psi", "Cached PSI features after argmax_h");
   // The caching switch is a plain bool and is registered directly.
92  m_parameters->add(&m_do_caching, "do_caching", "Indicate whether or not do PSI vector caching after argmax_h");
93 }
94 
95 
// Body of CLatentModel::get_features() const (signature on listing line 96 is
// not shown in this extract).
// Returns the stored latent features pointer.
97 {
   // NOTE(review): listing line 98 is missing from this extract; it may take a
   // reference on m_features before returning -- confirm who owns the result.
99  return m_features;
100 }
101 
// NOTE(review): the signature (listing line 102) is missing from this
// extract; by content this is presumably the routine that refreshes the PSI
// feature cache -- confirm the name against the header.
// Nothing happens unless caching is enabled (m_do_caching).
103 {
104  if (m_do_caching)
105  {
   // Checks whether a cache already exists.
106  if (m_cached_psi)
   // NOTE(review): listing lines 107-109 are missing from this extract;
   // presumably the old cache is released and a fresh one obtained (the
   // cross-reference lists get_psi_feature_vectors()) -- confirm.
110  }
111 }
112 
// NOTE(review): the signature (listing line 113) is missing from this
// extract; presumably this is get_cached_psi_features() const, which the
// page's cross-reference lists as returning CDotFeatures* -- confirm.
// Returns the cached PSI features when caching is enabled, otherwise NULL.
114 {
115  if (m_do_caching)
116  {
   // NOTE(review): listing line 117 is missing from this extract; it may take
   // a reference on m_cached_psi before returning -- confirm.
118  return m_cached_psi;
119  }
120  return NULL;
121 }
CDotFeatures * m_cached_psi
Definition: LatentModel.h:152
Latent Features class. The class is for representing features for latent learning, e...
virtual int32_t get_num_labels() const
CLatentLabels * get_labels() const
Definition: LatentModel.cpp:55
CLatentFeatures * get_features() const
Definition: LatentModel.cpp:96
Parameter * m_parameters
Definition: SGObject.h:546
float64_t get_label(int32_t idx)
Features that support dot products among other operations.
Definition: DotFeatures.h:44
#define SG_REF(x)
Definition: SGObject.h:54
void set_labels(CLatentLabels *labs)
Definition: LatentModel.cpp:48
static CBinaryLabels * to_binary(CLabels *base_labels)
void add(bool *param, const char *name, const char *description="")
Definition: Parameter.cpp:37
#define ASSERT(x)
Definition: SGIO.h:201
Class SGObject is the base class of all shogun objects.
Definition: SGObject.h:115
CLatentLabels * m_labels
Definition: LatentModel.h:148
dummy data holder
Definition: Data.h:25
CLabels * get_labels() const
virtual int32_t get_num_vectors() const
virtual CData * infer_latent_variable(const SGVector< float64_t > &w, index_t idx)=0
CLatentFeatures * m_features
Definition: LatentModel.h:146
virtual void argmax_h(const SGVector< float64_t > &w)
Definition: LatentModel.cpp:68
#define SG_UNREF(x)
Definition: SGObject.h:55
CDotFeatures * get_cached_psi_features() const
virtual CDotFeatures * get_psi_feature_vectors()=0
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
Binary Labels for binary classification.
Definition: BinaryLabels.h:37
virtual int32_t get_num_vectors() const
Definition: LatentModel.cpp:43
void set_features(CLatentFeatures *feats)
Definition: LatentModel.cpp:61
Abstract class for latent labels. As latent labels always depend on the given application, this class only defines the API that the user has to implement for latent labels.
Definition: LatentLabels.h:26
bool set_latent_label(int32_t idx, CData *label)

SHOGUN Machine Learning Toolbox - Documentation