SHOGUN  v2.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
MulticlassMachine.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 1999-2011 Soeren Sonnenburg
8  * Written (W) 2012 Fernando José Iglesias García and Sergey Lisitsyn
9  * Copyright (C) 2012 Sergey Lisitsyn, Fernando José Iglesias Garcia
10  */
11 
16 #include <shogun/base/Parameter.h>
19 
20 using namespace shogun;
21 
// NOTE(review): listing line 22, which should hold this constructor's
// signature, is absent from this extract (as is body line 26) - confirm
// against the original file.
// Initializer list: default-constructs the CBaseMulticlassMachine base,
// installs a newly allocated CMulticlassOneVsRestStrategy as
// m_multiclass_strategy and leaves m_machine NULL.
23 : CBaseMulticlassMachine(), m_multiclass_strategy(new CMulticlassOneVsRestStrategy()),
24  m_machine(NULL)
25 {
// Registers m_multiclass_strategy and m_machine via SG_ADD.
27  register_parameters();
28 }
29 
// NOTE(review): listing line 30 (the opening of this constructor's signature)
// is absent from this extract.
// Parameters, as used in the visible body:
//   strategy - stored in m_multiclass_strategy and SG_REF'ed
//   machine  - SG_REF'ed and stored in m_machine
//   labs     - passed to set_labels(); when NULL, init_strategy() is skipped
31  CMulticlassStrategy *strategy,
32  CMachine* machine, CLabels* labs)
33 : CBaseMulticlassMachine(), m_multiclass_strategy(strategy)
34 {
// Take a reference on the caller-supplied strategy stored by the initializer list.
35  SG_REF(strategy);
36  set_labels(labs);
// Take a reference on the base machine before keeping the pointer.
37  SG_REF(machine);
38  m_machine = machine;
39  register_parameters();
40 
// The strategy is only initialised here when labels were supplied.
41  if (labs)
42  init_strategy();
43 }
44 
// NOTE(review): this definition's signature (listing line 45) and its body
// (listing lines 47-48) are absent from this extract; its behaviour cannot be
// confirmed from here - consult the original file.
46 {
49 }
50 
// NOTE(review): the signature (listing line 51) and the first body statement
// (listing line 53) are absent from this extract; presumably this is the
// labels setter taking `lab` - confirm against the original file.
52 {
// init_strategy() is only called when `lab` is non-NULL.
54  if (lab)
55  init_strategy();
56 }
57 
// Registers the multiclass strategy and the base machine through SG_ADD, each
// with a name, a description and the MS_NOT_AVAILABLE flag.
// Called from both visible constructors.
58 void CMulticlassMachine::register_parameters()
59 {
// NOTE(review): the registered name "m_multiclass_type" differs from the member
// name m_multiclass_strategy; presumably other code or stored objects depend on
// this string - confirm before renaming.
60  SG_ADD((CSGObject**)&m_multiclass_strategy,"m_multiclass_type", "Multiclass strategy", MS_NOT_AVAILABLE);
61  SG_ADD((CSGObject**)&m_machine, "m_machine", "The base machine", MS_NOT_AVAILABLE);
62 }
63 
// NOTE(review): the signature (listing line 64) and the statement that
// consumes num_classes (listing line 67) are absent from this extract;
// presumably this is init_strategy(), which the constructor calls - confirm.
65 {
// m_labels is cast to CMulticlassLabels without a type check before the class
// count is read.
66  int32_t num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes();
68 }
69 
// NOTE(review): the signature (listing line 70) is absent from this extract;
// presumably get_submachine_outputs(i), which apply_multiclass() calls - confirm.
// Applies the i-th entry of m_machines and returns its binary labels.
71 {
// The element fetched here is released with SG_UNREF below, after it has been applied.
72  CMachine *machine = (CMachine*)m_machines->get_element(i);
73  ASSERT(machine);
74  CBinaryLabels* output = machine->apply_binary();
75  SG_UNREF(machine);
76  return output;
77 }
78 
// NOTE(review): the signature (listing line 79) is absent from this extract;
// presumably get_submachine_output(i, num) - confirm against the original file.
// Returns the output of machine i for the single example index `num`.
80 {
81  CMachine *machine = get_machine(i);
// Stays 0.0 if the machine is neither a CLinearMachine nor a CKernelMachine:
// that case is not reported as an error.
82  float64_t output = 0.0;
83  // dirty hack
// The dynamic_cast only tests the runtime type; apply_one() is then called
// through a C-style cast to that type.
84  if (dynamic_cast<CLinearMachine*>(machine))
85  output = ((CLinearMachine*)machine)->apply_one(num);
86  if (dynamic_cast<CKernelMachine*>(machine))
87  output = ((CKernelMachine*)machine)->apply_one(num);
// Release the machine obtained from get_machine().
88  SG_UNREF(machine);
89  return output;
90 }
91 
// NOTE(review): the signature (listing line 92) is absent from this extract; the
// debug strings below identify the function as apply_multiclass, taking `data`.
// Collects the binary outputs of every sub-machine and lets the multiclass
// strategy pick one label per vector. Returns the resulting CMulticlassLabels;
// reports an error via SG_ERROR when is_ready() is false or no machines exist.
93 {
94  SG_DEBUG("entering %s::apply_multiclass(%s at %p)\n",
95  get_name(), data ? data->get_name() : "NULL", data);
96 
97  CMulticlassLabels* return_labels=NULL;
98 
// NOTE(review): the statements of both branches (listing lines 100 and 102) are
// absent from this extract.
99  if (data)
101  else
103 
104  if (is_ready())
105  {
106  /* num vectors depends on whether data is provided */
// NOTE(review): the fallback operand of this conditional (listing line 108) is
// absent from this extract.
107  int32_t num_vectors=data ? data->get_num_vectors() :
109 
110  int32_t num_machines=m_machines->get_num_elements();
111  if (num_machines <= 0)
112  SG_ERROR("num_machines = %d, did you train your machine?", num_machines);
113 
114  CMulticlassLabels* result=new CMulticlassLabels(num_vectors);
// One label object per sub-machine; the array is freed with SG_FREE below.
115  CBinaryLabels** outputs=SG_MALLOC(CBinaryLabels*, num_machines);
116 
117  for (int32_t i=0; i < num_machines; ++i)
118  outputs[i] = (CBinaryLabels*) get_submachine_outputs(i);
119 
// Scratch vector reused for every example: entry j holds machine j's confidence.
120  SGVector<float64_t> output_for_i(num_machines);
121  for (int32_t i=0; i<num_vectors; i++)
122  {
123  for (int32_t j=0; j<num_machines; j++)
124  output_for_i[j] = outputs[j]->get_confidence(i);
125 
126  result->set_label(i, m_multiclass_strategy->decide_label(output_for_i));
// clone() is passed because output_for_i is overwritten on the next iteration
// (presumably clone() yields an independent copy - confirm).
127  result->set_multiclass_confidences(i, output_for_i.clone());
128  }
129 
// Release each sub-machine's labels, then the pointer array itself.
130  for (int32_t i=0; i < num_machines; ++i)
131  SG_UNREF(outputs[i]);
132 
133  SG_FREE(outputs);
134 
135  return_labels=result;
136  }
137  else
138  SG_ERROR("Not ready");
139 
140 
141  SG_DEBUG("leaving %s::apply_multiclass(%s at %p)\n",
142  get_name(), data ? data->get_name() : "NULL", data);
143  return return_labels;
144 }
145 
// NOTE(review): the signature (listing line 146) is absent from this extract;
// the body uses parameters `data` and `n_outputs` and returns
// CMulticlassMultipleOutputLabels - confirm the exact name and signature.
// Like apply_multiclass(), but asks the strategy for n_outputs labels per
// vector via decide_label_multiple_output().
147 {
148  CMulticlassMultipleOutputLabels* return_labels=NULL;
149 
// NOTE(review): the statements of both branches (listing lines 151 and 153) are
// absent from this extract.
150  if (data)
152  else
154 
155  if (is_ready())
156  {
157  /* num vectors depends on whether data is provided */
// NOTE(review): the fallback operand of this conditional (listing line 159) is
// absent from this extract.
158  int32_t num_vectors=data ? data->get_num_vectors() :
160 
161  int32_t num_machines=m_machines->get_num_elements();
162  if (num_machines <= 0)
163  SG_ERROR("num_machines = %d, did you train your machine?", num_machines);
// The caller may not request more outputs than there are sub-machines.
164  REQUIRE(n_outputs<=num_machines,"You request more outputs than machines available");
165 
// NOTE(review): the declaration of `result` (listing line 166) is absent from
// this extract.
// One label object per sub-machine; the array is freed with SG_FREE below.
167  CBinaryLabels** outputs=SG_MALLOC(CBinaryLabels*, num_machines);
168 
169  for (int32_t i=0; i < num_machines; ++i)
170  outputs[i] = (CBinaryLabels*) get_submachine_outputs(i);
171 
// Scratch vector reused for every example: entry j holds machine j's confidence.
172  SGVector<float64_t> output_for_i(num_machines);
173  for (int32_t i=0; i<num_vectors; i++)
174  {
175  for (int32_t j=0; j<num_machines; j++)
176  output_for_i[j] = outputs[j]->get_confidence(i);
177 
178  result->set_label(i, m_multiclass_strategy->decide_label_multiple_output(output_for_i, n_outputs));
179  }
180 
// Release each sub-machine's labels, then the pointer array itself.
181  for (int32_t i=0; i < num_machines; ++i)
182  SG_UNREF(outputs[i]);
183 
184  SG_FREE(outputs);
185 
186  return_labels=result;
187  }
188  else
189  SG_ERROR("Not ready");
190 
191  return return_labels;
192 }
193 
// NOTE(review): the signature (listing line 194) and several body lines (196,
// 201, 203, 207-209, 211, 219, 224, 228) are absent from this extract,
// including the loop header and the definition of `subset`. Presumably this is
// train_machine(data) - confirm against the original file.
// Visible behaviour: trains m_machine inside a block using a temporary
// CBinaryLabels object and returns true.
195 {
197 
// Training needs either explicit data or a machine that already reports ready.
198  if ( !data && !is_ready() )
199  SG_ERROR("Please provide training data.\n");
200  else
202 
// Temporary binary labels handed to the base machine; the reference taken
// here is dropped by the SG_UNREF before returning.
204  CBinaryLabels* train_labels = new CBinaryLabels(get_num_rhs_vectors());
205  SG_REF(train_labels);
206  m_machine->set_labels(train_labels);
207 
210  {
// A non-empty subset is applied to both the labels and the machine before training.
212  if (subset.vlen)
213  {
214  train_labels->add_subset(subset);
215  add_machine_subset(subset);
216  }
217 
218  m_machine->train();
220 
// Undo the label subset after training.
// NOTE(review): listing line 224 is absent; presumably the matching removal of
// the machine subset - confirm.
221  if (subset.vlen)
222  {
223  train_labels->remove_subset();
225  }
226  }
227 
229  SG_UNREF(train_labels);
230 
231  return true;
232 }
233 
// NOTE(review): the signature (listing line 234) and the lines that declare
// `outputs` (listing lines 236, 238-239) are absent from this extract;
// presumably apply_one(vec_idx) - confirm against the original file.
// Returns the label chosen by the multiclass strategy for the single example
// vec_idx.
235 {
237 
240 
// Gather every sub-machine's output for this one example.
241  for (int32_t i=0; i<m_machines->get_num_elements(); i++)
242  outputs[i] = get_submachine_output(i, vec_idx);
243 
244  float64_t result = m_multiclass_strategy->decide_label(outputs);
245 
246  return result;
247 }

SHOGUN Machine Learning Toolbox - Documentation