ClusteringMutualInformation.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Chiyuan Zhang
00008  * Copyright (C) 2012 Chiyuan Zhang
00009  */
00010 
00011 #include <shogun/lib/SGVector.h>
00012 #include <shogun/labels/MulticlassLabels.h>
00013 #include <shogun/evaluation/ClusteringMutualInformation.h>
00014 
00015 using namespace shogun;
00016 
00017 float64_t CClusteringMutualInformation::evaluate(CLabels* predicted, CLabels* ground_truth)
00018 {
00019     ASSERT(predicted && ground_truth);
00020     ASSERT(predicted->get_label_type() == LT_MULTICLASS);
00021     ASSERT(ground_truth->get_label_type() == LT_MULTICLASS);
00022     SGVector<float64_t> label_p=((CMulticlassLabels*) predicted)->get_unique_labels();
00023     SGVector<float64_t> label_g=((CMulticlassLabels*) ground_truth)->get_unique_labels();
00024 
00025     if (label_p.vlen != label_g.vlen)
00026         SG_ERROR("Number of classes are different\n");
00027     uint32_t n_class=label_p.vlen;
00028     float64_t n_label=predicted->get_num_labels();
00029 
00030     SGVector<int32_t> ilabels_p=((CMulticlassLabels*) predicted)->get_int_labels();
00031     SGVector<int32_t> ilabels_g=((CMulticlassLabels*) ground_truth)->get_int_labels();
00032 
00033     SGMatrix<float64_t> G(n_class, n_class);
00034     for (size_t i=0; i < n_class; ++i)
00035     {
00036         for (size_t j=0; j < n_class; ++j)
00037             G(i, j)=find_match_count(ilabels_g, label_g[i],
00038                 ilabels_p, label_p[j])/n_label;
00039     }
00040 
00041     SGVector<float64_t> G_rowsum(n_class);
00042     SGVector<float64_t> G_colsum(n_class);
00043     for (size_t i=0; i < n_class; ++i)
00044     {
00045         for (size_t j=0; j < n_class; ++j)
00046         {
00047             G_rowsum[i] += G(i, j);
00048             G_colsum[i] += G(j, i);
00049         }
00050     }
00051 
00052     float64_t mutual_info = 0;
00053     for (size_t i=0; i < n_class; ++i)
00054     {
00055         for (size_t j=0; j < n_class; ++j)
00056         {
00057             if (G(i, j) != 0)
00058                 mutual_info += G(i, j) * log(G(i,j) /
00059                     (G_rowsum[i]*G_colsum[j]))/log(2.);
00060         }
00061     }
00062 
00063     float64_t entropy_p = 0;
00064     float64_t entropy_g = 0;
00065     for (size_t i=0; i < n_class; ++i)
00066     {
00067         entropy_g += -G_rowsum[i] * log(G_rowsum[i])/log(2.);
00068         entropy_p += -G_colsum[i] * log(G_colsum[i])/log(2.);
00069     }
00070 
00071     return mutual_info / CMath::max(entropy_g, entropy_p);
00072 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation