Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #ifndef _SCATTERKERNELNORMALIZER_H___
00012 #define _SCATTERKERNELNORMALIZER_H___
00013
00014 #include "kernel/KernelNormalizer.h"
00015 #include "kernel/IdentityKernelNormalizer.h"
00016 #include "kernel/Kernel.h"
00017 #include "features/Labels.h"
00018 #include "lib/io.h"
00019
00020 namespace shogun
00021 {
00022 class CScatterKernelNormalizer: public CKernelNormalizer
00023 {
00024
00025 public:
00027 CScatterKernelNormalizer() : CKernelNormalizer()
00028 {
00029 init();
00030 }
00031
00034 CScatterKernelNormalizer(float64_t const_diag, float64_t const_offdiag,
00035 CLabels* labels,CKernelNormalizer* normalizer=NULL)
00036 : CKernelNormalizer()
00037 {
00038 init();
00039
00040 m_testing_class=-1;
00041 m_const_diag=const_diag;
00042 m_const_offdiag=const_offdiag;
00043
00044 ASSERT(labels)
00045 SG_REF(labels);
00046 m_labels=labels;
00047
00048 if (normalizer==NULL)
00049 normalizer=new CIdentityKernelNormalizer();
00050 SG_REF(normalizer);
00051 m_normalizer=normalizer;
00052
00053 SG_DEBUG("Constructing ScatterKernelNormalizer with const_diag=%g"
00054 " const_offdiag=%g num_labels=%d and normalizer='%s'\n",
00055 const_diag, const_offdiag, labels->get_num_labels(),
00056 normalizer->get_name());
00057 }
00058
00060 virtual ~CScatterKernelNormalizer()
00061 {
00062 SG_UNREF(m_labels);
00063 SG_UNREF(m_normalizer);
00064 }
00065
00068 virtual bool init(CKernel* k)
00069 {
00070 m_normalizer->init(k);
00071 return true;
00072 }
00073
00078 int32_t get_testing_class()
00079 {
00080 return m_testing_class;
00081 }
00082
00087 void set_testing_class(int32_t c)
00088 {
00089 m_testing_class=c;
00090 }
00091
00097 inline virtual float64_t normalize(float64_t value, int32_t idx_lhs,
00098 int32_t idx_rhs)
00099 {
00100 value=m_normalizer->normalize(value, idx_lhs, idx_rhs);
00101 float64_t c=m_const_offdiag;
00102
00103 if (m_testing_class>=0)
00104 {
00105 if (m_labels->get_label(idx_lhs) == m_testing_class)
00106 c=m_const_diag;
00107 }
00108 else
00109 {
00110 if (m_labels->get_label(idx_lhs) == m_labels->get_label(idx_rhs))
00111 c=m_const_diag;
00112
00113 }
00114 return value*c;
00115 }
00116
00121 inline virtual float64_t normalize_lhs(float64_t value, int32_t idx_lhs)
00122 {
00123 SG_ERROR("normalize_lhs not implemented");
00124 return 0;
00125 }
00126
00131 inline virtual float64_t normalize_rhs(float64_t value, int32_t idx_rhs)
00132 {
00133 SG_ERROR("normalize_rhs not implemented");
00134 return 0;
00135 }
00136
00138 inline virtual const char* get_name() const
00139 {
00140 return "ScatterKernelNormalizer";
00141 }
00142
00143 private:
00144 void init()
00145 {
00146 m_const_diag = 1.0;
00147 m_const_offdiag = 1.0;
00148
00149 m_labels = NULL;
00150 m_normalizer = NULL;
00151
00152 m_testing_class = -1;
00153
00154
00155 m_parameters->add(&m_testing_class, "m_testing_class"
00156 "Testing Class.");
00157 m_parameters->add(&m_const_diag, "m_const_diag"
00158 "Factor to multiply to diagonal elements.");
00159 m_parameters->add(&m_const_offdiag, "m_const_offdiag"
00160 "Factor to multiply to off-diagonal elements.");
00161
00162 m_parameters->add((CSGObject**) &m_labels, "m_labels", "Labels");
00163 m_parameters->add((CSGObject**) &m_normalizer, "m_normalizer", "Kernel normalizer.");
00164 }
00165
00166 protected:
00167
00169 float64_t m_const_diag;
00171 float64_t m_const_offdiag;
00172
00174 CLabels* m_labels;
00175
00177 CKernelNormalizer* m_normalizer;
00178
00180 int32_t m_testing_class;
00181 };
00182 }
00183 #endif
00184