StructuredAccuracy.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Fernando José Iglesias García
00008  * Copyright (C) 2012 Fernando José Iglesias García
00009  */
00010 
00011 #include <shogun/evaluation/StructuredAccuracy.h>
00012 #include <shogun/structure/HMSVMLabels.h>
00013 #include <shogun/structure/MulticlassSOLabels.h>
00014 
00015 using namespace shogun;
00016 
00017 CStructuredAccuracy::CStructuredAccuracy() : CEvaluation()
00018 {
00019 }
00020 
00021 CStructuredAccuracy::~CStructuredAccuracy()
00022 {
00023 }
00024 
00025 float64_t CStructuredAccuracy::evaluate(CLabels* predicted, CLabels* ground_truth)
00026 {
00027     REQUIRE(predicted && ground_truth, "CLabels objects passed to evaluate "
00028             "cannot be null\n");
00029     REQUIRE(predicted->get_num_labels() == ground_truth->get_num_labels(),
00030             "The number of predicted and ground truth labels must "
00031             "be the same\n");
00032     REQUIRE(predicted->get_label_type() == LT_STRUCTURED, "The predicted "
00033             "labels must be of type CStructuredLabels\n");
00034     REQUIRE(ground_truth->get_label_type() == LT_STRUCTURED, "The ground truth "
00035             "labels must be of type CStructuredLabels\n");
00036 
00037     CStructuredLabels* pred_labs = CStructuredLabels::obtain_from_generic(predicted);
00038     CStructuredLabels* true_labs = CStructuredLabels::obtain_from_generic(ground_truth);
00039 
00040     REQUIRE(pred_labs->get_structured_data_type() ==
00041             true_labs->get_structured_data_type(), "Predicted and ground truth "
00042             "labels must be composed of the same structured data\n");
00043 
00044     switch ( pred_labs->get_structured_data_type() )
00045     {
00046         case (SDT_REAL):
00047             return evaluate_real(pred_labs, true_labs);
00048         case (SDT_SEQUENCE):
00049             return evaluate_sequence(pred_labs, true_labs);
00050         default:
00051             SG_ERROR("Unknown structured data type for evaluation\n");
00052     }
00053 
00054     return 0.0;
00055 }
00056 
00057 SGMatrix< int32_t > CStructuredAccuracy::get_confusion_matrix(
00058         CLabels* predicted, CLabels* ground_truth)
00059 {
00060     SG_SERROR("Not implemented\n");
00061     return SGMatrix< int32_t >();
00062 }
00063 
00064 float64_t CStructuredAccuracy::evaluate_real(CStructuredLabels* predicted,
00065         CStructuredLabels* ground_truth)
00066 {
00067     int32_t length = predicted->get_num_labels();
00068     int32_t num_equal = 0;
00069 
00070     for ( int32_t i = 0 ; i < length ; ++i )
00071     {
00072         CRealNumber* truth =
00073             CRealNumber::obtain_from_generic(ground_truth->get_label(i));
00074         CRealNumber* pred =
00075             CRealNumber::obtain_from_generic(predicted->get_label(i));
00076 
00077         num_equal += truth->value == pred->value;
00078 
00079         SG_UNREF(truth);
00080         SG_UNREF(pred);
00081     }
00082 
00083     return (1.0*num_equal) / length;
00084 }
00085 
00086 float64_t CStructuredAccuracy::evaluate_sequence(CStructuredLabels* predicted,
00087         CStructuredLabels* ground_truth)
00088 {
00089     int32_t length = predicted->get_num_labels();
00090     // Accuracy of each each label
00091     SGVector< float64_t > accuracies(length);
00092     int32_t num_equal = 0;
00093 
00094     for ( int32_t i = 0 ; i < length ; ++i )
00095     {
00096         CSequence* true_seq =
00097             CSequence::obtain_from_generic(ground_truth->get_label(i));
00098         CSequence* pred_seq =
00099             CSequence::obtain_from_generic(predicted->get_label(i));
00100 
00101         SGVector<int32_t> true_seq_data = true_seq->get_data();
00102         SGVector<int32_t> pred_seq_data = pred_seq->get_data();
00103 
00104         REQUIRE(true_seq_data.size() == pred_seq_data.size(), "Corresponding ground "
00105                 "truth and predicted sequences must be equally long\n");
00106 
00107         num_equal = 0;
00108         // Count the number of elements that are equal in both sequences
00109         for ( int32_t j = 0 ; j < true_seq_data.size() ; ++j )
00110             num_equal += true_seq_data[j] == pred_seq_data[j];
00111 
00112         accuracies[i] = (1.0*num_equal) / true_seq_data.size();
00113 
00114         SG_UNREF(true_seq);
00115         SG_UNREF(pred_seq);
00116     }
00117 
00118     return accuracies.mean();
00119 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation