SHOGUN  v2.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
MulticlassSVM.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 1999-2009 Soeren Sonnenburg
8  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
9  */
10 
11 #include <shogun/lib/common.h>
12 #include <shogun/io/SGIO.h>
15 
16 using namespace shogun;
17 
19  :CKernelMulticlassMachine(new CMulticlassOneVsRestStrategy(), NULL, new CSVM(0), NULL), m_C(0)
20 {
21  init();
22 }
23 
25  :CKernelMulticlassMachine(strategy, NULL, new CSVM(0), NULL), m_C(0)
26 {
27  init();
28 }
29 
31  CMulticlassStrategy *strategy, float64_t C, CKernel* k, CLabels* lab)
32  : CKernelMulticlassMachine(strategy, k, new CSVM(C, k, lab), lab), m_C(C)
33 {
34  init();
35 }
36 
38 {
39 }
40 
41 void CMulticlassSVM::init()
42 {
43  SG_ADD(&m_C, "C", "C regularization constant",MS_AVAILABLE);
44 }
45 
46 bool CMulticlassSVM::create_multiclass_svm(int32_t num_classes)
47 {
48  if (num_classes>0)
49  {
50  int32_t num_svms=m_multiclass_strategy->get_num_machines();
51 
53  for (index_t i=0; i<num_svms; ++i)
54  m_machines->push_back(NULL);
55 
56  return true;
57  }
58  return false;
59 }
60 
61 bool CMulticlassSVM::set_svm(int32_t num, CSVM* svm)
62 {
63  if (m_machines->get_num_elements()>0 && m_machines->get_num_elements()>num && num>=0 && svm)
64  {
65  m_machines->set_element(svm, num);
66  return true;
67  }
68  return false;
69 }
70 
72 {
73  if (is_data_locked())
74  {
75  SG_ERROR("CKernelMachine::apply(CFeatures*) cannot be called when "
76  "data_lock was called before. Call data_unlock to allow.");
77  }
78 
79  if (!m_kernel)
80  SG_ERROR("No kernel assigned!\n");
81 
82  CFeatures* lhs=m_kernel->get_lhs();
83  if (!lhs && m_kernel->get_kernel_type()!=K_COMBINED)
84  SG_ERROR("%s: No left hand side specified\n", get_name());
85 
87  {
88  SG_ERROR("%s: No vectors on left hand side (%s). This is probably due to"
89  " an implementation error in %s, where it was forgotten to set "
90  "the data (m_svs) indices\n", get_name(),
91  data->get_name());
92  }
93 
94  if (data && m_kernel->get_kernel_type()!=K_COMBINED)
95  m_kernel->init(lhs, data);
96  SG_UNREF(lhs);
97 
98  for (int32_t i=0; i<m_machines->get_num_elements(); i++)
99  {
100  CSVM *the_svm = (CSVM *)m_machines->get_element(i);
101  ASSERT(the_svm);
102  the_svm->set_kernel(m_kernel);
103  SG_UNREF(the_svm);
104  }
105 
106  return true;
107 }
108 
109 bool CMulticlassSVM::load(FILE* modelfl)
110 {
111  bool result=true;
112  char char_buffer[1024];
113  int32_t int_buffer;
114  float64_t double_buffer;
115  int32_t line_number=1;
116  int32_t svm_idx=-1;
117 
119 
120  if (fscanf(modelfl,"%15s\n", char_buffer)==EOF)
121  SG_ERROR( "error in svm file, line nr:%d\n", line_number);
122  else
123  {
124  char_buffer[15]='\0';
125  if (strcmp("%MultiClassSVM", char_buffer)!=0)
126  SG_ERROR( "error in multiclass svm file, line nr:%d\n", line_number);
127 
128  line_number++;
129  }
130 
131  int_buffer=0;
132  if (fscanf(modelfl," num_classes=%d; \n", &int_buffer) != 1)
133  SG_ERROR( "error in svm file, line nr:%d\n", line_number);
134 
135  if (!feof(modelfl))
136  line_number++;
137 
138  if (int_buffer < 2)
139  SG_ERROR("less than 2 classes - how is this multiclass?\n");
140 
141  create_multiclass_svm(int_buffer);
142 
143  int_buffer=0;
144  if (fscanf(modelfl," num_svms=%d; \n", &int_buffer) != 1)
145  SG_ERROR( "error in svm file, line nr:%d\n", line_number);
146 
147  if (!feof(modelfl))
148  line_number++;
149 
150  if (m_machines->get_num_elements() != int_buffer)
151  SG_ERROR("Mismatch in number of svms: m_num_svms=%d vs m_num_svms(file)=%d\n", m_machines->get_num_elements(), int_buffer);
152 
153  if (fscanf(modelfl," kernel='%s'; \n", char_buffer) != 1)
154  SG_ERROR( "error in svm file, line nr:%d\n", line_number);
155 
156  if (!feof(modelfl))
157  line_number++;
158 
159  for (int32_t n=0; n<m_machines->get_num_elements(); n++)
160  {
161  svm_idx=-1;
162  if (fscanf(modelfl,"\n%4s %d of %d\n", char_buffer, &svm_idx, &int_buffer)==EOF)
163  {
164  result=false;
165  SG_ERROR( "error in svm file, line nr:%d\n", line_number);
166  }
167  else
168  {
169  char_buffer[4]='\0';
170  if (strncmp("%SVM", char_buffer, 4)!=0)
171  {
172  result=false;
173  SG_ERROR( "error in svm file, line nr:%d\n", line_number);
174  }
175 
176  if (svm_idx != n)
177  SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
178 
179  line_number++;
180  }
181 
182  int_buffer=0;
183  if (fscanf(modelfl,"numsv%d=%d;\n", &svm_idx, &int_buffer) != 2)
184  SG_ERROR( "error in svm file, line nr:%d\n", line_number);
185 
186  if (svm_idx != n)
187  SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
188 
189  if (!feof(modelfl))
190  line_number++;
191 
192  SG_INFO("loading %ld support vectors for svm %d\n",int_buffer, svm_idx);
193  CSVM* svm=new CSVM(int_buffer);
194 
195  double_buffer=0;
196 
197  if (fscanf(modelfl," b%d=%lf; \n", &svm_idx, &double_buffer) != 2)
198  SG_ERROR( "error in svm file, line nr:%d\n", line_number);
199 
200  if (svm_idx != n)
201  SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
202 
203  if (!feof(modelfl))
204  line_number++;
205 
206  svm->set_bias(double_buffer);
207 
208  if (fscanf(modelfl,"alphas%d=[\n", &svm_idx) != 1)
209  SG_ERROR( "error in svm file, line nr:%d\n", line_number);
210 
211  if (svm_idx != n)
212  SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx);
213 
214  if (!feof(modelfl))
215  line_number++;
216 
217  for (int32_t i=0; i<svm->get_num_support_vectors(); i++)
218  {
219  double_buffer=0;
220  int_buffer=0;
221 
222  if (fscanf(modelfl,"\t[%lf,%d]; \n", &double_buffer, &int_buffer) != 2)
223  SG_ERROR( "error in svm file, line nr:%d\n", line_number);
224 
225  if (!feof(modelfl))
226  line_number++;
227 
228  svm->set_support_vector(i, int_buffer);
229  svm->set_alpha(i, double_buffer);
230  }
231 
232  if (fscanf(modelfl,"%2s", char_buffer) == EOF)
233  {
234  result=false;
235  SG_ERROR( "error in svm file, line nr:%d\n", line_number);
236  }
237  else
238  {
239  char_buffer[3]='\0';
240  if (strcmp("];", char_buffer)!=0)
241  {
242  result=false;
243  SG_ERROR( "error in svm file, line nr:%d\n", line_number);
244  }
245  line_number++;
246  }
247 
248  set_svm(n, svm);
249  }
250 
251  svm_proto()->svm_loaded=result;
252 
254  return result;
255 }
256 
257 bool CMulticlassSVM::save(FILE* modelfl)
258 {
260 
261  if (!m_kernel)
262  SG_ERROR("Kernel not defined!\n");
263 
264  if (m_machines->get_num_elements()<1)
265  SG_ERROR("Multiclass SVM not trained!\n");
266 
267  SG_INFO( "Writing model file...");
268  fprintf(modelfl,"%%MultiClassSVM\n");
269  fprintf(modelfl,"num_classes=%d;\n", m_multiclass_strategy->get_num_classes());
270  fprintf(modelfl,"num_svms=%d;\n", m_machines->get_num_elements());
271  fprintf(modelfl,"kernel='%s';\n", m_kernel->get_name());
272 
273  for (int32_t i=0; i<m_machines->get_num_elements(); i++)
274  {
275  CSVM* svm=get_svm(i);
276  ASSERT(svm);
277  fprintf(modelfl,"\n%%SVM %d of %d\n", i, m_machines->get_num_elements()-1);
278  fprintf(modelfl,"numsv%d=%d;\n", i, svm->get_num_support_vectors());
279  fprintf(modelfl,"b%d=%+10.16e;\n",i,svm->get_bias());
280 
281  fprintf(modelfl, "alphas%d=[\n", i);
282 
283  for(int32_t j=0; j<svm->get_num_support_vectors(); j++)
284  {
285  fprintf(modelfl,"\t[%+10.16e,%d];\n",
286  svm->get_alpha(j), svm->get_support_vector(j));
287  }
288 
289  fprintf(modelfl, "];\n");
290  }
291 
293  SG_DONE();
294  return true ;
295 }

SHOGUN Machine Learning Toolbox - Documentation