ROCEvaluation.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2011 Sergey Lisitsyn
00008  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
00009  */
00010 
00011 #include <shogun/evaluation/ROCEvaluation.h>
00012 #include <shogun/mathematics/Math.h>
00013 
00014 using namespace shogun;
00015 
00016 CROCEvaluation::~CROCEvaluation()
00017 {
00018     SG_FREE(m_ROC_graph);
00019 }
00020 
00021 float64_t CROCEvaluation::evaluate(CLabels* predicted, CLabels* ground_truth)
00022 {
00023     ASSERT(predicted && ground_truth);
00024     ASSERT(predicted->get_num_labels()==ground_truth->get_num_labels());
00025     ASSERT(ground_truth->is_two_class_labeling());
00026 
00027     // assume threshold as negative infinity
00028     float64_t threshold = CMath::ALMOST_NEG_INFTY;
00029     // false positive rate
00030     float64_t fp = 0.0;
00031     // true positive rate
00032     float64_t tp=0.0;
00033 
00034     int32_t i;
00035     // total number of positive labels in predicted
00036     int32_t pos_count=0;
00037     int32_t neg_count=0;
00038 
00039     // initialize number of labels and labels
00040     SGVector<float64_t> orig_labels = predicted->get_labels();
00041     int32_t length = orig_labels.vlen;
00042     float64_t* labels = CMath::clone_vector(orig_labels.vector, length);
00043     orig_labels.free_vector();
00044 
00045     // get sorted indexes
00046     int32_t* idxs = SG_MALLOC(int32_t, length);
00047     for(i=0; i<length; i++)
00048         idxs[i] = i;
00049 
00050     CMath::qsort_backward_index(labels,idxs,length);
00051 
00052     // number of different predicted labels
00053     int32_t diff_count=1;
00054 
00055     // get number of different labels
00056     for (i=0; i<length-1; i++)
00057     {
00058         if (labels[i] != labels[i+1])
00059             diff_count++;
00060     }
00061 
00062     delete [] labels;
00063 
00064     // initialize graph and auROC
00065     SG_FREE(m_ROC_graph);
00066     m_ROC_graph = SG_MALLOC(float64_t, diff_count*2+2);
00067     m_auROC = 0.0;
00068 
00069     // get total numbers of positive and negative labels
00070     for(i=0; i<length; i++)
00071     {
00072         if (ground_truth->get_label(i) > 0)
00073             pos_count++;
00074         else
00075             neg_count++;
00076     }
00077 
00078     // assure both number of positive and negative examples is >0
00079     ASSERT(pos_count>0 && neg_count>0);
00080 
00081     int32_t j = 0;
00082     float64_t label;
00083 
00084     // create ROC curve and calculate auROC
00085     for(i=0; i<length; i++)
00086     {
00087         label = predicted->get_label(idxs[i]);
00088 
00089         if (label != threshold)
00090         {
00091             threshold = label;
00092             m_ROC_graph[2*j] = fp/neg_count;
00093             m_ROC_graph[2*j+1] = tp/pos_count;
00094             j++;
00095         }
00096 
00097         if (ground_truth->get_label(idxs[i]) > 0)
00098             tp+=1.0;
00099         else
00100             fp+=1.0;
00101     }
00102 
00103     // add (1,1) to ROC curve
00104     m_ROC_graph[2*diff_count] = 1.0;
00105     m_ROC_graph[2*diff_count+1] = 1.0;
00106 
00107     // set ROC length
00108     m_ROC_length = diff_count+1;
00109 
00110     // calc auROC using area under curve
00111     m_auROC = CMath::area_under_curve(m_ROC_graph,m_ROC_length,false);
00112 
00113     m_computed = true;
00114 
00115     return m_auROC;
00116 }
00117 
00118 SGMatrix<float64_t> CROCEvaluation::get_ROC()
00119 {
00120     if (!m_computed)
00121         SG_ERROR("Uninitialized, please call evaluate first");
00122 
00123     ASSERT(m_ROC_graph);
00124 
00125     return SGMatrix<float64_t>(m_ROC_graph,2,m_ROC_length);
00126 }
00127 
00128 float64_t CROCEvaluation::get_auROC()
00129 {
00130     if (!m_computed)
00131             SG_ERROR("Uninitialized, please call evaluate first");
00132 
00133     return m_auROC;
00134 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation