ROCEvaluation.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2011 Sergey Lisitsyn
00008  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
00009  */
00010 
00011 #include <shogun/evaluation/ROCEvaluation.h>
00012 #include <shogun/mathematics/Math.h>
00013 
00014 using namespace shogun;
00015 
00016 CROCEvaluation::~CROCEvaluation()
00017 {
00018 }
00019 
00020 float64_t CROCEvaluation::evaluate(CLabels* predicted, CLabels* ground_truth)
00021 {
00022     return evaluate_roc(predicted,ground_truth);
00023 }
00024 
00025 float64_t CROCEvaluation::evaluate_roc(CLabels* predicted, CLabels* ground_truth)
00026 {
00027     ASSERT(predicted && ground_truth);
00028     ASSERT(predicted->get_num_labels()==ground_truth->get_num_labels());
00029     ASSERT(predicted->get_label_type()==LT_BINARY);
00030     ASSERT(ground_truth->get_label_type()==LT_BINARY);
00031     ground_truth->ensure_valid();
00032 
00033     // assume threshold as negative infinity
00034     float64_t threshold = CMath::ALMOST_NEG_INFTY;
00035     // false positive rate
00036     float64_t fp = 0.0;
00037     // true positive rate
00038     float64_t tp=0.0;
00039 
00040     int32_t i;
00041     // total number of positive labels in predicted
00042     int32_t pos_count=0;
00043     int32_t neg_count=0;
00044 
00045     // initialize number of labels and labels
00046     SGVector<float64_t> orig_labels(predicted->get_num_labels());
00047     int32_t length = orig_labels.vlen;
00048     for (i=0; i<length; i++)
00049         orig_labels[i] = predicted->get_value(i);
00050     float64_t* labels = SGVector<float64_t>::clone_vector(orig_labels.vector, length);
00051 
00052     // get sorted indexes
00053     int32_t* idxs = SG_MALLOC(int32_t, length);
00054     for(i=0; i<length; i++)
00055         idxs[i] = i;
00056 
00057     CMath::qsort_backward_index(labels,idxs,length);
00058 
00059     // number of different predicted labels
00060     int32_t diff_count=1;
00061 
00062     // get number of different labels
00063     for (i=0; i<length-1; i++)
00064     {
00065         if (labels[i] != labels[i+1])
00066             diff_count++;
00067     }
00068 
00069     SG_FREE(labels);
00070 
00071     // initialize graph and auROC
00072     m_ROC_graph = SGMatrix<float64_t>(2,diff_count+1);
00073     m_thresholds = SGVector<float64_t>(length);
00074     m_auROC = 0.0;
00075 
00076     // get total numbers of positive and negative labels
00077     for(i=0; i<length; i++)
00078     {
00079         if (ground_truth->get_value(i) >= 0)
00080             pos_count++;
00081         else
00082             neg_count++;
00083     }
00084 
00085     // assure both number of positive and negative examples is >0
00086     ASSERT(pos_count>0 && neg_count>0);
00087 
00088     int32_t j = 0;
00089     float64_t label;
00090 
00091     // create ROC curve and calculate auROC
00092     for(i=0; i<length; i++)
00093     {
00094         label = predicted->get_value(idxs[i]);
00095 
00096         if (label != threshold)
00097         {
00098             threshold = label;
00099             m_ROC_graph[2*j] = fp/neg_count;
00100             m_ROC_graph[2*j+1] = tp/pos_count;
00101             j++;
00102         }
00103 
00104         m_thresholds[i]=threshold;
00105 
00106         if (ground_truth->get_value(idxs[i]) > 0)
00107             tp+=1.0;
00108         else
00109             fp+=1.0;
00110     }
00111 
00112     // add (1,1) to ROC curve
00113     m_ROC_graph[2*diff_count] = 1.0;
00114     m_ROC_graph[2*diff_count+1] = 1.0;
00115 
00116     // calc auROC using area under curve
00117     m_auROC = CMath::area_under_curve(m_ROC_graph.matrix,diff_count+1,false);
00118 
00119     m_computed = true;
00120 
00121     return m_auROC;
00122 }
00123 
00124 SGMatrix<float64_t> CROCEvaluation::get_ROC()
00125 {
00126     if (!m_computed)
00127         SG_ERROR("Uninitialized, please call evaluate first");
00128 
00129     return m_ROC_graph;
00130 }
00131 
00132 SGVector<float64_t> CROCEvaluation::get_thresholds()
00133 {
00134     if (!m_computed)
00135         SG_ERROR("Uninitialized, please call evaluate first");
00136 
00137     return m_thresholds;
00138 }
00139 
00140 float64_t CROCEvaluation::get_auROC()
00141 {
00142     if (!m_computed)
00143             SG_ERROR("Uninitialized, please call evaluate first");
00144 
00145     return m_auROC;
00146 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation