DomainAdaptationSVMLinear.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2007-2011 Christian Widmer
00008  * Copyright (C) 2007-2011 Max-Planck-Society
00009  */
00010 
00011 #include <shogun/lib/config.h>
00012 
00013 #ifdef HAVE_LAPACK
00014 
00015 #include <shogun/classifier/svm/DomainAdaptationSVMLinear.h>
00016 #include <shogun/io/SGIO.h>
00017 #include <shogun/base/Parameter.h>
00018 #include <iostream>
00019 #include <vector>
00020 
00021 using namespace shogun;
00022 
00023 
00024 CDomainAdaptationSVMLinear::CDomainAdaptationSVMLinear() : CLibLinear(L2R_L1LOSS_SVC_DUAL)
00025 {
00026     init(NULL, 0.0);
00027 }
00028 
00029 
00030 CDomainAdaptationSVMLinear::CDomainAdaptationSVMLinear(float64_t C, CDotFeatures* f, CLabels* lab, CLinearMachine* pre_svm, float64_t B_param) : CLibLinear(C, f, lab)
00031 {
00032     init(pre_svm, B_param);
00033 
00034 }
00035 
00036 
00037 CDomainAdaptationSVMLinear::~CDomainAdaptationSVMLinear()
00038 {
00039 
00040     SG_UNREF(presvm);
00041     SG_DEBUG("deleting DomainAdaptationSVMLinear\n");
00042 }
00043 
00044 
00045 void CDomainAdaptationSVMLinear::init(CLinearMachine* pre_svm, float64_t B_param)
00046 {
00047 
00048     if (pre_svm)
00049     {
00050         // increase reference counts
00051         SG_REF(pre_svm);
00052 
00053         // set bias of parent svm to zero
00054         pre_svm->set_bias(0.0);
00055     }
00056 
00057     this->presvm = pre_svm;
00058     this->B = B_param;
00059     this->train_factor = 1.0;
00060 
00061     set_liblinear_solver_type(L2R_L1LOSS_SVC_DUAL);
00062 
00063     // invoke sanity check
00064     is_presvm_sane();
00065 
00066     // serialization code
00067     m_parameters->add((CSGObject**) &presvm, "presvm", "SVM to regularize against");
00068     m_parameters->add(&B, "B",  "Regularization strenth B.");
00069     m_parameters->add(&train_factor, "train_factor",  "train_factor");
00070 
00071 }
00072 
00073 
00074 bool CDomainAdaptationSVMLinear::is_presvm_sane()
00075 {
00076 
00077     if (!presvm) {
00078 
00079         SG_WARNING("presvm is null");
00080 
00081     } else {
00082 
00083         if (presvm->get_bias() != 0) {
00084             SG_ERROR("presvm bias not set to zero");
00085         }
00086 
00087         if (presvm->get_features()->get_feature_type() != this->get_features()->get_feature_type()) {
00088             SG_ERROR("feature types do not agree");
00089         }
00090     }
00091 
00092     return true;
00093 
00094 }
00095 
00096 
00097 bool CDomainAdaptationSVMLinear::train_machine(CDotFeatures* train_data)
00098 {
00099 
00100     CDotFeatures* tmp_data;
00101 
00102     if (train_data)
00103     {
00104         if (labels->get_num_labels() != train_data->get_num_vectors())
00105             SG_ERROR("Number of training vectors does not match number of labels\n");
00106         tmp_data = train_data;
00107 
00108     } else {
00109 
00110         tmp_data = features;
00111     }
00112 
00113     int32_t num_training_points = get_labels()->get_num_labels();
00114 
00115     std::vector<float64_t> lin_term = std::vector<float64_t>(num_training_points);
00116 
00117     if (presvm)
00118     {
00119         ASSERT(presvm->get_bias() == 0.0);
00120 
00121         // bias of parent SVM was set to zero in constructor, already contains B
00122         CLabels* parent_svm_out = presvm->apply(tmp_data);
00123 
00124         SG_DEBUG("pre-computing linear term from presvm\n");
00125 
00126         // pre-compute linear term
00127         for (int32_t i=0; i!=num_training_points; i++)
00128         {
00129             lin_term[i] = train_factor * B * get_label(i) * parent_svm_out->get_label(i) - 1.0;
00130         }
00131 
00132         // set linear term for QP
00133         this->set_linear_term(
00134                 SGVector<float64_t>(&lin_term[0], lin_term.size()));
00135 
00136     }
00137 
00138     /*
00139     // warm-start liblinear
00140     //TODO test this code, measure speed-ups
00141     //presvm w stored in presvm
00142     float64_t* tmp_w;
00143     presvm->get_w(tmp_w, w_dim);
00144 
00145     //copy vector
00146     float64_t* tmp_w_copy = SG_MALLOC(float64_t, w_dim);
00147     std::copy(tmp_w, tmp_w + w_dim, tmp_w_copy);
00148 
00149     for (int32_t i=0; i!=w_dim; i++)
00150     {
00151         tmp_w_copy[i] = B * tmp_w_copy[i];
00152     }
00153 
00154     //set w (copied in setter)
00155     set_w(tmp_w_copy, w_dim);
00156     SG_FREE(tmp_w_copy);
00157     */
00158 
00159     bool success = false;
00160 
00161     //train SVM
00162     if (train_data)
00163     {
00164         success = CLibLinear::train_machine(train_data);
00165     } else {
00166         success = CLibLinear::train_machine();
00167     }
00168 
00169     //ASSERT(presvm)
00170 
00171     return success;
00172 
00173 }
00174 
00175 
00176 CLinearMachine* CDomainAdaptationSVMLinear::get_presvm()
00177 {
00178     return presvm;
00179 }
00180 
00181 
00182 float64_t CDomainAdaptationSVMLinear::get_B()
00183 {
00184     return B;
00185 }
00186 
00187 
00188 float64_t CDomainAdaptationSVMLinear::get_train_factor()
00189 {
00190     return train_factor;
00191 }
00192 
00193 
00194 void CDomainAdaptationSVMLinear::set_train_factor(float64_t factor)
00195 {
00196     train_factor = factor;
00197 }
00198 
00199 
00200 CLabels* CDomainAdaptationSVMLinear::apply(CDotFeatures* data)
00201 {
00202 
00203     ASSERT(presvm->get_bias()==0.0);
00204 
00205     int32_t num_examples = data->get_num_vectors();
00206 
00207     CLabels* out_current = CLibLinear::apply(data);
00208 
00209     if (presvm)
00210     {
00211 
00212         // recursive call if used on DomainAdaptationSVM object
00213         CLabels* out_presvm = presvm->apply(data);
00214 
00215 
00216         // combine outputs
00217         for (int32_t i=0; i!=num_examples; i++)
00218         {
00219             float64_t out_combined = out_current->get_label(i) + B*out_presvm->get_label(i);
00220             out_current->set_label(i, out_combined);
00221         }
00222 
00223     }
00224 
00225 
00226     return out_current;
00227 
00228 }
00229 
00230 #endif //HAVE_LAPACK
00231 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation