SqrtDiagKernelNormalizer.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2009 Soeren Sonnenburg
00008  * Copyright (C) 2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #ifndef _SQRTDIAGKERNELNORMALIZER_H___
00012 #define _SQRTDIAGKERNELNORMALIZER_H___
00013 
00014 #include <shogun/kernel/KernelNormalizer.h>
00015 #include <shogun/kernel/CommWordStringKernel.h>
00016 
00017 namespace shogun
00018 {
00029 class CSqrtDiagKernelNormalizer : public CKernelNormalizer
00030 {
00031     public:
00036         CSqrtDiagKernelNormalizer(bool use_opt_diag=false): CKernelNormalizer(),
00037             sqrtdiag_lhs(NULL), num_sqrtdiag_lhs(0),
00038             sqrtdiag_rhs(NULL), num_sqrtdiag_rhs(0),
00039             use_optimized_diagonal_computation(use_opt_diag)
00040         {
00041             m_parameters->add_vector(&sqrtdiag_lhs, &num_sqrtdiag_lhs, "sqrtdiag_lhs",
00042                               "sqrt(K(x,x)) for left hand side examples.");
00043             m_parameters->add_vector(&sqrtdiag_rhs, &num_sqrtdiag_rhs, "sqrtdiag_rhs",
00044                               "sqrt(K(x,x)) for right hand side examples.");
00045             m_parameters->add(&use_optimized_diagonal_computation,
00046                     "use_optimized_diagonal_computation",
00047                     "flat if optimized diagonal computation is used");
00048         }
00049 
00051         virtual ~CSqrtDiagKernelNormalizer()
00052         {
00053             SG_FREE(sqrtdiag_lhs);
00054             SG_FREE(sqrtdiag_rhs);
00055         }
00056 
00059         virtual bool init(CKernel* k)
00060         {
00061             ASSERT(k);
00062             num_sqrtdiag_lhs=k->get_num_vec_lhs();
00063             num_sqrtdiag_rhs=k->get_num_vec_rhs();
00064             ASSERT(num_sqrtdiag_lhs>0);
00065             ASSERT(num_sqrtdiag_rhs>0);
00066 
00067             CFeatures* old_lhs=k->lhs;
00068             CFeatures* old_rhs=k->rhs;
00069 
00070             k->lhs=old_lhs;
00071             k->rhs=old_lhs;
00072             bool r1=alloc_and_compute_diag(k, sqrtdiag_lhs, num_sqrtdiag_lhs);
00073 
00074             k->lhs=old_rhs;
00075             k->rhs=old_rhs;
00076             bool r2=alloc_and_compute_diag(k, sqrtdiag_rhs, num_sqrtdiag_rhs);
00077 
00078             k->lhs=old_lhs;
00079             k->rhs=old_rhs;
00080 
00081             return r1 && r2;
00082         }
00083 
00089         inline virtual float64_t normalize(
00090             float64_t value, int32_t idx_lhs, int32_t idx_rhs)
00091         {
00092             float64_t sqrt_both=sqrtdiag_lhs[idx_lhs]*sqrtdiag_rhs[idx_rhs];
00093             return value/sqrt_both;
00094         }
00095 
00100         inline virtual float64_t normalize_lhs(float64_t value, int32_t idx_lhs)
00101         {
00102             return value/sqrtdiag_lhs[idx_lhs];
00103         }
00104 
00109         inline virtual float64_t normalize_rhs(float64_t value, int32_t idx_rhs)
00110         {
00111             return value/sqrtdiag_rhs[idx_rhs];
00112         }
00113 
00114     public:
00119         bool alloc_and_compute_diag(CKernel* k, float64_t* &v, int32_t num)
00120         {
00121             SG_FREE(v);
00122             v=SG_MALLOC(float64_t, num);
00123 
00124             for (int32_t i=0; i<num; i++)
00125             {
00126                 if (k->get_kernel_type() == K_COMMWORDSTRING)
00127                 {
00128                     if (use_optimized_diagonal_computation)
00129                         v[i]=sqrt(((CCommWordStringKernel*) k)->compute_diag(i));
00130                     else
00131                         v[i]=sqrt(((CCommWordStringKernel*) k)->compute_helper(i,i, true));
00132                 }
00133                 else
00134                     v[i]=sqrt(k->compute(i,i));
00135 
00136                 if (v[i]==0.0)
00137                     v[i]=1e-16; /* avoid divide by zero exception */
00138             }
00139 
00140             return (v!=NULL);
00141         }
00142 
00144         inline virtual const char* get_name() const { return "SqrtDiagKernelNormalizer"; }
00145 
00146     protected:
00148         float64_t* sqrtdiag_lhs;
00149 
00151         int32_t num_sqrtdiag_lhs;
00152 
00154         float64_t* sqrtdiag_rhs;
00155 
00157         int32_t num_sqrtdiag_rhs;
00158 
00160         bool use_optimized_diagonal_computation;
00161 };
00162 }
00163 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation