MulticlassStrategy.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Chiyuan Zhang
00008  * Copyright (C) 2012 Chiyuan Zhang
00009  */
00010 
00011 #ifndef MULTICLASSSTRATEGY_H__
00012 #define MULTICLASSSTRATEGY_H__
00013 
00014 #include <shogun/base/SGObject.h>
00015 #include <shogun/labels/BinaryLabels.h>
00016 #include <shogun/labels/MulticlassLabels.h>
00017 #include <shogun/multiclass/RejectionStrategy.h>
00018 
00019 namespace shogun
00020 {
00021 
00025 class CMulticlassStrategy: public CSGObject
00026 {
00027 public:
00029     CMulticlassStrategy();
00030 
00032     virtual ~CMulticlassStrategy() {}
00033 
00035     virtual const char* get_name() const
00036     {
00037         return "MulticlassStrategy";
00038     };
00039 
00041     void set_num_classes(int32_t num_classes)
00042     {
00043         m_num_classes = num_classes;
00044     }
00045 
00047     int32_t get_num_classes() const
00048     {
00049         return m_num_classes;
00050     }
00051 
00053     CRejectionStrategy *get_rejection_strategy()
00054     {
00055         SG_REF(m_rejection_strategy);
00056         return m_rejection_strategy;
00057     }
00058 
00060     void set_rejection_strategy(CRejectionStrategy *rejection_strategy)
00061     {
00062         SG_REF(rejection_strategy);
00063         SG_UNREF(m_rejection_strategy);
00064         m_rejection_strategy = rejection_strategy;
00065     }
00066 
00068     virtual void train_start(CMulticlassLabels *orig_labels, CBinaryLabels *train_labels);
00069 
00071     virtual bool train_has_more()=0;
00072 
00076     virtual SGVector<int32_t> train_prepare_next();
00077 
00079     virtual void train_stop();
00080 
00084     virtual int32_t decide_label(SGVector<float64_t> outputs)=0;
00085     
00090     virtual SGVector<index_t> decide_label_multiple_output(SGVector<float64_t> outputs, int32_t n_outputs)
00091     {
00092         SG_NOTIMPLEMENTED;
00093         return SGVector<index_t>();
00094     }
00095 
00098     virtual int32_t get_num_machines()=0;
00099 
00100 protected:
00101 
00102     CRejectionStrategy* m_rejection_strategy; 
00103     CBinaryLabels *m_train_labels;    
00104     CMulticlassLabels *m_orig_labels; 
00105     int32_t m_train_iter;             
00106     int32_t m_num_classes;            
00107 };
00108 
00109 } // namespace shogun
00110 
00111 #endif /* end of include guard: MULTICLASSSTRATEGY_H__ */
00112 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation