CrossValidationPrintOutput.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Sergey Lisitsyn
00008  * Written (W) 2012 Heiko Strathmann
00009  */
00010 
00011 #include <shogun/evaluation/CrossValidationPrintOutput.h>
00012 #include <shogun/machine/LinearMachine.h>
00013 #include <shogun/machine/LinearMulticlassMachine.h>
00014 #include <shogun/machine/KernelMachine.h>
00015 #include <shogun/machine/KernelMulticlassMachine.h>
00016 #include <shogun/kernel/CombinedKernel.h>
00017 #include <shogun/classifier/mkl/MKL.h>
00018 #include <shogun/classifier/mkl/MKLMulticlass.h>
00019 
00020 using namespace shogun;
00021 
00022 void CCrossValidationPrintOutput::init_num_runs(index_t num_runs,
00023         const char* prefix)
00024 {
00025     SG_PRINT("%scross validation number of runs %d\n", prefix, num_runs);
00026 }
00027 
00029 void CCrossValidationPrintOutput::init_num_folds(index_t num_folds,
00030         const char* prefix)
00031 {
00032     SG_PRINT("%scross validation number of folds %d\n", prefix, num_folds);
00033 }
00034 
00035 void CCrossValidationPrintOutput::update_run_index(index_t run_index,
00036         const char* prefix)
00037 {
00038     SG_PRINT("%scross validation run %d\n", prefix, run_index);
00039 }
00040 
00041 void CCrossValidationPrintOutput::update_fold_index(index_t fold_index,
00042         const char* prefix)
00043 {
00044     SG_PRINT("%sfold %d\n", prefix, fold_index);
00045 }
00046 
00047 void CCrossValidationPrintOutput::update_train_indices(
00048         SGVector<index_t> indices, const char* prefix)
00049 {
00050     indices.display_vector("train_indices", prefix);
00051 }
00052 
00053 void CCrossValidationPrintOutput::update_test_indices(
00054         SGVector<index_t> indices, const char* prefix)
00055 {
00056     indices.display_vector("test_indices", prefix);
00057 }
00058 
00059 void CCrossValidationPrintOutput::update_trained_machine(
00060         CMachine* machine, const char* prefix)
00061 {
00062     if (dynamic_cast<CLinearMachine*>(machine))
00063     {
00064         CLinearMachine* linear_machine=(CLinearMachine*)machine;
00065         linear_machine->get_w().display_vector("learned_w", prefix);
00066         SG_PRINT("%slearned_bias=%f\n", prefix, linear_machine->get_bias());
00067     }
00068 
00069     if (dynamic_cast<CKernelMachine*>(machine))
00070     {
00071         CKernelMachine* kernel_machine=(CKernelMachine*)machine;
00072         kernel_machine->get_alphas().display_vector("learned_alphas", prefix);
00073         SG_PRINT("%slearned_bias=%f\n", prefix, kernel_machine->get_bias());
00074     }
00075 
00076     if (dynamic_cast<CLinearMulticlassMachine*>(machine)
00077             || dynamic_cast<CKernelMulticlassMachine*>(machine))
00078     {
00079         /* append one tab to prefix */
00080         char* new_prefix=append_tab_to_string(prefix);
00081 
00082         CMulticlassMachine* mc_machine=(CMulticlassMachine*)machine;
00083         for (int i=0; i<mc_machine->get_num_machines(); i++)
00084         {
00085             CMachine* sub_machine=mc_machine->get_machine(i);
00086             //SG_PRINT("%smulti-class machine %d:\n", i, sub_machine);
00087             this->update_trained_machine(sub_machine, new_prefix);
00088             SG_UNREF(sub_machine);
00089         }
00090 
00091         /* clean up */
00092         SG_FREE(new_prefix);
00093     }
00094 
00095     if (dynamic_cast<CMKL*>(machine))
00096     {
00097         CMKL* mkl=(CMKL*)machine;
00098         CCombinedKernel* kernel=dynamic_cast<CCombinedKernel*>(
00099                 mkl->get_kernel());
00100         kernel->get_subkernel_weights().display_vector("MKL sub-kernel weights",
00101                 prefix);
00102         SG_UNREF(kernel);
00103     }
00104     
00105     if (dynamic_cast<CMKLMulticlass*>(machine))
00106     {
00107         CMKLMulticlass* mkl=(CMKLMulticlass*)machine;
00108         CCombinedKernel* kernel=dynamic_cast<CCombinedKernel*>(
00109                 mkl->get_kernel());
00110         kernel->get_subkernel_weights().display_vector("MKL sub-kernel weights",
00111                 prefix);
00112         SG_UNREF(kernel);
00113     }
00114 }
00115 
00116 void CCrossValidationPrintOutput::update_test_result(CLabels* results,
00117         const char* prefix)
00118 {
00119     results->get_values().display_vector("test_labels", prefix);
00120 }
00121 
00122 void CCrossValidationPrintOutput::update_test_true_result(CLabels* results,
00123         const char* prefix)
00124 {
00125     results->get_values().display_vector("true_labels", prefix);
00126 }
00127 
00128 void CCrossValidationPrintOutput::update_evaluation_result(float64_t result,
00129         const char* prefix)
00130 {
00131     SG_PRINT("%sevaluation result=%f\n", prefix, result);
00132 }
00133 
00134 char* CCrossValidationPrintOutput::append_tab_to_string(const char* string)
00135 {
00136     /* allocate memory, concatenate and add termination character */
00137     index_t len=strlen(string);
00138     char* new_prefix=SG_MALLOC(char, len+2);
00139     memcpy(new_prefix, string, sizeof(char*)*len);
00140     new_prefix[len]='\t';
00141     new_prefix[len+1]='\0';
00142 
00143     return new_prefix;
00144 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation