LinearLatentMachine.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Viktor Gal
00008  * Copyright (C) 2012 Viktor Gal
00009  */
00010 
00011 #include <typeinfo>
00012 
00013 #include <shogun/machine/LinearLatentMachine.h>
00014 #include <shogun/features/LatentFeatures.h>
00015 #include <shogun/features/DenseFeatures.h>
00016 
00017 using namespace shogun;
00018 
00019 CLinearLatentMachine::CLinearLatentMachine()
00020     : CLinearMachine()
00021 {
00022     init();
00023 }
00024 
00025 CLinearLatentMachine::CLinearLatentMachine(CLatentModel* model, float64_t C)
00026     : CLinearMachine()
00027 {
00028     init();
00029     m_C= C;
00030     set_model(model);
00031 
00032     index_t feat_dim = m_model->get_dim();
00033     w.resize_vector(feat_dim);
00034     w.zero();
00035 }
00036 
00037 CLinearLatentMachine::~CLinearLatentMachine()
00038 {
00039     SG_UNREF(m_model);
00040 }
00041 
00042 CLatentLabels* CLinearLatentMachine::apply_latent(CFeatures* data)
00043 {
00044     if (m_model == NULL)
00045         SG_ERROR("LatentModel is not set!\n");
00046 
00047     CLatentFeatures* lf = CLatentFeatures::obtain_from_generic(data);
00048     m_model->set_features(lf);
00049 
00050     return apply_latent();
00051 }
00052 
00053 void CLinearLatentMachine::set_model(CLatentModel* latent_model)
00054 {
00055     ASSERT(latent_model != NULL);
00056     SG_UNREF(m_model);
00057     SG_REF(latent_model);
00058     m_model = latent_model;
00059 }
00060 
00061 bool CLinearLatentMachine::train_machine(CFeatures* data)
00062 {
00063     if (m_model == NULL)
00064         SG_ERROR("LatentModel is not set!\n");
00065 
00066     SG_DEBUG("PSI size: %d\n", m_model->get_dim());
00067     SG_DEBUG("Number of training data: %d\n", m_model->get_num_vectors());
00068     SG_DEBUG("Initialise PSI (x,h)\n");
00069     m_model->cache_psi_features();
00070 
00071     /*
00072      * define variables for calculating the stopping
00073      * criterion for the outer loop
00074      */
00075     float64_t decrement = 0.0, primal_obj = 0.0, prev_po = 0.0;
00076     float64_t inner_eps = 0.5*m_C*m_epsilon;
00077     bool stop = false;
00078     m_cur_iter = 0;
00079 
00080     /* do CCCP */
00081     SG_DEBUG("Starting CCCP\n");
00082     while ((m_cur_iter < 2)||(!stop&&(m_cur_iter < m_max_iter)))
00083     {
00084         SG_DEBUG("iteration: %d\n", m_cur_iter);
00085         /* do the SVM optimisation with fixed h* */
00086         SG_DEBUG("Do the inner loop of CCCP: optimize for w for fixed h*\n");
00087         primal_obj = do_inner_loop(inner_eps);
00088 
00089         /* calculate the decrement */
00090         decrement = prev_po - primal_obj;
00091         prev_po = primal_obj;
00092         SG_DEBUG("decrement: %f\n", decrement);
00093         SG_DEBUG("primal objective: %f\n", primal_obj);
00094 
00095         /* check the stopping criterion */
00096         stop = (inner_eps < (0.5*m_C*m_epsilon+1E-8)) && (decrement < m_C*m_epsilon);
00097 
00098         inner_eps = -decrement*0.01;
00099         inner_eps = CMath::max(inner_eps, 0.5*m_C*m_epsilon);
00100         SG_DEBUG("inner epsilon: %f\n", inner_eps);
00101 
00102         /* find argmaxH */
00103         SG_DEBUG("Find and set h_i = argmax_h (w, psi(x_i,h))\n");
00104         m_model->argmax_h(w);
00105 
00106         SG_DEBUG("Recalculating PSI (x,h) with the new h variables\n");
00107         m_model->cache_psi_features();
00108 
00109         /* increment iteration counter */
00110         m_cur_iter++;
00111     }
00112 
00113     return true;
00114 }
00115 
00116 void CLinearLatentMachine::init()
00117 {
00118     m_C = 10.0;
00119     m_epsilon = 1E-3;
00120     m_max_iter = 400;
00121     m_model = NULL;
00122 
00123     m_parameters->add(&m_C, "C",  "Cost constant.");
00124     m_parameters->add(&m_epsilon, "epsilon", "Convergence precision.");
00125     m_parameters->add(&m_max_iter, "max_iter", "Maximum iterations.");
00126     m_parameters->add((CSGObject**) &m_model, "latent_model", "Latent Model.");
00127 }
00128 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation