Machine.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Written (W) 2011-2012 Heiko Strathmann
00009  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00010  */
00011 
00012 #include <shogun/machine/Machine.h>
00013 #include <shogun/base/Parameter.h>
00014 #include <shogun/mathematics/Math.h>
00015 #include <shogun/base/ParameterMap.h>
00016 
00017 using namespace shogun;
00018 
00019 CMachine::CMachine() : CSGObject(), m_max_train_time(0), m_labels(NULL),
00020         m_solver_type(ST_AUTO)
00021 {
00022     m_data_locked=false;
00023     m_store_model_features=false;
00024 
00025     SG_ADD(&m_max_train_time, "max_train_time",
00026            "Maximum training time.", MS_NOT_AVAILABLE);
00027     SG_ADD((machine_int_t*) &m_solver_type, "solver_type",
00028            "Type of solver.", MS_NOT_AVAILABLE);
00029     SG_ADD((CSGObject**) &m_labels, "labels",
00030            "Labels to be used.", MS_NOT_AVAILABLE);
00031     SG_ADD(&m_store_model_features, "store_model_features",
00032            "Should feature data of model be stored after training?", MS_NOT_AVAILABLE);
00033     SG_ADD(&m_data_locked, "data_locked",
00034             "Indicates whether data is locked", MS_NOT_AVAILABLE);
00035 
00036     m_parameter_map->put(
00037         new SGParamInfo("data_locked", CT_SCALAR, ST_NONE, PT_BOOL, 1),
00038         new SGParamInfo()
00039     );
00040 
00041     m_parameter_map->finalize_map();
00042 }
00043 
00044 CMachine::~CMachine()
00045 {
00046     SG_UNREF(m_labels);
00047 }
00048 
00049 bool CMachine::train(CFeatures* data)
00050 {
00051     /* not allowed to train on locked data */
00052     if (m_data_locked)
00053     {
00054         SG_ERROR("%s::train data_lock() was called, only train_locked() is"
00055                 " possible. Call data_unlock if you want to call train()\n",
00056                 get_name());
00057     }
00058 
00059     if (train_require_labels())
00060     {
00061         if (m_labels == NULL)
00062             SG_ERROR("%s@%p: No labels given", get_name(), this);
00063 
00064         m_labels->ensure_valid(get_name());
00065     }
00066 
00067     bool result = train_machine(data);
00068 
00069     if (m_store_model_features)
00070         store_model_features();
00071 
00072     return result;
00073 }
00074 
00075 void CMachine::set_labels(CLabels* lab)
00076 {
00077     if (lab != NULL)
00078         if (!is_label_valid(lab))
00079             SG_ERROR("Invalid label for %s", get_name());
00080 
00081     SG_UNREF(m_labels);
00082     SG_REF(lab);
00083     m_labels = lab;
00084 }
00085 
00086 CLabels* CMachine::get_labels()
00087 {
00088     SG_REF(m_labels);
00089     return m_labels;
00090 }
00091 
00092 void CMachine::set_max_train_time(float64_t t)
00093 {
00094     m_max_train_time = t;
00095 }
00096 
00097 float64_t CMachine::get_max_train_time()
00098 {
00099     return m_max_train_time;
00100 }
00101 
00102 EMachineType CMachine::get_classifier_type()
00103 {
00104     return CT_NONE;
00105 }
00106 
00107 void CMachine::set_solver_type(ESolverType st)
00108 {
00109     m_solver_type = st;
00110 }
00111 
00112 ESolverType CMachine::get_solver_type()
00113 {
00114     return m_solver_type;
00115 }
00116 
00117 void CMachine::set_store_model_features(bool store_model)
00118 {
00119     m_store_model_features = store_model;
00120 }
00121 
00122 void CMachine::data_lock(CLabels* labs, CFeatures* features)
00123 {
00124     SG_DEBUG("entering %s::data_lock\n", get_name());
00125     if (!supports_locking())
00126     {
00127         {
00128             SG_ERROR("%s::data_lock(): Machine does not support data locking!\n",
00129                     get_name());
00130         }
00131     }
00132 
00133     if (!labs)
00134     {
00135         SG_ERROR("%s::data_lock() is not possible will NULL labels!\n",
00136                 get_name());
00137     }
00138 
00139     /* first set labels */
00140     set_labels(labs);
00141 
00142     if (m_data_locked)
00143     {
00144         SG_ERROR("%s::data_lock() was already called. Dont lock twice!",
00145                 get_name());
00146     }
00147 
00148     m_data_locked=true;
00149     post_lock(labs,features);
00150     SG_DEBUG("leaving %s::data_lock\n", get_name());
00151 }
00152 
00153 void CMachine::data_unlock()
00154 {
00155     SG_DEBUG("entering %s::data_lock\n", get_name());
00156     if (m_data_locked)
00157         m_data_locked=false;
00158 
00159     SG_DEBUG("leaving %s::data_lock\n", get_name());
00160 }
00161 
00162 CLabels* CMachine::apply(CFeatures* data)
00163 {
00164     SG_DEBUG("entering %s::apply(%s at %p)\n",
00165             get_name(), data ? data->get_name() : "NULL", data);
00166 
00167     CLabels* result=NULL;
00168 
00169     switch (get_machine_problem_type())
00170     {
00171         case PT_BINARY:
00172             result=apply_binary(data);
00173             break;
00174         case PT_REGRESSION:
00175             result=apply_regression(data);
00176             break;
00177         case PT_MULTICLASS:
00178             result=apply_multiclass(data);
00179             break;
00180         case PT_STRUCTURED:
00181             result=apply_structured(data);
00182             break;
00183         case PT_LATENT:
00184             result=apply_latent(data);
00185             break;
00186         default:
00187             SG_ERROR("Unknown problem type");
00188             break;
00189     }
00190 
00191     SG_DEBUG("leaving %s::apply(%s at %p)\n",
00192             get_name(), data ? data->get_name() : "NULL", data);
00193 
00194     return result;
00195 }
00196 
00197 CLabels* CMachine::apply_locked(SGVector<index_t> indices)
00198 {
00199     switch (get_machine_problem_type())
00200     {
00201         case PT_BINARY:
00202             return apply_locked_binary(indices);
00203         case PT_REGRESSION:
00204             return apply_locked_regression(indices);
00205         case PT_MULTICLASS:
00206             return apply_locked_multiclass(indices);
00207         case PT_STRUCTURED:
00208             return apply_locked_structured(indices);
00209         case PT_LATENT:
00210             return apply_locked_latent(indices);
00211         default:
00212             SG_ERROR("Unknown problem type");
00213             break;
00214     }
00215     return NULL;
00216 }
00217 
00218 CBinaryLabels* CMachine::apply_binary(CFeatures* data)
00219 {
00220     SG_ERROR("This machine does not support apply_binary()\n");
00221     return NULL;
00222 }
00223 
00224 CRegressionLabels* CMachine::apply_regression(CFeatures* data)
00225 {
00226     SG_ERROR("This machine does not support apply_regression()\n");
00227     return NULL;
00228 }
00229 
00230 CMulticlassLabels* CMachine::apply_multiclass(CFeatures* data)
00231 {
00232     SG_ERROR("This machine does not support apply_multiclass()\n");
00233     return NULL;
00234 }
00235 
00236 CStructuredLabels* CMachine::apply_structured(CFeatures* data)
00237 {
00238     SG_ERROR("This machine does not support apply_structured()\n");
00239     return NULL;
00240 }
00241 
00242 CLatentLabels* CMachine::apply_latent(CFeatures* data)
00243 {
00244     SG_ERROR("This machine does not support apply_latent()\n");
00245     return NULL;
00246 }
00247 
00248 CBinaryLabels* CMachine::apply_locked_binary(SGVector<index_t> indices)
00249 {
00250     SG_ERROR("apply_locked_binary(SGVector<index_t>) is not yet implemented "
00251             "for %s\n", get_name());
00252     return NULL;
00253 }
00254 
00255 CRegressionLabels* CMachine::apply_locked_regression(SGVector<index_t> indices)
00256 {
00257     SG_ERROR("apply_locked_regression(SGVector<index_t>) is not yet implemented "
00258             "for %s\n", get_name());
00259     return NULL;
00260 }
00261 
00262 CMulticlassLabels* CMachine::apply_locked_multiclass(SGVector<index_t> indices)
00263 {
00264     SG_ERROR("apply_locked_multiclass(SGVector<index_t>) is not yet implemented "
00265             "for %s\n", get_name());
00266     return NULL;
00267 }
00268 
00269 CStructuredLabels* CMachine::apply_locked_structured(SGVector<index_t> indices)
00270 {
00271     SG_ERROR("apply_locked_structured(SGVector<index_t>) is not yet implemented "
00272             "for %s\n", get_name());
00273     return NULL;
00274 }
00275 
00276 CLatentLabels* CMachine::apply_locked_latent(SGVector<index_t> indices)
00277 {
00278     SG_ERROR("apply_locked_latent(SGVector<index_t>) is not yet implemented "
00279             "for %s\n", get_name());
00280     return NULL;
00281 }
00282 
00283 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation