Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/evaluation/PRCEvaluation.h>
00012 #include <shogun/labels/RegressionLabels.h>
00013 #include <shogun/labels/BinaryLabels.h>
00014 #include <shogun/mathematics/Math.h>
00015
00016 using namespace shogun;
00017
00018 CPRCEvaluation::~CPRCEvaluation()
00019 {
00020 }
00021
00022 float64_t CPRCEvaluation::evaluate(CLabels* predicted, CLabels* ground_truth)
00023 {
00024 ASSERT(predicted && ground_truth);
00025 ASSERT(predicted->get_num_labels()==ground_truth->get_num_labels());
00026 ASSERT(predicted->get_label_type()==LT_BINARY);
00027 ASSERT(ground_truth->get_label_type()==LT_BINARY);
00028 ground_truth->ensure_valid();
00029
00030
00031 float64_t tp = 0.0;
00032 int32_t i;
00033
00034
00035 int32_t pos_count=0;
00036
00037
00038 SGVector<float64_t> orig_labels = predicted->get_values();
00039 int32_t length = orig_labels.vlen;
00040 float64_t* labels = SGVector<float64_t>::clone_vector(orig_labels.vector, length);
00041
00042
00043 int32_t* idxs = SG_MALLOC(int32_t, length);
00044 for(i=0; i<length; i++)
00045 idxs[i] = i;
00046
00047
00048 CMath::qsort_backward_index(labels,idxs,length);
00049
00050
00051 SG_FREE(labels);
00052 m_PRC_graph = SGMatrix<float64_t>(2,length);
00053 m_thresholds = SGVector<float64_t>(length);
00054 m_auPRC = 0.0;
00055
00056
00057 for (i=0; i<length; i++)
00058 {
00059 if (ground_truth->get_value(i) > 0)
00060 pos_count++;
00061 }
00062
00063
00064 ASSERT(pos_count>0);
00065
00066
00067 for (i=0; i<length; i++)
00068 {
00069
00070 if (ground_truth->get_value(idxs[i]) > 0)
00071 tp += 1.0;
00072
00073
00074 m_PRC_graph[2*i] = tp/float64_t(i+1);
00075
00076 m_PRC_graph[2*i+1] = tp/float64_t(pos_count);
00077
00078 m_thresholds[i]= predicted->get_value(idxs[i]);
00079 }
00080
00081
00082 m_auPRC = CMath::area_under_curve(m_PRC_graph.matrix,length,true);
00083
00084
00085 m_computed = true;
00086
00087 return m_auPRC;
00088 }
00089
00090 SGMatrix<float64_t> CPRCEvaluation::get_PRC()
00091 {
00092 if (!m_computed)
00093 SG_ERROR("Uninitialized, please call evaluate first");
00094
00095 return m_PRC_graph;
00096 }
00097
00098 SGVector<float64_t> CPRCEvaluation::get_thresholds()
00099 {
00100 if (!m_computed)
00101 SG_ERROR("Uninitialized, please call evaluate first");
00102
00103 return m_thresholds;
00104 }
00105
00106 float64_t CPRCEvaluation::get_auPRC()
00107 {
00108 if (!m_computed)
00109 SG_ERROR("Uninitialized, please call evaluate first");
00110
00111 return m_auPRC;
00112 }