DomainAdaptationMulticlassLibLinear.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Sergey Lisitsyn
00008  * Copyright (C) 2012 Sergey Lisitsyn
00009  */
00010 
00011 #include <shogun/lib/config.h>
00012 #ifdef HAVE_LAPACK
00013 #include <shogun/transfer/domain_adaptation/DomainAdaptationMulticlassLibLinear.h>
00014 #include <shogun/labels/MulticlassLabels.h>
00015 
00016 using namespace shogun;
00017 
00018 CDomainAdaptationMulticlassLibLinear::CDomainAdaptationMulticlassLibLinear() :
00019     CMulticlassLibLinear()
00020 {
00021     init_defaults();
00022 }
00023 
00024 CDomainAdaptationMulticlassLibLinear::CDomainAdaptationMulticlassLibLinear(
00025         float64_t target_C, CDotFeatures* target_features, CLabels* target_labels,
00026         CLinearMulticlassMachine* source_machine) :
00027     CMulticlassLibLinear(target_C,target_features,target_labels)
00028 {
00029     init_defaults();
00030 
00031     set_source_machine(source_machine);
00032 }
00033 
00034 void CDomainAdaptationMulticlassLibLinear::init_defaults()
00035 {
00036     m_train_factor = 1.0;
00037     m_source_bias = 0.5;
00038     m_source_machine = NULL;
00039 
00040     register_parameters();
00041 }
00042 
00043 float64_t CDomainAdaptationMulticlassLibLinear::get_source_bias() const
00044 {
00045     return m_source_bias;
00046 }
00047 
00048 void CDomainAdaptationMulticlassLibLinear::set_source_bias(float64_t source_bias)
00049 {
00050     m_source_bias = source_bias;
00051 }
00052 
00053 float64_t CDomainAdaptationMulticlassLibLinear::get_train_factor() const
00054 {
00055     return m_train_factor;
00056 }
00057 
00058 void CDomainAdaptationMulticlassLibLinear::set_train_factor(float64_t train_factor)
00059 {
00060     m_train_factor = train_factor;
00061 }
00062 
00063 CLinearMulticlassMachine* CDomainAdaptationMulticlassLibLinear::get_source_machine() const
00064 {
00065     SG_REF(m_source_machine);
00066     return m_source_machine;
00067 }
00068 
00069 void CDomainAdaptationMulticlassLibLinear::set_source_machine(
00070         CLinearMulticlassMachine* source_machine)
00071 {
00072     SG_UNREF(m_source_machine);
00073     SG_REF(source_machine);
00074     m_source_machine = source_machine;
00075 }
00076 
00077 void CDomainAdaptationMulticlassLibLinear::register_parameters()
00078 {
00079     SG_ADD((CSGObject**)&m_source_machine, "source_machine", "source domain machine",
00080             MS_NOT_AVAILABLE);
00081     SG_ADD(&m_train_factor, "train_factor", "factor of target domain regularization",
00082             MS_AVAILABLE);
00083     SG_ADD(&m_source_bias, "source_bias", "bias to source domain",
00084             MS_AVAILABLE);
00085 }
00086 
00087 CDomainAdaptationMulticlassLibLinear::~CDomainAdaptationMulticlassLibLinear()
00088 {
00089 }
00090 
00091 SGMatrix<float64_t> CDomainAdaptationMulticlassLibLinear::obtain_regularizer_matrix() const
00092 {
00093     ASSERT(get_use_bias()==false);
00094     int32_t n_classes = ((CMulticlassLabels*)m_source_machine->get_labels())->get_num_classes();
00095     int32_t n_features = ((CDotFeatures*)m_source_machine->get_features())->get_dim_feature_space();
00096     SGMatrix<float64_t> w0(n_classes,n_features);
00097 
00098     for (int32_t i=0; i<n_classes; i++)
00099     {
00100         SGVector<float64_t> w = ((CLinearMachine*)m_source_machine->get_machine(i))->get_w();
00101         for (int32_t j=0; j<n_features; j++)
00102             w0(j,i) = m_train_factor*w[j];
00103     }
00104 
00105     return w0;
00106 }
00107 
00108 CBinaryLabels* CDomainAdaptationMulticlassLibLinear::get_submachine_outputs(int32_t i)
00109 {
00110     CBinaryLabels* target_outputs = CMulticlassMachine::get_submachine_outputs(i);
00111     CBinaryLabels* source_outputs = m_source_machine->get_submachine_outputs(i);
00112     int32_t n_target_outputs = target_outputs->get_num_labels();
00113     ASSERT(n_target_outputs==source_outputs->get_num_labels());
00114     SGVector<float64_t> result(n_target_outputs);
00115     for (int32_t j=0; j<result.vlen; j++)
00116         result[j] = (1-m_source_bias)*target_outputs->get_value(j) + m_source_bias*source_outputs->get_value(j);
00117 
00118     SG_UNREF(target_outputs);
00119     SG_UNREF(source_outputs);
00120 
00121     return new CBinaryLabels(result);
00122 }
00123 #endif /* HAVE_LAPACK */
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation