00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2011 Heiko Strathmann 00008 * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society 00009 */ 00010 00011 #include <shogun/modelselection/GridSearchModelSelection.h> 00012 #include <shogun/modelselection/ParameterCombination.h> 00013 #include <shogun/modelselection/ModelSelectionParameters.h> 00014 #include <shogun/evaluation/CrossValidation.h> 00015 #include <shogun/machine/Machine.h> 00016 00017 using namespace shogun; 00018 00019 CGridSearchModelSelection::CGridSearchModelSelection() : 00020 CModelSelection(NULL, NULL) 00021 { 00022 00023 } 00024 00025 CGridSearchModelSelection::CGridSearchModelSelection( 00026 CModelSelectionParameters* model_parameters, 00027 CCrossValidation* cross_validation) : 00028 CModelSelection(model_parameters, cross_validation) 00029 { 00030 00031 } 00032 00033 CGridSearchModelSelection::~CGridSearchModelSelection() 00034 { 00035 } 00036 00037 CParameterCombination* CGridSearchModelSelection::select_model() 00038 { 00039 /* Retrieve all possible parameter combinations */ 00040 CDynamicObjectArray<CParameterCombination>* combinations= 00041 m_model_parameters->get_combinations(); 00042 00043 CrossValidationResult best_result; 00044 00045 CParameterCombination* best_combination=NULL; 00046 if (m_cross_validation->get_evaluation_direction()==ED_MAXIMIZE) 00047 best_result.mean=CMath::ALMOST_NEG_INFTY; 00048 else 00049 best_result.mean=CMath::ALMOST_INFTY; 00050 00051 /* underlying learning machine */ 00052 CMachine* machine=m_cross_validation->get_machine(); 00053 00054 /* apply all combinations and search for best one */ 00055 for (index_t i=0; i<combinations->get_num_elements(); ++i) 00056 { 00057 CParameterCombination* current_combination=combinations->get_element(i); 00058 current_combination->apply_to_modsel_parameter( 00059 machine->m_model_selection_parameters); 00060 CrossValidationResult result=m_cross_validation->evaluate(); 00061 00062 /* check if current result is better, delete old combinations */ 00063 if (m_cross_validation->get_evaluation_direction()==ED_MAXIMIZE) 00064 { 00065 if (result.mean>best_result.mean) 00066 { 00067 if (best_combination) 00068 SG_UNREF(best_combination); 00069 00070 best_combination=combinations->get_element(i); 00071 best_result=result; 00072 } 00073 else 00074 { 00075 CParameterCombination* combination=combinations->get_element(i); 00076 SG_UNREF(combination); 00077 } 00078 } 00079 else 00080 { 00081 if (result.mean<best_result.mean) 00082 { 00083 if (best_combination) 00084 SG_UNREF(best_combination); 00085 00086 best_combination=combinations->get_element(i); 00087 best_result=result; 00088 } 00089 else 00090 { 00091 CParameterCombination* combination=combinations->get_element(i); 00092 SG_UNREF(combination); 00093 } 00094 } 00095 00096 SG_UNREF(current_combination); 00097 } 00098 00099 SG_UNREF(machine); 00100 SG_UNREF(combinations); 00101 00102 return best_combination; 00103 } 00104