DomainAdaptationSVMLinear.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 2 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2007-2010 Christian Widmer
00008  * Copyright (C) 2007-2010 Max-Planck-Society
00009  */
00010 
00011 #include "lib/config.h"
00012 
00013 #ifdef HAVE_LAPACK
00014 
00015 #include "classifier/svm/DomainAdaptationSVMLinear.h"
00016 #include "lib/io.h"
00017 #include "base/Parameter.h"
00018 #include <iostream>
00019 #include <vector>
00020 
00021 using namespace shogun;
00022 
00023 
00024 CDomainAdaptationSVMLinear::CDomainAdaptationSVMLinear() : CLibLinear(L2R_L1LOSS_SVC_DUAL)
00025 {
00026     init(NULL, 0.0);
00027 }
00028 
00029 
00030 CDomainAdaptationSVMLinear::CDomainAdaptationSVMLinear(float64_t C, CDotFeatures* f, CLabels* lab, CLinearClassifier* pre_svm, float64_t B_param) : CLibLinear(C, f, lab)
00031 {
00032     init(pre_svm, B_param);
00033 
00034 }
00035 
00036 
00037 CDomainAdaptationSVMLinear::~CDomainAdaptationSVMLinear()
00038 {
00039 
00040     SG_UNREF(presvm);
00041     SG_DEBUG("deleting DomainAdaptationSVMLinear\n");
00042 }
00043 
00044 
00045 void CDomainAdaptationSVMLinear::init(CLinearClassifier* pre_svm, float64_t B_param)
00046 {
00047 
00048     if (pre_svm)
00049     {
00050         // increase reference counts
00051         SG_REF(pre_svm);
00052 
00053         // set bias of parent svm to zero
00054         pre_svm->set_bias(0.0);
00055     }
00056 
00057     this->presvm=pre_svm;
00058     this->B=B_param;
00059     this->train_factor=1.0;
00060 
00061     set_liblinear_solver_type(L2R_L1LOSS_SVC_DUAL);
00062 
00063     // invoke sanity check
00064     is_presvm_sane();
00065 
00066     // serialization code
00067     m_parameters->add((CSGObject**) &presvm, "presvm", "SVM to regularize against");
00068     m_parameters->add(&B, "B",  "Regularization strenth B.");
00069     m_parameters->add(&train_factor, "train_factor",  "train_factor");
00070 
00071 }
00072 
00073 
00074 bool CDomainAdaptationSVMLinear::is_presvm_sane()
00075 {
00076 
00077     if (!presvm) {
00078 
00079         SG_WARNING("presvm is null");
00080 
00081     } else {
00082 
00083         if (presvm->get_bias() != 0) {
00084             SG_ERROR("presvm bias not set to zero");
00085         }
00086 
00087         if (presvm->get_features()->get_feature_type() != this->get_features()->get_feature_type()) {
00088             SG_ERROR("feature types do not agree");
00089         }
00090     }
00091 
00092     return true;
00093 
00094 }
00095 
00096 
00097 bool CDomainAdaptationSVMLinear::train(CDotFeatures* train_data)
00098 {
00099 
00100     CDotFeatures* tmp_data;
00101 
00102     if (train_data)
00103     {
00104         if (labels->get_num_labels() != train_data->get_num_vectors())
00105             SG_ERROR("Number of training vectors does not match number of labels\n");
00106         tmp_data = train_data;
00107 
00108     } else {
00109 
00110         tmp_data = features;
00111     }
00112 
00113     int32_t num_training_points = get_labels()->get_num_labels();
00114 
00115     std::vector<float64_t> lin_term = std::vector<float64_t>(num_training_points);
00116 
00117     if (presvm)
00118     {
00119         // bias of parent SVM was set to zero in constructor, already contains B
00120         CLabels* parent_svm_out = presvm->classify(tmp_data);
00121 
00122         SG_DEBUG("pre-computing linear term from presvm\n");
00123 
00124         // pre-compute linear term
00125         for (int32_t i=0; i!=num_training_points; i++)
00126         {
00127             lin_term[i] = (- B*(get_label(i) * parent_svm_out->get_label(i)))*train_factor - 1.0;
00128         }
00129 
00130         // set linear term for QP
00131         this->set_linear_term(&lin_term[0], lin_term.size());
00132 
00133     }
00134 
00135     /*
00136     //TODO test this code, measure speed-ups
00137     //presvm w stored in presvm
00138     float64_t* tmp_w;
00139     presvm->get_w(tmp_w, w_dim);
00140 
00141     //copy vector
00142     float64_t* tmp_w_copy = new float64_t[w_dim];
00143     std::copy(tmp_w, tmp_w + w_dim, tmp_w_copy);
00144 
00145     for (int32_t i=0; i!=w_dim; i++)
00146     {
00147         tmp_w_copy[i] = B * tmp_w_copy[i];
00148     }
00149 
00150     //set w (copied in setter)
00151     set_w(tmp_w_copy, w_dim);
00152     delete[] tmp_w_copy;
00153     */
00154 
00155     bool success = false;
00156 
00157     //train SVM
00158     if (train_data)
00159     {
00160         success = CLibLinear::train(train_data);
00161     } else {
00162         success = CLibLinear::train();
00163     }
00164 
00165     //ASSERT(presvm)
00166 
00167     return success;
00168 
00169 }
00170 
00171 
00172 CLinearClassifier* CDomainAdaptationSVMLinear::get_presvm()
00173 {
00174     return presvm;
00175 }
00176 
00177 
00178 float64_t CDomainAdaptationSVMLinear::get_B()
00179 {
00180     return B;
00181 }
00182 
00183 
00184 float64_t CDomainAdaptationSVMLinear::get_train_factor()
00185 {
00186     return train_factor;
00187 }
00188 
00189 
00190 void CDomainAdaptationSVMLinear::set_train_factor(float64_t factor)
00191 {
00192     train_factor = factor;
00193 }
00194 
00195 
00196 CLabels* CDomainAdaptationSVMLinear::classify(CDotFeatures* data)
00197 {
00198 
00199     ASSERT(presvm->get_bias()==0.0);
00200 
00201     int32_t num_examples = data->get_num_vectors();
00202 
00203     CLabels* out_current = CLibLinear::classify(data);
00204 
00205     if (presvm)
00206     {
00207 
00208         // recursive call if used on DomainAdaptationSVM object
00209         CLabels* out_presvm = presvm->classify(data);
00210 
00211 
00212         // combine outputs
00213         for (int32_t i=0; i!=num_examples; i++)
00214         {
00215             float64_t out_combined = out_current->get_label(i) + B*out_presvm->get_label(i);
00216             out_current->set_label(i, out_combined);
00217         }
00218 
00219     }
00220 
00221 
00222     return out_current;
00223 
00224 }
00225 
00226 #endif //HAVE_LAPACK
00227 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation