00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #ifndef _MACHINE_H__
00012 #define _MACHINE_H__
00013
00014 #include <shogun/lib/common.h>
00015 #include <shogun/base/SGObject.h>
00016 #include <shogun/features/Labels.h>
00017 #include <shogun/features/Features.h>
00018
00019 namespace shogun
00020 {
00021
00022 class CFeatures;
00023 class CLabels;
00024 class CMath;
00025
/** @brief Tag identifying the concrete algorithm behind a CMachine.
 *
 * Returned by CMachine::get_classifier_type(). Every enumerator carries an
 * explicit value; the values are mostly multiples of ten, with
 * CT_LIGHTONECLASS (11) as the one exception. Because the numbers are spelled
 * out explicitly, new entries should get fresh values rather than shifting
 * existing ones (presumably so stored or compared tags stay stable — confirm
 * before renumbering anything).
 */
enum EClassifierType
{
	CT_NONE               = 0,
	CT_LIGHT              = 10,
	CT_LIGHTONECLASS      = 11,
	CT_LIBSVM             = 20,
	CT_LIBSVMONECLASS     = 30,
	CT_LIBSVMMULTICLASS   = 40,
	CT_MPD                = 50,
	CT_GPBT               = 60,
	CT_CPLEXSVM           = 70,
	CT_PERCEPTRON         = 80,
	CT_KERNELPERCEPTRON   = 90,
	CT_LDA                = 100,
	CT_LPM                = 110,
	CT_LPBOOST            = 120,
	CT_KNN                = 130,
	CT_SVMLIN             = 140,
	CT_KRR                = 150,
	CT_GNPPSVM            = 160,
	CT_GMNPSVM            = 170,
	CT_SUBGRADIENTSVM     = 180,
	CT_SUBGRADIENTLPM     = 190,
	CT_SVMPERF            = 200,
	CT_LIBSVR             = 210,
	CT_SVRLIGHT           = 220,
	CT_LIBLINEAR          = 230,
	CT_KMEANS             = 240,
	CT_HIERARCHICAL       = 250,
	CT_SVMOCAS            = 260,
	CT_WDSVMOCAS          = 270,
	CT_SVMSGD             = 280,
	CT_MKLMULTICLASS      = 290,
	CT_MKLCLASSIFICATION  = 300,
	CT_MKLONECLASS        = 310,
	CT_MKLREGRESSION      = 320,
	CT_SCATTERSVM         = 330,
	CT_DASVM              = 340,
	CT_LARANK             = 350,
	CT_DASVMLINEAR        = 360,
	CT_GAUSSIANNAIVEBAYES = 370,
	CT_AVERAGEDPERCEPTRON = 380,
	CT_SGDQN              = 390
};
00072
/** @brief Solver back-end selector for a CMachine.
 *
 * Stored through CMachine::set_solver_type() and read back through
 * CMachine::get_solver_type(). Values are explicit and consecutive from 0.
 * Which value a freshly constructed machine starts with is decided by the
 * CMachine constructor, which is not part of this header.
 */
enum ESolverType
{
	ST_AUTO       = 0,
	ST_CPLEX      = 1,
	ST_GLPK       = 2,
	ST_NEWTON     = 3,
	ST_DIRECT     = 4,
	ST_ELASTICNET = 5,
	ST_BLOCK_NORM = 6
};
00084
00096 class CMachine : public CSGObject
00097 {
00098 public:
00100 CMachine();
00101
00103 virtual ~CMachine();
00104
00114 virtual bool train(CFeatures* data=NULL);
00115
00120 virtual CLabels* apply()=0;
00121
00127 virtual CLabels* apply(CFeatures* data)=0;
00128
00136 virtual float64_t apply(int32_t num);
00137
00145 virtual bool load(FILE* srcfile);
00146
00154 virtual bool save(FILE* dstfile);
00155
00160 virtual void set_labels(CLabels* lab);
00161
00166 virtual CLabels* get_labels();
00167
00173 virtual float64_t get_label(int32_t i);
00174
00179 void set_max_train_time(float64_t t);
00180
00185 float64_t get_max_train_time();
00186
00191 virtual EClassifierType get_classifier_type();
00192
00197 void set_solver_type(ESolverType st);
00198
00203 ESolverType get_solver_type();
00204
00210 virtual void set_store_model_features(bool store_model);
00211
00212 protected:
00223 virtual bool train_machine(CFeatures* data=NULL)
00224 {
00225 SG_ERROR("train_machine is not yet implemented for %s!\n",
00226 get_name());
00227 return false;
00228 }
00229
00240 virtual void store_model_features()
00241 {
00242 SG_ERROR("Model storage and therefore Cross-Validation and "
00243 "Model-Selection is not supported for %s\n", get_name());
00244 }
00245
00246 protected:
00248 float64_t max_train_time;
00249
00251 CLabels* labels;
00252
00254 ESolverType solver_type;
00255
00257 bool m_store_model_features;
00258 };
00259 }
00260 #endif // _MACHINE_H__