00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/structure/HMSVMModel.h>
00012 #include <shogun/features/MatrixFeatures.h>
00013 #include <shogun/structure/TwoStateModel.h>
00014
00015 using namespace shogun;
00016
00017 CHMSVMModel::CHMSVMModel()
00018 : CStructuredModel()
00019 {
00020 init();
00021 }
00022
00023 CHMSVMModel::CHMSVMModel(CFeatures* features, CStructuredLabels* labels, EStateModelType smt, int32_t num_obs)
00024 : CStructuredModel(features, labels)
00025 {
00026 init();
00027
00028 m_num_obs = num_obs;
00029
00030 int32_t free_states = ((CHMSVMLabels*) m_labels)->get_num_states();
00031
00032 int32_t D = ((CMatrixFeatures< float64_t >*) m_features)->get_num_features();
00033 m_num_aux = free_states*D*(num_obs-1);
00034
00035 switch (smt)
00036 {
00037 case SMT_TWO_STATE:
00038 m_state_model = new CTwoStateModel();
00039 break;
00040 case SMT_UNKNOWN:
00041 default:
00042 SG_ERROR("The EStateModelType given is not valid\n");
00043 }
00044
00045 int32_t S = m_state_model->get_num_states();
00046 m_transmission_weights = SGMatrix< float64_t >(S,S);
00047 m_emission_weights = SGVector< float64_t >(S*D*m_num_obs);
00048 }
00049
00050 CHMSVMModel::~CHMSVMModel()
00051 {
00052 SG_UNREF(m_state_model);
00053 }
00054
00055 int32_t CHMSVMModel::get_dim() const
00056 {
00057
00058 int32_t S = ((CHMSVMLabels*) m_labels)->get_num_states();
00059
00060 int32_t D = ((CMatrixFeatures< float64_t >*) m_features)->get_num_features();
00061
00062 return S*(S + D*m_num_obs);
00063 }
00064
00065 SGVector< float64_t > CHMSVMModel::get_joint_feature_vector(
00066 int32_t feat_idx,
00067 CStructuredData* y)
00068 {
00069
00070 CMatrixFeatures< float64_t >* mf = (CMatrixFeatures< float64_t >*) m_features;
00071 int32_t D = mf->get_num_features();
00072
00073
00074 CSequence* label_seq = CSequence::obtain_from_generic(y);
00075
00076
00077 SGVector< float64_t > psi(get_dim());
00078 psi.zero();
00079
00080
00081 SGVector< int32_t > state_seq = m_state_model->labels_to_states(label_seq);
00082 m_transmission_weights.zero();
00083
00084 for ( int32_t i = 0 ; i < state_seq.vlen-1 ; ++i )
00085 m_transmission_weights(state_seq[i],state_seq[i+1]) += 1;
00086
00087 SGMatrix< float64_t > obs = mf->get_feature_vector(feat_idx);
00088 ASSERT(obs.num_rows == D && obs.num_cols == state_seq.vlen);
00089 m_emission_weights.zero();
00090 index_t aux_idx, weight_idx;
00091
00092 for ( int32_t f = 0 ; f < D ; ++f )
00093 {
00094 aux_idx = f*m_num_obs;
00095
00096 for ( int32_t j = 0 ; j < state_seq.vlen ; ++j )
00097 {
00098 weight_idx = aux_idx + state_seq[j]*D*m_num_obs + obs(f,j);
00099 m_emission_weights[weight_idx] += 1;
00100 }
00101 }
00102
00103 m_state_model->weights_to_vector(psi, m_transmission_weights, m_emission_weights,
00104 D, m_num_obs);
00105
00106 return psi;
00107 }
00108
00109 CResultSet* CHMSVMModel::argmax(
00110 SGVector< float64_t > w,
00111 int32_t feat_idx,
00112 bool const training)
00113 {
00114 int32_t dim = get_dim();
00115 ASSERT( w.vlen == get_dim() );
00116
00117
00118 CMatrixFeatures< float64_t >* mf = (CMatrixFeatures< float64_t >*) m_features;
00119 int32_t D = mf->get_num_features();
00120
00121 int32_t S = m_state_model->get_num_states();
00122
00123
00124 SGVector< float64_t > p = m_state_model->get_start_states();
00125
00126 SGVector< float64_t > q = m_state_model->get_stop_states();
00127
00128
00129
00130
00131 SGMatrix< float64_t > x = mf->get_feature_vector(feat_idx);
00132
00133 int32_t T = x.num_cols;
00134 SGMatrix< float64_t > E(S, T);
00135 E.zero();
00136 index_t em_idx;
00137 m_state_model->reshape_emission_params(m_emission_weights, w, D, m_num_obs);
00138
00139 for ( int32_t i = 0 ; i < T ; ++i )
00140 {
00141 for ( int32_t j = 0 ; j < D ; ++j )
00142 {
00143
00144 em_idx = j*m_num_obs + (index_t)CMath::round(x(j,i));
00145
00146 for ( int32_t s = 0 ; s < S ; ++s )
00147 E(s,i) += m_emission_weights[s*D*m_num_obs + em_idx];
00148 }
00149 }
00150
00151
00152 if ( training )
00153 {
00154 CSequence* ytrue =
00155 CSequence::obtain_from_generic(m_labels->get_label(feat_idx));
00156
00157 REQUIRE(ytrue->get_data().size() == T, "T, the length of the feature "
00158 "x^i (%d) and the length of its corresponding label y^i "
00159 "(%d) must be the same.\n", T, ytrue->get_data().size());
00160
00161 SGMatrix< float64_t > loss_matrix = m_state_model->loss_matrix(ytrue);
00162
00163 ASSERT(loss_matrix.num_rows == E.num_rows &&
00164 loss_matrix.num_cols == E.num_cols);
00165
00166 SGVector< float64_t >::add(E.matrix, 1.0, E.matrix,
00167 1.0, loss_matrix.matrix, E.num_rows*E.num_cols);
00168
00169
00170 SG_UNREF(ytrue);
00171 }
00172
00173
00174 SGMatrix< float64_t > dp(T, S);
00175 SGMatrix< float64_t > trb(T, S);
00176 m_state_model->reshape_transmission_params(m_transmission_weights, w);
00177
00178 for ( int32_t s = 0 ; s < S ; ++s )
00179 {
00180 if ( p[s] > -CMath::INFTY )
00181 {
00182
00183 dp(0,s) = E[s];
00184 }
00185 else
00186 {
00187 dp(0,s) = -CMath::INFTY;
00188 }
00189 }
00190
00191
00192 int32_t idx;
00193 float64_t tmp_score, e, a;
00194
00195 for ( int32_t i = 1 ; i < T ; ++i )
00196 {
00197 for ( int32_t cur = 0 ; cur < S ; ++cur )
00198 {
00199 idx = cur*T + i;
00200
00201 dp[idx] = -CMath::INFTY;
00202 trb[idx] = -1;
00203
00204
00205 e = E[i*S + cur];
00206
00207 for ( int32_t prev = 0 ; prev < S ; ++prev )
00208 {
00209
00210 a = m_transmission_weights[cur*S + prev];
00211
00212 if ( a > -CMath::INFTY )
00213 {
00214
00215 tmp_score = e + a + dp[prev*T + i-1];
00216
00217 if ( tmp_score > dp[idx] )
00218 {
00219 dp[idx] = tmp_score;
00220 trb[idx] = prev;
00221 }
00222 }
00223 }
00224 }
00225 }
00226
00227
00228 SGVector< int32_t > opt_path(T);
00229 CResultSet* ret = new CResultSet();
00230 SG_REF(ret);
00231 ret->score = -CMath::INFTY;
00232 opt_path[T-1] = -1;
00233
00234 for ( int32_t s = 0 ; s < S ; ++s )
00235 {
00236 idx = s*T + T-1;
00237
00238 if ( q[s] > -CMath::INFTY && dp[idx] > ret->score )
00239 {
00240 ret->score = dp[idx];
00241 opt_path[T-1] = s;
00242 }
00243 }
00244
00245 for ( int32_t i = T-1 ; i > 0 ; --i )
00246 opt_path[i-1] = trb[opt_path[i]*T + i];
00247
00248
00249 CSequence* ypred = m_state_model->states_to_labels(opt_path);
00250
00251 ret->psi_pred = get_joint_feature_vector(feat_idx, ypred);
00252 ret->argmax = ypred;
00253 if ( training )
00254 {
00255 ret->delta = CStructuredModel::delta_loss(feat_idx, ypred);
00256 ret->psi_truth = CStructuredModel::get_joint_feature_vector(
00257 feat_idx, feat_idx);
00258 ret->score -= SGVector< float64_t >::dot(w.vector,
00259 ret->psi_truth.vector, dim);
00260 }
00261
00262 return ret;
00263 }
00264
00265 float64_t CHMSVMModel::delta_loss(CStructuredData* y1, CStructuredData* y2)
00266 {
00267 CSequence* seq1 = CSequence::obtain_from_generic(y1);
00268 CSequence* seq2 = CSequence::obtain_from_generic(y2);
00269
00270
00271 return m_state_model->loss(seq1, seq2);
00272 }
00273
00274 void CHMSVMModel::init_opt(
00275 SGMatrix< float64_t > & A,
00276 SGVector< float64_t > a,
00277 SGMatrix< float64_t > B,
00278 SGVector< float64_t > & b,
00279 SGVector< float64_t > lb,
00280 SGVector< float64_t > ub,
00281 SGMatrix< float64_t > & C)
00282 {
00283
00284 int32_t S = ((CHMSVMLabels*) m_labels)->get_num_states();
00285
00286 int32_t D = ((CMatrixFeatures< float64_t >*) m_features)->get_num_features();
00287
00288
00289 SGVector< int32_t > monotonicity = m_state_model->get_monotonicity(S,D);
00290
00291
00292
00293 float64_t C_small = 5.0;
00294 float64_t C_smooth = 10.0;
00295
00296 C = SGMatrix< float64_t >(get_dim()+m_num_aux, get_dim()+m_num_aux);
00297 C.zero();
00298 for ( int32_t i = 0 ; i < get_dim() ; ++i )
00299 C(i,i) = C_small;
00300 for ( int32_t i = get_dim() ; i < get_dim()+m_num_aux ; ++i )
00301 C(i,i) = C_smooth;
00302
00303
00304
00305
00306
00307 A = SGMatrix< float64_t >(2*m_num_aux, get_dim()+m_num_aux);
00308 A.zero();
00309
00310
00311
00312 SGVector< int32_t > score_starts(S*D);
00313 for ( int32_t idx = S*S, k = 0 ; k < S*D ; idx += m_num_obs, ++k )
00314 score_starts[k] = idx;
00315
00316
00317 SGVector< int32_t > aux_starts_smooth(S*D);
00318 for ( int32_t idx = get_dim(), k = 0 ; k < S*D ; idx += m_num_obs-1, ++k )
00319 aux_starts_smooth[k] = idx;
00320
00321
00322
00323
00324
00325 int32_t con_idx = 0, scr_idx, aux_idx;
00326
00327 for ( int32_t i = 0 ; i < score_starts.vlen ; ++i )
00328 {
00329 scr_idx = score_starts[i];
00330 aux_idx = aux_starts_smooth[i];
00331
00332 for ( int32_t j = 0 ; j < m_num_obs-1 ; ++j )
00333 {
00334 A(con_idx, scr_idx) = 1;
00335 A(con_idx, scr_idx+1) = -1;
00336
00337 if ( monotonicity[i] != 1 )
00338 A(con_idx, aux_idx) = -1;
00339 ++con_idx;
00340
00341 A(con_idx, scr_idx) = -1;
00342 A(con_idx, scr_idx+1) = 1;
00343
00344 if ( monotonicity[i] != -1 )
00345 A(con_idx, aux_idx) = -1;
00346 ++con_idx;
00347
00348 ++scr_idx, ++aux_idx;
00349 }
00350 }
00351
00352
00353 b = SGVector< float64_t >(2*m_num_aux);
00354 b.zero();
00355 }
00356
00357 bool CHMSVMModel::check_training_setup() const
00358 {
00359
00360 CHMSVMLabels* hmsvm_labels = (CHMSVMLabels*) m_labels;
00361
00362 SGVector< int32_t > state_freq( hmsvm_labels->get_num_states() );
00363 state_freq.zero();
00364
00365 CSequence* seq;
00366 int32_t state;
00367 for ( int32_t i = 0 ; i < hmsvm_labels->get_num_labels() ; ++i )
00368 {
00369 seq = CSequence::obtain_from_generic(hmsvm_labels->get_label(i));
00370
00371 SGVector<int32_t> seq_data = seq->get_data();
00372 for ( int32_t j = 0 ; j < seq_data.size() ; ++j )
00373 {
00374 state = seq_data[j];
00375
00376 if ( state < 0 || state >= hmsvm_labels->get_num_states() )
00377 {
00378 SG_ERROR("Found state out of {0, 1, ..., "
00379 "num_states-1}\n");
00380 return false;
00381 }
00382 else
00383 {
00384 ++state_freq[state];
00385 }
00386 }
00387
00388
00389 SG_UNREF(seq);
00390 }
00391
00392 for ( int32_t i = 0 ; i < hmsvm_labels->get_num_states() ; ++i )
00393 {
00394 if ( state_freq[i] <= 0 )
00395 {
00396 SG_ERROR("What? State %d has never appeared\n", i);
00397 return false;
00398 }
00399 }
00400
00401 return true;
00402 }
00403
00404 void CHMSVMModel::init()
00405 {
00406 SG_ADD(&m_num_states, "m_num_states", "The number of states", MS_NOT_AVAILABLE);
00407 SG_ADD((CSGObject**) &m_state_model, "m_state_model", "The state model", MS_NOT_AVAILABLE);
00408 SG_ADD(&m_transmission_weights, "m_transmission_weights",
00409 "Transmission weights used in Viterbi", MS_NOT_AVAILABLE);
00410 SG_ADD(&m_emission_weights, "m_emission_weights",
00411 "Emission weights used in Viterbi", MS_NOT_AVAILABLE);
00412
00413 m_num_states = 0;
00414 m_num_obs = 0;
00415 m_num_aux = 0;
00416 m_state_model = NULL;
00417 }
00418
00419 int32_t CHMSVMModel::get_num_aux() const
00420 {
00421 return m_num_aux;
00422 }
00423
00424 int32_t CHMSVMModel::get_num_aux_con() const
00425 {
00426 return 2*m_num_aux;
00427 }