Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #ifndef _MULTICLASSOCAS_H___
00012 #define _MULTICLASSOCAS_H___
00013
00014 #include <shogun/lib/common.h>
00015 #include <shogun/features/DotFeatures.h>
00016 #include <shogun/lib/external/libocas.h>
00017 #include <shogun/machine/LinearMulticlassMachine.h>
00018
00019 namespace shogun
00020 {
00021
00023 class CMulticlassOCAS : public CLinearMulticlassMachine
00024 {
00025 public:
00026 MACHINE_PROBLEM_TYPE(PT_MULTICLASS)
00027
00028
00029 CMulticlassOCAS();
00030
00036 CMulticlassOCAS(float64_t C, CDotFeatures* features, CLabels* labs);
00037
00039 virtual ~CMulticlassOCAS();
00040
00042 virtual const char* get_name() const
00043 {
00044 return "MulticlassOCAS";
00045 }
00046
00050 inline void set_C(float64_t C)
00051 {
00052 ASSERT(C>0);
00053 m_C = C;
00054 }
00058 inline float64_t get_C() const { return m_C; }
00059
00063 inline void set_epsilon(float64_t epsilon)
00064 {
00065 ASSERT(epsilon>0);
00066 m_epsilon = epsilon;
00067 }
00071 inline float64_t get_epsilon() const { return m_epsilon; }
00072
00076 inline void set_max_iter(int32_t max_iter)
00077 {
00078 ASSERT(max_iter>0);
00079 m_max_iter = max_iter;
00080 }
00084 inline int32_t get_max_iter() const { return m_max_iter; }
00085
00089 inline void set_method(int32_t method)
00090 {
00091 ASSERT(method==0 || method==1);
00092 m_method = method;
00093 }
00097 inline int32_t get_method() const { return m_method; }
00098
00102 inline void set_buf_size(int32_t buf_size)
00103 {
00104 ASSERT(buf_size>0);
00105 m_buf_size = buf_size;
00106 }
00110 inline int32_t get_buf_size() const { return m_buf_size; }
00111
00112 protected:
00113
00115 virtual bool train_machine(CFeatures* data = NULL);
00116
00118 static float64_t msvm_update_W(float64_t t, void* user_data);
00119
00121 static void msvm_full_compute_W(float64_t *sq_norm_W, float64_t *dp_WoldW,
00122 float64_t *alpha, uint32_t nSel, void* user_data);
00123
00125 static int msvm_full_add_new_cut(float64_t *new_col_H, uint32_t *new_cut,
00126 uint32_t nSel, void* user_data);
00127
00129 static int msvm_full_compute_output(float64_t *output, void* user_data);
00130
00132 static int msvm_sort_data(float64_t* vals, float64_t* data, uint32_t size);
00133
00135 static void msvm_print(ocas_return_value_T value);
00136
00137 private:
00138
00140 void register_parameters();
00141
00142 protected:
00143
00145 float64_t m_C;
00146
00148 float64_t m_epsilon;
00149
00151 int32_t m_max_iter;
00152
00154 int32_t m_method;
00155
00157 int32_t m_buf_size;
00158 };
00159 }
00160 #endif