Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/evaluation/CrossValidationPrintOutput.h>
00012 #include <shogun/machine/LinearMachine.h>
00013 #include <shogun/machine/LinearMulticlassMachine.h>
00014 #include <shogun/machine/KernelMachine.h>
00015 #include <shogun/machine/KernelMulticlassMachine.h>
00016 #include <shogun/kernel/CombinedKernel.h>
00017 #include <shogun/classifier/mkl/MKL.h>
00018 #include <shogun/classifier/mkl/MKLMulticlass.h>
00019
00020 using namespace shogun;
00021
00022 void CCrossValidationPrintOutput::init_num_runs(index_t num_runs,
00023 const char* prefix)
00024 {
00025 SG_PRINT("%scross validation number of runs %d\n", prefix, num_runs);
00026 }
00027
00029 void CCrossValidationPrintOutput::init_num_folds(index_t num_folds,
00030 const char* prefix)
00031 {
00032 SG_PRINT("%scross validation number of folds %d\n", prefix, num_folds);
00033 }
00034
00035 void CCrossValidationPrintOutput::update_run_index(index_t run_index,
00036 const char* prefix)
00037 {
00038 SG_PRINT("%scross validation run %d\n", prefix, run_index);
00039 }
00040
00041 void CCrossValidationPrintOutput::update_fold_index(index_t fold_index,
00042 const char* prefix)
00043 {
00044 SG_PRINT("%sfold %d\n", prefix, fold_index);
00045 }
00046
00047 void CCrossValidationPrintOutput::update_train_indices(
00048 SGVector<index_t> indices, const char* prefix)
00049 {
00050 indices.display_vector("train_indices", prefix);
00051 }
00052
00053 void CCrossValidationPrintOutput::update_test_indices(
00054 SGVector<index_t> indices, const char* prefix)
00055 {
00056 indices.display_vector("test_indices", prefix);
00057 }
00058
00059 void CCrossValidationPrintOutput::update_trained_machine(
00060 CMachine* machine, const char* prefix)
00061 {
00062 if (dynamic_cast<CLinearMachine*>(machine))
00063 {
00064 CLinearMachine* linear_machine=(CLinearMachine*)machine;
00065 linear_machine->get_w().display_vector("learned_w", prefix);
00066 SG_PRINT("%slearned_bias=%f\n", prefix, linear_machine->get_bias());
00067 }
00068
00069 if (dynamic_cast<CKernelMachine*>(machine))
00070 {
00071 CKernelMachine* kernel_machine=(CKernelMachine*)machine;
00072 kernel_machine->get_alphas().display_vector("learned_alphas", prefix);
00073 SG_PRINT("%slearned_bias=%f\n", prefix, kernel_machine->get_bias());
00074 }
00075
00076 if (dynamic_cast<CLinearMulticlassMachine*>(machine)
00077 || dynamic_cast<CKernelMulticlassMachine*>(machine))
00078 {
00079
00080 char* new_prefix=append_tab_to_string(prefix);
00081
00082 CMulticlassMachine* mc_machine=(CMulticlassMachine*)machine;
00083 for (int i=0; i<mc_machine->get_num_machines(); i++)
00084 {
00085 CMachine* sub_machine=mc_machine->get_machine(i);
00086
00087 this->update_trained_machine(sub_machine, new_prefix);
00088 SG_UNREF(sub_machine);
00089 }
00090
00091
00092 SG_FREE(new_prefix);
00093 }
00094
00095 if (dynamic_cast<CMKL*>(machine))
00096 {
00097 CMKL* mkl=(CMKL*)machine;
00098 CCombinedKernel* kernel=dynamic_cast<CCombinedKernel*>(
00099 mkl->get_kernel());
00100 kernel->get_subkernel_weights().display_vector("MKL sub-kernel weights",
00101 prefix);
00102 SG_UNREF(kernel);
00103 }
00104
00105 if (dynamic_cast<CMKLMulticlass*>(machine))
00106 {
00107 CMKLMulticlass* mkl=(CMKLMulticlass*)machine;
00108 CCombinedKernel* kernel=dynamic_cast<CCombinedKernel*>(
00109 mkl->get_kernel());
00110 kernel->get_subkernel_weights().display_vector("MKL sub-kernel weights",
00111 prefix);
00112 SG_UNREF(kernel);
00113 }
00114 }
00115
00116 void CCrossValidationPrintOutput::update_test_result(CLabels* results,
00117 const char* prefix)
00118 {
00119 results->get_values().display_vector("test_labels", prefix);
00120 }
00121
00122 void CCrossValidationPrintOutput::update_test_true_result(CLabels* results,
00123 const char* prefix)
00124 {
00125 results->get_values().display_vector("true_labels", prefix);
00126 }
00127
00128 void CCrossValidationPrintOutput::update_evaluation_result(float64_t result,
00129 const char* prefix)
00130 {
00131 SG_PRINT("%sevaluation result=%f\n", prefix, result);
00132 }
00133
00134 char* CCrossValidationPrintOutput::append_tab_to_string(const char* string)
00135 {
00136
00137 index_t len=strlen(string);
00138 char* new_prefix=SG_MALLOC(char, len+2);
00139 memcpy(new_prefix, string, sizeof(char*)*len);
00140 new_prefix[len]='\t';
00141 new_prefix[len+1]='\0';
00142
00143 return new_prefix;
00144 }