00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2011 Soeren Sonnenburg 00008 * Written (W) 2012 Fernando José Iglesias García and Sergey Lisitsyn 00009 * Copyright (C) 2012 Sergey Lisitsyn, Fernando José Iglesias Garcia 00010 */ 00011 00012 #ifndef _MULTICLASSMACHINE_H___ 00013 #define _MULTICLASSMACHINE_H___ 00014 00015 #include <shogun/machine/BaseMulticlassMachine.h> 00016 #include <shogun/lib/DynamicObjectArray.h> 00017 #include <shogun/multiclass/MulticlassStrategy.h> 00018 #include <shogun/labels/MulticlassLabels.h> 00019 #include <shogun/labels/MulticlassMultipleOutputLabels.h> 00020 00021 namespace shogun 00022 { 00023 00024 class CFeatures; 00025 class CLabels; 00026 00028 class CMulticlassMachine : public CBaseMulticlassMachine 00029 { 00030 public: 00032 CMulticlassMachine(); 00033 00039 CMulticlassMachine(CMulticlassStrategy* strategy, CMachine* machine, CLabels* labels); 00040 00042 virtual ~CMulticlassMachine(); 00043 00048 virtual void set_labels(CLabels* lab); 00049 00056 inline bool set_machine(int32_t num, CMachine* machine) 00057 { 00058 ASSERT(num<m_machines->get_num_elements() && num>=0); 00059 if (machine != NULL && !is_acceptable_machine(machine)) 00060 SG_ERROR("Machine %s is not acceptable by %s", machine->get_name(), this->get_name()); 00061 00062 m_machines->set_element(machine, num); 00063 return true; 00064 } 00065 00071 inline CMachine* get_machine(int32_t num) const 00072 { 00073 return (CMachine*) m_machines->get_element_safe(num); 00074 } 00075 00080 virtual CBinaryLabels* get_submachine_outputs(int32_t i); 00081 00087 virtual float64_t get_submachine_output(int32_t i, int32_t num); 00088 00093 virtual CMulticlassLabels* apply_multiclass(CFeatures* data=NULL); 00094 00099 virtual CMulticlassMultipleOutputLabels* apply_multiclass_multiple_output(CFeatures* data=NULL, int32_t n_outputs=5); 00100 00105 virtual float64_t apply_one(int32_t vec_idx); 00106 00111 inline CMulticlassStrategy* get_multiclass_strategy() const 00112 { 00113 SG_REF(m_multiclass_strategy); 00114 return m_multiclass_strategy; 00115 } 00116 00121 inline CRejectionStrategy* get_rejection_strategy() const 00122 { 00123 return m_multiclass_strategy->get_rejection_strategy(); 00124 } 00125 00130 inline void set_rejection_strategy(CRejectionStrategy* rejection_strategy) 00131 { 00132 m_multiclass_strategy->set_rejection_strategy(rejection_strategy); 00133 } 00134 00136 virtual const char* get_name() const 00137 { 00138 return "MulticlassMachine"; 00139 } 00140 00141 protected: 00143 void init_strategy(); 00144 00146 void clear_machines(); 00147 00149 virtual bool train_machine(CFeatures* data = NULL); 00150 00152 virtual bool init_machine_for_train(CFeatures* data) = 0; 00153 00155 virtual bool init_machines_for_apply(CFeatures* data) = 0; 00156 00158 virtual bool is_ready() = 0; 00159 00161 virtual CMachine* get_machine_from_trained(CMachine* machine) = 0; 00162 00164 virtual int32_t get_num_rhs_vectors() = 0; 00165 00170 virtual void add_machine_subset(SGVector<index_t> subset) = 0; 00171 00173 virtual void remove_machine_subset() = 0; 00174 00176 virtual bool is_acceptable_machine(CMachine *machine) 00177 { 00178 return true; 00179 } 00180 00181 private: 00182 00184 void register_parameters(); 00185 00186 protected: 00188 CMulticlassStrategy *m_multiclass_strategy; 00189 00191 CMachine* m_machine; 00192 }; 00193 } 00194 #endif