00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2012 Chiyuan Zhang 00008 * Copyright (C) 2012 Chiyuan Zhang 00009 */ 00010 00011 #ifndef MULTICLASSSTRATEGY_H__ 00012 #define MULTICLASSSTRATEGY_H__ 00013 00014 #include <shogun/base/SGObject.h> 00015 #include <shogun/labels/BinaryLabels.h> 00016 #include <shogun/labels/MulticlassLabels.h> 00017 #include <shogun/multiclass/RejectionStrategy.h> 00018 00019 namespace shogun 00020 { 00021 00025 class CMulticlassStrategy: public CSGObject 00026 { 00027 public: 00029 CMulticlassStrategy(); 00030 00032 virtual ~CMulticlassStrategy() {} 00033 00035 virtual const char* get_name() const 00036 { 00037 return "MulticlassStrategy"; 00038 }; 00039 00041 void set_num_classes(int32_t num_classes) 00042 { 00043 m_num_classes = num_classes; 00044 } 00045 00047 int32_t get_num_classes() const 00048 { 00049 return m_num_classes; 00050 } 00051 00053 CRejectionStrategy *get_rejection_strategy() 00054 { 00055 SG_REF(m_rejection_strategy); 00056 return m_rejection_strategy; 00057 } 00058 00060 void set_rejection_strategy(CRejectionStrategy *rejection_strategy) 00061 { 00062 SG_REF(rejection_strategy); 00063 SG_UNREF(m_rejection_strategy); 00064 m_rejection_strategy = rejection_strategy; 00065 } 00066 00068 virtual void train_start(CMulticlassLabels *orig_labels, CBinaryLabels *train_labels); 00069 00071 virtual bool train_has_more()=0; 00072 00076 virtual SGVector<int32_t> train_prepare_next(); 00077 00079 virtual void train_stop(); 00080 00084 virtual int32_t decide_label(SGVector<float64_t> outputs)=0; 00085 00090 virtual SGVector<index_t> decide_label_multiple_output(SGVector<float64_t> outputs, int32_t n_outputs) 00091 { 00092 SG_NOTIMPLEMENTED; 00093 return SGVector<index_t>(); 00094 } 00095 00098 virtual int32_t get_num_machines()=0; 00099 00100 protected: 00101 00102 CRejectionStrategy* m_rejection_strategy; 00103 CBinaryLabels *m_train_labels; 00104 CMulticlassLabels *m_orig_labels; 00105 int32_t m_train_iter; 00106 int32_t m_num_classes; 00107 }; 00108 00109 } // namespace shogun 00110 00111 #endif /* end of include guard: MULTICLASSSTRATEGY_H__ */ 00112