RelaxedTreeUtil.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Chiyuan Zhang
00008  * Copyright (C) 2012 Chiyuan Zhang
00009  */
00010 
00011 #include <shogun/evaluation/CrossValidationSplitting.h>
00012 #include <shogun/multiclass/tree/RelaxedTreeUtil.h>
00013 #include <shogun/evaluation/MulticlassAccuracy.h>
00014 
00015 using namespace shogun;
00016 
00017 SGMatrix<float64_t> RelaxedTreeUtil::estimate_confusion_matrix(CBaseMulticlassMachine *machine, CFeatures *X, CMulticlassLabels *Y, int32_t num_classes)
00018 {
00019     const int32_t N_splits = 2; // 5
00020     CCrossValidationSplitting *split = new CCrossValidationSplitting(Y, N_splits);
00021     split->build_subsets();
00022 
00023     SGMatrix<float64_t> conf_mat(num_classes, num_classes), tmp_mat(num_classes, num_classes);
00024     conf_mat.zero();
00025 
00026     machine->set_labels(Y);
00027     machine->set_store_model_features(true);
00028 
00029     for (int32_t i=0; i < N_splits; ++i)
00030     {
00031         // subset for training
00032         SGVector<index_t> inverse_subset_indices = split->generate_subset_inverse(i);
00033         X->add_subset(inverse_subset_indices);
00034         Y->add_subset(inverse_subset_indices);
00035 
00036         machine->train(X);
00037         X->remove_subset();
00038         Y->remove_subset();
00039 
00040         // subset for predicting
00041         SGVector<index_t> subset_indices = split->generate_subset_indices(i);
00042         X->add_subset(subset_indices);
00043         Y->add_subset(subset_indices);
00044 
00045         CMulticlassLabels *pred = machine->apply_multiclass(X);
00046 
00047         get_confusion_matrix(tmp_mat, Y, pred);
00048 
00049         for (index_t j=0; j < tmp_mat.num_rows; ++j)
00050         {
00051             for (index_t k=0; k < tmp_mat.num_cols; ++k)
00052             {
00053                 conf_mat(j, k) += tmp_mat(j, k);
00054             }
00055         }
00056 
00057         SG_UNREF(pred);
00058 
00059         X->remove_subset();
00060         Y->remove_subset();
00061     }
00062 
00063     SG_UNREF(split);
00064 
00065     for (index_t j=0; j < tmp_mat.num_rows; ++j)
00066     {
00067         for (index_t k=0; k < tmp_mat.num_cols; ++k)
00068         {
00069             conf_mat(j, k) /= N_splits;
00070         }
00071     }
00072 
00073     return conf_mat;
00074 }
00075 
00076 void RelaxedTreeUtil::get_confusion_matrix(SGMatrix<float64_t> &conf_mat, CMulticlassLabels *gt, CMulticlassLabels *pred)
00077 {
00078     SGMatrix<int32_t> conf_mat_int = CMulticlassAccuracy::get_confusion_matrix(pred, gt);
00079 
00080     for (index_t i=0; i < conf_mat.num_rows; ++i)
00081     {
00082         float64_t n=0;
00083         for (index_t j=0; j < conf_mat.num_cols; ++j)
00084         {
00085             conf_mat(i, j) = conf_mat_int(i, j);
00086             n += conf_mat(i, j);
00087         }
00088 
00089         if (n != 0)
00090         {
00091             for (index_t j=0; j < conf_mat.num_cols; ++j)
00092                 conf_mat(i, j) /= n;
00093         }
00094     }
00095 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation