Machine.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #ifndef _MACHINE_H__
00012 #define _MACHINE_H__
00013 
00014 #include <shogun/lib/common.h>
00015 #include <shogun/base/SGObject.h>
00016 #include <shogun/mathematics/Math.h>
00017 #include <shogun/features/Labels.h>
00018 #include <shogun/features/Features.h>
00019 
00020 namespace shogun
00021 {
00022 
00023 class CFeatures;
00024 class CLabels;
00025 class CMath;
00026 
/** @brief Identifies the concrete algorithm implemented by a CMachine.
 *
 * Reported by CMachine::get_classifier_type(); the base class returns
 * CT_NONE. Every enumerator carries an explicit value, so an identifier's
 * number never depends on its position in this list. Values are mostly
 * spaced by 10 (exception: CT_LIGHTONECLASS = 11).
 *
 * NOTE: the final enumerator deliberately has no trailing comma -- a trailing
 * comma in an enumerator list is ill-formed in C++98/03 (only C++11 and later
 * accept it), and this header avoids C++11-only constructs.
 */
enum EClassifierType
{
    CT_NONE = 0,
    CT_LIGHT = 10,
    CT_LIGHTONECLASS = 11,
    CT_LIBSVM = 20,
    CT_LIBSVMONECLASS = 30,
    CT_LIBSVMMULTICLASS = 40,
    CT_MPD = 50,
    CT_GPBT = 60,
    CT_CPLEXSVM = 70,
    CT_PERCEPTRON = 80,
    CT_KERNELPERCEPTRON = 90,
    CT_LDA = 100,
    CT_LPM = 110,
    CT_LPBOOST = 120,
    CT_KNN = 130,
    CT_SVMLIN = 140,
    CT_KRR = 150,
    CT_GNPPSVM = 160,
    CT_GMNPSVM = 170,
    CT_SUBGRADIENTSVM = 180,
    CT_SUBGRADIENTLPM = 190,
    CT_SVMPERF = 200,
    CT_LIBSVR = 210,
    CT_SVRLIGHT = 220,
    CT_LIBLINEAR = 230,
    CT_KMEANS = 240,
    CT_HIERARCHICAL = 250,
    CT_SVMOCAS = 260,
    CT_WDSVMOCAS = 270,
    CT_SVMSGD = 280,
    CT_MKLMULTICLASS = 290,
    CT_MKLCLASSIFICATION = 300,
    CT_MKLONECLASS = 310,
    CT_MKLREGRESSION = 320,
    CT_SCATTERSVM = 330,
    CT_DASVM = 340,
    CT_LARANK = 350,
    CT_DASVMLINEAR = 360,
    CT_GAUSSIANNAIVEBAYES = 370,
    CT_AVERAGEDPERCEPTRON = 380,
    CT_SGDQN = 390
};
00073 
/** @brief Optimizer selector stored on a CMachine.
 *
 * Written with CMachine::set_solver_type() and read back with
 * CMachine::get_solver_type(). Which of these values a particular machine
 * actually honours is decided by that machine's implementation (not visible
 * in this header). NOTE(review): ST_AUTO is presumably "let the machine
 * choose" -- confirm against the implementations.
 */
enum ESolverType
{
    ST_AUTO        = 0,
    ST_CPLEX       = 1,
    ST_GLPK        = 2,
    ST_NEWTON      = 3,
    ST_DIRECT      = 4,
    ST_ELASTICNET  = 5,
    ST_BLOCK_NORM  = 6
};
00085 
00097 class CMachine : public CSGObject
00098 {
00099     public:
00101         CMachine();
00102         virtual ~CMachine();
00103 
00113         virtual bool train(CFeatures* data=NULL)
00114         {
00115             bool result=train_machine(data);
00116 
00117             if (m_store_model_features)
00118                 store_model_features();
00119 
00120             return result;
00121         }
00122 
00127         virtual CLabels* apply()=0;
00128 
00134         virtual CLabels* apply(CFeatures* data)=0;
00135 
00143         virtual float64_t apply(int32_t num)
00144         {
00145             SG_NOTIMPLEMENTED;
00146             return CMath::INFTY;
00147         }
00148 
00156         virtual bool load(FILE* srcfile) { ASSERT(srcfile); return false; }
00157 
00165         virtual bool save(FILE* dstfile) { ASSERT(dstfile); return false; }
00166 
00171         virtual inline void set_labels(CLabels* lab)
00172         {
00173             SG_UNREF(labels);
00174             SG_REF(lab);
00175             labels=lab;
00176         }
00177 
00182         virtual inline CLabels* get_labels() { SG_REF(labels); return labels; }
00183 
00189         virtual inline float64_t get_label(int32_t i)
00190         {
00191             if (!labels)
00192                 SG_ERROR("No Labels assigned\n");
00193 
00194             return labels->get_label(i);
00195         }
00196 
00201         inline void set_max_train_time(float64_t t) { max_train_time=t; }
00202 
00207         inline float64_t get_max_train_time() { return max_train_time; }
00208 
00213         virtual inline EClassifierType get_classifier_type() { return CT_NONE; }
00214 
00219         inline void set_solver_type(ESolverType st) { solver_type=st; }
00220 
00225         inline ESolverType get_solver_type() { return solver_type; }
00226 
00232         virtual void set_store_model_features(bool store_model)
00233         {
00234             m_store_model_features=store_model;
00235         }
00236 
00237     protected:
00248         virtual bool train_machine(CFeatures* data=NULL)
00249         {
00250             SG_ERROR("train_machine is not yet implemented for %s!\n",
00251                     get_name());
00252             return false;
00253         }
00254 
00265         virtual void store_model_features()
00266         {
00267             SG_ERROR("Model storage and therefore Cross-Validation and "
00268                     "Model-Selection is not supported for %s\n", get_name());
00269         }
00270 
00271     protected:
00273         float64_t max_train_time;
00274 
00276         CLabels* labels;
00277 
00279         ESolverType solver_type;
00280 
00282         bool m_store_model_features;
00283 };
00284 }
00285 #endif // _MACHINE_H__
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation