ScatterKernelNormalizer.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2010 Soeren Sonnenburg
00008  * Copyright (C) 2010 Berlin Institute of Technology
00009  */
00010 
00011 #ifndef _SCATTERKERNELNORMALIZER_H___
00012 #define _SCATTERKERNELNORMALIZER_H___
00013 
00014 #include <shogun/kernel/normalizer/KernelNormalizer.h>
00015 #include <shogun/kernel/normalizer/IdentityKernelNormalizer.h>
00016 #include <shogun/kernel/Kernel.h>
00017 #include <shogun/labels/Labels.h>
00018 #include <shogun/labels/MulticlassLabels.h>
00019 #include <shogun/io/SGIO.h>
00020 
00021 namespace shogun
00022 {
00024 class CScatterKernelNormalizer: public CKernelNormalizer
00025 {
00026 
00027 public:
00029     CScatterKernelNormalizer() : CKernelNormalizer()
00030     {
00031         init();
00032     }
00033 
00036     CScatterKernelNormalizer(float64_t const_diag, float64_t const_offdiag,
00037             CLabels* labels,CKernelNormalizer* normalizer=NULL)
00038         : CKernelNormalizer()
00039     {
00040         init();
00041 
00042         m_testing_class=-1;
00043         m_const_diag=const_diag;
00044         m_const_offdiag=const_offdiag;
00045 
00046         ASSERT(labels)
00047         SG_REF(labels);
00048         m_labels=labels;
00049         ASSERT(labels->get_label_type()==LT_MULTICLASS);
00050         labels->ensure_valid();
00051 
00052         if (normalizer==NULL)
00053             normalizer=new CIdentityKernelNormalizer();
00054         SG_REF(normalizer);
00055         m_normalizer=normalizer;
00056 
00057         SG_DEBUG("Constructing ScatterKernelNormalizer with const_diag=%g"
00058                 " const_offdiag=%g num_labels=%d and normalizer='%s'\n",
00059                 const_diag, const_offdiag, labels->get_num_labels(),
00060                 normalizer->get_name());
00061     }
00062 
00064     virtual ~CScatterKernelNormalizer()
00065     {
00066         SG_UNREF(m_labels);
00067         SG_UNREF(m_normalizer);
00068     }
00069 
00072     virtual bool init(CKernel* k)
00073     {
00074         m_normalizer->init(k);
00075         return true;
00076     }
00077 
00082     int32_t get_testing_class()
00083     {
00084         return m_testing_class;
00085     }
00086 
00091     void set_testing_class(int32_t c)
00092     {
00093         m_testing_class=c;
00094     }
00095 
00101     virtual float64_t normalize(float64_t value, int32_t idx_lhs,
00102             int32_t idx_rhs)
00103     {
00104         value=m_normalizer->normalize(value, idx_lhs, idx_rhs);
00105         float64_t c=m_const_offdiag;
00106 
00107         if (m_testing_class>=0)
00108         {
00109             if (((CMulticlassLabels*) m_labels)->get_label(idx_lhs) == m_testing_class)
00110                 c=m_const_diag;
00111         }
00112         else
00113         {
00114             if (((CMulticlassLabels*) m_labels)->get_label(idx_lhs) == ((CMulticlassLabels*) m_labels)->get_label(idx_rhs))
00115                 c=m_const_diag;
00116 
00117         }
00118         return value*c;
00119     }
00120 
00125     virtual float64_t normalize_lhs(float64_t value, int32_t idx_lhs)
00126     {
00127         SG_ERROR("normalize_lhs not implemented");
00128         return 0;
00129     }
00130 
00135     virtual float64_t normalize_rhs(float64_t value, int32_t idx_rhs)
00136     {
00137         SG_ERROR("normalize_rhs not implemented");
00138         return 0;
00139     }
00140 
00142     virtual const char* get_name() const
00143     {
00144         return "ScatterKernelNormalizer";
00145     }
00146 
00147 private:
00148     void init()
00149     {
00150         m_const_diag = 1.0;
00151         m_const_offdiag = 1.0;
00152 
00153         m_labels = NULL;
00154         m_normalizer = NULL;
00155 
00156         m_testing_class = -1;
00157 
00158         SG_ADD(&m_testing_class, "m_testing_class",
00159                 "Testing Class.", MS_NOT_AVAILABLE);
00160         SG_ADD(&m_const_diag, "m_const_diag",
00161                 "Factor to multiply to diagonal elements.", MS_AVAILABLE);
00162         SG_ADD(&m_const_offdiag, "m_const_offdiag",
00163                 "Factor to multiply to off-diagonal elements.", MS_AVAILABLE);
00164 
00165         SG_ADD((CSGObject**) &m_labels, "m_labels", "Labels", MS_NOT_AVAILABLE);
00166         SG_ADD((CSGObject**) &m_normalizer, "m_normalizer", "Kernel normalizer.",
00167             MS_AVAILABLE);
00168     }
00169 
00170 protected:
00171 
00173     float64_t m_const_diag;
00175     float64_t m_const_offdiag;
00176 
00178     CLabels* m_labels;
00179 
00181     CKernelNormalizer* m_normalizer;
00182 
00184     int32_t m_testing_class;
00185 };
00186 }
00187 #endif
00188 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation