Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #ifndef _SCATTERKERNELNORMALIZER_H___
00012 #define _SCATTERKERNELNORMALIZER_H___
00013
00014 #include <shogun/kernel/normalizer/KernelNormalizer.h>
00015 #include <shogun/kernel/normalizer/IdentityKernelNormalizer.h>
00016 #include <shogun/kernel/Kernel.h>
00017 #include <shogun/labels/Labels.h>
00018 #include <shogun/labels/MulticlassLabels.h>
00019 #include <shogun/io/SGIO.h>
00020
00021 namespace shogun
00022 {
00024 class CScatterKernelNormalizer: public CKernelNormalizer
00025 {
00026
00027 public:
00029 CScatterKernelNormalizer() : CKernelNormalizer()
00030 {
00031 init();
00032 }
00033
00036 CScatterKernelNormalizer(float64_t const_diag, float64_t const_offdiag,
00037 CLabels* labels,CKernelNormalizer* normalizer=NULL)
00038 : CKernelNormalizer()
00039 {
00040 init();
00041
00042 m_testing_class=-1;
00043 m_const_diag=const_diag;
00044 m_const_offdiag=const_offdiag;
00045
00046 ASSERT(labels)
00047 SG_REF(labels);
00048 m_labels=labels;
00049 ASSERT(labels->get_label_type()==LT_MULTICLASS);
00050 labels->ensure_valid();
00051
00052 if (normalizer==NULL)
00053 normalizer=new CIdentityKernelNormalizer();
00054 SG_REF(normalizer);
00055 m_normalizer=normalizer;
00056
00057 SG_DEBUG("Constructing ScatterKernelNormalizer with const_diag=%g"
00058 " const_offdiag=%g num_labels=%d and normalizer='%s'\n",
00059 const_diag, const_offdiag, labels->get_num_labels(),
00060 normalizer->get_name());
00061 }
00062
00064 virtual ~CScatterKernelNormalizer()
00065 {
00066 SG_UNREF(m_labels);
00067 SG_UNREF(m_normalizer);
00068 }
00069
00072 virtual bool init(CKernel* k)
00073 {
00074 m_normalizer->init(k);
00075 return true;
00076 }
00077
00082 int32_t get_testing_class()
00083 {
00084 return m_testing_class;
00085 }
00086
00091 void set_testing_class(int32_t c)
00092 {
00093 m_testing_class=c;
00094 }
00095
00101 virtual float64_t normalize(float64_t value, int32_t idx_lhs,
00102 int32_t idx_rhs)
00103 {
00104 value=m_normalizer->normalize(value, idx_lhs, idx_rhs);
00105 float64_t c=m_const_offdiag;
00106
00107 if (m_testing_class>=0)
00108 {
00109 if (((CMulticlassLabels*) m_labels)->get_label(idx_lhs) == m_testing_class)
00110 c=m_const_diag;
00111 }
00112 else
00113 {
00114 if (((CMulticlassLabels*) m_labels)->get_label(idx_lhs) == ((CMulticlassLabels*) m_labels)->get_label(idx_rhs))
00115 c=m_const_diag;
00116
00117 }
00118 return value*c;
00119 }
00120
00125 virtual float64_t normalize_lhs(float64_t value, int32_t idx_lhs)
00126 {
00127 SG_ERROR("normalize_lhs not implemented");
00128 return 0;
00129 }
00130
00135 virtual float64_t normalize_rhs(float64_t value, int32_t idx_rhs)
00136 {
00137 SG_ERROR("normalize_rhs not implemented");
00138 return 0;
00139 }
00140
00142 virtual const char* get_name() const
00143 {
00144 return "ScatterKernelNormalizer";
00145 }
00146
00147 private:
00148 void init()
00149 {
00150 m_const_diag = 1.0;
00151 m_const_offdiag = 1.0;
00152
00153 m_labels = NULL;
00154 m_normalizer = NULL;
00155
00156 m_testing_class = -1;
00157
00158 SG_ADD(&m_testing_class, "m_testing_class",
00159 "Testing Class.", MS_NOT_AVAILABLE);
00160 SG_ADD(&m_const_diag, "m_const_diag",
00161 "Factor to multiply to diagonal elements.", MS_AVAILABLE);
00162 SG_ADD(&m_const_offdiag, "m_const_offdiag",
00163 "Factor to multiply to off-diagonal elements.", MS_AVAILABLE);
00164
00165 SG_ADD((CSGObject**) &m_labels, "m_labels", "Labels", MS_NOT_AVAILABLE);
00166 SG_ADD((CSGObject**) &m_normalizer, "m_normalizer", "Kernel normalizer.",
00167 MS_AVAILABLE);
00168 }
00169
00170 protected:
00171
00173 float64_t m_const_diag;
00175 float64_t m_const_offdiag;
00176
00178 CLabels* m_labels;
00179
00181 CKernelNormalizer* m_normalizer;
00182
00184 int32_t m_testing_class;
00185 };
00186 }
00187 #endif
00188