00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Copyright (C) 2012 Jacob Walker 00008 * 00009 * Code adapted from CCombinedKernel 00010 */ 00011 00012 #ifndef _PRODUCTKERNEL_H___ 00013 #define _PRODUCTKERNEL_H___ 00014 00015 #include <shogun/lib/List.h> 00016 #include <shogun/io/SGIO.h> 00017 #include <shogun/kernel/Kernel.h> 00018 00019 #include <shogun/features/Features.h> 00020 #include <shogun/features/CombinedFeatures.h> 00021 00022 namespace shogun 00023 { 00024 class CFeatures; 00025 class CCombinedFeatures; 00026 class CList; 00027 class CListElement; 00042 class CProductKernel : public CKernel 00043 { 00044 public: 00050 CProductKernel(int32_t size=10); 00051 00052 virtual ~CProductKernel(); 00053 00060 virtual bool init(CFeatures* lhs, CFeatures* rhs); 00061 00063 virtual void cleanup(); 00064 00069 virtual EKernelType get_kernel_type() 00070 { 00071 return K_PRODUCT; 00072 } 00073 00078 virtual EFeatureType get_feature_type() 00079 { 00080 return F_UNKNOWN; 00081 } 00082 00087 virtual EFeatureClass get_feature_class() 00088 { 00089 return C_COMBINED; 00090 } 00091 00096 virtual const char* get_name() const { return "ProductKernel"; } 00097 00099 void list_kernels(); 00100 00105 inline CKernel* get_first_kernel() 00106 { 00107 return (CKernel*) kernel_list->get_first_element(); 00108 } 00109 00115 inline CKernel* get_first_kernel(CListElement*& current) 00116 { 00117 return (CKernel*) kernel_list->get_first_element(current); 00118 } 00119 00125 inline CKernel* get_kernel(int32_t idx) 00126 { 00127 CKernel * k = get_first_kernel(); 00128 for (int32_t i=0; i<idx; i++) 00129 { 00130 SG_UNREF(k); 00131 k = get_next_kernel(); 00132 } 00133 return k; 00134 } 00135 00140 inline CKernel* get_last_kernel() 00141 { 00142 return (CKernel*) kernel_list->get_last_element(); 00143 } 00144 00149 inline CKernel* get_next_kernel() 00150 { 00151 return (CKernel*) kernel_list->get_next_element(); 00152 } 00153 00159 inline CKernel* get_next_kernel(CListElement*& current) 00160 { 00161 return (CKernel*) kernel_list->get_next_element(current); 00162 } 00163 00169 inline bool insert_kernel(CKernel* k) 00170 { 00171 ASSERT(k); 00172 adjust_num_lhs_rhs_initialized(k); 00173 00174 if (!(k->has_property(KP_LINADD))) 00175 unset_property(KP_LINADD); 00176 00177 return kernel_list->insert_element(k); 00178 } 00179 00185 inline bool append_kernel(CKernel* k) 00186 { 00187 ASSERT(k); 00188 adjust_num_lhs_rhs_initialized(k); 00189 00190 if (!(k->has_property(KP_LINADD))) 00191 unset_property(KP_LINADD); 00192 00193 return kernel_list->append_element(k); 00194 } 00195 00196 00201 inline bool delete_kernel() 00202 { 00203 CKernel* k=(CKernel*) kernel_list->delete_element(); 00204 SG_UNREF(k); 00205 00206 if (!k) 00207 { 00208 num_lhs=0; 00209 num_rhs=0; 00210 } 00211 00212 return (k!=NULL); 00213 } 00214 00215 00220 inline int32_t get_num_subkernels() 00221 { 00222 return kernel_list->get_num_elements(); 00223 } 00224 00229 virtual bool has_features() 00230 { 00231 return initialized; 00232 } 00233 00235 virtual void remove_lhs(); 00236 00238 virtual void remove_rhs(); 00239 00241 virtual void remove_lhs_and_rhs(); 00242 00244 bool precompute_subkernels(); 00245 00249 CProductKernel* KernelToProductKernel(shogun::CKernel* n) 00250 { 00251 return dynamic_cast<CProductKernel*>(n); 00252 } 00253 00262 SGMatrix<float64_t> get_parameter_gradient(TParameter* param, 00263 CSGObject* obj, index_t index); 00264 00269 inline CList* get_list() {SG_REF(kernel_list); return kernel_list;} 00270 00271 protected: 00278 virtual float64_t compute(int32_t x, int32_t y); 00279 00285 inline void adjust_num_lhs_rhs_initialized(CKernel* k) 00286 { 00287 ASSERT(k); 00288 00289 if (k->get_num_vec_lhs()) 00290 { 00291 if (num_lhs) 00292 ASSERT(num_lhs==k->get_num_vec_lhs()); 00293 num_lhs=k->get_num_vec_lhs(); 00294 00295 if (!get_num_subkernels()) 00296 { 00297 initialized=true; 00298 #ifdef USE_SVMLIGHT 00299 cache_reset(); 00300 #endif //USE_SVMLIGHT 00301 } 00302 } 00303 else 00304 initialized=false; 00305 00306 if (k->get_num_vec_rhs()) 00307 { 00308 if (num_rhs) 00309 ASSERT(num_rhs==k->get_num_vec_rhs()); 00310 num_rhs=k->get_num_vec_rhs(); 00311 00312 if (!get_num_subkernels()) 00313 { 00314 initialized=true; 00315 #ifdef USE_SVMLIGHT 00316 cache_reset(); 00317 #endif //USE_SVMLIGHT 00318 } 00319 } 00320 else 00321 initialized=false; 00322 } 00323 00324 private: 00325 void init(); 00326 00327 protected: 00329 CList* kernel_list; 00331 bool initialized; 00332 }; 00333 } 00334 #endif /* _PRODUCTKERNEL_H__ */