DomainAdaptationSVM.cpp

/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 2007-2011 Christian Widmer
 * Copyright (C) 2007-2011 Max-Planck-Society
 */

#include <shogun/lib/config.h>

#ifdef USE_SVMLIGHT

#include <shogun/classifier/svm/DomainAdaptationSVM.h>
#include <shogun/io/SGIO.h>
#include <iostream>
#include <vector>

using namespace shogun;

CDomainAdaptationSVM::CDomainAdaptationSVM() : CSVMLight()
{
}

CDomainAdaptationSVM::CDomainAdaptationSVM(float64_t C, CKernel* k, CLabels* lab, CSVM* pre_svm, float64_t B_param) : CSVMLight(C, k, lab)
{
    init();
    init(pre_svm, B_param);
}

CDomainAdaptationSVM::~CDomainAdaptationSVM()
{
    SG_UNREF(presvm);
    SG_DEBUG("deleting DomainAdaptationSVM\n");
}

void CDomainAdaptationSVM::init(CSVM* pre_svm, float64_t B_param)
{
    // increase reference count
    SG_REF(pre_svm);

    this->presvm=pre_svm;
    this->B=B_param;
    this->train_factor=1.0;

    // set bias of parent svm to zero
    this->presvm->set_bias(0.0);

    // invoke sanity check
    is_presvm_sane();
}

bool CDomainAdaptationSVM::is_presvm_sane()
{
    if (!presvm) {
        SG_ERROR("presvm is null");
    }

    if (presvm->get_num_support_vectors() == 0) {
        SG_ERROR("presvm has no support vectors, please train first");
    }

    if (presvm->get_bias() != 0) {
        SG_ERROR("presvm bias not set to zero");
    }

    if (presvm->get_kernel()->get_kernel_type() != this->get_kernel()->get_kernel_type()) {
        SG_ERROR("kernel types do not agree");
    }

    if (presvm->get_kernel()->get_feature_type() != this->get_kernel()->get_feature_type()) {
        SG_ERROR("feature types do not agree");
    }

    return true;
}

bool CDomainAdaptationSVM::train_machine(CFeatures* data)
{
    if (data)
    {
        if (labels->get_num_labels() != data->get_num_vectors())
            SG_ERROR("Number of training vectors does not match number of labels\n");
        kernel->init(data, data);
    }

    int32_t num_training_points = get_labels()->get_num_labels();

    float64_t* lin_term = SG_MALLOC(float64_t, num_training_points);

    // grab current training features
    CFeatures* train_data = get_kernel()->get_lhs();

    // bias of parent SVM was set to zero in constructor, already contains B
    CLabels* parent_svm_out = presvm->apply(train_data);

    // pre-compute linear term
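    // training the new SVM with a hinge loss on the combined output
    // f_new(x) + B*f_pre(x) shifts the standard linear term of -1 per
    // example by train_factor * B * y_i * f_pre(x_i); train_factor can be
    // used to scale this coupling down or switch it off (see set_train_factor)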
    for (int32_t i=0; i<num_training_points; i++)
    {
        lin_term[i] = train_factor * B * get_label(i) * parent_svm_out->get_label(i) - 1.0;
    }

    // release objects handed out above: get_lhs() and apply() both return
    // reference-counted objects owned by the caller
    SG_UNREF(parent_svm_out);
    SG_UNREF(train_data);

    //set linear term for QP
    this->set_linear_term(SGVector<float64_t>(lin_term, num_training_points));

    SG_FREE(lin_term);

    //train SVM
    bool success = CSVMLight::train_machine();

    ASSERT(presvm)

    return success;
}

CSVM* CDomainAdaptationSVM::get_presvm()
{
    SG_REF(presvm);
    return presvm;
}

float64_t CDomainAdaptationSVM::get_B()
{
    return B;
}

float64_t CDomainAdaptationSVM::get_train_factor()
{
    return train_factor;
}

void CDomainAdaptationSVM::set_train_factor(float64_t factor)
{
    train_factor = factor;
}

CLabels* CDomainAdaptationSVM::apply(CFeatures* data)
{
    ASSERT(presvm->get_bias()==0.0);

    int32_t num_examples = data->get_num_vectors();

    CLabels* out_current = CSVMLight::apply(data);

    // recursive call if used on DomainAdaptationSVM object
    CLabels* out_presvm = presvm->apply(data);

    // combine outputs: f(x) = f_new(x) + B * f_pre(x)
    for (int32_t i=0; i!=num_examples; i++)
    {
        float64_t out_combined = out_current->get_label(i) + B*out_presvm->get_label(i);
        out_current->set_label(i, out_combined);
    }

    // out_presvm is owned by this method and no longer needed
    SG_UNREF(out_presvm);

    return out_current;
}

void CDomainAdaptationSVM::init()
{
    presvm = NULL;
    B = 0;
    train_factor = 1.0;

    m_parameters->add((CSGObject**) &presvm, "presvm",
                      "SVM to regularize against.");
    m_parameters->add(&B, "B", "regularization parameter B.");
    m_parameters->add(&train_factor,
            "train_factor", "flag to switch off regularization in training.");
}

#endif //USE_SVMLIGHT
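
Example usage (a minimal sketch, not part of the source file above): it follows the CDomainAdaptationSVM and CSVMLight constructor signatures defined in this file, but the feature, label and kernel classes (CSimpleFeatures, CLabels, CGaussianKernel), their headers, and their constructor arguments are assumed from the Shogun 1.x-era API and may differ in other versions.

#include <shogun/features/SimpleFeatures.h>
#include <shogun/features/Labels.h>
#include <shogun/kernel/GaussianKernel.h>
#include <shogun/classifier/svm/SVMLight.h>
#include <shogun/classifier/svm/DomainAdaptationSVM.h>

using namespace shogun;

// pre-train an SVM on the source domain, then adapt it to the target domain
CLabels* adapt_and_predict(
    CSimpleFeatures<float64_t>* src_feats, CLabels* src_labels,
    CSimpleFeatures<float64_t>* tgt_feats, CLabels* tgt_labels)
{
    // 1) train a plain SVM on the source domain (C=1.0, kernel width 2.0)
    CGaussianKernel* src_kernel = new CGaussianKernel(src_feats, src_feats, 2.0);
    CSVMLight* presvm = new CSVMLight(1.0, src_kernel, src_labels);
    SG_REF(presvm);
    presvm->train();

    // 2) train the adapted SVM on the target domain, coupled to the
    //    pre-trained model with weight B=1.0
    CGaussianKernel* tgt_kernel = new CGaussianKernel(tgt_feats, tgt_feats, 2.0);
    CDomainAdaptationSVM* dasvm =
        new CDomainAdaptationSVM(1.0, tgt_kernel, tgt_labels, presvm, 1.0);
    SG_REF(dasvm);
    dasvm->train();

    // 3) apply() returns the combined output f_new(x) + B * f_pre(x)
    CLabels* out = dasvm->apply(tgt_feats);

    // the SVM objects hold their own references to kernels, labels and presvm
    SG_UNREF(dasvm);
    SG_UNREF(presvm);
    return out;
}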