CrossValidationMulticlassStorage.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Copyright (C) 2012 Sergey Lisitsyn, Heiko Strathmann
00008  */
00009 
00010 #include <shogun/evaluation/CrossValidationMulticlassStorage.h>
00011 #include <shogun/evaluation/ROCEvaluation.h>
00012 #include <shogun/evaluation/PRCEvaluation.h>
00013 #include <shogun/evaluation/MulticlassAccuracy.h>
00014 
00015 using namespace shogun;
00016 
00017 CCrossValidationMulticlassStorage::CCrossValidationMulticlassStorage(bool compute_ROC, bool compute_PRC, bool compute_conf_matrices) : 
00018     CCrossValidationOutput()
00019 {
00020     m_initialized = false;
00021     m_compute_ROC = compute_ROC;
00022     m_compute_PRC = compute_PRC;
00023     m_compute_conf_matrices = compute_conf_matrices;
00024     m_pred_labels = NULL;
00025     m_true_labels = NULL;
00026     m_num_classes = 0;
00027     m_binary_evaluations = new CDynamicObjectArray();
00028 
00029     m_fold_ROC_graphs=NULL;
00030     m_conf_matrices=NULL;
00031 }
00032 
00033 
00034 CCrossValidationMulticlassStorage::~CCrossValidationMulticlassStorage()
00035 {
00036     if (m_compute_ROC)
00037     {
00038         for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++)
00039             m_fold_ROC_graphs[i].~SGMatrix<float64_t>();
00040 
00041         SG_FREE(m_fold_ROC_graphs);
00042     }
00043 
00044     if (m_compute_PRC)
00045     {
00046         for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++)
00047             m_fold_PRC_graphs[i].~SGMatrix<float64_t>();
00048 
00049         SG_FREE(m_fold_PRC_graphs);
00050     }
00051 
00052     if (m_compute_conf_matrices)
00053     {
00054         for (int32_t i=0; i<m_num_folds*m_num_runs; i++)
00055             m_conf_matrices[i].~SGMatrix<int32_t>();
00056     
00057         SG_FREE(m_conf_matrices);
00058     }
00059 
00060     SG_UNREF(m_binary_evaluations);
00061 };
00062 
00063 
00064 void CCrossValidationMulticlassStorage::post_init()
00065 {
00066     if (m_initialized)
00067         SG_ERROR("CrossValidationMulticlassStorage was already initialized once\n");
00068 
00069     if (m_compute_ROC)
00070     {
00071         SG_DEBUG("Allocating %d ROC graphs\n", m_num_folds*m_num_runs*m_num_classes);
00072         m_fold_ROC_graphs = SG_MALLOC(SGMatrix<float64_t>, m_num_folds*m_num_runs*m_num_classes);
00073         for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++)
00074             new (&m_fold_ROC_graphs[i]) SGMatrix<float64_t>();
00075     }
00076 
00077     if (m_compute_PRC)
00078     {
00079         SG_DEBUG("Allocating %d PRC graphs\n", m_num_folds*m_num_runs*m_num_classes);
00080         m_fold_PRC_graphs = SG_MALLOC(SGMatrix<float64_t>, m_num_folds*m_num_runs*m_num_classes);
00081         for (int32_t i=0; i<m_num_folds*m_num_runs*m_num_classes; i++)
00082             new (&m_fold_PRC_graphs[i]) SGMatrix<float64_t>();
00083     }
00084 
00085     if (m_binary_evaluations->get_num_elements())
00086         m_evaluations_results = SGVector<float64_t>(m_num_folds*m_num_runs*m_num_classes*m_binary_evaluations->get_num_elements());
00087 
00088     m_accuracies = SGVector<float64_t>(m_num_folds*m_num_runs);
00089 
00090     if (m_compute_conf_matrices)
00091     {
00092         m_conf_matrices = SG_MALLOC(SGMatrix<int32_t>, m_num_folds*m_num_runs);
00093         for (int32_t i=0; i<m_num_folds*m_num_runs; i++)
00094             new (&m_conf_matrices[i]) SGMatrix<int32_t>();
00095     }
00096 
00097     m_initialized = true;
00098 }
00099 
00100 void CCrossValidationMulticlassStorage::init_expose_labels(CLabels* labels)
00101 {
00102     ASSERT((CMulticlassLabels*)labels);
00103     m_num_classes = ((CMulticlassLabels*)labels)->get_num_classes();
00104 }
00105 
00106 void CCrossValidationMulticlassStorage::post_update_results()
00107 {
00108     CROCEvaluation eval_ROC;
00109     CPRCEvaluation eval_PRC;
00110     int32_t n_evals = m_binary_evaluations->get_num_elements();
00111     for (int32_t c=0; c<m_num_classes; c++)
00112     {
00113         SG_DEBUG("Computing ROC for run %d fold %d class %d", m_current_run_index, m_current_fold_index, c);
00114         CBinaryLabels* pred_labels_binary = m_pred_labels->get_binary_for_class(c);
00115         CBinaryLabels* true_labels_binary = m_true_labels->get_binary_for_class(c);
00116         if (m_compute_ROC)
00117         {
00118             eval_ROC.evaluate(pred_labels_binary, true_labels_binary);
00119             m_fold_ROC_graphs[m_current_run_index*m_num_folds*m_num_classes+m_current_fold_index*m_num_classes+c] = 
00120                 eval_ROC.get_ROC();
00121         }
00122         if (m_compute_PRC)
00123         {
00124             eval_PRC.evaluate(pred_labels_binary, true_labels_binary);
00125             m_fold_PRC_graphs[m_current_run_index*m_num_folds*m_num_classes+m_current_fold_index*m_num_classes+c] = 
00126                 eval_PRC.get_PRC();
00127         }
00128         
00129         for (int32_t i=0; i<n_evals; i++)
00130         {
00131             CBinaryClassEvaluation* evaluator = (CBinaryClassEvaluation*)m_binary_evaluations->get_element_safe(i);
00132             m_evaluations_results[m_current_run_index*m_num_folds*m_num_classes*n_evals+m_current_fold_index*m_num_classes*n_evals+c*n_evals+i] =
00133                 evaluator->evaluate(pred_labels_binary, true_labels_binary);
00134             SG_UNREF(evaluator);
00135         }
00136 
00137         SG_UNREF(pred_labels_binary);
00138         SG_UNREF(true_labels_binary);
00139     }
00140     CMulticlassAccuracy accuracy;
00141 
00142     m_accuracies[m_current_run_index*m_num_folds+m_current_fold_index] = accuracy.evaluate(m_pred_labels, m_true_labels);
00143     
00144     if (m_compute_conf_matrices)
00145     {
00146         m_conf_matrices[m_current_run_index*m_num_folds+m_current_fold_index] = CMulticlassAccuracy::get_confusion_matrix(m_pred_labels, m_true_labels);
00147     }
00148 }
00149 
00150 void CCrossValidationMulticlassStorage::update_test_result(CLabels* results, const char* prefix)
00151 {
00152     m_pred_labels = (CMulticlassLabels*)results;
00153 }
00154 
00155 void CCrossValidationMulticlassStorage::update_test_true_result(CLabels* results, const char* prefix)
00156 {
00157     m_true_labels = (CMulticlassLabels*)results;
00158 }
00159 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation