Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #ifndef _LINEARMULTICLASSMACHINE_H___
00012 #define _LINEARMULTICLASSMACHINE_H___
00013
00014 #include <shogun/lib/common.h>
00015 #include <shogun/features/DotFeatures.h>
00016 #include <shogun/machine/LinearMachine.h>
00017 #include <shogun/machine/MulticlassMachine.h>
00018
00019 namespace shogun
00020 {
00021
// Forward declarations of the types that CLinearMulticlassMachine below only
// handles through pointers (constructor arguments, the m_features member and
// the casts applied to the wrapped machines).
// NOTE(review): the headers included above presumably already declare
// CDotFeatures and CLinearMachine, which would make these redundant but
// harmless - confirm before removing.
00022 class CDotFeatures;
00023 class CLinearMachine;
00024 class CMulticlassStrategy;
00025
00027 class CLinearMulticlassMachine : public CMulticlassMachine
00028 {
00029 public:
00031 CLinearMulticlassMachine() : CMulticlassMachine(), m_features(NULL)
00032 {
00033 SG_ADD((CSGObject**)&m_features, "m_features", "Feature object.",
00034 MS_NOT_AVAILABLE);
00035 }
00036
00043 CLinearMulticlassMachine(CMulticlassStrategy *strategy, CDotFeatures* features, CLinearMachine* machine, CLabels* labs) :
00044 CMulticlassMachine(strategy,(CMachine*)machine,labs), m_features(NULL)
00045 {
00046 set_features(features);
00047 SG_ADD((CSGObject**)&m_features, "m_features", "Feature object.",
00048 MS_NOT_AVAILABLE);
00049 }
00050
00052 virtual ~CLinearMulticlassMachine()
00053 {
00054 SG_UNREF(m_features);
00055 }
00056
00058 virtual const char* get_name() const
00059 {
00060 return "LinearMulticlassMachine";
00061 }
00062
00067 void set_features(CDotFeatures* f)
00068 {
00069 SG_REF(f);
00070 SG_UNREF(m_features);
00071 m_features = f;
00072 }
00073
00078 CDotFeatures* get_features() const
00079 {
00080 SG_REF(m_features);
00081 return m_features;
00082 }
00083
00084 protected:
00085
00087 virtual bool init_machine_for_train(CFeatures* data)
00088 {
00089 if (!m_machine)
00090 SG_ERROR("No machine given in Multiclass constructor\n");
00091
00092 if (data)
00093 set_features((CDotFeatures*)data);
00094
00095 ((CLinearMachine*)m_machine)->set_features(m_features);
00096
00097 return true;
00098 }
00099
00101 virtual bool init_machines_for_apply(CFeatures* data)
00102 {
00103 if (data)
00104 set_features((CDotFeatures*)data);
00105
00106 for (int32_t i=0; i<m_machines->get_num_elements(); i++)
00107 {
00108 CLinearMachine* machine = (CLinearMachine*)m_machines->get_element(i);
00109 ASSERT(m_features);
00110 ASSERT(machine);
00111 machine->set_features(m_features);
00112 SG_UNREF(machine);
00113 }
00114
00115 return true;
00116 }
00117
00119 virtual bool is_ready()
00120 {
00121 if (m_features)
00122 return true;
00123
00124 return false;
00125 }
00126
00128 virtual CMachine* get_machine_from_trained(CMachine* machine)
00129 {
00130 return new CLinearMachine((CLinearMachine*)machine);
00131 }
00132
00134 virtual int32_t get_num_rhs_vectors()
00135 {
00136 return m_features->get_num_vectors();
00137 }
00138
00143 virtual void add_machine_subset(SGVector<index_t> subset)
00144 {
00145
00146
00147 m_features->add_subset(subset);
00148 }
00149
00151 virtual void remove_machine_subset()
00152 {
00153
00154
00155 m_features->remove_subset();
00156 }
00157
00162 virtual void store_model_features() {}
00163
00164 protected:
00165
00167 CDotFeatures* m_features;
00168 };
00169 }
00170 #endif