ClusteringEvaluation.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Chiyuan Zhang
00008  * Copyright (C) 2012 Chiyuan Zhang
00009  */
00010 
00011 #include <set>
00012 #include <map>
00013 #include <vector>
00014 #include <algorithm>
00015 
00016 #include <shogun/evaluation/ClusteringEvaluation.h>
00017 #include <shogun/labels/MulticlassLabels.h>
00018 #include <shogun/mathematics/munkres.h>
00019 
00020 using namespace shogun;
00021 using namespace std;
00022 
00023 int32_t CClusteringEvaluation::find_match_count(SGVector<int32_t> l1, int32_t m1, SGVector<int32_t> l2, int32_t m2)
00024 {
00025     int32_t match_count=0;
00026     for (int32_t i=l1.vlen-1; i >= 0; --i)
00027     {
00028         if (l1[i] == m1 && l2[i] == m2)
00029             match_count++;
00030     }
00031 
00032     return match_count;
00033 }
00034 
00035 int32_t CClusteringEvaluation::find_mismatch_count(SGVector<int32_t> l1, int32_t m1, SGVector<int32_t> l2, int32_t m2)
00036 {
00037     return l1.vlen - find_match_count(l1, m1, l2, m2);
00038 }
00039 
00040 void CClusteringEvaluation::best_map(CLabels* predicted, CLabels* ground_truth)
00041 {
00042     ASSERT(predicted->get_num_labels() == ground_truth->get_num_labels());
00043     ASSERT(predicted->get_label_type() == LT_MULTICLASS);
00044     ASSERT(ground_truth->get_label_type() == LT_MULTICLASS);
00045 
00046     SGVector<float64_t> label_p=((CMulticlassLabels*) predicted)->get_unique_labels();
00047     SGVector<float64_t> label_g=((CMulticlassLabels*) ground_truth)->get_unique_labels();
00048 
00049     SGVector<int32_t> predicted_ilabels=((CMulticlassLabels*) predicted)->get_int_labels();
00050     SGVector<int32_t> groundtruth_ilabels=((CMulticlassLabels*) ground_truth)->get_int_labels();
00051 
00052     int32_t n_class=max(label_p.vlen, label_g.vlen);
00053     SGMatrix<float64_t> G(n_class, n_class);
00054     G.zero();
00055 
00056     for (int32_t i=0; i < label_g.vlen; ++i)
00057     {
00058         for (int32_t j=0; j < label_p.vlen; ++j)
00059         {
00060             G(i, j)=find_mismatch_count(groundtruth_ilabels, static_cast<int32_t>(label_g[i]),
00061                 predicted_ilabels, static_cast<int32_t>(label_p[j]));
00062         }
00063     }
00064 
00065     Munkres munkres_solver(G);
00066     munkres_solver.solve();
00067 
00068     std::map<int32_t, int32_t> label_map;
00069     for (int32_t i=0; i < label_p.vlen; ++i)
00070     {
00071         for (int32_t j=0; j < label_g.vlen; ++j)
00072         {
00073             if (G(j, i) == 0)
00074             {
00075                 label_map.insert(make_pair(static_cast<int32_t>(label_p[i]), 
00076                         static_cast<int32_t>(label_g[j])));
00077                 break;
00078             }
00079         }
00080     }
00081 
00082     for (int32_t i= 0; i < predicted_ilabels.vlen; ++i)
00083         ((CMulticlassLabels*) predicted)->set_int_label(i, label_map[predicted_ilabels[i]]);
00084 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation