00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Viktor Gal 00008 * Copyright (C) 2012 Viktor Gal 00009 */ 00010 00011 #ifndef __LATENTMODEL_H__ 00012 #define __LATENTMODEL_H__ 00013 00014 #include <shogun/labels/LatentLabels.h> 00015 #include <shogun/features/LatentFeatures.h> 00016 #include <shogun/features/DotFeatures.h> 00017 #include <shogun/features/DenseFeatures.h> 00018 00019 namespace shogun 00020 { 00031 class CLatentModel: public CSGObject 00032 { 00033 public: 00035 CLatentModel(); 00036 00043 CLatentModel(CLatentFeatures* feats, CLatentLabels* labels, bool do_caching = true); 00044 00046 virtual ~CLatentModel(); 00047 00052 virtual int32_t get_num_vectors() const; 00053 00058 virtual int32_t get_dim() const=0; 00059 00064 void set_labels(CLatentLabels* labs); 00065 00070 CLatentLabels* get_labels() const; 00071 00076 void set_features(CLatentFeatures* feats); 00077 00082 CLatentFeatures* get_features() const; 00083 00088 virtual CDotFeatures* get_psi_feature_vectors()=0; 00089 00098 virtual CData* infer_latent_variable(const SGVector<float64_t>& w, index_t idx)=0; 00099 00105 virtual void argmax_h(const SGVector<float64_t>& w); 00106 00110 void cache_psi_features(); 00111 00116 CDotFeatures* get_cached_psi_features() const; 00117 00122 inline bool get_caching() const 00123 { 00124 return m_do_caching; 00125 } 00126 00131 inline void set_caching(bool caching) 00132 { 00133 m_do_caching = caching; 00134 } 00135 00140 virtual const char* get_name() const { return "LatentModel"; } 00141 00142 protected: 00144 CLatentFeatures* m_features; 00146 CLatentLabels* m_labels; 00148 bool m_do_caching; 00150 CDotFeatures* m_cached_psi; 00151 00152 private: 00154 void register_parameters(); 00155 }; 00156 } 00157 00158 #endif /* __LATENTMODEL_H__ */ 00159