DomainAdaptationSVM.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2007-2011 Christian Widmer
00008  * Copyright (C) 2007-2011 Max-Planck-Society
00009  */
00010 
00011 #include <shogun/lib/config.h>
00012 
00013 #ifdef USE_SVMLIGHT
00014 
00015 #include <shogun/transfer/domain_adaptation/DomainAdaptationSVM.h>
00016 #include <shogun/io/SGIO.h>
00017 #include <shogun/labels/Labels.h>
00018 #include <shogun/labels/BinaryLabels.h>
00019 #include <shogun/labels/RegressionLabels.h>
00020 #include <iostream>
00021 #include <vector>
00022 
00023 using namespace shogun;
00024 
00025 CDomainAdaptationSVM::CDomainAdaptationSVM() : CSVMLight()
00026 {
00027 }
00028 
00029 CDomainAdaptationSVM::CDomainAdaptationSVM(float64_t C, CKernel* k, CLabels* lab, CSVM* pre_svm, float64_t B_param) : CSVMLight(C, k, lab)
00030 {
00031     init();
00032     init(pre_svm, B_param);
00033 }
00034 
00035 CDomainAdaptationSVM::~CDomainAdaptationSVM()
00036 {
00037     SG_UNREF(presvm);
00038     SG_DEBUG("deleting DomainAdaptationSVM\n");
00039 }
00040 
00041 
00042 void CDomainAdaptationSVM::init(CSVM* pre_svm, float64_t B_param)
00043 {
00044     // increase reference counts
00045     SG_REF(pre_svm);
00046 
00047     this->presvm=pre_svm;
00048     this->B=B_param;
00049     this->train_factor=1.0;
00050 
00051     // set bias of parent svm to zero
00052     this->presvm->set_bias(0.0);
00053 
00054     // invoke sanity check
00055     is_presvm_sane();
00056 }
00057 
00058 bool CDomainAdaptationSVM::is_presvm_sane()
00059 {
00060     if (!presvm) {
00061         SG_ERROR("presvm is null");
00062     }
00063 
00064     if (presvm->get_num_support_vectors() == 0) {
00065         SG_ERROR("presvm has no support vectors, please train first");
00066     }
00067 
00068     if (presvm->get_bias() != 0) {
00069         SG_ERROR("presvm bias not set to zero");
00070     }
00071 
00072     if (presvm->get_kernel()->get_kernel_type() != this->get_kernel()->get_kernel_type()) {
00073         SG_ERROR("kernel types do not agree");
00074     }
00075 
00076     if (presvm->get_kernel()->get_feature_type() != this->get_kernel()->get_feature_type()) {
00077         SG_ERROR("feature types do not agree");
00078     }
00079 
00080     return true;
00081 }
00082 
00083 
00084 bool CDomainAdaptationSVM::train_machine(CFeatures* data)
00085 {
00086 
00087     if (data)
00088     {
00089         if (m_labels->get_num_labels() != data->get_num_vectors())
00090             SG_ERROR("Number of training vectors does not match number of labels\n");
00091         kernel->init(data, data);
00092     }
00093 
00094     if (m_labels->get_label_type() != LT_BINARY)
00095         SG_ERROR("DomainAdaptationSVM requires binary labels\n");
00096 
00097     int32_t num_training_points = get_labels()->get_num_labels();
00098     CBinaryLabels* labels = (CBinaryLabels*) get_labels();
00099 
00100     float64_t* lin_term = SG_MALLOC(float64_t, num_training_points);
00101 
00102     // grab current training features
00103     CFeatures* train_data = get_kernel()->get_lhs();
00104 
00105     // bias of parent SVM was set to zero in constructor, already contains B
00106     CBinaryLabels* parent_svm_out = presvm->apply_binary(train_data);
00107 
00108     // pre-compute linear term
00109     for (int32_t i=0; i<num_training_points; i++)
00110     {
00111         lin_term[i] = train_factor * B * labels->get_label(i) * parent_svm_out->get_label(i) - 1.0;
00112     }
00113 
00114     //set linear term for QP
00115     this->set_linear_term(SGVector<float64_t>(lin_term, num_training_points));
00116 
00117     //train SVM
00118     bool success = CSVMLight::train_machine();
00119     SG_UNREF(labels);
00120 
00121     ASSERT(presvm)
00122 
00123     return success;
00124 
00125 }
00126 
00127 
00128 CSVM* CDomainAdaptationSVM::get_presvm()
00129 {
00130     SG_REF(presvm);
00131     return presvm;
00132 }
00133 
00134 
00135 float64_t CDomainAdaptationSVM::get_B()
00136 {
00137     return B;
00138 }
00139 
00140 
00141 float64_t CDomainAdaptationSVM::get_train_factor()
00142 {
00143     return train_factor;
00144 }
00145 
00146 
00147 void CDomainAdaptationSVM::set_train_factor(float64_t factor)
00148 {
00149     train_factor = factor;
00150 }
00151 
00152 
00153 CBinaryLabels* CDomainAdaptationSVM::apply_binary(CFeatures* data)
00154 {
00155     ASSERT(data);
00156     ASSERT(presvm->get_bias()==0.0);
00157 
00158     int32_t num_examples = data->get_num_vectors();
00159 
00160     CBinaryLabels* out_current = CSVMLight::apply_binary(data);
00161 
00162     // recursive call if used on DomainAdaptationSVM object
00163     CBinaryLabels* out_presvm = presvm->apply_binary(data);
00164 
00165     // combine outputs
00166     SGVector<float64_t> out_combined(num_examples);
00167     for (int32_t i=0; i<num_examples; i++)
00168     {
00169         out_combined[i] = out_current->get_value(i) + B*out_presvm->get_value(i);
00170     }
00171     SG_UNREF(out_current);
00172     SG_UNREF(out_presvm);
00173 
00174     return new CBinaryLabels(out_combined);
00175 
00176 }
00177 
/** Set default member values and register parameters with the framework.
 *  Called by both constructors before any pre-trained SVM is installed.
 */
void CDomainAdaptationSVM::init()
{
    // default state: no parent SVM, no regularization towards it
    presvm = NULL;
    B = 0;
    train_factor = 1.0;

    // register members for serialization / model-selection machinery;
    // NOTE(review): registration order may affect serialized layout —
    // kept exactly as-is
    m_parameters->add((CSGObject**) &presvm, "presvm",
                      "SVM to regularize against.");
    m_parameters->add(&B, "B", "regularization parameter B.");
    m_parameters->add(&train_factor,
            "train_factor", "flag to switch off regularization in training.");
}
00190 
00191 #endif //USE_SVMLIGHT
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation