Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #ifndef _SQRTDIAGKERNELNORMALIZER_H___
00012 #define _SQRTDIAGKERNELNORMALIZER_H___
00013
00014 #include <shogun/kernel/KernelNormalizer.h>
00015 #include <shogun/kernel/CommWordStringKernel.h>
00016
00017 namespace shogun
00018 {
00029 class CSqrtDiagKernelNormalizer : public CKernelNormalizer
00030 {
00031 public:
00036 CSqrtDiagKernelNormalizer(bool use_opt_diag=false): CKernelNormalizer(),
00037 sqrtdiag_lhs(NULL), num_sqrtdiag_lhs(0),
00038 sqrtdiag_rhs(NULL), num_sqrtdiag_rhs(0),
00039 use_optimized_diagonal_computation(use_opt_diag)
00040 {
00041 m_parameters->add_vector(&sqrtdiag_lhs, &num_sqrtdiag_lhs, "sqrtdiag_lhs",
00042 "sqrt(K(x,x)) for left hand side examples.");
00043 m_parameters->add_vector(&sqrtdiag_rhs, &num_sqrtdiag_rhs, "sqrtdiag_rhs",
00044 "sqrt(K(x,x)) for right hand side examples.");
00045 m_parameters->add(&use_optimized_diagonal_computation,
00046 "use_optimized_diagonal_computation",
00047 "flat if optimized diagonal computation is used");
00048 }
00049
00051 virtual ~CSqrtDiagKernelNormalizer()
00052 {
00053 SG_FREE(sqrtdiag_lhs);
00054 SG_FREE(sqrtdiag_rhs);
00055 }
00056
00059 virtual bool init(CKernel* k)
00060 {
00061 ASSERT(k);
00062 num_sqrtdiag_lhs=k->get_num_vec_lhs();
00063 num_sqrtdiag_rhs=k->get_num_vec_rhs();
00064 ASSERT(num_sqrtdiag_lhs>0);
00065 ASSERT(num_sqrtdiag_rhs>0);
00066
00067 CFeatures* old_lhs=k->lhs;
00068 CFeatures* old_rhs=k->rhs;
00069
00070 k->lhs=old_lhs;
00071 k->rhs=old_lhs;
00072 bool r1=alloc_and_compute_diag(k, sqrtdiag_lhs, num_sqrtdiag_lhs);
00073
00074 k->lhs=old_rhs;
00075 k->rhs=old_rhs;
00076 bool r2=alloc_and_compute_diag(k, sqrtdiag_rhs, num_sqrtdiag_rhs);
00077
00078 k->lhs=old_lhs;
00079 k->rhs=old_rhs;
00080
00081 return r1 && r2;
00082 }
00083
00089 inline virtual float64_t normalize(
00090 float64_t value, int32_t idx_lhs, int32_t idx_rhs)
00091 {
00092 float64_t sqrt_both=sqrtdiag_lhs[idx_lhs]*sqrtdiag_rhs[idx_rhs];
00093 return value/sqrt_both;
00094 }
00095
00100 inline virtual float64_t normalize_lhs(float64_t value, int32_t idx_lhs)
00101 {
00102 return value/sqrtdiag_lhs[idx_lhs];
00103 }
00104
00109 inline virtual float64_t normalize_rhs(float64_t value, int32_t idx_rhs)
00110 {
00111 return value/sqrtdiag_rhs[idx_rhs];
00112 }
00113
00114 public:
00119 bool alloc_and_compute_diag(CKernel* k, float64_t* &v, int32_t num)
00120 {
00121 SG_FREE(v);
00122 v=SG_MALLOC(float64_t, num);
00123
00124 for (int32_t i=0; i<num; i++)
00125 {
00126 if (k->get_kernel_type() == K_COMMWORDSTRING)
00127 {
00128 if (use_optimized_diagonal_computation)
00129 v[i]=sqrt(((CCommWordStringKernel*) k)->compute_diag(i));
00130 else
00131 v[i]=sqrt(((CCommWordStringKernel*) k)->compute_helper(i,i, true));
00132 }
00133 else
00134 v[i]=sqrt(k->compute(i,i));
00135
00136 if (v[i]==0.0)
00137 v[i]=1e-16;
00138 }
00139
00140 return (v!=NULL);
00141 }
00142
00144 inline virtual const char* get_name() const { return "SqrtDiagKernelNormalizer"; }
00145
00146 protected:
00148 float64_t* sqrtdiag_lhs;
00149
00151 int32_t num_sqrtdiag_lhs;
00152
00154 float64_t* sqrtdiag_rhs;
00155
00157 int32_t num_sqrtdiag_rhs;
00158
00160 bool use_optimized_diagonal_computation;
00161 };
00162 }
00163 #endif