00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Copyright (C) 2011 Heiko Strathmann 00008 * Copyright (C) 2012 Sergey Lisitsyn 00009 */ 00010 00011 #include <shogun/modelselection/RandomSearchModelSelection.h> 00012 #include <shogun/modelselection/ParameterCombination.h> 00013 #include <shogun/modelselection/ModelSelectionParameters.h> 00014 #include <shogun/evaluation/CrossValidation.h> 00015 #include <shogun/mathematics/Statistics.h> 00016 #include <shogun/machine/Machine.h> 00017 00018 using namespace shogun; 00019 00020 CRandomSearchModelSelection::CRandomSearchModelSelection() : 00021 CModelSelection(NULL, NULL) 00022 { 00023 set_ratio(0.5); 00024 } 00025 00026 CRandomSearchModelSelection::CRandomSearchModelSelection( 00027 CModelSelectionParameters* model_parameters, 00028 CMachineEvaluation* machine_eval, float64_t ratio) : 00029 CModelSelection(model_parameters, machine_eval) 00030 { 00031 set_ratio(ratio); 00032 } 00033 00034 CRandomSearchModelSelection::~CRandomSearchModelSelection() 00035 { 00036 } 00037 00038 CParameterCombination* CRandomSearchModelSelection::select_model(bool print_state) 00039 { 00040 if (print_state) 00041 SG_PRINT("Generating parameter combinations\n"); 00042 00043 /* Retrieve all possible parameter combinations */ 00044 CDynamicObjectArray* all_combinations= 00045 (CDynamicObjectArray*)m_model_parameters->get_combinations(); 00046 00047 int32_t n_all_combinations = all_combinations->get_num_elements(); 00048 SGVector<index_t> combinations_indices = CStatistics::sample_indices(n_all_combinations*m_ratio, n_all_combinations); 00049 00050 CDynamicObjectArray* combinations = new CDynamicObjectArray(); 00051 00052 for (int32_t i=0; i<combinations_indices.vlen; i++) 00053 combinations->append_element(all_combinations->get_element(i)); 00054 00055 CCrossValidationResult* best_result = new CCrossValidationResult(); 00056 00057 CParameterCombination* best_combination=NULL; 00058 if (m_machine_eval->get_evaluation_direction()==ED_MAXIMIZE) 00059 { 00060 if (print_state) SG_PRINT("Direction is maximize\n"); 00061 best_result->mean=CMath::ALMOST_NEG_INFTY; 00062 } 00063 else 00064 { 00065 if (print_state) SG_PRINT("Direction is maximize\n"); 00066 best_result->mean=CMath::ALMOST_INFTY; 00067 } 00068 00069 /* underlying learning machine */ 00070 CMachine* machine=m_machine_eval->get_machine(); 00071 00072 /* apply all combinations and search for best one */ 00073 for (index_t i=0; i<combinations->get_num_elements(); ++i) 00074 { 00075 CParameterCombination* current_combination=(CParameterCombination*) 00076 combinations->get_element(i); 00077 00078 /* eventually print */ 00079 if (print_state) 00080 { 00081 SG_PRINT("trying combination:\n"); 00082 current_combination->print_tree(); 00083 } 00084 00085 current_combination->apply_to_modsel_parameter( 00086 machine->m_model_selection_parameters); 00087 00088 /* note that this may implicitly lock and unlockthe machine */ 00089 CCrossValidationResult* result = 00090 (CCrossValidationResult*)(m_machine_eval->evaluate()); 00091 00092 if (result->get_result_type() != CROSSVALIDATION_RESULT) 00093 SG_ERROR("Evaluation result is not of type CCrossValidationResult!"); 00094 00095 if (print_state) 00096 result->print_result(); 00097 00098 /* check if current result is better, delete old combinations */ 00099 if (m_machine_eval->get_evaluation_direction()==ED_MAXIMIZE) 00100 { 00101 if (result->mean>best_result->mean) 00102 { 00103 if (best_combination) 00104 SG_UNREF(best_combination); 00105 00106 best_combination=(CParameterCombination*) 00107 combinations->get_element(i); 00108 00109 SG_UNREF(best_result); 00110 SG_REF(result); 00111 best_result = result; 00112 } 00113 else 00114 { 00115 CParameterCombination* combination=(CParameterCombination*) 00116 combinations->get_element(i); 00117 SG_UNREF(combination); 00118 } 00119 } 00120 else 00121 { 00122 if (result->mean<best_result->mean) 00123 { 00124 if (best_combination) 00125 SG_UNREF(best_combination); 00126 00127 best_combination=(CParameterCombination*) 00128 combinations->get_element(i); 00129 00130 SG_UNREF(best_result); 00131 SG_REF(result); 00132 best_result = result; 00133 } 00134 else 00135 { 00136 CParameterCombination* combination=(CParameterCombination*) 00137 combinations->get_element(i); 00138 SG_UNREF(combination); 00139 } 00140 } 00141 00142 SG_UNREF(result); 00143 SG_UNREF(current_combination); 00144 } 00145 00146 SG_UNREF(best_result); 00147 SG_UNREF(machine); 00148 SG_UNREF(combinations); 00149 00150 return best_combination; 00151 } 00152