KernelMachine.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #ifndef _KERNEL_MACHINE_H__
00012 #define _KERNEL_MACHINE_H__
00013 
00014 #include "lib/common.h"
00015 #include "lib/io.h"
00016 #include "kernel/Kernel.h"
00017 #include "features/Labels.h"
00018 #include "classifier/Classifier.h"
00019 
00020 #include <stdio.h>
00021 
00022 namespace shogun
00023 {
00024 class CClassifier;
00025 class CLabels;
00026 class CKernel;
00027 
00043 class CKernelMachine : public CClassifier
00044 {
00045     public:
00047         CKernelMachine();
00048 
00050         virtual ~CKernelMachine();
00051 
00057         virtual const char* get_name(void) const {
00058             return "KernelMachine"; }
00059 
00064         inline void set_kernel(CKernel* k)
00065         {
00066             SG_UNREF(kernel);
00067             SG_REF(k);
00068             kernel=k;
00069         }
00070 
00075         inline CKernel* get_kernel()
00076         {
00077             SG_REF(kernel);
00078             return kernel;
00079         }
00080 
00085         inline void set_batch_computation_enabled(bool enable)
00086         {
00087             use_batch_computation=enable;
00088         }
00089 
00094         inline bool get_batch_computation_enabled()
00095         {
00096             return use_batch_computation;
00097         }
00098 
00103         inline void set_linadd_enabled(bool enable)
00104         {
00105             use_linadd=enable;
00106         }
00107 
00112         inline bool get_linadd_enabled()
00113         {
00114             return use_linadd ;
00115         }
00116 
00121         inline void set_bias_enabled(bool enable_bias) { use_bias=enable_bias; }
00122 
00127         inline bool get_bias_enabled() { return use_bias; }
00128 
00133         inline float64_t get_bias()
00134         {
00135             return m_bias;
00136         }
00137 
00142         inline void set_bias(float64_t bias)
00143         {
00144             m_bias=bias;
00145         }
00146 
00152         inline int32_t get_support_vector(int32_t idx)
00153         {
00154             ASSERT(m_svs && idx<num_svs);
00155             return m_svs[idx];
00156         }
00157 
00163         inline float64_t get_alpha(int32_t idx)
00164         {
00165             ASSERT(m_alpha && idx<num_svs);
00166             return m_alpha[idx];
00167         }
00168 
00175         inline bool set_support_vector(int32_t idx, int32_t val)
00176         {
00177             if (m_svs && idx<num_svs)
00178                 m_svs[idx]=val;
00179             else
00180                 return false;
00181 
00182             return true;
00183         }
00184 
00191         inline bool set_alpha(int32_t idx, float64_t val)
00192         {
00193             if (m_alpha && idx<num_svs)
00194                 m_alpha[idx]=val;
00195             else
00196                 return false;
00197 
00198             return true;
00199         }
00200 
00205         inline int32_t get_num_support_vectors()
00206         {
00207             return num_svs;
00208         }
00209 
00215         void set_alphas(float64_t* alphas, int32_t d)
00216         {
00217             ASSERT(alphas);
00218             ASSERT(m_alpha);
00219             ASSERT(d==num_svs);
00220 
00221             for(int32_t i=0; i<d; i++)
00222                 m_alpha[i]=alphas[i];
00223         }
00224 
00230         void set_support_vectors(int32_t* svs, int32_t d)
00231         {
00232             ASSERT(m_svs);
00233             ASSERT(svs);
00234             ASSERT(d==num_svs);
00235 
00236             for(int32_t i=0; i<d; i++)
00237                 m_svs[i]=svs[i];
00238         }
00239 
00245         void get_support_vectors(int32_t** svs, int32_t* num)
00246         {
00247             int32_t nsv = get_num_support_vectors();
00248 
00249             ASSERT(svs && num);
00250             *svs=NULL;
00251             *num=nsv;
00252 
00253             if (nsv>0)
00254             {
00255                 *svs = (int32_t*) malloc(sizeof(int32_t)*nsv);
00256                 for(int32_t i=0; i<nsv; i++)
00257                     (*svs)[i] = get_support_vector(i);
00258             }
00259         }
00260 
00266         void get_alphas(float64_t** alphas, int32_t* d1)
00267         {
00268             int32_t nsv = get_num_support_vectors();
00269 
00270             ASSERT(alphas && d1);
00271             *alphas=NULL;
00272             *d1=nsv;
00273 
00274             if (nsv>0)
00275             {
00276                 *alphas = (float64_t*) malloc(nsv*sizeof(float64_t));
00277                 for(int32_t i=0; i<nsv; i++)
00278                     (*alphas)[i] = get_alpha(i);
00279             }
00280         }
00281 
00286         inline bool create_new_model(int32_t num)
00287         {
00288             delete[] m_alpha;
00289             delete[] m_svs;
00290 
00291             m_bias=0;
00292             num_svs=num;
00293 
00294             if (num>0)
00295             {
00296                 m_alpha= new float64_t[num];
00297                 m_svs= new int32_t[num];
00298                 return (m_alpha!=NULL && m_svs!=NULL);
00299             }
00300             else
00301             {
00302                 m_alpha= NULL;
00303                 m_svs=NULL;
00304                 return true;
00305             }
00306         }
00307 
00312         bool init_kernel_optimization();
00313 
00318         virtual CLabels* classify();
00319 
00325         virtual CLabels* classify(CFeatures* data);
00326 
00332         virtual float64_t classify_example(int32_t num);
00333 
00339         static void* classify_example_helper(void* p);
00340 
00341     protected:
00343         CKernel* kernel;
00345         bool use_batch_computation;
00347         bool use_linadd;
00349         bool use_bias;
00351         float64_t m_bias;
00353         float64_t* m_alpha;
00355         int32_t* m_svs;
00357         int32_t num_svs;
00358 };
00359 }
00360 #endif /* _KERNEL_MACHINE_H__ */
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation