Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010 #include <shogun/evaluation/CrossValidationMulticlassStorage.h>
00011 #include <shogun/evaluation/ROCEvaluation.h>
00012 #include <shogun/evaluation/PRCEvaluation.h>
00013 #include <shogun/evaluation/MulticlassAccuracy.h>
00014
00015 using namespace shogun;
00016
00017 CCrossValidationMulticlassStorage::CCrossValidationMulticlassStorage(bool compute_ROC, bool compute_PRC, bool compute_conf_matrices) :
00018 CCrossValidationOutput()
00019 {
00020 m_initialized = false;
00021 m_compute_ROC = compute_ROC;
00022 m_compute_PRC = compute_PRC;
00023 m_compute_conf_matrices = compute_conf_matrices;
00024 m_pred_labels = NULL;
00025 m_true_labels = NULL;
00026 m_num_classes = 0;
00027 m_binary_evaluations = new CDynamicObjectArray();
00028
00029 m_fold_ROC_graphs=NULL;
00030 m_conf_matrices=NULL;
00031 }
00032
00033
00034 CCrossValidationMulticlassStorage::~CCrossValidationMulticlassStorage()
00035 {
00036 if (m_compute_ROC)
00037 {
00038 for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++)
00039 m_fold_ROC_graphs[i].~SGMatrix<float64_t>();
00040
00041 SG_FREE(m_fold_ROC_graphs);
00042 }
00043
00044 if (m_compute_PRC)
00045 {
00046 for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++)
00047 m_fold_PRC_graphs[i].~SGMatrix<float64_t>();
00048
00049 SG_FREE(m_fold_PRC_graphs);
00050 }
00051
00052 if (m_compute_conf_matrices)
00053 {
00054 for (int32_t i=0; i<m_num_folds*m_num_runs; i++)
00055 m_conf_matrices[i].~SGMatrix<int32_t>();
00056
00057 SG_FREE(m_conf_matrices);
00058 }
00059
00060 SG_UNREF(m_binary_evaluations);
00061 };
00062
00063
00064 void CCrossValidationMulticlassStorage::post_init()
00065 {
00066 if (m_initialized)
00067 SG_ERROR("CrossValidationMulticlassStorage was already initialized once\n");
00068
00069 if (m_compute_ROC)
00070 {
00071 SG_DEBUG("Allocating %d ROC graphs\n", m_num_folds*m_num_runs*m_num_classes);
00072 m_fold_ROC_graphs = SG_MALLOC(SGMatrix<float64_t>, m_num_folds*m_num_runs*m_num_classes);
00073 for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++)
00074 new (&m_fold_ROC_graphs[i]) SGMatrix<float64_t>();
00075 }
00076
00077 if (m_compute_PRC)
00078 {
00079 SG_DEBUG("Allocating %d PRC graphs\n", m_num_folds*m_num_runs*m_num_classes);
00080 m_fold_PRC_graphs = SG_MALLOC(SGMatrix<float64_t>, m_num_folds*m_num_runs*m_num_classes);
00081 for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++)
00082 new (&m_fold_PRC_graphs[i]) SGMatrix<float64_t>();
00083 }
00084
00085 if (m_binary_evaluations->get_num_elements())
00086 m_evaluations_results = SGVector<float64_t>(m_num_folds*m_num_runs*m_num_classes*m_binary_evaluations->get_num_elements());
00087
00088 m_accuracies = SGVector<float64_t>(m_num_folds*m_num_runs);
00089
00090 if (m_compute_conf_matrices)
00091 {
00092 m_conf_matrices = SG_MALLOC(SGMatrix<int32_t>, m_num_folds*m_num_runs);
00093 for (int32_t i=0; i<m_num_folds*m_num_runs; i++)
00094 new (&m_conf_matrices[i]) SGMatrix<int32_t>();
00095 }
00096
00097 m_initialized = true;
00098 }
00099
00100 void CCrossValidationMulticlassStorage::init_expose_labels(CLabels* labels)
00101 {
00102 ASSERT((CMulticlassLabels*)labels);
00103 m_num_classes = ((CMulticlassLabels*)labels)->get_num_classes();
00104 }
00105
00106 void CCrossValidationMulticlassStorage::post_update_results()
00107 {
00108 CROCEvaluation eval_ROC;
00109 CPRCEvaluation eval_PRC;
00110 int32_t n_evals = m_binary_evaluations->get_num_elements();
00111 for (int32_t c=0; c<m_num_classes; c++)
00112 {
00113 SG_DEBUG("Computing ROC for run %d fold %d class %d", m_current_run_index, m_current_fold_index, c);
00114 CBinaryLabels* pred_labels_binary = m_pred_labels->get_binary_for_class(c);
00115 CBinaryLabels* true_labels_binary = m_true_labels->get_binary_for_class(c);
00116 if (m_compute_ROC)
00117 {
00118 eval_ROC.evaluate(pred_labels_binary, true_labels_binary);
00119 m_fold_ROC_graphs[m_current_run_index*m_num_folds*m_num_classes+m_current_fold_index*m_num_classes+c] =
00120 eval_ROC.get_ROC();
00121 }
00122 if (m_compute_PRC)
00123 {
00124 eval_PRC.evaluate(pred_labels_binary, true_labels_binary);
00125 m_fold_PRC_graphs[m_current_run_index*m_num_folds*m_num_classes+m_current_fold_index*m_num_classes+c] =
00126 eval_PRC.get_PRC();
00127 }
00128
00129 for (int32_t i=0; i<n_evals; i++)
00130 {
00131 CBinaryClassEvaluation* evaluator = (CBinaryClassEvaluation*)m_binary_evaluations->get_element_safe(i);
00132 m_evaluations_results[m_current_run_index*m_num_folds*m_num_classes*n_evals+m_current_fold_index*m_num_classes*n_evals+c*n_evals+i] =
00133 evaluator->evaluate(pred_labels_binary, true_labels_binary);
00134 SG_UNREF(evaluator);
00135 }
00136
00137 SG_UNREF(pred_labels_binary);
00138 SG_UNREF(true_labels_binary);
00139 }
00140 CMulticlassAccuracy accuracy;
00141
00142 m_accuracies[m_current_run_index*m_num_folds+m_current_fold_index] = accuracy.evaluate(m_pred_labels, m_true_labels);
00143
00144 if (m_compute_conf_matrices)
00145 {
00146 m_conf_matrices[m_current_run_index*m_num_folds+m_current_fold_index] = CMulticlassAccuracy::get_confusion_matrix(m_pred_labels, m_true_labels);
00147 }
00148 }
00149
00150 void CCrossValidationMulticlassStorage::update_test_result(CLabels* results, const char* prefix)
00151 {
00152 m_pred_labels = (CMulticlassLabels*)results;
00153 }
00154
00155 void CCrossValidationMulticlassStorage::update_test_true_result(CLabels* results, const char* prefix)
00156 {
00157 m_true_labels = (CMulticlassLabels*)results;
00158 }
00159