SHOGUN  v2.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
DomainAdaptationSVM.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2007-2011 Christian Widmer
8  * Copyright (C) 2007-2011 Max-Planck-Society
9  */
10 
11 #include <shogun/lib/config.h>
12 
13 #ifdef USE_SVMLIGHT
14 
16 #include <shogun/io/SGIO.h>
17 #include <shogun/labels/Labels.h>
20 #include <iostream>
21 #include <vector>
22 
23 using namespace shogun;
24 
26 {
27 }
28 
30 {
31  init();
32  init(pre_svm, B_param);
33 }
34 
36 {
38  SG_DEBUG("deleting DomainAdaptationSVM\n");
39 }
40 
41 
42 void CDomainAdaptationSVM::init(CSVM* pre_svm, float64_t B_param)
43 {
44  // increase reference counts
45  SG_REF(pre_svm);
46 
47  this->presvm=pre_svm;
48  this->B=B_param;
49  this->train_factor=1.0;
50 
51  // set bias of parent svm to zero
52  this->presvm->set_bias(0.0);
53 
54  // invoke sanity check
56 }
57 
59 {
60  if (!presvm) {
61  SG_ERROR("presvm is null");
62  }
63 
64  if (presvm->get_num_support_vectors() == 0) {
65  SG_ERROR("presvm has no support vectors, please train first");
66  }
67 
68  if (presvm->get_bias() != 0) {
69  SG_ERROR("presvm bias not set to zero");
70  }
71 
73  SG_ERROR("kernel types do not agree");
74  }
75 
77  SG_ERROR("feature types do not agree");
78  }
79 
80  return true;
81 }
82 
83 
85 {
86 
87  if (data)
88  {
89  if (m_labels->get_num_labels() != data->get_num_vectors())
90  SG_ERROR("Number of training vectors does not match number of labels\n");
91  kernel->init(data, data);
92  }
93 
95  SG_ERROR("DomainAdaptationSVM requires binary labels\n");
96 
97  int32_t num_training_points = get_labels()->get_num_labels();
98  CBinaryLabels* labels = (CBinaryLabels*) get_labels();
99 
100  float64_t* lin_term = SG_MALLOC(float64_t, num_training_points);
101 
102  // grab current training features
103  CFeatures* train_data = get_kernel()->get_lhs();
104 
105  // bias of parent SVM was set to zero in constructor, already contains B
106  CBinaryLabels* parent_svm_out = presvm->apply_binary(train_data);
107 
108  // pre-compute linear term
109  for (int32_t i=0; i<num_training_points; i++)
110  {
111  lin_term[i] = train_factor * B * labels->get_label(i) * parent_svm_out->get_label(i) - 1.0;
112  }
113 
114  //set linear term for QP
115  this->set_linear_term(SGVector<float64_t>(lin_term, num_training_points));
116 
117  //train SVM
118  bool success = CSVMLight::train_machine();
119  SG_UNREF(labels);
120 
121  ASSERT(presvm)
122 
123  return success;
124 
125 }
126 
127 
129 {
130  SG_REF(presvm);
131  return presvm;
132 }
133 
134 
136 {
137  return B;
138 }
139 
140 
142 {
143  return train_factor;
144 }
145 
146 
148 {
149  train_factor = factor;
150 }
151 
152 
154 {
155  ASSERT(data);
156  ASSERT(presvm->get_bias()==0.0);
157 
158  int32_t num_examples = data->get_num_vectors();
159 
160  CBinaryLabels* out_current = CSVMLight::apply_binary(data);
161 
162  // recursive call if used on DomainAdaptationSVM object
163  CBinaryLabels* out_presvm = presvm->apply_binary(data);
164 
165  // combine outputs
166  SGVector<float64_t> out_combined(num_examples);
167  for (int32_t i=0; i<num_examples; i++)
168  {
169  out_combined[i] = out_current->get_confidence(i) + B*out_presvm->get_confidence(i);
170  }
171  SG_UNREF(out_current);
172  SG_UNREF(out_presvm);
173 
174  return new CBinaryLabels(out_combined);
175 
176 }
177 
178 void CDomainAdaptationSVM::init()
179 {
180  presvm = NULL;
181  B = 0;
182  train_factor = 1.0;
183 
184  m_parameters->add((CSGObject**) &presvm, "presvm",
185  "SVM to regularize against.");
186  m_parameters->add(&B, "B", "regularization parameter B.");
188  "train_factor", "flag to switch off regularization in training.");
189 }
190 
191 #endif //USE_SVMLIGHT

SHOGUN Machine Learning Toolbox - Documentation