MulticlassOVREvaluation.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Copyright (C) 2012 Sergey Lisitsyn
00008  */
00009 
00010 #include <shogun/evaluation/MulticlassOVREvaluation.h>
00011 #include <shogun/evaluation/ROCEvaluation.h>
00012 #include <shogun/evaluation/PRCEvaluation.h>
00013 #include <shogun/labels/MulticlassLabels.h>
00014 #include <shogun/mathematics/Statistics.h>
00015 
00016 using namespace shogun;
00017 
00018 CMulticlassOVREvaluation::CMulticlassOVREvaluation() :
00019     CEvaluation(), m_binary_evaluation(NULL), m_graph_results(NULL), m_num_graph_results(0)
00020 {
00021 }
00022 
00023 CMulticlassOVREvaluation::CMulticlassOVREvaluation(CBinaryClassEvaluation* binary_evaluation) :
00024     CEvaluation(), m_binary_evaluation(NULL), m_graph_results(NULL), m_num_graph_results(0)
00025 {
00026     set_binary_evaluation(binary_evaluation);
00027 }
00028 
00029 CMulticlassOVREvaluation::~CMulticlassOVREvaluation()
00030 {
00031     SG_UNREF(m_binary_evaluation);
00032     if (m_graph_results)
00033     {
00034         for (int32_t i=0; i<m_num_graph_results; i++)
00035             m_graph_results[i].~SGMatrix<float64_t>();
00036         SG_FREE(m_graph_results);
00037     }
00038 }
00039 
00040 float64_t CMulticlassOVREvaluation::evaluate(CLabels* predicted, CLabels* ground_truth)
00041 {
00042     ASSERT(m_binary_evaluation);
00043     ASSERT(predicted);
00044     ASSERT(ground_truth);
00045     int32_t n_labels = predicted->get_num_labels();
00046     ASSERT(n_labels);
00047     CMulticlassLabels* predicted_mc = (CMulticlassLabels*)predicted;
00048     CMulticlassLabels* ground_truth_mc = (CMulticlassLabels*)ground_truth;
00049     int32_t n_classes = predicted_mc->get_multiclass_confidences(0).size();
00050     ASSERT(n_classes>0);
00051     m_last_results = SGVector<float64_t>(n_classes);
00052     
00053     SGMatrix<float64_t> all(n_labels,n_classes);
00054     for (int32_t i=0; i<n_labels; i++)
00055     {
00056         SGVector<float64_t> confs = predicted_mc->get_multiclass_confidences(i);
00057         for (int32_t j=0; j<n_classes; j++)
00058         {
00059             all(i,j) = confs[j];
00060         }
00061     }
00062     if (dynamic_cast<CROCEvaluation*>(m_binary_evaluation) || dynamic_cast<CPRCEvaluation*>(m_binary_evaluation))
00063     {
00064         for (int32_t i=0; i<m_num_graph_results; i++)
00065             m_graph_results[i].~SGMatrix<float64_t>();
00066         SG_FREE(m_graph_results);
00067         m_graph_results = SG_MALLOC(SGMatrix<float64_t>, n_classes);
00068         m_num_graph_results = n_classes;
00069     }
00070     for (int32_t c=0; c<n_classes; c++)
00071     {
00072         CLabels* pred = new CBinaryLabels(SGVector<float64_t>(all.get_column_vector(c),n_labels,false));
00073         SGVector<float64_t> gt_vec(n_labels);
00074         for (int32_t i=0; i<n_labels; i++)
00075         {
00076             if (ground_truth_mc->get_label(i)==c)
00077                 gt_vec[i] = +1.0;
00078             else
00079                 gt_vec[i] = -1.0;
00080         }
00081         CLabels* gt = new CBinaryLabels(gt_vec);
00082         m_last_results[c] = m_binary_evaluation->evaluate(pred, gt);
00083 
00084         if (dynamic_cast<CROCEvaluation*>(m_binary_evaluation))
00085         {
00086             new (&m_graph_results[c]) SGMatrix<float64_t>();
00087             m_graph_results[c] = ((CROCEvaluation*)m_binary_evaluation)->get_ROC();
00088         }
00089         if (dynamic_cast<CPRCEvaluation*>(m_binary_evaluation))
00090         {
00091             new (&m_graph_results[c]) SGMatrix<float64_t>();
00092             m_graph_results[c] = ((CPRCEvaluation*)m_binary_evaluation)->get_PRC();
00093         }
00094     }
00095     return CStatistics::mean(m_last_results);
00096 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation