Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/evaluation/CrossValidationSplitting.h>
00012 #include <shogun/multiclass/tree/RelaxedTreeUtil.h>
00013 #include <shogun/evaluation/MulticlassAccuracy.h>
00014
00015 using namespace shogun;
00016
00017 SGMatrix<float64_t> RelaxedTreeUtil::estimate_confusion_matrix(CBaseMulticlassMachine *machine, CFeatures *X, CMulticlassLabels *Y, int32_t num_classes)
00018 {
00019 const int32_t N_splits = 2;
00020 CCrossValidationSplitting *split = new CCrossValidationSplitting(Y, N_splits);
00021 split->build_subsets();
00022
00023 SGMatrix<float64_t> conf_mat(num_classes, num_classes), tmp_mat(num_classes, num_classes);
00024 conf_mat.zero();
00025
00026 machine->set_labels(Y);
00027 machine->set_store_model_features(true);
00028
00029 for (int32_t i=0; i < N_splits; ++i)
00030 {
00031
00032 SGVector<index_t> inverse_subset_indices = split->generate_subset_inverse(i);
00033 X->add_subset(inverse_subset_indices);
00034 Y->add_subset(inverse_subset_indices);
00035
00036 machine->train(X);
00037 X->remove_subset();
00038 Y->remove_subset();
00039
00040
00041 SGVector<index_t> subset_indices = split->generate_subset_indices(i);
00042 X->add_subset(subset_indices);
00043 Y->add_subset(subset_indices);
00044
00045 CMulticlassLabels *pred = machine->apply_multiclass(X);
00046
00047 get_confusion_matrix(tmp_mat, Y, pred);
00048
00049 for (index_t j=0; j < tmp_mat.num_rows; ++j)
00050 {
00051 for (index_t k=0; k < tmp_mat.num_cols; ++k)
00052 {
00053 conf_mat(j, k) += tmp_mat(j, k);
00054 }
00055 }
00056
00057 SG_UNREF(pred);
00058
00059 X->remove_subset();
00060 Y->remove_subset();
00061 }
00062
00063 SG_UNREF(split);
00064
00065 for (index_t j=0; j < tmp_mat.num_rows; ++j)
00066 {
00067 for (index_t k=0; k < tmp_mat.num_cols; ++k)
00068 {
00069 conf_mat(j, k) /= N_splits;
00070 }
00071 }
00072
00073 return conf_mat;
00074 }
00075
00076 void RelaxedTreeUtil::get_confusion_matrix(SGMatrix<float64_t> &conf_mat, CMulticlassLabels *gt, CMulticlassLabels *pred)
00077 {
00078 SGMatrix<int32_t> conf_mat_int = CMulticlassAccuracy::get_confusion_matrix(pred, gt);
00079
00080 for (index_t i=0; i < conf_mat.num_rows; ++i)
00081 {
00082 float64_t n=0;
00083 for (index_t j=0; j < conf_mat.num_cols; ++j)
00084 {
00085 conf_mat(i, j) = conf_mat_int(i, j);
00086 n += conf_mat(i, j);
00087 }
00088
00089 if (n != 0)
00090 {
00091 for (index_t j=0; j < conf_mat.num_cols; ++j)
00092 conf_mat(i, j) /= n;
00093 }
00094 }
00095 }