SHOGUN  v3.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
MulticlassMachine.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 1999-2011 Soeren Sonnenburg
8  * Written (W) 2012 Fernando José Iglesias García and Sergey Lisitsyn
9  * Written (W) 2013 Shell Hu and Heiko Strathmann
10  * Copyright (C) 2012 Sergey Lisitsyn, Fernando José Iglesias Garcia
11  */
12 
17 #include <shogun/base/Parameter.h>
21 
22 using namespace shogun;
23 
// Appears to be the default constructor; its signature line (and one body
// line) are missing from this listing. Starts with a one-vs-rest strategy and
// no base machine (m_machine is NULL), then registers both members through
// register_parameters().
// NOTE(review): the strategy is created with `new` here, while the
// strategy-taking constructor SG_REFs its argument; confirm the missing body
// line balances the reference counting the same way.
25 : CBaseMulticlassMachine(), m_multiclass_strategy(new CMulticlassOneVsRestStrategy()),
26  m_machine(NULL)
27 {
29  register_parameters();
30 }
31 
// Constructor taking a multiclass strategy, a base machine and training
// labels (the opening line of the signature is missing from this listing).
// strategy and machine are SG_REF'd and stored in the members; labs is handed
// to set_labels(). The strategy is initialized only when labs is non-NULL.
// NOTE(review): set_labels() runs before m_machine is assigned and before
// register_parameters(). If set_labels() resolves to this class's override,
// which also calls init_strategy() for non-NULL labels, init_strategy() runs
// twice here. Confirm, and consider dropping one of the calls.
33  CMulticlassStrategy *strategy,
34  CMachine* machine, CLabels* labs)
35 : CBaseMulticlassMachine(), m_multiclass_strategy(strategy)
36 {
37  SG_REF(strategy);
38  set_labels(labs);
39  SG_REF(machine);
40  m_machine = machine;
41  register_parameters();
42 
43  if (labs)
44  init_strategy();
45 }
46 
// NOTE(review): only the braces of this definition survive in this listing;
// its signature and body lines are missing. From its position after the
// constructors it is presumably the destructor, which would be expected to
// release the references taken on the strategy and base machine. Confirm
// against the full source.
48 {
51 }
52 
// Fragment: the signature and the first body statement are missing from this
// listing. It takes a labels argument `lab` and re-runs init_strategy()
// whenever lab is non-NULL. It is presumably the set_labels() override that
// the constructors call; confirm.
54 {
56  if (lab)
57  init_strategy();
58 }
59 
60 void CMulticlassMachine::register_parameters()
61 {
62  SG_ADD((CSGObject**)&m_multiclass_strategy,"m_multiclass_type", "Multiclass strategy", MS_NOT_AVAILABLE);
63  SG_ADD((CSGObject**)&m_machine, "m_machine", "The base machine", MS_NOT_AVAILABLE);
64 }
65 
// Fragment: the signature and the line that consumes num_classes are missing
// from this listing. It is presumably init_strategy(). It reads the number of
// classes from the current labels (m_labels).
// NOTE(review): the C-style cast assumes m_labels is non-NULL and really is a
// CMulticlassLabels. Nothing here checks either condition.
67 {
68  int32_t num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes();
70 }
71 
// Fragment: the signature is missing from this listing. It is presumably
// get_submachine_outputs(i), which apply_multiclass() and
// apply_multiclass_multiple_output() call once per sub-machine.
// It fetches the i-th trained machine from m_machines, asserts that it exists,
// and returns the result of calling apply_binary() on it with no data
// argument.
// Ownership: the machine is SG_UNREF'd before returning, which suggests
// get_element() hands back a new reference. The callers SG_UNREF the returned
// labels, so presumably apply_binary() returns an already-referenced object;
// confirm.
73 {
74  CMachine *machine = (CMachine*)m_machines->get_element(i);
75  ASSERT(machine)
76  CBinaryLabels* output = machine->apply_binary();
77  SG_UNREF(machine);
78  return output;
79 }
80 
// Fragment: the signature is missing from this listing. It is presumably
// get_submachine_output(i, num), as called from apply_one(). It returns the
// raw real-valued output of the i-th sub-machine (from get_machine(i)) on
// example index `num`.
// NOTE(review): only CLinearMachine and CKernelMachine are handled. Any other
// machine type silently yields 0.0, which then feeds the label decision as if
// it were a real score. Consider raising an error instead.
82 {
83  CMachine *machine = get_machine(i);
84  float64_t output = 0.0;
85  // dirty hack: apply_one(num) is reached by testing the concrete machine
86  if (dynamic_cast<CLinearMachine*>(machine))
87  output = ((CLinearMachine*)machine)->apply_one(num);
88  if (dynamic_cast<CKernelMachine*>(machine))
89  output = ((CKernelMachine*)machine)->apply_one(num);
90  SG_UNREF(machine);
91  return output;
92 }
93 
// Fragment of apply_multiclass(data). This listing is missing:
//   - the signature;
//   - the statements under `if (data)` / `else`;
//   - the second half of the num_vectors expression;
//   - the declaration of the probability heuristic `heuris`.
// Flow:
//   1. Apply every trained sub-machine.
//   2. Optionally turn the scores into probabilities or rescaled outputs
//      according to `heuris`.
//   3. Build a new CMulticlassLabels holding, per example, the label chosen by
//      the strategy plus a confidence vector.
// Returns the new labels. Raises SG_ERROR when there are no trained
// sub-machines, or when is_ready() is false.
95 {
96  SG_DEBUG("entering %s::apply_multiclass(%s at %p)\n",
97  get_name(), data ? data->get_name() : "NULL", data);
98 
99  CMulticlassLabels* return_labels=NULL;
100 
101  if (data)
103  else
105 
106  if (is_ready())
107  {
108  /* num vectors depends on whether data is provided */
109  int32_t num_vectors=data ? data->get_num_vectors() :
111 
112  int32_t num_machines=m_machines->get_num_elements();
113  if (num_machines <= 0)
114  SG_ERROR("num_machines = %d, did you train your machine?", num_machines)
115 
116  CMulticlassLabels* result=new CMulticlassLabels(num_vectors);
117 
118  // with a probability heuristic store one confidence per class, otherwise one raw output per sub-machine
119  int32_t num_classes=m_multiclass_strategy->get_num_classes();
121 
122  if (heuris!=PROB_HEURIS_NONE)
123  result->allocate_confidences_for(num_classes);
124  else
125  result->allocate_confidences_for(num_machines);
126 
// NOTE(review): raw array released only at the end of this branch; if anything
// between here and SG_FREE throws, the array and the labels it holds leak.
127  CBinaryLabels** outputs=SG_MALLOC(CBinaryLabels*, num_machines);
128  SGVector<float64_t> As(num_machines);
129  SGVector<float64_t> Bs(num_machines);
130 
// Collect each sub-machine's binary outputs. For OVA_SOFTMAX, fit per-machine
// sigmoid parameters (As, Bs). For the other heuristics, call
// scores_to_probabilities(0,0) on the outputs.
131  for (int32_t i=0; i<num_machines; ++i)
132  {
133  outputs[i] = (CBinaryLabels*) get_submachine_outputs(i);
134 
135  if (heuris==OVA_SOFTMAX)
136  {
137  CStatistics::SigmoidParamters params = CStatistics::fit_sigmoid(outputs[i]->get_values());
138  As[i] = params.a;
139  Bs[i] = params.b;
140  }
141 
142  if (heuris!=PROB_HEURIS_NONE && heuris!=OVA_SOFTMAX)
143  outputs[i]->scores_to_probabilities(0,0);
144  }
145 
// output_for_i holds one value per sub-machine. r_output_for_i is what gets
// reported: it has the same size without a heuristic, and num_classes entries
// with one.
146  SGVector<float64_t> output_for_i(num_machines);
147  SGVector<float64_t> r_output_for_i(num_machines);
148  if (heuris!=PROB_HEURIS_NONE)
149  r_output_for_i.resize_vector(num_classes);
150 
// Per example: gather the sub-machine outputs, rescale them through the
// strategy when a heuristic is active, then decide the label.
151  for (int32_t i=0; i<num_vectors; i++)
152  {
153  for (int32_t j=0; j<num_machines; j++)
154  output_for_i[j] = outputs[j]->get_value(i);
155 
156  if (heuris==PROB_HEURIS_NONE)
157  {
158  r_output_for_i = output_for_i;
159  }
160  else
161  {
162  if (heuris==OVA_SOFTMAX)
163  m_multiclass_strategy->rescale_outputs(output_for_i,As,Bs);
164  else
165  m_multiclass_strategy->rescale_outputs(output_for_i);
166 
167  // only first num_classes are returned
// NOTE(review): output_for_i has num_machines entries, so this copy assumes
// num_classes <= num_machines; confirm this holds for every strategy.
168  for (int32_t r=0; r<num_classes; r++)
169  r_output_for_i[r] = output_for_i[r];
170 
171  SG_DEBUG("%s::apply_multiclass(): sum(r_output_for_i) = %f\n",
172  get_name(), SGVector<float64_t>::sum(r_output_for_i.vector,num_classes));
173  }
174 
175  // decide on r_output_for_i: rescaled when a heuristic is active, raw outputs otherwise
176  result->set_label(i, m_multiclass_strategy->decide_label(r_output_for_i));
177  result->set_multiclass_confidences(i, r_output_for_i);
178  }
179 
180  for (int32_t i=0; i < num_machines; ++i)
181  SG_UNREF(outputs[i]);
182 
183  SG_FREE(outputs);
184 
185  return_labels=result;
186  }
187  else
188  SG_ERROR("Not ready")
189 
190 
191  SG_DEBUG("leaving %s::apply_multiclass(%s at %p)\n",
192  get_name(), data ? data->get_name() : "NULL", data);
193  return return_labels;
194 }
195 
// Fragment of apply_multiclass_multiple_output(data, n_outputs). This listing
// is missing:
//   - the signature;
//   - the statements under `if (data)` / `else`;
//   - the second half of the num_vectors expression;
//   - the declaration of `result`.
// For each example, the raw sub-machine outputs (no probability heuristic
// appears in the visible lines) are passed to the strategy's
// decide_label_multiple_output() with n_outputs, and the result is stored as
// that example's label. Returns the filled labels. Raises SG_ERROR when there
// are no trained sub-machines, when n_outputs exceeds the number of machines
// (REQUIRE), or when is_ready() is false.
// NOTE(review): only the upper bound of n_outputs is checked here. Also, the
// raw `outputs` array leaks if anything between SG_MALLOC and SG_FREE throws.
197 {
198  CMulticlassMultipleOutputLabels* return_labels=NULL;
199 
200  if (data)
202  else
204 
205  if (is_ready())
206  {
207  /* num vectors depends on whether data is provided */
208  int32_t num_vectors=data ? data->get_num_vectors() :
210 
211  int32_t num_machines=m_machines->get_num_elements();
212  if (num_machines <= 0)
213  SG_ERROR("num_machines = %d, did you train your machine?", num_machines)
214  REQUIRE(n_outputs<=num_machines,"You request more outputs than machines available")
215 
217  CBinaryLabels** outputs=SG_MALLOC(CBinaryLabels*, num_machines);
218 
219  for (int32_t i=0; i < num_machines; ++i)
220  outputs[i] = (CBinaryLabels*) get_submachine_outputs(i);
221 
// one raw output per sub-machine for the current example
222  SGVector<float64_t> output_for_i(num_machines);
223  for (int32_t i=0; i<num_vectors; i++)
224  {
225  for (int32_t j=0; j<num_machines; j++)
226  output_for_i[j] = outputs[j]->get_value(i);
227 
228  result->set_label(i, m_multiclass_strategy->decide_label_multiple_output(output_for_i, n_outputs));
229  }
230 
231  for (int32_t i=0; i < num_machines; ++i)
232  SG_UNREF(outputs[i]);
233 
234  SG_FREE(outputs);
235 
236  return_labels=result;
237  }
238  else
239  SG_ERROR("Not ready")
240 
241  return return_labels;
242 }
243 
// Fragment of train_machine(data). Many lines are missing from this listing:
//   - the signature;
//   - the statement under `else`;
//   - the loop header and the per-round computation of `subset`;
//   - the storing of each trained machine;
//   - the removal of the machine subset.
// Visible flow:
//   1. Error out when no data is given and the machine is not ready.
//   2. Create binary labels sized by get_num_rhs_vectors(), SG_REF them, and
//      give them to the base machine.
//   3. Each round, optionally restrict the labels and machine to `subset`,
//      train m_machine, then drop the label subset.
//   4. Finally release train_labels and return true.
// NOTE(review): m_machine is NULL after the default constructor and is
// dereferenced below without a check. Consider a REQUIRE(m_machine, ...)
// before use.
245 {
247 
248  if ( !data && !is_ready() )
249  SG_ERROR("Please provide training data.\n")
250  else
252 
// Labels shared by all rounds. Presumably each round fills them for the
// current binary sub-problem (that code is not visible here).
254  CBinaryLabels* train_labels = new CBinaryLabels(get_num_rhs_vectors());
255  SG_REF(train_labels);
256  m_machine->set_labels(train_labels);
257 
260  {
// an empty subset (vlen == 0) means this round trains on all examples
262  if (subset.vlen)
263  {
264  train_labels->add_subset(subset);
265  add_machine_subset(subset);
266  }
267 
268  m_machine->train();
270 
271  if (subset.vlen)
272  {
273  train_labels->remove_subset();
275  }
276  }
277 
279  SG_UNREF(train_labels);
280 
281  return true;
282 }
283 
// Fragment of apply_one(vec_idx). The signature and the declaration of the
// `outputs` vector are missing from this listing. It collects each
// sub-machine's raw output on example vec_idx via get_submachine_output() and
// returns the label the strategy decides from them. No probability heuristic
// appears in the visible lines, unlike apply_multiclass().
// NOTE(review): `outputs` must hold at least m_machines->get_num_elements()
// entries. Its declaration is not visible here; confirm.
285 {
287 
290 
291  for (int32_t i=0; i<m_machines->get_num_elements(); i++)
292  outputs[i] = get_submachine_output(i, vec_idx);
293 
294  float64_t result = m_multiclass_strategy->decide_label(outputs);
295 
296  return result;
297 }

SHOGUN Machine Learning Toolbox - Documentation