MulticlassAccuracy.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2011 Sergey Lisitsyn
00008  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
00009  */
00010 
00011 #include <shogun/evaluation/MulticlassAccuracy.h>
00012 #include <shogun/labels/Labels.h>
00013 #include <shogun/labels/MulticlassLabels.h>
00014 #include <shogun/mathematics/Math.h>
00015 
00016 using namespace shogun;
00017 
00018 float64_t CMulticlassAccuracy::evaluate(CLabels* predicted, CLabels* ground_truth)
00019 {
00020     ASSERT(predicted && ground_truth);
00021     ASSERT(predicted->get_num_labels() == ground_truth->get_num_labels());
00022     ASSERT(predicted->get_label_type() == LT_MULTICLASS);
00023     ASSERT(ground_truth->get_label_type() == LT_MULTICLASS);
00024     int32_t length = predicted->get_num_labels();
00025     int32_t correct = 0;
00026     if (m_ignore_rejects)
00027     {
00028         for (int32_t i=0; i<length; i++)
00029         {
00030             if (((CMulticlassLabels*) predicted)->get_int_label(i)==((CMulticlassLabels*) ground_truth)->get_int_label(i))
00031                 correct++;
00032         }
00033         return ((float64_t)correct)/length;
00034     }
00035     else
00036     {
00037         int32_t total = length;
00038         for (int32_t i=0; i<length; i++)
00039         {
00040             int32_t predicted_label = ((CMulticlassLabels*) predicted)->get_int_label(i);
00041 
00042             if (predicted_label==((CMulticlassLabels*) predicted)->REJECTION_LABEL)
00043                 total--;
00044             else if (predicted_label==((CMulticlassLabels*) ground_truth)->get_int_label(i))
00045                 correct++;
00046         }
00047         m_rejects_num = length-total;
00048         SG_DEBUG("correct=%d, total=%d, rejected=%d\n",correct,total,length-total);
00049         return ((float64_t)correct)/total;
00050     }
00051     return 0.0;
00052 }
00053 
00054 SGMatrix<int32_t> CMulticlassAccuracy::get_confusion_matrix(CLabels* predicted, CLabels* ground_truth)
00055 {
00056     ASSERT(predicted->get_num_labels() == ground_truth->get_num_labels());
00057     int32_t length = ground_truth->get_num_labels();
00058     int32_t num_classes = ((CMulticlassLabels*) ground_truth)->get_num_classes();
00059     SGMatrix<int32_t> confusion_matrix(num_classes, num_classes);
00060     memset(confusion_matrix.matrix,0,sizeof(int32_t)*num_classes*num_classes);
00061     for (int32_t i=0; i<length; i++)
00062     {
00063         int32_t predicted_label = ((CMulticlassLabels*) predicted)->get_int_label(i);
00064         int32_t ground_truth_label = ((CMulticlassLabels*) ground_truth)->get_int_label(i);
00065 
00066         if (predicted_label==((CMulticlassLabels*) predicted)->REJECTION_LABEL)
00067             continue;
00068 
00069         confusion_matrix[predicted_label*num_classes+ground_truth_label]++;
00070     }
00071     return confusion_matrix;
00072 }
00073 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation