Machine.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Written (W) 2011-2012 Heiko Strathmann
00009  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00010  */
00011 
00012 #ifndef _MACHINE_H__
00013 #define _MACHINE_H__
00014 
00015 #include <shogun/lib/common.h>
00016 #include <shogun/base/SGObject.h>
00017 #include <shogun/labels/Labels.h>
00018 #include <shogun/labels/BinaryLabels.h>
00019 #include <shogun/labels/RegressionLabels.h>
00020 #include <shogun/labels/MulticlassLabels.h>
00021 #include <shogun/labels/StructuredLabels.h>
00022 #include <shogun/labels/LatentLabels.h>
00023 #include <shogun/features/Features.h>
00024 
00025 namespace shogun
00026 {
00027 
00028 class CFeatures;
00029 class CLabels;
00030 class CMath;
00031 
/** Identifiers of the concrete machine implementations (classifiers,
 * regressors, clustering methods, ...), as reported by
 * CMachine::get_classifier_type().
 *
 * Values are spaced in steps of ten, with the single exception of
 * CT_LIGHTONECLASS (11).
 * NOTE(review): these numbers may be compared or persisted elsewhere —
 * append new entries instead of renumbering existing ones.
 */
enum EMachineType
{
    CT_NONE                      = 0,
    CT_LIGHT                     = 10,
    CT_LIGHTONECLASS             = 11,
    CT_LIBSVM                    = 20,
    CT_LIBSVMONECLASS            = 30,
    CT_LIBSVMMULTICLASS          = 40,
    CT_MPD                       = 50,
    CT_GPBT                      = 60,
    CT_CPLEXSVM                  = 70,
    CT_PERCEPTRON                = 80,
    CT_KERNELPERCEPTRON          = 90,
    CT_LDA                       = 100,
    CT_LPM                       = 110,
    CT_LPBOOST                   = 120,
    CT_KNN                       = 130,
    CT_SVMLIN                    = 140,
    CT_KERNELRIDGEREGRESSION     = 150,
    CT_GNPPSVM                   = 160,
    CT_GMNPSVM                   = 170,
    CT_SUBGRADIENTSVM            = 180,
    CT_SUBGRADIENTLPM            = 190,
    CT_SVMPERF                   = 200,
    CT_LIBSVR                    = 210,
    CT_SVRLIGHT                  = 220,
    CT_LIBLINEAR                 = 230,
    CT_KMEANS                    = 240,
    CT_HIERARCHICAL              = 250,
    CT_SVMOCAS                   = 260,
    CT_WDSVMOCAS                 = 270,
    CT_SVMSGD                    = 280,
    CT_MKLMULTICLASS             = 290,
    CT_MKLCLASSIFICATION         = 300,
    CT_MKLONECLASS               = 310,
    CT_MKLREGRESSION             = 320,
    CT_SCATTERSVM                = 330,
    CT_DASVM                     = 340,
    CT_LARANK                    = 350,
    CT_DASVMLINEAR               = 360,
    CT_GAUSSIANNAIVEBAYES        = 370,
    CT_AVERAGEDPERCEPTRON        = 380,
    CT_SGDQN                     = 390,
    CT_CONJUGATEINDEX            = 400,
    CT_LINEARRIDGEREGRESSION     = 410,
    CT_LEASTSQUARESREGRESSION    = 420,
    CT_QDA                       = 430,
    CT_NEWTONSVM                 = 440,
    CT_GAUSSIANPROCESSREGRESSION = 450,
    CT_LARS                      = 460,
    CT_MULTICLASS                = 470,
    CT_DIRECTORLINEAR            = 480,
    CT_DIRECTORKERNEL            = 490
};
00087 
/** Solver back-ends a machine can be configured to use
 * (see CMachine::set_solver_type() / CMachine::get_solver_type()).
 * ST_AUTO is the zero value; presumably it lets the machine pick its own
 * solver — confirm in the individual machines.
 */
enum ESolverType
{
    ST_AUTO       = 0,
    ST_CPLEX      = 1,
    ST_GLPK       = 2,
    ST_NEWTON     = 3,
    ST_DIRECT     = 4,
    ST_ELASTICNET = 5,
    ST_BLOCK_NORM = 6
};
00099 
/** Kind of learning problem a machine solves, as reported by
 * CMachine::get_machine_problem_type(). Each value corresponds by name to
 * one of the label families CMachine can produce (apply_binary,
 * apply_regression, apply_multiclass, apply_structured, apply_latent).
 */
enum EProblemType
{
    PT_BINARY     = 0,
    PT_REGRESSION = 1,
    PT_MULTICLASS = 2,
    PT_STRUCTURED = 3,
    PT_LATENT     = 4
};
00109 
/** Place inside a class body to define
 * `virtual EProblemType get_machine_problem_type() const` returning @p PT.
 * The generated signature matches CMachine::get_machine_problem_type(), so
 * in a CMachine subclass it overrides the base version.
 *
 * @param PT EProblemType value the class reports
 */
#define MACHINE_PROBLEM_TYPE(PT) \
    virtual EProblemType get_machine_problem_type() const \
    { \
        return PT; \
    }
00115 
00133 class CMachine : public CSGObject
00134 {
00135     public:
00137         CMachine();
00138 
00140         virtual ~CMachine();
00141 
00151         virtual bool train(CFeatures* data=NULL);
00152 
00159         virtual CLabels* apply(CFeatures* data=NULL);
00160 
00162         virtual CBinaryLabels* apply_binary(CFeatures* data=NULL);
00164         virtual CRegressionLabels* apply_regression(CFeatures* data=NULL);
00166         virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL);
00168         virtual CStructuredLabels* apply_structured(CFeatures* data=NULL);
00170         virtual CLatentLabels* apply_latent(CFeatures* data=NULL);
00171 
00176         virtual void set_labels(CLabels* lab);
00177 
00182         virtual CLabels* get_labels();
00183 
00188         void set_max_train_time(float64_t t);
00189 
00194         float64_t get_max_train_time();
00195 
00200         virtual EMachineType get_classifier_type();
00201 
00206         void set_solver_type(ESolverType st);
00207 
00212         ESolverType get_solver_type();
00213 
00219         virtual void set_store_model_features(bool store_model);
00220 
00229         virtual bool train_locked(SGVector<index_t> indices)
00230         {
00231             SG_ERROR("train_locked(SGVector<index_t>) is not yet implemented "
00232                     "for %s\n", get_name());
00233             return false;
00234         }
00235 
00237         virtual float64_t apply_one(int32_t i)
00238         {
00239             SG_NOTIMPLEMENTED;
00240             return 0.0;
00241         }
00242 
00248         virtual CLabels* apply_locked(SGVector<index_t> indices);
00249 
00251         virtual CBinaryLabels* apply_locked_binary(
00252                 SGVector<index_t> indices);
00254         virtual CRegressionLabels* apply_locked_regression(
00255                 SGVector<index_t> indices);
00257         virtual CMulticlassLabels* apply_locked_multiclass(
00258                 SGVector<index_t> indices);
00260         virtual CStructuredLabels* apply_locked_structured(
00261                 SGVector<index_t> indices);
00263         virtual CLatentLabels* apply_locked_latent(
00264                 SGVector<index_t> indices);
00265 
00274         virtual void data_lock(CLabels* labs, CFeatures* features);
00275 
00277         virtual void post_lock(CLabels* labs, CFeatures* features) { };
00278 
00280         virtual void data_unlock();
00281 
00283         virtual bool supports_locking() const { return false; }
00284 
00286         bool is_data_locked() const { return m_data_locked; }
00287 
00289         virtual EProblemType get_machine_problem_type() const
00290         {
00291             SG_NOTIMPLEMENTED;
00292             return PT_BINARY;
00293         }
00294 
00296         virtual CMachine* clone()
00297         {
00298             SG_NOTIMPLEMENTED;
00299             return NULL;
00300         }
00301 
00302         virtual const char* get_name() const { return "Machine"; }
00303 
00304     protected:
00315         virtual bool train_machine(CFeatures* data=NULL)
00316         {
00317             SG_ERROR("train_machine is not yet implemented for %s!\n",
00318                     get_name());
00319             return false;
00320         }
00321 
00332         virtual void store_model_features()
00333         {
00334             SG_ERROR("Model storage and therefore unlocked Cross-Validation and"
00335                     " Model-Selection is not supported for %s. Locked may"
00336                     " work though.\n", get_name());
00337         }
00338 
00345         virtual bool is_label_valid(CLabels *lab) const
00346         {
00347             return true;
00348         }
00349 
00351         virtual bool train_require_labels() const { return true; }
00352 
00353     protected:
00355         float64_t m_max_train_time;
00356 
00358         CLabels* m_labels;
00359 
00361         ESolverType m_solver_type;
00362 
00364         bool m_store_model_features;
00365 
00367         bool m_data_locked;
00368 };
00369 }
00370 #endif // _MACHINE_H__
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation