00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2011-2012 Heiko Strathmann 00008 * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society 00009 */ 00010 00011 #include <shogun/modelselection/GridSearchModelSelection.h> 00012 #include <shogun/modelselection/ParameterCombination.h> 00013 #include <shogun/modelselection/ModelSelectionParameters.h> 00014 #include <shogun/evaluation/CrossValidation.h> 00015 #include <shogun/machine/Machine.h> 00016 00017 using namespace shogun; 00018 00019 CGridSearchModelSelection::CGridSearchModelSelection() : 00020 CModelSelection(NULL, NULL) 00021 { 00022 00023 } 00024 00025 CGridSearchModelSelection::CGridSearchModelSelection( 00026 CModelSelectionParameters* model_parameters, 00027 CMachineEvaluation* machine_eval) : 00028 CModelSelection(model_parameters, machine_eval) 00029 { 00030 00031 } 00032 00033 CGridSearchModelSelection::~CGridSearchModelSelection() 00034 { 00035 } 00036 00037 CParameterCombination* CGridSearchModelSelection::select_model(bool print_state) 00038 { 00039 if (print_state) 00040 SG_PRINT("Generating parameter combinations\n"); 00041 00042 /* Retrieve all possible parameter combinations */ 00043 CDynamicObjectArray* combinations= 00044 (CDynamicObjectArray*)m_model_parameters->get_combinations(); 00045 00046 CCrossValidationResult* best_result = new CCrossValidationResult(); 00047 00048 CParameterCombination* best_combination=NULL; 00049 if (m_machine_eval->get_evaluation_direction()==ED_MAXIMIZE) 00050 { 00051 if (print_state) SG_PRINT("Direction is maximize\n"); 00052 best_result->mean=CMath::ALMOST_NEG_INFTY; 00053 } 00054 else 00055 { 00056 if (print_state) SG_PRINT("Direction is maximize\n"); 00057 best_result->mean=CMath::ALMOST_INFTY; 00058 } 00059 00060 /* underlying learning machine */ 00061 CMachine* machine=m_machine_eval->get_machine(); 00062 00063 /* apply all combinations and search for best one */ 00064 for (index_t i=0; i<combinations->get_num_elements(); ++i) 00065 { 00066 CParameterCombination* current_combination=(CParameterCombination*) 00067 combinations->get_element(i); 00068 00069 /* eventually print */ 00070 if (print_state) 00071 { 00072 SG_PRINT("trying combination:\n"); 00073 current_combination->print_tree(); 00074 } 00075 00076 current_combination->apply_to_modsel_parameter( 00077 machine->m_model_selection_parameters); 00078 00079 /* note that this may implicitly lock and unlockthe machine */ 00080 CCrossValidationResult* result = 00081 (CCrossValidationResult*)(m_machine_eval->evaluate()); 00082 00083 if (result->get_result_type() != CROSSVALIDATION_RESULT) 00084 SG_ERROR("Evaluation result is not of type CCrossValidationResult!"); 00085 00086 if (print_state) 00087 result->print_result(); 00088 00089 /* check if current result is better, delete old combinations */ 00090 if (m_machine_eval->get_evaluation_direction()==ED_MAXIMIZE) 00091 { 00092 if (result->mean>best_result->mean) 00093 { 00094 if (best_combination) 00095 SG_UNREF(best_combination); 00096 00097 best_combination=(CParameterCombination*) 00098 combinations->get_element(i); 00099 00100 SG_UNREF(best_result); 00101 SG_REF(result); 00102 best_result = result; 00103 } 00104 else 00105 { 00106 CParameterCombination* combination=(CParameterCombination*) 00107 combinations->get_element(i); 00108 SG_UNREF(combination); 00109 } 00110 } 00111 else 00112 { 00113 if (result->mean<best_result->mean) 00114 { 00115 if (best_combination) 00116 SG_UNREF(best_combination); 00117 00118 best_combination=(CParameterCombination*) 00119 combinations->get_element(i); 00120 00121 SG_UNREF(best_result); 00122 SG_REF(result); 00123 best_result = result; 00124 } 00125 else 00126 { 00127 CParameterCombination* combination=(CParameterCombination*) 00128 combinations->get_element(i); 00129 SG_UNREF(combination); 00130 } 00131 } 00132 00133 SG_UNREF(result); 00134 SG_UNREF(current_combination); 00135 } 00136 00137 SG_UNREF(best_result); 00138 SG_UNREF(machine); 00139 SG_UNREF(combinations); 00140 00141 return best_combination; 00142 } 00143