Go to the documentation of this file.
00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #ifndef _STRUCTURED_MODEL__H__
00012 #define _STRUCTURED_MODEL__H__
00013
00014 #include <shogun/base/SGObject.h>
00015 #include <shogun/features/Features.h>
00016 #include <shogun/labels/StructuredLabels.h>
00017 #include <shogun/lib/SGVector.h>
00018 #include <shogun/lib/StructuredData.h>
00019
00020 namespace shogun
00021 {
00022
/* Marker macro: expands to nothing, so it has no effect on compilation.
 * It only tags the declaration that follows (presumably so that a
 * class-list generator skips it -- confirm against the build scripts). */
#define IGNORE_IN_CLASSLIST

/** @brief Small bookkeeping record that can be handed to
 * CStructuredModel::risk() through its optional `info` argument.
 *
 * NOTE(review): the two fields presumably describe a contiguous range of
 * examples, [_from, _from + N) -- confirm against the risk() implementation.
 */
IGNORE_IN_CLASSLIST struct TMultipleCPinfo {
	/** starting index (presumably of the first example covered) */
	uint32_t _from;
	/** element count (presumably examples covered, starting at _from) */
	uint32_t N;
};
00035
00036 class CStructuredModel;
00037
00039 struct CResultSet : public CSGObject
00040 {
00042 CResultSet() : CSGObject(), argmax(NULL) { };
00043
00045 virtual ~CResultSet() { SG_UNREF(argmax) }
00046
00048 CStructuredData* argmax;
00049
00051 SGVector< float64_t > psi_truth;
00052
00054 SGVector< float64_t > psi_pred;
00055
00058 float64_t score;
00059
00061 float64_t delta;
00062
00064 virtual const char* get_name() const { return "ResultSet"; }
00065 };
00066
00077 class CStructuredModel : public CSGObject
00078 {
00079 public:
00081 CStructuredModel();
00082
00088 CStructuredModel(CFeatures* features, CStructuredLabels* labels);
00089
00091 virtual ~CStructuredModel();
00092
00103 virtual void init_opt(
00104 SGMatrix< float64_t > & A, SGVector< float64_t > a,
00105 SGMatrix< float64_t > B, SGVector< float64_t > & b,
00106 SGVector< float64_t > lb, SGVector< float64_t > ub,
00107 SGMatrix < float64_t > & C);
00108
00113 virtual int32_t get_dim() const = 0;
00114
00119 void set_labels(CStructuredLabels* labs);
00120
00125 CStructuredLabels* get_labels();
00126
00131 void set_features(CFeatures* feats);
00132
00137 CFeatures* get_features();
00138
00151 SGVector< float64_t > get_joint_feature_vector(int32_t feat_idx, int32_t lab_idx);
00152
00165 virtual SGVector< float64_t > get_joint_feature_vector(int32_t feat_idx, CStructuredData* y);
00166
00180 virtual CResultSet* argmax(SGVector< float64_t > w, int32_t feat_idx, bool const training = true) = 0;
00181
00189 float64_t delta_loss(int32_t ytrue_idx, CStructuredData* ypred);
00190
00198 virtual float64_t delta_loss(CStructuredData* y1, CStructuredData* y2);
00199
00201 virtual const char* get_name() const { return "StructuredModel"; }
00202
00210 virtual bool check_training_setup() const;
00211
00221 virtual int32_t get_num_aux() const;
00222
00232 virtual int32_t get_num_aux_con() const;
00233
00241 virtual float64_t risk(float64_t* subgrad, float64_t* W, TMultipleCPinfo* info=0);
00242
00243 private:
00245 void init();
00246
00247 protected:
00249 CStructuredLabels* m_labels;
00250
00252 CFeatures* m_features;
00253
00254 };
00255
00256 }
00257
00258 #endif