00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #ifndef _MACHINE_H__
00013 #define _MACHINE_H__
00014
00015 #include <shogun/lib/common.h>
00016 #include <shogun/base/SGObject.h>
00017 #include <shogun/labels/Labels.h>
00018 #include <shogun/labels/BinaryLabels.h>
00019 #include <shogun/labels/RegressionLabels.h>
00020 #include <shogun/labels/MulticlassLabels.h>
00021 #include <shogun/labels/StructuredLabels.h>
00022 #include <shogun/labels/LatentLabels.h>
00023 #include <shogun/features/Features.h>
00024
00025 namespace shogun
00026 {
00027
00028 class CFeatures;
00029 class CLabels;
00030 class CMath;
00031
/**
 * Type tags for the concrete machine implementations.
 *
 * CMachine::get_classifier_type() returns a value of this enum. The numeric
 * values are spaced in steps of 10.
 * NOTE(review): these numbers may be compared or persisted elsewhere — keep
 * existing values stable and append new ones at the end.
 */
enum EMachineType
{
	CT_NONE = 0,

	CT_LIGHT = 10,
	CT_LIGHTONECLASS = 11,
	CT_LIBSVM = 20,
	CT_LIBSVMONECLASS = 30,
	CT_LIBSVMMULTICLASS = 40,
	CT_MPD = 50,
	CT_GPBT = 60,
	CT_CPLEXSVM = 70,
	CT_PERCEPTRON = 80,
	CT_KERNELPERCEPTRON = 90,

	CT_LDA = 100,
	CT_LPM = 110,
	CT_LPBOOST = 120,
	CT_KNN = 130,
	CT_SVMLIN = 140,
	CT_KERNELRIDGEREGRESSION = 150,
	CT_GNPPSVM = 160,
	CT_GMNPSVM = 170,
	CT_SUBGRADIENTSVM = 180,
	CT_SUBGRADIENTLPM = 190,

	CT_SVMPERF = 200,
	CT_LIBSVR = 210,
	CT_SVRLIGHT = 220,
	CT_LIBLINEAR = 230,
	CT_KMEANS = 240,
	CT_HIERARCHICAL = 250,
	CT_SVMOCAS = 260,
	CT_WDSVMOCAS = 270,
	CT_SVMSGD = 280,
	CT_MKLMULTICLASS = 290,

	CT_MKLCLASSIFICATION = 300,
	CT_MKLONECLASS = 310,
	CT_MKLREGRESSION = 320,
	CT_SCATTERSVM = 330,
	CT_DASVM = 340,
	CT_LARANK = 350,
	CT_DASVMLINEAR = 360,
	CT_GAUSSIANNAIVEBAYES = 370,
	CT_AVERAGEDPERCEPTRON = 380,
	CT_SGDQN = 390,

	CT_CONJUGATEINDEX = 400,
	CT_LINEARRIDGEREGRESSION = 410,
	CT_LEASTSQUARESREGRESSION = 420,
	CT_QDA = 430,
	CT_NEWTONSVM = 440,
	CT_GAUSSIANPROCESSREGRESSION = 450,
	CT_LARS = 460,
	CT_MULTICLASS = 470,
	CT_DIRECTORLINEAR = 480,
	CT_DIRECTORKERNEL = 490
};
00087
/**
 * Solver back-end selector.
 *
 * The chosen value is passed to CMachine::set_solver_type() and kept in
 * CMachine::m_solver_type. Which machines honour which solver is decided
 * by the individual implementations (not visible here).
 */
enum ESolverType
{
	ST_AUTO = 0,       /**< presumably lets the machine pick — confirm */
	ST_CPLEX = 1,
	ST_GLPK = 2,
	ST_NEWTON = 3,
	ST_DIRECT = 4,
	ST_ELASTICNET = 5,
	ST_BLOCK_NORM = 6
};
00099
/**
 * Kind of learning problem a machine solves.
 *
 * Reported by CMachine::get_machine_problem_type(). Each value corresponds
 * by name to one of the CMachine::apply_* variants (apply_binary,
 * apply_regression, apply_multiclass, apply_structured, apply_latent).
 */
enum EProblemType
{
	PT_BINARY = 0,      /**< see apply_binary() */
	PT_REGRESSION = 1,  /**< see apply_regression() */
	PT_MULTICLASS = 2,  /**< see apply_multiclass() */
	PT_STRUCTURED = 3,  /**< see apply_structured() */
	PT_LATENT = 4       /**< see apply_latent() */
};
00109
/**
 * Declares, inside a class body, a get_machine_problem_type() member that
 * always returns the constant PT.
 *
 * The expansion has the same signature as
 * CMachine::get_machine_problem_type(), so placing it in a CMachine
 * subclass overrides the base version (which reports SG_NOTIMPLEMENTED).
 *
 * @param PT EProblemType value the machine reports
 */
#define MACHINE_PROBLEM_TYPE(PT) \
	virtual EProblemType get_machine_problem_type() const \
	{ \
		return PT; \
	}
00115
00133 class CMachine : public CSGObject
00134 {
00135 public:
00137 CMachine();
00138
00140 virtual ~CMachine();
00141
00151 virtual bool train(CFeatures* data=NULL);
00152
00159 virtual CLabels* apply(CFeatures* data=NULL);
00160
00162 virtual CBinaryLabels* apply_binary(CFeatures* data=NULL);
00164 virtual CRegressionLabels* apply_regression(CFeatures* data=NULL);
00166 virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL);
00168 virtual CStructuredLabels* apply_structured(CFeatures* data=NULL);
00170 virtual CLatentLabels* apply_latent(CFeatures* data=NULL);
00171
00176 virtual void set_labels(CLabels* lab);
00177
00182 virtual CLabels* get_labels();
00183
00188 void set_max_train_time(float64_t t);
00189
00194 float64_t get_max_train_time();
00195
00200 virtual EMachineType get_classifier_type();
00201
00206 void set_solver_type(ESolverType st);
00207
00212 ESolverType get_solver_type();
00213
00219 virtual void set_store_model_features(bool store_model);
00220
00229 virtual bool train_locked(SGVector<index_t> indices)
00230 {
00231 SG_ERROR("train_locked(SGVector<index_t>) is not yet implemented "
00232 "for %s\n", get_name());
00233 return false;
00234 }
00235
00237 virtual float64_t apply_one(int32_t i)
00238 {
00239 SG_NOTIMPLEMENTED;
00240 return 0.0;
00241 }
00242
00248 virtual CLabels* apply_locked(SGVector<index_t> indices);
00249
00251 virtual CBinaryLabels* apply_locked_binary(
00252 SGVector<index_t> indices);
00254 virtual CRegressionLabels* apply_locked_regression(
00255 SGVector<index_t> indices);
00257 virtual CMulticlassLabels* apply_locked_multiclass(
00258 SGVector<index_t> indices);
00260 virtual CStructuredLabels* apply_locked_structured(
00261 SGVector<index_t> indices);
00263 virtual CLatentLabels* apply_locked_latent(
00264 SGVector<index_t> indices);
00265
00274 virtual void data_lock(CLabels* labs, CFeatures* features);
00275
00277 virtual void post_lock(CLabels* labs, CFeatures* features) { };
00278
00280 virtual void data_unlock();
00281
00283 virtual bool supports_locking() const { return false; }
00284
00286 bool is_data_locked() const { return m_data_locked; }
00287
00289 virtual EProblemType get_machine_problem_type() const
00290 {
00291 SG_NOTIMPLEMENTED;
00292 return PT_BINARY;
00293 }
00294
00296 virtual CMachine* clone()
00297 {
00298 SG_NOTIMPLEMENTED;
00299 return NULL;
00300 }
00301
00302 virtual const char* get_name() const { return "Machine"; }
00303
00304 protected:
00315 virtual bool train_machine(CFeatures* data=NULL)
00316 {
00317 SG_ERROR("train_machine is not yet implemented for %s!\n",
00318 get_name());
00319 return false;
00320 }
00321
00332 virtual void store_model_features()
00333 {
00334 SG_ERROR("Model storage and therefore unlocked Cross-Validation and"
00335 " Model-Selection is not supported for %s. Locked may"
00336 " work though.\n", get_name());
00337 }
00338
00345 virtual bool is_label_valid(CLabels *lab) const
00346 {
00347 return true;
00348 }
00349
00351 virtual bool train_require_labels() const { return true; }
00352
00353 protected:
00355 float64_t m_max_train_time;
00356
00358 CLabels* m_labels;
00359
00361 ESolverType m_solver_type;
00362
00364 bool m_store_model_features;
00365
00367 bool m_data_locked;
00368 };
00369 }
00370 #endif // _MACHINE_H__