00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2011 Soeren Sonnenburg 00008 * Written (W) 2012 Soeren Sonnenburg, Chiyuan Zhang 00009 * Copyright (C) 1999-2011 Fraunhofer Institute FIRST and Max-Planck-Society 00010 */ 00011 00012 #ifndef _MULTICLASSSVM_H___ 00013 #define _MULTICLASSSVM_H___ 00014 00015 #include <shogun/lib/common.h> 00016 #include <shogun/features/Features.h> 00017 #include <shogun/classifier/svm/SVM.h> 00018 #include <shogun/machine/KernelMulticlassMachine.h> 00019 00020 namespace shogun 00021 { 00022 00023 class CSVM; 00024 00026 class CMulticlassSVM : public CKernelMulticlassMachine 00027 { 00028 public: 00030 MACHINE_PROBLEM_TYPE(PT_MULTICLASS); 00031 00033 CMulticlassSVM(); 00034 00039 CMulticlassSVM(CMulticlassStrategy *strategy); 00040 00048 CMulticlassSVM( 00049 CMulticlassStrategy *strategy, float64_t C, CKernel* k, CLabels* lab); 00050 virtual ~CMulticlassSVM(); 00051 00059 bool create_multiclass_svm(int32_t num_classes); 00060 00067 bool set_svm(int32_t num, CSVM* svm); 00068 00074 CSVM* get_svm(int32_t num) 00075 { 00076 return dynamic_cast<CSVM *>(m_machines->get_element_safe(num)); 00077 } 00078 00082 bool load(FILE* svm_file); 00083 00087 bool save(FILE* svm_file); 00088 00089 // TODO remove if unnecessary here 00093 SGVector<float64_t> get_linear_term() { return svm_proto()->get_linear_term(); } 00094 // TODO remove if unnecessary here 00098 float64_t get_tube_epsilon() { return svm_proto()->get_tube_epsilon(); } 00099 // TODO remove if unnecessary here 00103 float64_t get_epsilon() { return svm_proto()->get_epsilon(); } 00104 // TODO remove if unnecessary here 00108 float64_t get_nu() { return svm_proto()->get_nu(); } 00109 // TODO remove if unnecessary here 00113 float64_t get_C() { return m_C; } 00114 // TODO remove if unnecessary here 00118 int32_t get_qpsize() { return svm_proto()->get_qpsize(); } 00119 // TODO remove if unnecessary here 00123 bool get_shrinking_enabled() { return svm_proto()->get_shrinking_enabled(); } 00124 // TODO remove if unnecessary here 00128 float64_t get_objective() { return svm_proto()->get_objective(); } 00129 00130 // TODO remove if unnecessary here 00134 bool get_bias_enabled() { return svm_proto()->get_bias_enabled(); } 00135 // TODO remove if unnecessary here 00139 bool get_linadd_enabled() { return svm_proto()->get_linadd_enabled(); } 00140 // TODO remove if unnecessary here 00144 bool get_batch_computation_enabled() { return svm_proto()->get_batch_computation_enabled(); } 00145 00146 // TODO remove in unnecessary here 00150 void set_defaults(int32_t num_sv=0) { svm_proto()->set_defaults(num_sv); } 00151 // TODO remove in unnecessary here 00155 void set_linear_term(SGVector<float64_t> linear_term) { svm_proto()->set_linear_term(linear_term); } 00156 // TODO remove in unnecessary here 00160 void set_C(float64_t C) { svm_proto()->set_C(C,C); m_C = C; } 00161 // TODO remove in unnecessary here 00165 void set_epsilon(float64_t eps) { svm_proto()->set_epsilon(eps); } 00166 // TODO remove in unnecessary here 00170 void set_nu(float64_t nue) { svm_proto()->set_nu(nue); } 00171 // TODO remove in unnecessary here 00175 void set_tube_epsilon(float64_t eps) { svm_proto()->set_tube_epsilon(eps); } 00176 // TODO remove in unnecessary here 00180 void set_qpsize(int32_t qps) { svm_proto()->set_qpsize(qps); } 00181 // TODO remove in unnecessary here 00185 void set_shrinking_enabled(bool enable) { svm_proto()->set_shrinking_enabled(enable); } 00186 // TODO remove in unnecessary here 00190 void set_objective(float64_t v) { svm_proto()->set_objective(v); } 00191 // TODO remove in unnecessary here 00195 void set_bias_enabled(bool enable_bias) { svm_proto()->set_bias_enabled(enable_bias); } 00196 // TODO remove in unnecessary here 00200 void set_linadd_enabled(bool enable) { svm_proto()->set_linadd_enabled(enable); } 00201 // TODO remove in unnecessary here 00205 void set_batch_computation_enabled(bool enable) { svm_proto()->set_batch_computation_enabled(enable); } 00206 00207 protected: 00208 00210 CSVM *svm_proto() 00211 { 00212 return dynamic_cast<CSVM*>(m_machine); 00213 } 00215 SGVector<int32_t> svm_svs() 00216 { 00217 return svm_proto()->m_svs; 00218 } 00219 00221 virtual bool init_machines_for_apply(CFeatures* data); 00222 00224 virtual bool is_acceptable_machine(CMachine *machine) 00225 { 00226 CSVM *svm = dynamic_cast<CSVM*>(machine); 00227 if (svm == NULL) 00228 return false; 00229 return true; 00230 } 00231 00232 private: 00233 00234 void init(); 00235 00236 protected: 00237 00239 float64_t m_C; 00240 }; 00241 } 00242 #endif