Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #ifndef __CROSSVALIDATION_H_
00012 #define __CROSSVALIDATION_H_
00013
00014 #include <shogun/base/SGObject.h>
00015 #include <shogun/evaluation/Evaluation.h>
00016
00017 namespace shogun
00018 {
00019
/* Forward declarations: the types below are only used through pointers in
 * this header, so their full definitions are not needed here. */
class CEvaluation;
class CFeatures;
class CLabels;
class CMachine;
class CSplittingStrategy;
00025
00031 typedef struct
00032 {
00034 float64_t mean;
00036 bool has_conf_int;
00038 float64_t conf_int_low;
00040 float64_t conf_int_up;
00042 float64_t conf_int_alpha;
00043
00045 void print_result()
00046 {
00047 if (has_conf_int)
00048 {
00049 SG_SPRINT("[%f,%f] with alpha=%f, mean=%f\n", conf_int_low, conf_int_up,
00050 conf_int_alpha, mean);
00051 }
00052 else
00053 SG_SPRINT("%f\n", mean);
00054 }
00055 } CrossValidationResult;
00056
00077 class CCrossValidation: public CSGObject
00078 {
00079 public:
00081 CCrossValidation();
00082
00090 CCrossValidation(CMachine* machine, CFeatures* features, CLabels* labels,
00091 CSplittingStrategy* splitting_strategy,
00092 CEvaluation* evaluation_criterium);
00093
00095 virtual ~CCrossValidation();
00096
00098 inline EEvaluationDirection get_evaluation_direction()
00099 {
00100 return m_evaluation_criterium->get_evaluation_direction();
00101 }
00102
00110 CrossValidationResult evaluate();
00111
00113 CMachine* get_machine() const;
00114
00116 void set_num_runs(int32_t num_runs);
00117
00119 void set_conf_int_alpha(float64_t m_conf_int_alpha);
00120
00122 inline virtual const char* get_name() const
00123 {
00124 return "CrossValidation";
00125 }
00126
00127 private:
00128 void init();
00129
00130 protected:
00139 virtual float64_t evaluate_one_run();
00140
00141 private:
00142 int32_t m_num_runs;
00143 float64_t m_conf_int_alpha;
00144
00145 CMachine* m_machine;
00146 CFeatures* m_features;
00147 CLabels* m_labels;
00148 CSplittingStrategy* m_splitting_strategy;
00149 CEvaluation* m_evaluation_criterium;
00150 };
00151
00152 }
00153
00154 #endif