00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #ifndef _KERNEL_MACHINE_H__ 00012 #define _KERNEL_MACHINE_H__ 00013 00014 #include "lib/common.h" 00015 #include "lib/io.h" 00016 #include "kernel/Kernel.h" 00017 #include "features/Labels.h" 00018 #include "classifier/Classifier.h" 00019 00020 #include <stdio.h> 00021 00022 namespace shogun 00023 { 00024 class CClassifier; 00025 class CLabels; 00026 class CKernel; 00027 00043 class CKernelMachine : public CClassifier 00044 { 00045 public: 00047 CKernelMachine(); 00048 00050 virtual ~CKernelMachine(); 00051 00057 virtual const char* get_name(void) const { 00058 return "KernelMachine"; } 00059 00064 inline void set_kernel(CKernel* k) 00065 { 00066 SG_UNREF(kernel); 00067 SG_REF(k); 00068 kernel=k; 00069 } 00070 00075 inline CKernel* get_kernel() 00076 { 00077 SG_REF(kernel); 00078 return kernel; 00079 } 00080 00085 inline void set_batch_computation_enabled(bool enable) 00086 { 00087 use_batch_computation=enable; 00088 } 00089 00094 inline bool get_batch_computation_enabled() 00095 { 00096 return use_batch_computation; 00097 } 00098 00103 inline void set_linadd_enabled(bool enable) 00104 { 00105 use_linadd=enable; 00106 } 00107 00112 inline bool get_linadd_enabled() 00113 { 00114 return use_linadd ; 00115 } 00116 00121 inline void set_bias_enabled(bool enable_bias) { use_bias=enable_bias; } 00122 00127 inline bool get_bias_enabled() { return use_bias; } 00128 00133 inline float64_t get_bias() 00134 { 00135 return m_bias; 00136 } 00137 00142 inline void set_bias(float64_t bias) 00143 { 00144 m_bias=bias; 00145 } 00146 00152 inline int32_t get_support_vector(int32_t idx) 00153 { 00154 ASSERT(m_svs && idx<num_svs); 00155 return m_svs[idx]; 00156 } 00157 00163 inline float64_t get_alpha(int32_t idx) 00164 { 00165 ASSERT(m_alpha && idx<num_svs); 00166 return m_alpha[idx]; 00167 } 00168 00175 inline bool set_support_vector(int32_t idx, int32_t val) 00176 { 00177 if (m_svs && idx<num_svs) 00178 m_svs[idx]=val; 00179 else 00180 return false; 00181 00182 return true; 00183 } 00184 00191 inline bool set_alpha(int32_t idx, float64_t val) 00192 { 00193 if (m_alpha && idx<num_svs) 00194 m_alpha[idx]=val; 00195 else 00196 return false; 00197 00198 return true; 00199 } 00200 00205 inline int32_t get_num_support_vectors() 00206 { 00207 return num_svs; 00208 } 00209 00215 void set_alphas(float64_t* alphas, int32_t d) 00216 { 00217 ASSERT(alphas); 00218 ASSERT(m_alpha); 00219 ASSERT(d==num_svs); 00220 00221 for(int32_t i=0; i<d; i++) 00222 m_alpha[i]=alphas[i]; 00223 } 00224 00230 void set_support_vectors(int32_t* svs, int32_t d) 00231 { 00232 ASSERT(m_svs); 00233 ASSERT(svs); 00234 ASSERT(d==num_svs); 00235 00236 for(int32_t i=0; i<d; i++) 00237 m_svs[i]=svs[i]; 00238 } 00239 00245 void get_support_vectors(int32_t** svs, int32_t* num) 00246 { 00247 int32_t nsv = get_num_support_vectors(); 00248 00249 ASSERT(svs && num); 00250 *svs=NULL; 00251 *num=nsv; 00252 00253 if (nsv>0) 00254 { 00255 *svs = (int32_t*) malloc(sizeof(int32_t)*nsv); 00256 for(int32_t i=0; i<nsv; i++) 00257 (*svs)[i] = get_support_vector(i); 00258 } 00259 } 00260 00266 void get_alphas(float64_t** alphas, int32_t* d1) 00267 { 00268 int32_t nsv = get_num_support_vectors(); 00269 00270 ASSERT(alphas && d1); 00271 *alphas=NULL; 00272 *d1=nsv; 00273 00274 if (nsv>0) 00275 { 00276 *alphas = (float64_t*) malloc(nsv*sizeof(float64_t)); 00277 for(int32_t i=0; i<nsv; i++) 00278 (*alphas)[i] = get_alpha(i); 00279 } 00280 } 00281 00286 inline bool create_new_model(int32_t num) 00287 { 00288 delete[] m_alpha; 00289 delete[] m_svs; 00290 00291 m_bias=0; 00292 num_svs=num; 00293 00294 if (num>0) 00295 { 00296 m_alpha= new float64_t[num]; 00297 m_svs= new int32_t[num]; 00298 return (m_alpha!=NULL && m_svs!=NULL); 00299 } 00300 else 00301 { 00302 m_alpha= NULL; 00303 m_svs=NULL; 00304 return true; 00305 } 00306 } 00307 00312 bool init_kernel_optimization(); 00313 00318 virtual CLabels* classify(); 00319 00325 virtual CLabels* classify(CFeatures* data); 00326 00332 virtual float64_t classify_example(int32_t num); 00333 00339 static void* classify_example_helper(void* p); 00340 00341 protected: 00343 CKernel* kernel; 00345 bool use_batch_computation; 00347 bool use_linadd; 00349 bool use_bias; 00351 float64_t m_bias; 00353 float64_t* m_alpha; 00355 int32_t* m_svs; 00357 int32_t num_svs; 00358 }; 00359 } 00360 #endif /* _KERNEL_MACHINE_H__ */