CustomKernel.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #ifndef _CUSTOMKERNEL_H___
00012 #define _CUSTOMKERNEL_H___
00013 
00014 #include <shogun/mathematics/Math.h>
00015 #include <shogun/lib/common.h>
00016 #include <shogun/kernel/Kernel.h>
00017 #include <shogun/features/Features.h>
00018 
00019 namespace shogun
00020 {
00029 class CCustomKernel: public CKernel
00030 {
00031     void init(void);
00032 
00033     public:
00035         CCustomKernel();
00036 
00042         CCustomKernel(CKernel* k);
00043 
00051         CCustomKernel(SGMatrix<float64_t> km);
00052 
00056         virtual ~CCustomKernel();
00057 
00068         virtual bool dummy_init(int32_t rows, int32_t cols);
00069 
00076         virtual bool init(CFeatures* l, CFeatures* r);
00077 
00079         virtual void cleanup();
00080 
00085         inline virtual EKernelType get_kernel_type() { return K_CUSTOM; }
00086 
00091         inline virtual EFeatureType get_feature_type() { return F_ANY; }
00092 
00097         inline virtual EFeatureClass get_feature_class() { return C_ANY; }
00098 
00103         virtual const char* get_name() const { return "CustomKernel"; }
00104 
00114         bool set_triangle_kernel_matrix_from_triangle(
00115             SGVector<float64_t> tri_kernel_matrix)
00116         {
00117             return set_triangle_kernel_matrix_from_triangle_generic(tri_kernel_matrix);
00118         }
00119 
00129         template <class T>
00130         bool set_triangle_kernel_matrix_from_triangle_generic(
00131             SGVector<T> tri_kernel_matrix)
00132         {
00133             ASSERT(tri_kernel_matrix.vector);
00134 
00135             int64_t len = tri_kernel_matrix.vlen;
00136             int64_t cols = (int64_t) floor(-0.5 + CMath::sqrt(0.25+2*len));
00137 
00138             if (cols*(cols+1)/2 != len)
00139             {
00140                 SG_ERROR("km should be a vector containing a lower triangle matrix, with len=cols*(cols+1)/2 elements\n");
00141                 return false;
00142             }
00143 
00144             cleanup_custom();
00145             SG_DEBUG( "using custom kernel of size %dx%d\n", cols,cols);
00146 
00147             kmatrix.matrix = SG_MALLOC(float32_t, len);
00148             kmatrix.num_rows=cols;
00149             kmatrix.num_cols=cols;
00150             upper_diagonal=true;
00151 
00152             for (int64_t i=0; i<len; i++)
00153                 kmatrix.matrix[i]=tri_kernel_matrix.vector[i];
00154 
00155             dummy_init(cols,cols);
00156             return true;
00157         }
00158 
00166         inline bool set_triangle_kernel_matrix_from_full(
00167             SGMatrix<float64_t> full_kernel_matrix)
00168         {
00169             return set_triangle_kernel_matrix_from_full_generic(full_kernel_matrix);
00170         }
00171 
00177         template <class T>
00178         bool set_triangle_kernel_matrix_from_full_generic(
00179             SGMatrix<T> full_kernel_matrix)
00180         {
00181             int32_t rows = full_kernel_matrix.num_rows;
00182             int32_t cols = full_kernel_matrix.num_cols;
00183             ASSERT(rows==cols);
00184 
00185             cleanup_custom();
00186             SG_DEBUG( "using custom kernel of size %dx%d\n", cols,cols);
00187 
00188             kmatrix.matrix = SG_MALLOC(float32_t, int64_t(rows)*cols);
00189             kmatrix.num_rows = rows;
00190             kmatrix.num_cols = cols;
00191             upper_diagonal = false;
00192 
00193             for (int64_t row=0; row<rows; row++)
00194             {
00195                 for (int64_t col=row; col<cols; col++)
00196                 {
00197                     int64_t idx=row * cols - row*(row+1)/2 + col;
00198                     kmatrix.matrix[idx] = full_kernel_matrix.matrix[col*rows+row];
00199                 }
00200             }
00201 
00202             dummy_init(rows, cols);
00203             return true;
00204         }
00205 
00212         bool set_full_kernel_matrix_from_full(
00213             SGMatrix<float32_t> full_kernel_matrix)
00214         {
00215             cleanup_custom();
00216             kmatrix.matrix = full_kernel_matrix.matrix;
00217             kmatrix.num_rows=full_kernel_matrix.num_rows;
00218             kmatrix.num_cols=full_kernel_matrix.num_cols;
00219             dummy_init(kmatrix.num_rows, kmatrix.num_cols);
00220             return true;
00221         }
00222 
00229         bool set_full_kernel_matrix_from_full(
00230             SGMatrix<float64_t> full_kernel_matrix)
00231         {
00232             cleanup_custom();
00233             int32_t rows=full_kernel_matrix.num_rows;
00234             int32_t cols=full_kernel_matrix.num_cols;
00235             SG_DEBUG( "using custom kernel of size %dx%d\n", rows,cols);
00236 
00237             kmatrix.matrix = SG_MALLOC(float32_t, int64_t(rows)*cols);
00238             kmatrix.num_rows = rows;
00239             kmatrix.num_cols = cols;
00240             upper_diagonal = false;
00241 
00242             for (int32_t row=0; row<rows; row++)
00243             {
00244                 for (int32_t col=0; col<cols; col++)
00245                     kmatrix.matrix[int64_t(row) * cols + col] =
00246                             full_kernel_matrix.matrix[int64_t(col)*rows+row];
00247             }
00248 
00249             dummy_init(rows, cols);
00250             return true;
00251         }
00252 
00257         virtual inline int32_t get_num_vec_lhs()
00258         {
00259             return kmatrix.num_rows;
00260         }
00261 
00266         virtual inline int32_t get_num_vec_rhs()
00267         {
00268             return kmatrix.num_cols;
00269         }
00270 
00275         virtual inline bool has_features()
00276         {
00277             return (kmatrix.num_rows>0) && (kmatrix.num_cols>0);
00278         }
00279 
00280     protected:
00281 
00288         inline virtual float64_t compute(int32_t row, int32_t col)
00289         {
00290             ASSERT(kmatrix.matrix);
00291 
00292             if (upper_diagonal)
00293             {
00294                 if (row <= col)
00295                 {
00296                     int64_t r=row;
00297                     return kmatrix.matrix[r*kmatrix.num_rows - r*(r+1)/2 + col];
00298                 }
00299                 else
00300                 {
00301                     int64_t c=col;
00302                     return kmatrix.matrix[c*kmatrix.num_cols - c*(c+1)/2 + row];
00303                 }
00304             }
00305             else
00306             {
00307                 int64_t r=row;
00308                 return kmatrix.matrix[r*kmatrix.num_cols+col];
00309             }
00310         }
00311 
00312     private:
00313 
00315         void cleanup_custom();
00316 
00317     protected:
00318 
00320         SGMatrix<float32_t> kmatrix;
00321 
00323         bool upper_diagonal;
00324 };
00325 
00326 }
00327 #endif /* _CUSTOMKERNEL_H__ */
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation