Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/structure/StructuredModel.h>
00012
00013 using namespace shogun;
00014
00015 CStructuredModel::CStructuredModel() : CSGObject()
00016 {
00017 init();
00018 }
00019
00020 CStructuredModel::CStructuredModel(
00021 CFeatures* features,
00022 CStructuredLabels* labels)
00023 : CSGObject()
00024 {
00025 init();
00026
00027 m_features = features;
00028 m_labels = labels;
00029
00030 SG_REF(features);
00031 SG_REF(labels);
00032 }
00033
00034 CStructuredModel::~CStructuredModel()
00035 {
00036 SG_UNREF(m_labels);
00037 SG_UNREF(m_features);
00038 }
00039
00040 void CStructuredModel::init_opt(
00041 SGMatrix< float64_t > & A,
00042 SGVector< float64_t > a,
00043 SGMatrix< float64_t > B,
00044 SGVector< float64_t > & b,
00045 SGVector< float64_t > lb,
00046 SGVector< float64_t > ub,
00047 SGMatrix< float64_t > & C)
00048 {
00049 SG_ERROR("init_opt is not implemented for %s!\n", get_name());
00050 }
00051
00052 void CStructuredModel::set_labels(CStructuredLabels* labels)
00053 {
00054 SG_UNREF(m_labels);
00055 SG_REF(labels);
00056 m_labels = labels;
00057 }
00058
00059 CStructuredLabels* CStructuredModel::get_labels()
00060 {
00061 SG_REF(m_labels);
00062 return m_labels;
00063 }
00064
00065 void CStructuredModel::set_features(CFeatures* features)
00066 {
00067 SG_UNREF(m_features);
00068 SG_REF(features);
00069 m_features = features;
00070 }
00071
00072 CFeatures* CStructuredModel::get_features()
00073 {
00074 SG_REF(m_features);
00075 return m_features;
00076 }
00077
00078 SGVector< float64_t > CStructuredModel::get_joint_feature_vector(
00079 int32_t feat_idx,
00080 int32_t lab_idx)
00081 {
00082 CStructuredData* label = m_labels->get_label(lab_idx);
00083 SGVector< float64_t > ret = get_joint_feature_vector(feat_idx, label);
00084 SG_UNREF(label);
00085
00086 return ret;
00087 }
00088
00089 SGVector< float64_t > CStructuredModel::get_joint_feature_vector(
00090 int32_t feat_idx,
00091 CStructuredData* y)
00092 {
00093 SG_ERROR("compute_joint_feature(int32_t, CStructuredData*) is not "
00094 "implemented for %s!\n", get_name());
00095
00096 return SGVector< float64_t >();
00097 }
00098
00099 float64_t CStructuredModel::delta_loss(int32_t ytrue_idx, CStructuredData* ypred)
00100 {
00101 REQUIRE(ytrue_idx >= 0 || ytrue_idx < m_labels->get_num_labels(),
00102 "The label index must be inside [0, num_labels-1]\n");
00103
00104 CStructuredData* ytrue = m_labels->get_label(ytrue_idx);
00105 float64_t ret = delta_loss(ytrue, ypred);
00106 SG_UNREF(ytrue);
00107
00108 return ret;
00109 }
00110
00111 float64_t CStructuredModel::delta_loss(CStructuredData* y1, CStructuredData* y2)
00112 {
00113 SG_ERROR("delta_loss(CStructuredData*, CStructuredData*) is not "
00114 "implemented for %s!\n", get_name());
00115
00116 return 0.0;
00117 }
00118
00119 void CStructuredModel::init()
00120 {
00121 SG_ADD((CSGObject**) &m_labels, "m_labels", "Structured labels",
00122 MS_NOT_AVAILABLE);
00123 SG_ADD((CSGObject**) &m_features, "m_features", "Feature vectors",
00124 MS_NOT_AVAILABLE);
00125
00126 m_features = NULL;
00127 m_labels = NULL;
00128 }
00129
00130 bool CStructuredModel::check_training_setup() const
00131 {
00132
00133 return true;
00134 }
00135
00136 int32_t CStructuredModel::get_num_aux() const
00137 {
00138 return 0;
00139 }
00140
00141 int32_t CStructuredModel::get_num_aux_con() const
00142 {
00143 return 0;
00144 }
00145
00146 float64_t CStructuredModel::risk(float64_t* subgrad, float64_t* W, TMultipleCPinfo* info)
00147 {
00148 int32_t from=0, to=0;
00149 if (info)
00150 {
00151 from = info->_from;
00152 to = (info->N == 0) ? m_features->get_num_vectors() : from+info->N;
00153 }
00154 else
00155 {
00156 from = 0;
00157 to = m_features->get_num_vectors();
00158 }
00159
00160 int32_t dim = this->get_dim();
00161 float64_t R = 0.0;
00162 for (int32_t i=0; i<dim; i++)
00163 subgrad[i] = 0;
00164
00165 for (int32_t i=from; i<to; i++)
00166 {
00167 CResultSet* result = this->argmax(SGVector<float64_t>(W,dim,false), i, true);
00168 SGVector<float64_t> psi_pred = result->psi_pred;
00169 SGVector<float64_t> psi_truth = result->psi_truth;
00170 SGVector<float64_t>::vec1_plus_scalar_times_vec2(subgrad, 1.0, psi_pred.vector, dim);
00171 SGVector<float64_t>::vec1_plus_scalar_times_vec2(subgrad, -1.0, psi_truth.vector, dim);
00172 R += result->score;
00173 SG_UNREF(result);
00174 }
00175
00176 return R;
00177 }