KernelMachine.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #ifndef _KERNEL_MACHINE_H__
00012 #define _KERNEL_MACHINE_H__
00013 
00014 #include <shogun/lib/common.h>
00015 #include <shogun/io/SGIO.h>
00016 #include <shogun/kernel/Kernel.h>
00017 #include <shogun/features/Labels.h>
00018 #include <shogun/machine/Machine.h>
00019 
00020 #include <stdio.h>
00021 
00022 namespace shogun
00023 {
00024 class CMachine;
00025 class CLabels;
00026 class CKernel;
00027 
00043 class CKernelMachine : public CMachine
00044 {
00045     public:
00047         CKernelMachine();
00048 
00050         virtual ~CKernelMachine();
00051 
00057         virtual const char* get_name(void) const {
00058             return "KernelMachine"; }
00059 
00064         inline void set_kernel(CKernel* k)
00065         {
00066             SG_UNREF(kernel);
00067             SG_REF(k);
00068             kernel=k;
00069         }
00070 
00075         inline CKernel* get_kernel()
00076         {
00077             SG_REF(kernel);
00078             return kernel;
00079         }
00080 
00085         inline void set_batch_computation_enabled(bool enable)
00086         {
00087             use_batch_computation=enable;
00088         }
00089 
00094         inline bool get_batch_computation_enabled()
00095         {
00096             return use_batch_computation;
00097         }
00098 
00103         inline void set_linadd_enabled(bool enable)
00104         {
00105             use_linadd=enable;
00106         }
00107 
00112         inline bool get_linadd_enabled()
00113         {
00114             return use_linadd ;
00115         }
00116 
00121         inline void set_bias_enabled(bool enable_bias) { use_bias=enable_bias; }
00122 
00127         inline bool get_bias_enabled() { return use_bias; }
00128 
00133         inline float64_t get_bias()
00134         {
00135             return m_bias;
00136         }
00137 
00142         inline void set_bias(float64_t bias)
00143         {
00144             m_bias=bias;
00145         }
00146 
00152         inline int32_t get_support_vector(int32_t idx)
00153         {
00154             ASSERT(m_svs.vector && idx<m_svs.vlen);
00155             return m_svs.vector[idx];
00156         }
00157 
00163         inline float64_t get_alpha(int32_t idx)
00164         {
00165             if (!m_alpha.vector)
00166                 SG_ERROR("No alphas set\n");
00167             if (idx>=m_alpha.vlen)
00168                 SG_ERROR("Alphas index (%d) out of range (%d)\n", idx, m_svs.vlen);
00169             return m_alpha.vector[idx];
00170         }
00171 
00178         inline bool set_support_vector(int32_t idx, int32_t val)
00179         {
00180             if (m_svs.vector && idx<m_svs.vlen)
00181                 m_svs.vector[idx]=val;
00182             else
00183                 return false;
00184 
00185             return true;
00186         }
00187 
00194         inline bool set_alpha(int32_t idx, float64_t val)
00195         {
00196             if (m_alpha.vector && idx<m_alpha.vlen)
00197                 m_alpha.vector[idx]=val;
00198             else
00199                 return false;
00200 
00201             return true;
00202         }
00203 
00208         inline int32_t get_num_support_vectors()
00209         {
00210             return m_svs.vlen;
00211         }
00212 
00217         void set_alphas(SGVector<float64_t> alphas)
00218         {
00219             m_alpha = alphas;
00220         }
00221 
00226         void set_support_vectors(SGVector<int32_t> svs)
00227         {
00228             m_svs = svs;
00229         }
00230 
00234         SGVector<int32_t> get_support_vectors()
00235         {
00236             int32_t nsv = get_num_support_vectors();
00237             int32_t* svs = NULL;
00238 
00239             if (nsv>0)
00240             {
00241                 svs = SG_MALLOC(int32_t, nsv);
00242                 for(int32_t i=0; i<nsv; i++)
00243                     svs[i] = get_support_vector(i);
00244             }
00245 
00246             return SGVector<int32_t>(svs,nsv);
00247         }
00248 
00252         SGVector<float64_t> get_alphas()
00253         {
00254             int32_t nsv = get_num_support_vectors();
00255             float64_t* alphas = NULL;
00256 
00257             if (nsv>0)
00258             {
00259                 alphas = SG_MALLOC(float64_t, nsv);
00260                 for(int32_t i=0; i<nsv; i++)
00261                     alphas[i] = get_alpha(i);
00262             }
00263 
00264             return SGVector<float64_t>(alphas,nsv);
00265         }
00266 
00271         inline bool create_new_model(int32_t num)
00272         {
00273             m_alpha.destroy_vector();
00274             m_svs.destroy_vector();
00275 
00276             m_bias=0;
00277 
00278             if (num>0)
00279             {
00280                 m_alpha= SGVector<float64_t>(num);
00281                 m_svs= SGVector<int32_t>(num);
00282                 return (m_alpha.vector!=NULL && m_svs.vector!=NULL);
00283             }
00284             else
00285                 return true;
00286         }
00287 
00292         bool init_kernel_optimization();
00293 
00298         virtual CLabels* apply();
00299 
00305         virtual CLabels* apply(CFeatures* data);
00306 
00312         virtual float64_t apply(int32_t num);
00313 
00319         static void* apply_helper(void* p);
00320 
00321     protected:
00325         virtual void store_model_features();
00326 
00327     protected:
00329         CKernel* kernel;
00331         bool use_batch_computation;
00333         bool use_linadd;
00335         bool use_bias;
00337         float64_t m_bias;
00338 
00340         SGVector<float64_t> m_alpha;
00341 
00343         SGVector<int32_t> m_svs;
00344 };
00345 }
00346 #endif /* _KERNEL_MACHINE_H__ */
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation