Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #ifndef _SVMOCAS_H___
00013 #define _SVMOCAS_H___
00014
00015 #include "lib/common.h"
00016 #include "classifier/LinearClassifier.h"
00017 #include "classifier/svm/libocas.h"
00018 #include "features/DotFeatures.h"
00019 #include "features/Labels.h"
00020
00021 namespace shogun
00022 {
/** Solver variant selectable for CSVMOcas.
 *
 * The numeric values are fixed (0 and 1), so they may be stored or
 * compared as plain integers.
 */
enum E_SVM_TYPE
{
	/** selects the OCAS solver (value 0) */
	SVM_OCAS = 0,
	/** selects the BMRM solver (value 1, immediately after SVM_OCAS) */
	SVM_BMRM = SVM_OCAS + 1
};
00028
00030 class CSVMOcas : public CLinearClassifier
00031 {
00032 public:
00034 CSVMOcas(void);
00035
00040 CSVMOcas(E_SVM_TYPE type);
00041
00048 CSVMOcas(
00049 float64_t C, CDotFeatures* traindat,
00050 CLabels* trainlab);
00051 virtual ~CSVMOcas();
00052
00057 virtual inline EClassifierType get_classifier_type() { return CT_SVMOCAS; }
00058
00067 virtual bool train(CFeatures* data=NULL);
00068
00075 inline void set_C(float64_t c_neg, float64_t c_pos) { C1=c_neg; C2=c_pos; }
00076
00081 inline float64_t get_C1() { return C1; }
00082
00087 inline float64_t get_C2() { return C2; }
00088
00093 inline void set_epsilon(float64_t eps) { epsilon=eps; }
00094
00099 inline float64_t get_epsilon() { return epsilon; }
00100
00105 inline void set_bias_enabled(bool enable_bias) { use_bias=enable_bias; }
00106
00111 inline bool get_bias_enabled() { return use_bias; }
00112
00117 inline void set_bufsize(int32_t sz) { bufsize=sz; }
00118
00123 inline int32_t get_bufsize() { return bufsize; }
00124
00125 protected:
00134 static void compute_W(
00135 float64_t *sq_norm_W, float64_t *dp_WoldW, float64_t *alpha,
00136 uint32_t nSel, void* ptr);
00137
00144 static float64_t update_W(float64_t t, void* ptr );
00145
00154 static int add_new_cut(
00155 float64_t *new_col_H, uint32_t *new_cut, uint32_t cut_length,
00156 uint32_t nSel, void* ptr );
00157
00163 static int compute_output( float64_t *output, void* ptr );
00164
00171 static int sort( float64_t* vals, float64_t* data, uint32_t size);
00172
00174 static inline void print(ocas_return_value_T value)
00175 {
00176 return;
00177 }
00178
00180 inline virtual const char* get_name() const { return "SVMOcas"; }
00181 private:
00182 void init();
00183
00184 protected:
00186 bool use_bias;
00188 int32_t bufsize;
00190 float64_t C1;
00192 float64_t C2;
00194 float64_t epsilon;
00196 E_SVM_TYPE method;
00197
00199 float64_t* old_w;
00201 float64_t old_bias;
00203 float64_t* tmp_a_buf;
00205 float64_t* lab;
00206
00209 float64_t** cp_value;
00211 uint32_t** cp_index;
00213 uint32_t* cp_nz_dims;
00215 float64_t* cp_bias;
00216 };
00217 }
00218 #endif