00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #ifndef _KERNEL_MACHINE_H__ 00012 #define _KERNEL_MACHINE_H__ 00013 00014 #include <shogun/lib/common.h> 00015 #include <shogun/io/SGIO.h> 00016 #include <shogun/kernel/Kernel.h> 00017 #include <shogun/features/Labels.h> 00018 #include <shogun/machine/Machine.h> 00019 00020 #include <stdio.h> 00021 00022 namespace shogun 00023 { 00024 class CMachine; 00025 class CLabels; 00026 class CKernel; 00027 00043 class CKernelMachine : public CMachine 00044 { 00045 public: 00047 CKernelMachine(); 00048 00050 virtual ~CKernelMachine(); 00051 00057 virtual const char* get_name() const { 00058 return "KernelMachine"; } 00059 00064 inline void set_kernel(CKernel* k) 00065 { 00066 SG_UNREF(kernel); 00067 SG_REF(k); 00068 kernel=k; 00069 } 00070 00075 inline CKernel* get_kernel() 00076 { 00077 SG_REF(kernel); 00078 return kernel; 00079 } 00080 00085 inline void set_batch_computation_enabled(bool enable) 00086 { 00087 use_batch_computation=enable; 00088 } 00089 00094 inline bool get_batch_computation_enabled() 00095 { 00096 return use_batch_computation; 00097 } 00098 00103 inline void set_linadd_enabled(bool enable) 00104 { 00105 use_linadd=enable; 00106 } 00107 00112 inline bool get_linadd_enabled() 00113 { 00114 return use_linadd ; 00115 } 00116 00121 inline void set_bias_enabled(bool enable_bias) { use_bias=enable_bias; } 00122 00127 inline bool get_bias_enabled() { return use_bias; } 00128 00133 inline float64_t get_bias() 00134 { 00135 return m_bias; 00136 } 00137 00142 inline void set_bias(float64_t bias) 00143 { 00144 m_bias=bias; 00145 } 00146 00152 inline int32_t get_support_vector(int32_t idx) 00153 { 00154 ASSERT(m_svs.vector && idx<m_svs.vlen); 00155 return m_svs.vector[idx]; 00156 } 00157 00163 inline float64_t get_alpha(int32_t idx) 00164 { 00165 if (!m_alpha.vector) 00166 SG_ERROR("No alphas set\n"); 00167 if (idx>=m_alpha.vlen) 00168 SG_ERROR("Alphas index (%d) out of range (%d)\n", idx, m_svs.vlen); 00169 return m_alpha.vector[idx]; 00170 } 00171 00178 inline bool set_support_vector(int32_t idx, int32_t val) 00179 { 00180 if (m_svs.vector && idx<m_svs.vlen) 00181 m_svs.vector[idx]=val; 00182 else 00183 return false; 00184 00185 return true; 00186 } 00187 00194 inline bool set_alpha(int32_t idx, float64_t val) 00195 { 00196 if (m_alpha.vector && idx<m_alpha.vlen) 00197 m_alpha.vector[idx]=val; 00198 else 00199 return false; 00200 00201 return true; 00202 } 00203 00208 inline int32_t get_num_support_vectors() 00209 { 00210 return m_svs.vlen; 00211 } 00212 00217 void set_alphas(SGVector<float64_t> alphas) 00218 { 00219 m_alpha = alphas; 00220 } 00221 00226 void set_support_vectors(SGVector<int32_t> svs) 00227 { 00228 m_svs = svs; 00229 } 00230 00234 SGVector<int32_t> get_support_vectors() 00235 { 00236 int32_t nsv = get_num_support_vectors(); 00237 int32_t* svs = NULL; 00238 00239 if (nsv>0) 00240 { 00241 svs = SG_MALLOC(int32_t, nsv); 00242 for(int32_t i=0; i<nsv; i++) 00243 svs[i] = get_support_vector(i); 00244 } 00245 00246 return SGVector<int32_t>(svs,nsv); 00247 } 00248 00252 SGVector<float64_t> get_alphas() 00253 { 00254 int32_t nsv = get_num_support_vectors(); 00255 float64_t* alphas = NULL; 00256 00257 if (nsv>0) 00258 { 00259 alphas = SG_MALLOC(float64_t, nsv); 00260 for(int32_t i=0; i<nsv; i++) 00261 alphas[i] = get_alpha(i); 00262 } 00263 00264 return SGVector<float64_t>(alphas,nsv); 00265 } 00266 00271 inline bool create_new_model(int32_t num) 00272 { 00273 m_alpha.destroy_vector(); 00274 m_svs.destroy_vector(); 00275 00276 m_bias=0; 00277 00278 if (num>0) 00279 { 00280 m_alpha= SGVector<float64_t>(num); 00281 m_svs= SGVector<int32_t>(num); 00282 return (m_alpha.vector!=NULL && m_svs.vector!=NULL); 00283 } 00284 else 00285 return true; 00286 } 00287 00292 bool init_kernel_optimization(); 00293 00298 virtual CLabels* apply(); 00299 00305 virtual CLabels* apply(CFeatures* data); 00306 00312 virtual float64_t apply(int32_t num); 00313 00319 static void* apply_helper(void* p); 00320 00321 protected: 00325 virtual void store_model_features(); 00326 00327 protected: 00329 CKernel* kernel; 00331 bool use_batch_computation; 00333 bool use_linadd; 00335 bool use_bias; 00337 float64_t m_bias; 00338 00340 SGVector<float64_t> m_alpha; 00341 00343 SGVector<int32_t> m_svs; 00344 }; 00345 } 00346 #endif /* _KERNEL_MACHINE_H__ */