Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #include <shogun/multiclass/MulticlassOneVsRestStrategy.h>
00013 #include <shogun/machine/LinearMachine.h>
00014 #include <shogun/machine/KernelMachine.h>
00015 #include <shogun/machine/MulticlassMachine.h>
00016 #include <shogun/base/Parameter.h>
00017 #include <shogun/labels/MulticlassLabels.h>
00018 #include <shogun/labels/RegressionLabels.h>
00019
00020 using namespace shogun;
00021
00022 CMulticlassMachine::CMulticlassMachine()
00023 : CBaseMulticlassMachine(), m_multiclass_strategy(new CMulticlassOneVsRestStrategy()),
00024 m_machine(NULL)
00025 {
00026 SG_REF(m_multiclass_strategy);
00027 register_parameters();
00028 }
00029
00030 CMulticlassMachine::CMulticlassMachine(
00031 CMulticlassStrategy *strategy,
00032 CMachine* machine, CLabels* labs)
00033 : CBaseMulticlassMachine(), m_multiclass_strategy(strategy)
00034 {
00035 SG_REF(strategy);
00036 set_labels(labs);
00037 SG_REF(machine);
00038 m_machine = machine;
00039 register_parameters();
00040
00041 if (labs)
00042 init_strategy();
00043 }
00044
00045 CMulticlassMachine::~CMulticlassMachine()
00046 {
00047 SG_UNREF(m_multiclass_strategy);
00048 SG_UNREF(m_machine);
00049 }
00050
00051 void CMulticlassMachine::set_labels(CLabels* lab)
00052 {
00053 CMachine::set_labels(lab);
00054 if (lab)
00055 init_strategy();
00056 }
00057
00058 void CMulticlassMachine::register_parameters()
00059 {
00060 SG_ADD((CSGObject**)&m_multiclass_strategy,"m_multiclass_type", "Multiclass strategy", MS_NOT_AVAILABLE);
00061 SG_ADD((CSGObject**)&m_machine, "m_machine", "The base machine", MS_NOT_AVAILABLE);
00062 }
00063
00064 void CMulticlassMachine::init_strategy()
00065 {
00066 int32_t num_classes = ((CMulticlassLabels*) m_labels)->get_num_classes();
00067 m_multiclass_strategy->set_num_classes(num_classes);
00068 }
00069
00070 CBinaryLabels* CMulticlassMachine::get_submachine_outputs(int32_t i)
00071 {
00072 CMachine *machine = (CMachine*)m_machines->get_element(i);
00073 ASSERT(machine);
00074 CBinaryLabels* output = machine->apply_binary();
00075 SG_UNREF(machine);
00076 return output;
00077 }
00078
00079 float64_t CMulticlassMachine::get_submachine_output(int32_t i, int32_t num)
00080 {
00081 CMachine *machine = get_machine(i);
00082 float64_t output = 0.0;
00083
00084 if (dynamic_cast<CLinearMachine*>(machine))
00085 output = ((CLinearMachine*)machine)->apply_one(num);
00086 if (dynamic_cast<CKernelMachine*>(machine))
00087 output = ((CKernelMachine*)machine)->apply_one(num);
00088 SG_UNREF(machine);
00089 return output;
00090 }
00091
00092 CMulticlassLabels* CMulticlassMachine::apply_multiclass(CFeatures* data)
00093 {
00094 SG_DEBUG("entering %s::apply_multiclass(%s at %p)\n",
00095 get_name(), data ? data->get_name() : "NULL", data);
00096
00097 CMulticlassLabels* return_labels=NULL;
00098
00099 if (data)
00100 init_machines_for_apply(data);
00101 else
00102 init_machines_for_apply(NULL);
00103
00104 if (is_ready())
00105 {
00106
00107 int32_t num_vectors=data ? data->get_num_vectors() :
00108 get_num_rhs_vectors();
00109
00110 int32_t num_machines=m_machines->get_num_elements();
00111 if (num_machines <= 0)
00112 SG_ERROR("num_machines = %d, did you train your machine?", num_machines);
00113
00114 CMulticlassLabels* result=new CMulticlassLabels(num_vectors);
00115 CBinaryLabels** outputs=SG_MALLOC(CBinaryLabels*, num_machines);
00116
00117 for (int32_t i=0; i < num_machines; ++i)
00118 outputs[i] = (CBinaryLabels*) get_submachine_outputs(i);
00119
00120 SGVector<float64_t> output_for_i(num_machines);
00121 for (int32_t i=0; i<num_vectors; i++)
00122 {
00123 for (int32_t j=0; j<num_machines; j++)
00124 output_for_i[j] = outputs[j]->get_value(i);
00125
00126 result->set_label(i, m_multiclass_strategy->decide_label(output_for_i));
00127 result->set_multiclass_confidences(i, output_for_i.clone());
00128 }
00129
00130 for (int32_t i=0; i < num_machines; ++i)
00131 SG_UNREF(outputs[i]);
00132
00133 SG_FREE(outputs);
00134
00135 return_labels=result;
00136 }
00137 else
00138 SG_ERROR("Not ready");
00139
00140
00141 SG_DEBUG("leaving %s::apply_multiclass(%s at %p)\n",
00142 get_name(), data ? data->get_name() : "NULL", data);
00143 return return_labels;
00144 }
00145
00146 CMulticlassMultipleOutputLabels* CMulticlassMachine::apply_multiclass_multiple_output(CFeatures* data, int32_t n_outputs)
00147 {
00148 CMulticlassMultipleOutputLabels* return_labels=NULL;
00149
00150 if (data)
00151 init_machines_for_apply(data);
00152 else
00153 init_machines_for_apply(NULL);
00154
00155 if (is_ready())
00156 {
00157
00158 int32_t num_vectors=data ? data->get_num_vectors() :
00159 get_num_rhs_vectors();
00160
00161 int32_t num_machines=m_machines->get_num_elements();
00162 if (num_machines <= 0)
00163 SG_ERROR("num_machines = %d, did you train your machine?", num_machines);
00164 REQUIRE(n_outputs<=num_machines,"You request more outputs than machines available");
00165
00166 CMulticlassMultipleOutputLabels* result=new CMulticlassMultipleOutputLabels(num_vectors);
00167 CBinaryLabels** outputs=SG_MALLOC(CBinaryLabels*, num_machines);
00168
00169 for (int32_t i=0; i < num_machines; ++i)
00170 outputs[i] = (CBinaryLabels*) get_submachine_outputs(i);
00171
00172 SGVector<float64_t> output_for_i(num_machines);
00173 for (int32_t i=0; i<num_vectors; i++)
00174 {
00175 for (int32_t j=0; j<num_machines; j++)
00176 output_for_i[j] = outputs[j]->get_value(i);
00177
00178 result->set_label(i, m_multiclass_strategy->decide_label_multiple_output(output_for_i, n_outputs));
00179 }
00180
00181 for (int32_t i=0; i < num_machines; ++i)
00182 SG_UNREF(outputs[i]);
00183
00184 SG_FREE(outputs);
00185
00186 return_labels=result;
00187 }
00188 else
00189 SG_ERROR("Not ready");
00190
00191 return return_labels;
00192 }
00193
00194 bool CMulticlassMachine::train_machine(CFeatures* data)
00195 {
00196 ASSERT(m_multiclass_strategy);
00197
00198 if ( !data && !is_ready() )
00199 SG_ERROR("Please provide training data.\n");
00200 else
00201 init_machine_for_train(data);
00202
00203 m_machines->reset_array();
00204 CBinaryLabels* train_labels = new CBinaryLabels(get_num_rhs_vectors());
00205 SG_REF(train_labels);
00206 m_machine->set_labels(train_labels);
00207
00208 m_multiclass_strategy->train_start(CMulticlassLabels::obtain_from_generic(m_labels), train_labels);
00209 while (m_multiclass_strategy->train_has_more())
00210 {
00211 SGVector<index_t> subset=m_multiclass_strategy->train_prepare_next();
00212 if (subset.vlen)
00213 {
00214 train_labels->add_subset(subset);
00215 add_machine_subset(subset);
00216 }
00217
00218 m_machine->train();
00219 m_machines->push_back(get_machine_from_trained(m_machine));
00220
00221 if (subset.vlen)
00222 {
00223 train_labels->remove_subset();
00224 remove_machine_subset();
00225 }
00226 }
00227
00228 m_multiclass_strategy->train_stop();
00229 SG_UNREF(train_labels);
00230
00231 return true;
00232 }
00233
00234 float64_t CMulticlassMachine::apply_one(int32_t vec_idx)
00235 {
00236 init_machines_for_apply(NULL);
00237
00238 ASSERT(m_machines->get_num_elements()>0);
00239 SGVector<float64_t> outputs(m_machines->get_num_elements());
00240
00241 for (int32_t i=0; i<m_machines->get_num_elements(); i++)
00242 outputs[i] = get_submachine_output(i, vec_idx);
00243
00244 float64_t result = m_multiclass_strategy->decide_label(outputs);
00245
00246 return result;
00247 }