Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/evaluation/ROCEvaluation.h>
00012 #include <shogun/mathematics/Math.h>
00013
00014 using namespace shogun;
00015
00016 CROCEvaluation::~CROCEvaluation()
00017 {
00018 SG_FREE(m_ROC_graph);
00019 }
00020
00021 float64_t CROCEvaluation::evaluate(CLabels* predicted, CLabels* ground_truth)
00022 {
00023 ASSERT(predicted && ground_truth);
00024 ASSERT(predicted->get_num_labels()==ground_truth->get_num_labels());
00025 ASSERT(ground_truth->is_two_class_labeling());
00026
00027
00028 float64_t threshold = CMath::ALMOST_NEG_INFTY;
00029
00030 float64_t fp = 0.0;
00031
00032 float64_t tp=0.0;
00033
00034 int32_t i;
00035
00036 int32_t pos_count=0;
00037 int32_t neg_count=0;
00038
00039
00040 SGVector<float64_t> orig_labels = predicted->get_labels();
00041 int32_t length = orig_labels.vlen;
00042 float64_t* labels = CMath::clone_vector(orig_labels.vector, length);
00043 orig_labels.free_vector();
00044
00045
00046 int32_t* idxs = SG_MALLOC(int32_t, length);
00047 for(i=0; i<length; i++)
00048 idxs[i] = i;
00049
00050 CMath::qsort_backward_index(labels,idxs,length);
00051
00052
00053 int32_t diff_count=1;
00054
00055
00056 for (i=0; i<length-1; i++)
00057 {
00058 if (labels[i] != labels[i+1])
00059 diff_count++;
00060 }
00061
00062 delete [] labels;
00063
00064
00065 SG_FREE(m_ROC_graph);
00066 m_ROC_graph = SG_MALLOC(float64_t, diff_count*2+2);
00067 m_thresholds = SG_MALLOC(float64_t, length);
00068 m_auROC = 0.0;
00069
00070
00071 for(i=0; i<length; i++)
00072 {
00073 if (ground_truth->get_label(i) > 0)
00074 pos_count++;
00075 else
00076 neg_count++;
00077 }
00078
00079
00080 ASSERT(pos_count>0 && neg_count>0);
00081
00082 int32_t j = 0;
00083 float64_t label;
00084
00085
00086 for(i=0; i<length; i++)
00087 {
00088 label = predicted->get_label(idxs[i]);
00089
00090 if (label != threshold)
00091 {
00092 threshold = label;
00093 m_ROC_graph[2*j] = fp/neg_count;
00094 m_ROC_graph[2*j+1] = tp/pos_count;
00095 j++;
00096 }
00097
00098 m_thresholds[i]=threshold;
00099
00100 if (ground_truth->get_label(idxs[i]) > 0)
00101 tp+=1.0;
00102 else
00103 fp+=1.0;
00104 }
00105
00106
00107 m_ROC_graph[2*diff_count] = 1.0;
00108 m_ROC_graph[2*diff_count+1] = 1.0;
00109
00110
00111 m_ROC_length = diff_count+1;
00112
00113
00114 m_auROC = CMath::area_under_curve(m_ROC_graph,m_ROC_length,false);
00115
00116 m_computed = true;
00117
00118 return m_auROC;
00119 }
00120
00121 SGMatrix<float64_t> CROCEvaluation::get_ROC()
00122 {
00123 if (!m_computed)
00124 SG_ERROR("Uninitialized, please call evaluate first");
00125
00126 ASSERT(m_ROC_graph);
00127
00128 return SGMatrix<float64_t>(m_ROC_graph,2,m_ROC_length);
00129 }
00130
00131 SGVector<float64_t> CROCEvaluation::get_thresholds()
00132 {
00133 if (!m_computed)
00134 SG_ERROR("Uninitialized, please call evaluate first");
00135
00136 ASSERT(m_thresholds);
00137
00138 return SGVector<float64_t>(m_thresholds,m_ROC_length);
00139 }
00140
00141 float64_t CROCEvaluation::get_auROC()
00142 {
00143 if (!m_computed)
00144 SG_ERROR("Uninitialized, please call evaluate first");
00145
00146 return m_auROC;
00147 }