00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2009 Alexander Binder 00008 * Copyright (C) 2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #ifndef MKLMULTICLASS_H_ 00012 #define MKLMULTICLASS_H_ 00013 00014 #include <vector> 00015 00016 #include <shogun/base/SGObject.h> 00017 #include <shogun/kernel/Kernel.h> 00018 #include <shogun/kernel/CombinedKernel.h> 00019 #include <shogun/multiclass/GMNPSVM.h> 00020 #include <shogun/classifier/mkl/MKLMulticlassGLPK.h> 00021 #include <shogun/classifier/mkl/MKLMulticlassGradient.h> 00022 #include <shogun/multiclass/MulticlassSVM.h> 00023 00024 00025 namespace shogun 00026 { 00033 class CMKLMulticlass : public CMulticlassSVM 00034 { 00035 public: 00039 CMKLMulticlass(); 00045 CMKLMulticlass(float64_t C, CKernel* k, CLabels* lab); 00046 00047 00051 virtual ~CMKLMulticlass(); 00052 00057 virtual EMachineType get_classifier_type() 00058 { return CT_MKLMULTICLASS; } 00059 00060 00069 float64_t* getsubkernelweights(int32_t & numweights); 00070 00078 void set_mkl_epsilon(float64_t eps ); 00079 00087 void set_max_num_mkliters(int32_t maxnum); 00088 00092 virtual void set_mkl_norm(float64_t norm); 00093 00094 00095 protected: 00100 CMKLMulticlass( const CMKLMulticlass & cm); 00105 CMKLMulticlass operator=( const CMKLMulticlass & cm); 00106 00111 void initlpsolver(); 00112 00116 void initsvm(); 00117 00118 00119 00120 00126 virtual bool evaluatefinishcriterion(const int32_t 00127 numberofsilpiterations); 00128 00129 00139 void addingweightsstep( const std::vector<float64_t> & curweights); 00144 float64_t getsumofsignfreealphas(); 00151 float64_t getsquarenormofprimalcoefficients( 00152 const int32_t ind); 00153 00154 00163 virtual bool train_machine(CFeatures* data=NULL); 00164 00166 virtual const char* get_name() const { return "MKLMulticlass"; } 00167 00168 protected: 00173 CGMNPSVM* svm; 00177 MKLMulticlassOptimizationBase* lpw; 00181 ::std::vector< std::vector< float64_t> > weightshistory; 00182 00186 float64_t mkl_eps; 00190 int32_t max_num_mkl_iters; 00194 float64_t pnorm; 00198 std::vector<float64_t> normweightssquared; 00199 00200 }; 00201 } 00202 #endif // GMNPMKL_H_