SHOGUN  3.2.1
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
StructuredAccuracy.cpp
Go to the documentation of this file.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2012-2013 Fernando José Iglesias García
8  * Copyright (C) 2012-2013 Fernando José Iglesias García
9  */
10 
16 
17 using namespace shogun;
18 
20 {
21 }
22 
24 {
25 }
26 
28 {
29  REQUIRE(predicted && ground_truth, "CLabels objects passed to evaluate "
30  "cannot be null\n");
31  REQUIRE(predicted->get_num_labels() == ground_truth->get_num_labels(),
32  "The number of predicted and ground truth labels must "
33  "be the same\n");
34  REQUIRE(predicted->get_label_type() == LT_STRUCTURED, "The predicted "
35  "labels must be of type CStructuredLabels\n");
36  REQUIRE(ground_truth->get_label_type() == LT_STRUCTURED, "The ground truth "
37  "labels must be of type CStructuredLabels\n");
38 
39  CStructuredLabels * pred_labs = CLabelsFactory::to_structured(predicted);
40  CStructuredLabels * true_labs = CLabelsFactory::to_structured(ground_truth);
41 
42  REQUIRE(pred_labs->get_structured_data_type() ==
43  true_labs->get_structured_data_type(), "Predicted and ground truth "
44  "labels must be composed of the same structured data\n");
45 
46  switch (pred_labs->get_structured_data_type())
47  {
48  case (SDT_REAL):
49  return evaluate_real(pred_labs, true_labs);
50 
51  case (SDT_SEQUENCE):
52  return evaluate_sequence(pred_labs, true_labs);
53 
54  case (SDT_SPARSE_MULTILABEL):
55  return evaluate_sparse_multilabel(pred_labs, true_labs);
56 
57  default:
58  SG_ERROR("Unknown structured data type for evaluation\n")
59  }
60 
61  return 0.0;
62 }
63 
65  CLabels * predicted, CLabels * ground_truth)
66 {
67  SG_SERROR("Not implemented\n")
68  return SGMatrix<int32_t>();
69 }
70 
71 float64_t CStructuredAccuracy::evaluate_real(CStructuredLabels * predicted,
72  CStructuredLabels * ground_truth)
73 {
74  int32_t length = predicted->get_num_labels();
75  int32_t num_equal = 0;
76 
77  for (int32_t i = 0 ; i < length ; ++i)
78  {
79  CRealNumber * truth = CRealNumber::obtain_from_generic(ground_truth->get_label(i));
81 
82  num_equal += truth->value == pred->value;
83 
84  SG_UNREF(truth);
85  SG_UNREF(pred);
86  }
87 
88  return (1.0 * num_equal) / length;
89 }
90 
91 float64_t CStructuredAccuracy::evaluate_sequence(CStructuredLabels * predicted,
92  CStructuredLabels * ground_truth)
93 {
94  int32_t length = predicted->get_num_labels();
95  // Accuracy of each each label
96  SGVector<float64_t> accuracies(length);
97  int32_t num_equal = 0;
98 
99  for (int32_t i = 0 ; i < length ; ++i)
100  {
101  CSequence * true_seq = CSequence::obtain_from_generic(ground_truth->get_label(i));
102  CSequence * pred_seq = CSequence::obtain_from_generic(predicted->get_label(i));
103 
104  SGVector<int32_t> true_seq_data = true_seq->get_data();
105  SGVector<int32_t> pred_seq_data = pred_seq->get_data();
106 
107  REQUIRE(true_seq_data.size() == pred_seq_data.size(), "Corresponding ground "
108  "truth and predicted sequences must be equally long\n");
109 
110  num_equal = 0;
111 
112  // Count the number of elements that are equal in both sequences
113  for (int32_t j = 0 ; j < true_seq_data.size() ; ++j)
114  {
115  num_equal += true_seq_data[j] == pred_seq_data[j];
116  }
117 
118  accuracies[i] = (1.0 * num_equal) / true_seq_data.size();
119 
120  SG_UNREF(true_seq);
121  SG_UNREF(pred_seq);
122  }
123 
124  return accuracies.mean();
125 }
126 
127 float64_t CStructuredAccuracy::evaluate_sparse_multilabel(CStructuredLabels * predicted,
128  CStructuredLabels * ground_truth)
129 {
130  CMultilabelSOLabels * multi_pred = (CMultilabelSOLabels *) predicted;
131  CMultilabelSOLabels * multi_truth = (CMultilabelSOLabels *) ground_truth;
132 
133  CMultilabelAccuracy * evaluator = new CMultilabelAccuracy();
134  SG_REF(evaluator);
135 
136  float64_t accuracy = evaluator->evaluate(multi_pred->get_multilabel_labels(),
137  multi_truth->get_multilabel_labels());
138 
139  SG_UNREF(evaluator);
140 
141  return accuracy;
142 }
143 

SHOGUN Machine Learning Toolbox - Documentation