SHOGUN  4.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
DomainAdaptationSVM.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2007-2011 Christian Widmer
8  * Copyright (C) 2007-2011 Max-Planck-Society
9  */
10 
11 #include <shogun/lib/config.h>
12 
13 #ifdef USE_SVMLIGHT
14 
16 #include <shogun/io/SGIO.h>
17 #include <shogun/labels/Labels.h>
20 #include <iostream>
21 #include <vector>
22 
23 using namespace shogun;
24 
26 {
27  init();
28 }
29 
31 {
32  init();
33  init(pre_svm, B_param);
34 }
35 
37 {
39  SG_DEBUG("deleting DomainAdaptationSVM\n")
40 }
41 
42 
43 void CDomainAdaptationSVM::init(CSVM* pre_svm, float64_t B_param)
44 {
45  REQUIRE(pre_svm != NULL, "Pre SVM should not be null");
46  // increase reference counts
47  SG_REF(pre_svm);
48 
49  this->presvm=pre_svm;
50  this->B=B_param;
51  this->train_factor=1.0;
52 
53  // set bias of parent svm to zero
54  this->presvm->set_bias(0.0);
55 
56  // invoke sanity check
58 }
59 
61 {
62  if (!presvm) {
63  SG_ERROR("presvm is null")
64  }
65 
66  if (presvm->get_num_support_vectors() == 0) {
67  SG_ERROR("presvm has no support vectors, please train first")
68  }
69 
70  if (presvm->get_bias() != 0) {
71  SG_ERROR("presvm bias not set to zero")
72  }
73 
75  SG_ERROR("kernel types do not agree")
76  }
77 
79  SG_ERROR("feature types do not agree")
80  }
81 
82  return true;
83 }
84 
85 
87 {
88 
89  if (data)
90  {
91  if (m_labels->get_num_labels() != data->get_num_vectors())
92  SG_ERROR("Number of training vectors does not match number of labels\n")
93  kernel->init(data, data);
94  }
95 
97  SG_ERROR("DomainAdaptationSVM requires binary labels\n")
98 
99  int32_t num_training_points = get_labels()->get_num_labels();
100  CBinaryLabels* labels = (CBinaryLabels*) get_labels();
101 
102  float64_t* lin_term = SG_MALLOC(float64_t, num_training_points);
103 
104  // grab current training features
105  CFeatures* train_data = get_kernel()->get_lhs();
106 
107  // bias of parent SVM was set to zero in constructor, already contains B
108  CBinaryLabels* parent_svm_out = presvm->apply_binary(train_data);
109 
110  // pre-compute linear term
111  for (int32_t i=0; i<num_training_points; i++)
112  {
113  lin_term[i] = train_factor * B * labels->get_label(i) * parent_svm_out->get_label(i) - 1.0;
114  }
115 
116  //set linear term for QP
117  this->set_linear_term(SGVector<float64_t>(lin_term, num_training_points));
118 
119  //train SVM
120  bool success = CSVMLight::train_machine();
121  SG_UNREF(labels);
122 
123  ASSERT(presvm)
124 
125  return success;
126 
127 }
128 
129 
131 {
132  SG_REF(presvm);
133  return presvm;
134 }
135 
136 
138 {
139  return B;
140 }
141 
142 
144 {
145  return train_factor;
146 }
147 
148 
150 {
151  train_factor = factor;
152 }
153 
154 
156 {
157  ASSERT(data)
158  ASSERT(presvm->get_bias()==0.0)
159 
160  int32_t num_examples = data->get_num_vectors();
161 
162  CBinaryLabels* out_current = CSVMLight::apply_binary(data);
163 
164  // recursive call if used on DomainAdaptationSVM object
165  CBinaryLabels* out_presvm = presvm->apply_binary(data);
166 
167  // combine outputs
168  SGVector<float64_t> out_combined(num_examples);
169  for (int32_t i=0; i<num_examples; i++)
170  {
171  out_combined[i] = out_current->get_value(i) + B*out_presvm->get_value(i);
172  }
173  SG_UNREF(out_current);
174  SG_UNREF(out_presvm);
175 
176  return new CBinaryLabels(out_combined);
177 
178 }
179 
180 void CDomainAdaptationSVM::init()
181 {
182  presvm = NULL;
183  B = 0;
184  train_factor = 1.0;
185 
186  SG_ADD((CSGObject**) &presvm, "presvm", "SVM to regularize against.",
188  SG_ADD(&B, "B", "regularization parameter B.", MS_AVAILABLE);
189  SG_ADD(&train_factor, "train_factor",
190  "flag to switch off regularization in training.", MS_AVAILABLE);
191 }
192 
193 #endif //USE_SVMLIGHT
virtual bool init(CFeatures *lhs, CFeatures *rhs)
Definition: Kernel.cpp:98
virtual float64_t get_value(int32_t idx)
Definition: Labels.cpp:59
void init(CSVM *presvm, float64_t B)
virtual ELabelType get_label_type() const =0
binary labels +1/-1
Definition: LabelTypes.h:18
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:43
virtual int32_t get_num_labels() const =0
virtual bool train_machine(CFeatures *data=NULL)
virtual CBinaryLabels * apply_binary(CFeatures *data=NULL)
virtual int32_t get_num_vectors() const =0
CLabels * m_labels
Definition: Machine.h:361
#define SG_ERROR(...)
Definition: SGIO.h:129
#define REQUIRE(x,...)
Definition: SGIO.h:206
virtual void set_linear_term(const SGVector< float64_t > linear_term)
Definition: SVM.cpp:314
float64_t get_label(int32_t idx)
#define SG_REF(x)
Definition: SGObject.h:54
#define ASSERT(x)
Definition: SGIO.h:201
Class SGObject is the base class of all shogun objects.
Definition: SGObject.h:115
void set_bias(float64_t bias)
double float64_t
Definition: common.h:50
virtual EFeatureType get_feature_type()=0
virtual CLabels * get_labels()
Definition: Machine.cpp:76
virtual bool train_machine(CFeatures *data=NULL)
Definition: SVMLight.cpp:181
virtual void set_train_factor(float64_t factor)
#define SG_UNREF(x)
Definition: SGObject.h:55
#define SG_DEBUG(...)
Definition: SGIO.h:107
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
virtual EKernelType get_kernel_type()=0
The class Features is the base class of all feature objects.
Definition: Features.h:68
class SVMlight
Definition: SVMLight.h:225
A generic Support Vector Machine Interface.
Definition: SVM.h:49
The Kernel base class.
Definition: Kernel.h:159
Binary Labels for binary classification.
Definition: BinaryLabels.h:37
#define SG_ADD(...)
Definition: SGObject.h:84
virtual CBinaryLabels * apply_binary(CFeatures *data=NULL)
CFeatures * get_lhs()
Definition: Kernel.h:505

SHOGUN Machine Learning Toolbox - Documentation