SHOGUN  v3.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
MulticlassSVM.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 1999-2009 Soeren Sonnenburg
8  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
9  */
10 
11 #include <shogun/lib/common.h>
12 #include <shogun/io/SGIO.h>
15 
16 using namespace shogun;
17 
20 {
21  init();
22 }
23 
25  :CKernelMulticlassMachine(strategy, NULL, new CSVM(0), NULL)
26 {
27  init();
28 }
29 
31  CMulticlassStrategy *strategy, float64_t C, CKernel* k, CLabels* lab)
32  : CKernelMulticlassMachine(strategy, k, new CSVM(C, k, lab), lab)
33 {
34  init();
35  m_C=C;
36 }
37 
39 {
40 }
41 
42 void CMulticlassSVM::init()
43 {
44  SG_ADD(&m_C, "C", "C regularization constant",MS_AVAILABLE);
45  m_C=0;
46 }
47 
48 bool CMulticlassSVM::create_multiclass_svm(int32_t num_classes)
49 {
50  if (num_classes>0)
51  {
52  int32_t num_svms=m_multiclass_strategy->get_num_machines();
53 
55  for (index_t i=0; i<num_svms; ++i)
56  m_machines->push_back(NULL);
57 
58  return true;
59  }
60  return false;
61 }
62 
63 bool CMulticlassSVM::set_svm(int32_t num, CSVM* svm)
64 {
65  if (m_machines->get_num_elements()>0 && m_machines->get_num_elements()>num && num>=0 && svm)
66  {
67  m_machines->set_element(svm, num);
68  return true;
69  }
70  return false;
71 }
72 
74 {
75  if (is_data_locked())
76  {
77  SG_ERROR("CKernelMachine::apply(CFeatures*) cannot be called when "
78  "data_lock was called before. Call data_unlock to allow.");
79  }
80 
81  if (!m_kernel)
82  SG_ERROR("No kernel assigned!\n")
83 
84  CFeatures* lhs=m_kernel->get_lhs();
85  if (!lhs && m_kernel->get_kernel_type()!=K_COMBINED)
86  SG_ERROR("%s: No left hand side specified\n", get_name())
87 
89  {
90  SG_ERROR("%s: No vectors on left hand side (%s). This is probably due to"
91  " an implementation error in %s, where it was forgotten to set "
92  "the data (m_svs) indices\n", get_name(),
93  data->get_name());
94  }
95 
96  if (data && m_kernel->get_kernel_type()!=K_COMBINED)
97  m_kernel->init(lhs, data);
98  SG_UNREF(lhs);
99 
100  for (int32_t i=0; i<m_machines->get_num_elements(); i++)
101  {
102  CSVM *the_svm = (CSVM *)m_machines->get_element(i);
103  ASSERT(the_svm)
104  the_svm->set_kernel(m_kernel);
105  SG_UNREF(the_svm);
106  }
107 
108  return true;
109 }
110 
111 bool CMulticlassSVM::load(FILE* modelfl)
112 {
113  bool result=true;
114  char char_buffer[1024];
115  int32_t int_buffer;
116  float64_t double_buffer;
117  int32_t line_number=1;
118  int32_t svm_idx=-1;
119 
121 
122  if (fscanf(modelfl,"%15s\n", char_buffer)==EOF)
123  SG_ERROR("error in svm file, line nr:%d\n", line_number)
124  else
125  {
126  char_buffer[15]='\0';
127  if (strcmp("%MultiClassSVM", char_buffer)!=0)
128  SG_ERROR("error in multiclass svm file, line nr:%d\n", line_number)
129 
130  line_number++;
131  }
132 
133  int_buffer=0;
134  if (fscanf(modelfl," num_classes=%d; \n", &int_buffer) != 1)
135  SG_ERROR("error in svm file, line nr:%d\n", line_number)
136 
137  if (!feof(modelfl))
138  line_number++;
139 
140  if (int_buffer < 2)
141  SG_ERROR("less than 2 classes - how is this multiclass?\n")
142 
143  create_multiclass_svm(int_buffer);
144 
145  int_buffer=0;
146  if (fscanf(modelfl," num_svms=%d; \n", &int_buffer) != 1)
147  SG_ERROR("error in svm file, line nr:%d\n", line_number)
148 
149  if (!feof(modelfl))
150  line_number++;
151 
152  if (m_machines->get_num_elements() != int_buffer)
153  SG_ERROR("Mismatch in number of svms: m_num_svms=%d vs m_num_svms(file)=%d\n", m_machines->get_num_elements(), int_buffer)
154 
155  if (fscanf(modelfl," kernel='%s'; \n", char_buffer) != 1)
156  SG_ERROR("error in svm file, line nr:%d\n", line_number)
157 
158  if (!feof(modelfl))
159  line_number++;
160 
161  for (int32_t n=0; n<m_machines->get_num_elements(); n++)
162  {
163  svm_idx=-1;
164  if (fscanf(modelfl,"\n%4s %d of %d\n", char_buffer, &svm_idx, &int_buffer)==EOF)
165  {
166  result=false;
167  SG_ERROR("error in svm file, line nr:%d\n", line_number)
168  }
169  else
170  {
171  char_buffer[4]='\0';
172  if (strncmp("%SVM", char_buffer, 4)!=0)
173  {
174  result=false;
175  SG_ERROR("error in svm file, line nr:%d\n", line_number)
176  }
177 
178  if (svm_idx != n)
179  SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx)
180 
181  line_number++;
182  }
183 
184  int_buffer=0;
185  if (fscanf(modelfl,"numsv%d=%d;\n", &svm_idx, &int_buffer) != 2)
186  SG_ERROR("error in svm file, line nr:%d\n", line_number)
187 
188  if (svm_idx != n)
189  SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx)
190 
191  if (!feof(modelfl))
192  line_number++;
193 
194  SG_INFO("loading %ld support vectors for svm %d\n",int_buffer, svm_idx)
195  CSVM* svm=new CSVM(int_buffer);
196 
197  double_buffer=0;
198 
199  if (fscanf(modelfl," b%d=%lf; \n", &svm_idx, &double_buffer) != 2)
200  SG_ERROR("error in svm file, line nr:%d\n", line_number)
201 
202  if (svm_idx != n)
203  SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx)
204 
205  if (!feof(modelfl))
206  line_number++;
207 
208  svm->set_bias(double_buffer);
209 
210  if (fscanf(modelfl,"alphas%d=[\n", &svm_idx) != 1)
211  SG_ERROR("error in svm file, line nr:%d\n", line_number)
212 
213  if (svm_idx != n)
214  SG_ERROR("svm index mismatch n=%d, n(file)=%d\n", n, svm_idx)
215 
216  if (!feof(modelfl))
217  line_number++;
218 
219  for (int32_t i=0; i<svm->get_num_support_vectors(); i++)
220  {
221  double_buffer=0;
222  int_buffer=0;
223 
224  if (fscanf(modelfl,"\t[%lf,%d]; \n", &double_buffer, &int_buffer) != 2)
225  SG_ERROR("error in svm file, line nr:%d\n", line_number)
226 
227  if (!feof(modelfl))
228  line_number++;
229 
230  svm->set_support_vector(i, int_buffer);
231  svm->set_alpha(i, double_buffer);
232  }
233 
234  if (fscanf(modelfl,"%2s", char_buffer) == EOF)
235  {
236  result=false;
237  SG_ERROR("error in svm file, line nr:%d\n", line_number)
238  }
239  else
240  {
241  char_buffer[3]='\0';
242  if (strcmp("];", char_buffer)!=0)
243  {
244  result=false;
245  SG_ERROR("error in svm file, line nr:%d\n", line_number)
246  }
247  line_number++;
248  }
249 
250  set_svm(n, svm);
251  }
252 
253  svm_proto()->svm_loaded=result;
254 
256  return result;
257 }
258 
259 bool CMulticlassSVM::save(FILE* modelfl)
260 {
262 
263  if (!m_kernel)
264  SG_ERROR("Kernel not defined!\n")
265 
266  if (m_machines->get_num_elements()<1)
267  SG_ERROR("Multiclass SVM not trained!\n")
268 
269  SG_INFO("Writing model file...")
270  fprintf(modelfl,"%%MultiClassSVM\n");
271  fprintf(modelfl,"num_classes=%d;\n", m_multiclass_strategy->get_num_classes());
272  fprintf(modelfl,"num_svms=%d;\n", m_machines->get_num_elements());
273  fprintf(modelfl,"kernel='%s';\n", m_kernel->get_name());
274 
275  for (int32_t i=0; i<m_machines->get_num_elements(); i++)
276  {
277  CSVM* svm=get_svm(i);
278  ASSERT(svm)
279  fprintf(modelfl,"\n%%SVM %d of %d\n", i, m_machines->get_num_elements()-1);
280  fprintf(modelfl,"numsv%d=%d;\n", i, svm->get_num_support_vectors());
281  fprintf(modelfl,"b%d=%+10.16e;\n",i,svm->get_bias());
282 
283  fprintf(modelfl, "alphas%d=[\n", i);
284 
285  for(int32_t j=0; j<svm->get_num_support_vectors(); j++)
286  {
287  fprintf(modelfl,"\t[%+10.16e,%d];\n",
288  svm->get_alpha(j), svm->get_support_vector(j));
289  }
290 
291  fprintf(modelfl, "];\n");
292  }
293 
295  SG_DONE()
296  return true ;
297 }

SHOGUN Machine Learning Toolbox - Documentation