Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010 #include <shogun/evaluation/MulticlassOVREvaluation.h>
00011 #include <shogun/evaluation/ROCEvaluation.h>
00012 #include <shogun/evaluation/PRCEvaluation.h>
00013 #include <shogun/labels/MulticlassLabels.h>
00014 #include <shogun/mathematics/Statistics.h>
00015
00016 using namespace shogun;
00017
00018 CMulticlassOVREvaluation::CMulticlassOVREvaluation() :
00019 CEvaluation(), m_binary_evaluation(NULL), m_graph_results(NULL), m_num_graph_results(0)
00020 {
00021 }
00022
00023 CMulticlassOVREvaluation::CMulticlassOVREvaluation(CBinaryClassEvaluation* binary_evaluation) :
00024 CEvaluation(), m_binary_evaluation(NULL), m_graph_results(NULL), m_num_graph_results(0)
00025 {
00026 set_binary_evaluation(binary_evaluation);
00027 }
00028
00029 CMulticlassOVREvaluation::~CMulticlassOVREvaluation()
00030 {
00031 SG_UNREF(m_binary_evaluation);
00032 if (m_graph_results)
00033 {
00034 for (int32_t i=0; i<m_num_graph_results; i++)
00035 m_graph_results[i].~SGMatrix<float64_t>();
00036 SG_FREE(m_graph_results);
00037 }
00038 }
00039
00040 float64_t CMulticlassOVREvaluation::evaluate(CLabels* predicted, CLabels* ground_truth)
00041 {
00042 ASSERT(m_binary_evaluation);
00043 ASSERT(predicted);
00044 ASSERT(ground_truth);
00045 int32_t n_labels = predicted->get_num_labels();
00046 ASSERT(n_labels);
00047 CMulticlassLabels* predicted_mc = (CMulticlassLabels*)predicted;
00048 CMulticlassLabels* ground_truth_mc = (CMulticlassLabels*)ground_truth;
00049 int32_t n_classes = predicted_mc->get_multiclass_confidences(0).size();
00050 ASSERT(n_classes>0);
00051 m_last_results = SGVector<float64_t>(n_classes);
00052
00053 SGMatrix<float64_t> all(n_labels,n_classes);
00054 for (int32_t i=0; i<n_labels; i++)
00055 {
00056 SGVector<float64_t> confs = predicted_mc->get_multiclass_confidences(i);
00057 for (int32_t j=0; j<n_classes; j++)
00058 {
00059 all(i,j) = confs[j];
00060 }
00061 }
00062 if (dynamic_cast<CROCEvaluation*>(m_binary_evaluation) || dynamic_cast<CPRCEvaluation*>(m_binary_evaluation))
00063 {
00064 for (int32_t i=0; i<m_num_graph_results; i++)
00065 m_graph_results[i].~SGMatrix<float64_t>();
00066 SG_FREE(m_graph_results);
00067 m_graph_results = SG_MALLOC(SGMatrix<float64_t>, n_classes);
00068 m_num_graph_results = n_classes;
00069 }
00070 for (int32_t c=0; c<n_classes; c++)
00071 {
00072 CLabels* pred = new CBinaryLabels(SGVector<float64_t>(all.get_column_vector(c),n_labels,false));
00073 SGVector<float64_t> gt_vec(n_labels);
00074 for (int32_t i=0; i<n_labels; i++)
00075 {
00076 if (ground_truth_mc->get_label(i)==c)
00077 gt_vec[i] = +1.0;
00078 else
00079 gt_vec[i] = -1.0;
00080 }
00081 CLabels* gt = new CBinaryLabels(gt_vec);
00082 m_last_results[c] = m_binary_evaluation->evaluate(pred, gt);
00083
00084 if (dynamic_cast<CROCEvaluation*>(m_binary_evaluation))
00085 {
00086 new (&m_graph_results[c]) SGMatrix<float64_t>();
00087 m_graph_results[c] = ((CROCEvaluation*)m_binary_evaluation)->get_ROC();
00088 }
00089 if (dynamic_cast<CPRCEvaluation*>(m_binary_evaluation))
00090 {
00091 new (&m_graph_results[c]) SGMatrix<float64_t>();
00092 m_graph_results[c] = ((CPRCEvaluation*)m_binary_evaluation)->get_PRC();
00093 }
00094 }
00095 return CStatistics::mean(m_last_results);
00096 }