MulticlassModel.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Fernando José Iglesias García
00008  * Copyright (C) 2012 Fernando José Iglesias García
00009  */
00010 
00011 #include <shogun/features/DotFeatures.h>
00012 #include <shogun/mathematics/Math.h>
00013 #include <shogun/structure/MulticlassModel.h>
00014 #include <shogun/structure/MulticlassSOLabels.h>
00015 
00016 using namespace shogun;
00017 
00018 CMulticlassModel::CMulticlassModel()
00019 : CStructuredModel()
00020 {
00021     init();
00022 }
00023 
00024     CMulticlassModel::CMulticlassModel(CFeatures* features, CStructuredLabels* labels)
00025 : CStructuredModel(features, labels)
00026 {
00027     init();
00028 }
00029 
00030 CMulticlassModel::~CMulticlassModel()
00031 {
00032 }
00033 
00034 int32_t CMulticlassModel::get_dim() const
00035 {
00036     // TODO make the casts safe!
00037     int32_t num_classes = ((CMulticlassSOLabels*) m_labels)->get_num_classes();
00038     int32_t feats_dim   = ((CDotFeatures*) m_features)->get_dim_feature_space();
00039 
00040     return feats_dim*num_classes;
00041 }
00042 
00043 SGVector< float64_t > CMulticlassModel::get_joint_feature_vector(int32_t feat_idx, CStructuredData* y)
00044 {
00045     SGVector< float64_t > psi( get_dim() );
00046     psi.zero();
00047 
00048     SGVector< float64_t > x = ((CDotFeatures*) m_features)->
00049         get_computed_dot_feature_vector(feat_idx);
00050     CRealNumber* r = CRealNumber::obtain_from_generic(y);
00051     ASSERT(r != NULL)
00052     float64_t label_value = r->value;
00053 
00054     for ( index_t i = 0, j = label_value*x.vlen ; i < x.vlen ; ++i, ++j )
00055         psi[j] = x[i];
00056 
00057     return psi;
00058 }
00059 
00060 CResultSet* CMulticlassModel::argmax(
00061         SGVector< float64_t > w,
00062         int32_t feat_idx,
00063         bool const training)
00064 {
00065     CDotFeatures* df = (CDotFeatures*) m_features;
00066     int32_t feats_dim   = df->get_dim_feature_space();
00067 
00068     if ( training )
00069     {
00070         CMulticlassSOLabels* ml = (CMulticlassSOLabels*) m_labels;
00071         m_num_classes = ml->get_num_classes();
00072     }
00073     else
00074     {
00075         REQUIRE(m_num_classes > 0, "The model needs to be trained before "
00076                 "using it for prediction\n");
00077     }
00078 
00079     int32_t dim = get_dim();
00080     ASSERT(dim == w.vlen);
00081 
00082     // Find the class that gives the maximum score
00083 
00084     float64_t score = 0, ypred = 0;
00085     float64_t max_score = -CMath::INFTY;
00086 
00087     for ( int32_t c = 0 ; c < m_num_classes ; ++c )
00088     {
00089         score = df->dense_dot(feat_idx, w.vector+c*feats_dim, feats_dim);
00090         if ( training )
00091             score += delta_loss(feat_idx, c);
00092 
00093         if ( score > max_score )
00094         {
00095             max_score = score;
00096             ypred = c;
00097         }
00098     }
00099 
00100     // Build the CResultSet object to return
00101     CResultSet* ret = new CResultSet();
00102     SG_REF(ret);
00103     CRealNumber* y  = new CRealNumber(ypred);
00104     SG_REF(y);
00105 
00106     ret->psi_pred = get_joint_feature_vector(feat_idx, y);
00107     ret->score    = max_score;
00108     ret->argmax   = y;
00109     if ( training )
00110     {
00111         ret->delta     = CStructuredModel::delta_loss(feat_idx, y);
00112         ret->psi_truth = CStructuredModel::get_joint_feature_vector(
00113                     feat_idx, feat_idx);
00114         ret->score    -= SGVector< float64_t >::dot(w.vector,
00115                     ret->psi_truth.vector, dim);
00116     }
00117 
00118     return ret;
00119 }
00120 
00121 float64_t CMulticlassModel::delta_loss(CStructuredData* y1, CStructuredData* y2)
00122 {
00123     CRealNumber* rn1 = CRealNumber::obtain_from_generic(y1);
00124     CRealNumber* rn2 = CRealNumber::obtain_from_generic(y2);
00125     ASSERT(rn1 != NULL);
00126     ASSERT(rn2 != NULL);
00127 
00128     return delta_loss(rn1->value, rn2->value);
00129 }
00130 
00131 float64_t CMulticlassModel::delta_loss(int32_t y1_idx, float64_t y2)
00132 {
00133     REQUIRE(y1_idx >= 0 || y1_idx < m_labels->get_num_labels(),
00134             "The label index must be inside [0, num_labels-1]\n");
00135 
00136     CRealNumber* rn1 = CRealNumber::obtain_from_generic(m_labels->get_label(y1_idx));
00137     float64_t ret = delta_loss(rn1->value, y2);
00138     SG_UNREF(rn1);
00139 
00140     return ret;
00141 }
00142 
00143 float64_t CMulticlassModel::delta_loss(float64_t y1, float64_t y2)
00144 {
00145     return (y1 == y2) ? 0 : 1;
00146 }
00147 
00148 void CMulticlassModel::init_opt(
00149         SGMatrix< float64_t > & A,
00150         SGVector< float64_t > a,
00151         SGMatrix< float64_t > B,
00152         SGVector< float64_t > & b,
00153         SGVector< float64_t > lb,
00154         SGVector< float64_t > ub,
00155         SGMatrix< float64_t > & C)
00156 {
00157     C = SGMatrix< float64_t >::create_identity_matrix(get_dim(), 1);
00158 }
00159 
00160 void CMulticlassModel::init()
00161 {
00162     SG_ADD(&m_num_classes, "m_num_classes", "The number of classes",
00163             MS_NOT_AVAILABLE);
00164 
00165     m_num_classes = 0;
00166 }
00167 
00168 float64_t CMulticlassModel::risk(float64_t* subgrad, float64_t* W, TMultipleCPinfo* info)
00169 {
00170     CDotFeatures* X=(CDotFeatures*)m_features;
00171     CMulticlassSOLabels* y=(CMulticlassSOLabels*)m_labels;
00172     m_num_classes = y->get_num_classes();
00173     uint32_t from, to;
00174 
00175     if (info)
00176     {
00177         from=info->_from;
00178         to=(info->N == 0) ? X->get_num_vectors() : from+info->N;
00179     } else {
00180         from=0;
00181         to=X->get_num_vectors();
00182     }
00183 
00184     uint32_t num_classes=y->get_num_classes();
00185     uint32_t feats_dim=X->get_dim_feature_space();
00186     const uint32_t w_dim=get_dim();
00187 
00188     float64_t R=0.0;
00189     for (uint32_t i=0; i<w_dim; i++)
00190         subgrad[i] = 0;
00191 
00192     float64_t Rtmp=0.0;
00193     float64_t Rmax=0.0;
00194     float64_t loss=0.0;
00195     uint32_t yhat=0;
00196     uint32_t GT=0;
00197     CRealNumber* GT_rn=NULL;
00198 
00199     /* loop through examples */
00200     for(uint32_t i=from; i<to; ++i)
00201     {
00202         Rmax=-CMath::INFTY;
00203         GT_rn=CRealNumber::obtain_from_generic(y->get_label(i));
00204         GT=(uint32_t)GT_rn->value;
00205 
00206         for (uint32_t c = 0; c < num_classes; ++c)
00207         {
00208             loss=(c == GT) ? 0.0 : 1.0;
00209             Rtmp=loss+X->dense_dot(i, W+c*feats_dim, feats_dim)
00210                 -X->dense_dot(i, W+GT*feats_dim, feats_dim);
00211 
00212             if (Rtmp > Rmax)
00213             {
00214                 Rmax=Rtmp;
00215                 yhat=c;
00216             }
00217         }
00218         R += Rmax;
00219 
00220         X->add_to_dense_vec(1.0, i, subgrad+yhat*feats_dim, feats_dim);
00221         X->add_to_dense_vec(-1.0, i, subgrad+GT*feats_dim, feats_dim);
00222 
00223         SG_UNREF(GT_rn);
00224     }
00225 
00226     return R;
00227 }
00228 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation