ScatterKernelNormalizer.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2010 Soeren Sonnenburg
00008  * Copyright (C) 2010 Berlin Institute of Technology
00009  */
00010 
00011 #ifndef _SCATTERKERNELNORMALIZER_H___
00012 #define _SCATTERKERNELNORMALIZER_H___
00013 
00014 #include <shogun/kernel/KernelNormalizer.h>
00015 #include <shogun/kernel/IdentityKernelNormalizer.h>
00016 #include <shogun/kernel/Kernel.h>
00017 #include <shogun/features/Labels.h>
00018 #include <shogun/io/SGIO.h>
00019 
00020 namespace shogun
00021 {
00023 class CScatterKernelNormalizer: public CKernelNormalizer
00024 {
00025 
00026 public:
00028     CScatterKernelNormalizer() : CKernelNormalizer()
00029     {
00030         init();
00031     }
00032 
00035     CScatterKernelNormalizer(float64_t const_diag, float64_t const_offdiag,
00036             CLabels* labels,CKernelNormalizer* normalizer=NULL)
00037         : CKernelNormalizer()
00038     {
00039         init();
00040 
00041         m_testing_class=-1;
00042         m_const_diag=const_diag;
00043         m_const_offdiag=const_offdiag;
00044 
00045         ASSERT(labels)
00046         SG_REF(labels);
00047         m_labels=labels;
00048 
00049         if (normalizer==NULL)
00050             normalizer=new CIdentityKernelNormalizer();
00051         SG_REF(normalizer);
00052         m_normalizer=normalizer;
00053 
00054         SG_DEBUG("Constructing ScatterKernelNormalizer with const_diag=%g"
00055                 " const_offdiag=%g num_labels=%d and normalizer='%s'\n",
00056                 const_diag, const_offdiag, labels->get_num_labels(),
00057                 normalizer->get_name());
00058     }
00059 
00061     virtual ~CScatterKernelNormalizer()
00062     {
00063         SG_UNREF(m_labels);
00064         SG_UNREF(m_normalizer);
00065     }
00066 
00069     virtual bool init(CKernel* k)
00070     {
00071         m_normalizer->init(k);
00072         return true;
00073     }
00074 
00079     int32_t get_testing_class()
00080     {
00081         return m_testing_class;
00082     }
00083 
00088     void set_testing_class(int32_t c)
00089     {
00090         m_testing_class=c;
00091     }
00092 
00098     inline virtual float64_t normalize(float64_t value, int32_t idx_lhs,
00099             int32_t idx_rhs)
00100     {
00101         value=m_normalizer->normalize(value, idx_lhs, idx_rhs);
00102         float64_t c=m_const_offdiag;
00103 
00104         if (m_testing_class>=0)
00105         {
00106             if (m_labels->get_label(idx_lhs) == m_testing_class)
00107                 c=m_const_diag;
00108         }
00109         else
00110         {
00111             if (m_labels->get_label(idx_lhs) == m_labels->get_label(idx_rhs))
00112                 c=m_const_diag;
00113 
00114         }
00115         return value*c;
00116     }
00117 
00122     inline virtual float64_t normalize_lhs(float64_t value, int32_t idx_lhs)
00123     {
00124         SG_ERROR("normalize_lhs not implemented");
00125         return 0;
00126     }
00127 
00132     inline virtual float64_t normalize_rhs(float64_t value, int32_t idx_rhs)
00133     {
00134         SG_ERROR("normalize_rhs not implemented");
00135         return 0;
00136     }
00137 
00139     inline virtual const char* get_name() const
00140     {
00141         return "ScatterKernelNormalizer";
00142     }
00143 
00144 private:
00145     void init()
00146     {
00147         m_const_diag = 1.0;
00148         m_const_offdiag = 1.0;
00149 
00150         m_labels = NULL;
00151         m_normalizer = NULL;
00152 
00153         m_testing_class = -1;
00154 
00155         
00156         m_parameters->add(&m_testing_class, "m_testing_class"
00157                 "Testing Class.");
00158         m_parameters->add(&m_const_diag, "m_const_diag"
00159                 "Factor to multiply to diagonal elements.");
00160         m_parameters->add(&m_const_offdiag, "m_const_offdiag"
00161                 "Factor to multiply to off-diagonal elements.");
00162 
00163         m_parameters->add((CSGObject**) &m_labels, "m_labels", "Labels");
00164         m_parameters->add((CSGObject**) &m_normalizer, "m_normalizer", "Kernel normalizer.");
00165     }
00166 
00167 protected:
00168 
00170     float64_t m_const_diag;
00172     float64_t m_const_offdiag;
00173 
00175     CLabels* m_labels;
00176 
00178     CKernelNormalizer* m_normalizer;
00179 
00181     int32_t m_testing_class;
00182 };
00183 }
00184 #endif
00185 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation