Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/evaluation/MulticlassAccuracy.h>
00012 #include <shogun/labels/Labels.h>
00013 #include <shogun/labels/MulticlassLabels.h>
00014 #include <shogun/mathematics/Math.h>
00015
00016 using namespace shogun;
00017
00018 float64_t CMulticlassAccuracy::evaluate(CLabels* predicted, CLabels* ground_truth)
00019 {
00020 ASSERT(predicted && ground_truth);
00021 ASSERT(predicted->get_num_labels() == ground_truth->get_num_labels());
00022 ASSERT(predicted->get_label_type() == LT_MULTICLASS);
00023 ASSERT(ground_truth->get_label_type() == LT_MULTICLASS);
00024 int32_t length = predicted->get_num_labels();
00025 int32_t correct = 0;
00026 if (m_ignore_rejects)
00027 {
00028 for (int32_t i=0; i<length; i++)
00029 {
00030 if (((CMulticlassLabels*) predicted)->get_int_label(i)==((CMulticlassLabels*) ground_truth)->get_int_label(i))
00031 correct++;
00032 }
00033 return ((float64_t)correct)/length;
00034 }
00035 else
00036 {
00037 int32_t total = length;
00038 for (int32_t i=0; i<length; i++)
00039 {
00040 int32_t predicted_label = ((CMulticlassLabels*) predicted)->get_int_label(i);
00041
00042 if (predicted_label==((CMulticlassLabels*) predicted)->REJECTION_LABEL)
00043 total--;
00044 else if (predicted_label==((CMulticlassLabels*) ground_truth)->get_int_label(i))
00045 correct++;
00046 }
00047 m_rejects_num = length-total;
00048 SG_DEBUG("correct=%d, total=%d, rejected=%d\n",correct,total,length-total);
00049 return ((float64_t)correct)/total;
00050 }
00051 return 0.0;
00052 }
00053
00054 SGMatrix<int32_t> CMulticlassAccuracy::get_confusion_matrix(CLabels* predicted, CLabels* ground_truth)
00055 {
00056 ASSERT(predicted->get_num_labels() == ground_truth->get_num_labels());
00057 int32_t length = ground_truth->get_num_labels();
00058 int32_t num_classes = ((CMulticlassLabels*) ground_truth)->get_num_classes();
00059 SGMatrix<int32_t> confusion_matrix(num_classes, num_classes);
00060 memset(confusion_matrix.matrix,0,sizeof(int32_t)*num_classes*num_classes);
00061 for (int32_t i=0; i<length; i++)
00062 {
00063 int32_t predicted_label = ((CMulticlassLabels*) predicted)->get_int_label(i);
00064 int32_t ground_truth_label = ((CMulticlassLabels*) ground_truth)->get_int_label(i);
00065
00066 if (predicted_label==((CMulticlassLabels*) predicted)->REJECTION_LABEL)
00067 continue;
00068
00069 confusion_matrix[predicted_label*num_classes+ground_truth_label]++;
00070 }
00071 return confusion_matrix;
00072 }
00073