SHOGUN  4.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
GPBTSVM.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 1999-2009 Soeren Sonnenburg
8  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
9  */
10 
11 
13 #ifdef USE_GPL_SHOGUN
14 #include <shogun/lib/external/gpdt.h>
15 #include <shogun/lib/external/gpdtsolve.h>
16 #include <shogun/io/SGIO.h>
18 
19 using namespace shogun;
20 
/** Default constructor: default-constructs the CSVM base and sets model to NULL. */
21 CGPBTSVM::CGPBTSVM()
22 : CSVM(), model(NULL)
23 {
24 }
25 
/**
 * Constructor taking the training setup.
 *
 * All three arguments are passed through unchanged to the CSVM base
 * constructor; model is set to NULL.
 *
 * @param C   forwarded to CSVM(C, k, lab)
 * @param k   kernel, forwarded to CSVM(C, k, lab)
 * @param lab labels, forwarded to CSVM(C, k, lab)
 */
26 CGPBTSVM::CGPBTSVM(float64_t C, CKernel* k, CLabels* lab)
27 : CSVM(C, k, lab), model(NULL)
28 {
29 }
30 
/**
 * Destructor: releases model with SG_FREE.
 *
 * NOTE(review): the only values model receives in the code visible here are
 * the NULL initialisers in the constructors -- confirm whether anything else
 * ever assigns it.
 */
31 CGPBTSVM::~CGPBTSVM()
32 {
33  SG_FREE(model);
34 }
35 
/**
 * Train the SVM with the GPDT solver (QPproblem::gpdtsolve).
 *
 * Requires a kernel and a non-empty set of binary labels (checked with
 * ASSERT). The solver is configured from this machine's settings, run once,
 * and every example whose solution entry exceeds prob.DELTAsv is stored as a
 * support vector with alpha = solution[i] * label(i).
 *
 * @param data optional training features; when non-NULL its vector count must
 *             equal the number of labels (otherwise SG_ERROR is raised) and
 *             the kernel is initialised on (data, data). When NULL the kernel
 *             is used as it currently is.
 * @return true -- the function has no other return statement; problems are
 *         reported through ASSERT / SG_ERROR instead
 */
36 bool CGPBTSVM::train_machine(CFeatures* data)
37 {
38  float64_t* solution; /* solver output, allocated below with prob.ell entries */
39  QPproblem prob; /* object containing the solvers */
40 
41  ASSERT(kernel)
42  ASSERT(m_labels && m_labels->get_num_labels())
43  ASSERT(m_labels->get_label_type() == LT_BINARY)
44  if (data)
45  {
46  if (m_labels->get_num_labels() != data->get_num_vectors())
47  SG_ERROR("Number of training vectors does not match number of labels\n")
48  kernel->init(data, data);
49  }
50 
51  SGVector<int32_t> lab=((CBinaryLabels*) m_labels)->get_int_labels(); // prob.y below aliases this vector's buffer
52  prob.KER=new sKernel(kernel, lab.vlen); // released with delete at the end of this function
53  prob.y=lab.vector;
54  prob.ell=lab.vlen;
55  SG_INFO("%d trainlabels\n", prob.ell)
56 
57  /*** set solver options from this machine's settings ***/
58  prob.delta = epsilon;
59  prob.maxmw = kernel->get_cache_size();
60  prob.verbosity = 0;
61  prob.preprocess_size = -1; // -1 = not chosen yet, filled in further down
62  prob.projection_projector = -1; // -1 = not chosen yet, filled in further down
63  prob.c_const = get_C1();
64  prob.chunk_size = get_qpsize();
65  prob.linadd = get_linadd_enabled();
66 
 // Clamp the sizes: chunk_size >= 2, 2 <= q <= chunk_size, maxmw >= 5.
 // NOTE(review): prob.q is read here without being assigned in this
 // function -- presumably QPproblem's constructor gives it a default; confirm.
67  if (prob.chunk_size < 2) prob.chunk_size = 2;
68  if (prob.q <= 0) prob.q = prob.chunk_size / 3;
69  if (prob.q < 2) prob.q = 2;
70  if (prob.q > prob.chunk_size) prob.q = prob.chunk_size;
71  prob.q = prob.q & (~1); // clear the lowest bit so q is even
72  if (prob.maxmw < 5)
73  prob.maxmw = 5;
74 
75  /*** log the problem description ***/
 // Mind the labels: the "q" line prints prob.chunk_size and the "n" line
 // prints prob.q. prob.ker_type is logged but never set in this function.
76  SG_INFO("\nTRAINING PARAMETERS:\n")
77  SG_INFO("\tNumber of training documents: %d\n", prob.ell)
78  SG_INFO("\tq: %d\n", prob.chunk_size)
79  SG_INFO("\tn: %d\n", prob.q)
80  SG_INFO("\tC: %lf\n", prob.c_const)
81  SG_INFO("\tkernel type: %d\n", prob.ker_type)
82  SG_INFO("\tcache size: %dMb\n", prob.maxmw)
83  SG_INFO("\tStopping tolerance: %lf\n", prob.delta)
84 
85  /*** fill in the options that are still at -1 ***/
86  if (prob.preprocess_size == -1)
87  prob.preprocess_size = (int32_t) ( (float64_t)prob.chunk_size * 1.5 );
88 
89  if (prob.projection_projector == -1)
90  {
91  if (prob.chunk_size <= 20) prob.projection_projector = 0;
92  else prob.projection_projector = 1;
93  }
94 
95  /*** compute the problem solution *******************************************/
96  solution = SG_MALLOC(float64_t, prob.ell); // freed with SG_FREE at the end of this function
97  prob.gpdtsolve(solution);
98  /****************************************************************************/
99 
100  CSVM::set_objective(prob.objective_value);
101 
102  int32_t num_sv=0; // examples with solution[i] > prob.DELTAsv
103  int32_t bsv=0; // of those, examples with solution[i] > prob.c_const - prob.DELTAsv
104  int32_t i=0;
105  int32_t k=0; // next support-vector slot to fill in the second pass
106 
 // First pass: only count, so the model can be created with the final size.
107  for (i = 0; i < prob.ell; i++)
108  {
109  if (solution[i] > prob.DELTAsv)
110  {
111  num_sv++;
112  if (solution[i] > (prob.c_const - prob.DELTAsv)) bsv++;
113  }
114  }
115 
116  create_new_model(num_sv);
117  set_bias(prob.bee);
118 
119  SG_INFO("SV: %d BSV = %d\n", num_sv, bsv)
120 
 // Second pass: same threshold as above; slot k receives example index i and
 // alpha = solution[i] * label(i).
121  for (i = 0; i < prob.ell; i++)
122  {
123  if (solution[i] > prob.DELTAsv)
124  {
125  set_support_vector(k, i);
126  set_alpha(k++, solution[i]*((CBinaryLabels*) m_labels)->get_label(i));
127  }
128  }
129 
130  delete prob.KER;
131  SG_FREE(solution);
132 
133  return true;
134 }
135 #endif //USE_GPL_SHOGUN
#define SG_INFO(...)
Definition: SGIO.h:118
binary labels +1/-1
Definition: LabelTypes.h:18
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:43
#define SG_ERROR(...)
Definition: SGIO.h:129
index_t vlen
Definition: SGVector.h:494
#define ASSERT(x)
Definition: SGIO.h:201
double float64_t
Definition: common.h:50
void set_objective(float64_t v)
Definition: SVM.h:209
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
The class Features is the base class of all feature objects.
Definition: Features.h:68
A generic Support Vector Machine Interface.
Definition: SVM.h:49
The Kernel base class.
Definition: Kernel.h:159
Binary Labels for binary classification.
Definition: BinaryLabels.h:37

SHOGUN Machine Learning Toolbox - Documentation