SHOGUN  4.1.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
StructuredAccuracy.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2012-2013 Fernando José Iglesias García
8  * Copyright (C) 2012-2013 Fernando José Iglesias García
9  */
10 
17 
18 using namespace shogun;
19 
21 {
22 }
23 
25 {
26 }
27 
29 {
30  REQUIRE(predicted && ground_truth, "CLabels objects passed to evaluate "
31  "cannot be null\n");
32  REQUIRE(predicted->get_num_labels() == ground_truth->get_num_labels(),
33  "The number of predicted and ground truth labels must "
34  "be the same\n");
35  REQUIRE(predicted->get_label_type() == LT_STRUCTURED, "The predicted "
36  "labels must be of type CStructuredLabels\n");
37  REQUIRE(ground_truth->get_label_type() == LT_STRUCTURED, "The ground truth "
38  "labels must be of type CStructuredLabels\n");
39 
40  CStructuredLabels * pred_labs = CLabelsFactory::to_structured(predicted);
41  CStructuredLabels * true_labs = CLabelsFactory::to_structured(ground_truth);
42 
43  REQUIRE(pred_labs->get_structured_data_type() ==
44  true_labs->get_structured_data_type(), "Predicted and ground truth "
45  "labels must be composed of the same structured data\n");
46 
47  switch (pred_labs->get_structured_data_type())
48  {
49  case (SDT_REAL):
50  return evaluate_real(pred_labs, true_labs);
51 
52  case (SDT_SEQUENCE):
53  return evaluate_sequence(pred_labs, true_labs);
54 
55  case (SDT_SPARSE_MULTILABEL):
56  return evaluate_sparse_multilabel(pred_labs, true_labs);
57 
58  default:
59  SG_ERROR("Unknown structured data type for evaluation\n")
60  }
61 
62  return 0.0;
63 }
64 
66  CLabels * predicted, CLabels * ground_truth)
67 {
68  SG_SERROR("Not implemented\n")
69  return SGMatrix<int32_t>();
70 }
71 
72 float64_t CStructuredAccuracy::evaluate_real(CStructuredLabels * predicted,
73  CStructuredLabels * ground_truth)
74 {
75  int32_t length = predicted->get_num_labels();
76  int32_t num_equal = 0;
77 
78  for (int32_t i = 0 ; i < length ; ++i)
79  {
80  CRealNumber * truth = CRealNumber::obtain_from_generic(ground_truth->get_label(i));
82 
83  num_equal += truth->value == pred->value;
84 
85  SG_UNREF(truth);
86  SG_UNREF(pred);
87  }
88 
89  return (1.0 * num_equal) / length;
90 }
91 
92 float64_t CStructuredAccuracy::evaluate_sequence(CStructuredLabels * predicted,
93  CStructuredLabels * ground_truth)
94 {
95  int32_t length = predicted->get_num_labels();
96  // Accuracy of each each label
97  SGVector<float64_t> accuracies(length);
98  int32_t num_equal = 0;
99 
100  for (int32_t i = 0 ; i < length ; ++i)
101  {
102  CSequence * true_seq = CSequence::obtain_from_generic(ground_truth->get_label(i));
103  CSequence * pred_seq = CSequence::obtain_from_generic(predicted->get_label(i));
104 
105  SGVector<int32_t> true_seq_data = true_seq->get_data();
106  SGVector<int32_t> pred_seq_data = pred_seq->get_data();
107 
108  REQUIRE(true_seq_data.size() == pred_seq_data.size(), "Corresponding ground "
109  "truth and predicted sequences must be equally long\n");
110 
111  num_equal = 0;
112 
113  // Count the number of elements that are equal in both sequences
114  for (int32_t j = 0 ; j < true_seq_data.size() ; ++j)
115  {
116  num_equal += true_seq_data[j] == pred_seq_data[j];
117  }
118 
119  accuracies[i] = (1.0 * num_equal) / true_seq_data.size();
120 
121  SG_UNREF(true_seq);
122  SG_UNREF(pred_seq);
123  }
124 
125  return CStatistics::mean(accuracies);
126 }
127 
128 float64_t CStructuredAccuracy::evaluate_sparse_multilabel(CStructuredLabels * predicted,
129  CStructuredLabels * ground_truth)
130 {
131  CMultilabelSOLabels * multi_pred = (CMultilabelSOLabels *) predicted;
132  CMultilabelSOLabels * multi_truth = (CMultilabelSOLabels *) ground_truth;
133 
134  CMultilabelAccuracy * evaluator = new CMultilabelAccuracy();
135  SG_REF(evaluator);
136 
137  float64_t accuracy = evaluator->evaluate(multi_pred->get_multilabel_labels(),
138  multi_truth->get_multilabel_labels());
139 
140  SG_UNREF(evaluator);
141 
142  return accuracy;
143 }
144 
Base class of the labels used in Structured Output (SO) problems.
virtual ELabelType get_label_type() const =0
virtual float64_t evaluate(CLabels *predicted, CLabels *ground_truth)
Class CMultilabelSOLabels used in the application of Structured Output (SO) learning to Multilabel Classification.
The class Labels models labels, i.e. class assignments of objects.
Definition: Labels.h:43
virtual int32_t get_num_labels() const =0
structured labels (e.g. sequences, trees) used in Structured Output problems
Definition: LabelTypes.h:24
virtual float64_t evaluate(CLabels *predicted, CLabels *ground_truth)
#define SG_ERROR(...)
Definition: SGIO.h:129
#define REQUIRE(x,...)
Definition: SGIO.h:206
#define SG_REF(x)
Definition: SGObject.h:51
static CRealNumber * obtain_from_generic(CStructuredData *base_data)
Class CSequence to be used in the application of Structured Output (SO) learning to Hidden Markov Support Vector Machines (HM-SVM).
int32_t size() const
Definition: SGVector.h:115
double float64_t
Definition: common.h:50
static floatmax_t mean(SGVector< T > vec)
Definition: Statistics.h:44
static SGMatrix< int32_t > get_confusion_matrix(CLabels *predicted, CLabels *ground_truth)
#define SG_UNREF(x)
Definition: SGObject.h:52
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
virtual int32_t get_num_labels() const
Class CMultilabelAccuracy used to compute accuracy of multilabel classification.
#define SG_SERROR(...)
Definition: SGIO.h:179
SGVector< int32_t > get_data() const
virtual CStructuredData * get_label(int32_t idx)
static CSequence * obtain_from_generic(CStructuredData *base_data)
static CStructuredLabels * to_structured(CLabels *base_labels)
Class CRealNumber to be used in the application of Structured Output (SO) learning to multiclass classification.
EStructuredDataType get_structured_data_type()
Class Evaluation, a base class for other classes used to evaluate labels, e.g. accuracy of classification or mean squared error of regression.
Definition: Evaluation.h:40
virtual CMultilabelLabels * get_multilabel_labels()

SHOGUN Machine Learning Toolbox - Documentation