GridSearchModelSelection.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2011-2012 Heiko Strathmann
00008  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society
00009  */
00010 
00011 #include <shogun/modelselection/GridSearchModelSelection.h>
00012 #include <shogun/modelselection/ParameterCombination.h>
00013 #include <shogun/modelselection/ModelSelectionParameters.h>
00014 #include <shogun/evaluation/CrossValidation.h>
00015 #include <shogun/machine/Machine.h>
00016 
00017 using namespace shogun;
00018 
00019 CGridSearchModelSelection::CGridSearchModelSelection() :
00020     CModelSelection(NULL, NULL)
00021 {
00022 
00023 }
00024 
00025 CGridSearchModelSelection::CGridSearchModelSelection(
00026         CModelSelectionParameters* model_parameters,
00027         CMachineEvaluation* machine_eval) :
00028     CModelSelection(model_parameters, machine_eval)
00029 {
00030 
00031 }
00032 
00033 CGridSearchModelSelection::~CGridSearchModelSelection()
00034 {
00035 }
00036 
00037 CParameterCombination* CGridSearchModelSelection::select_model(bool print_state)
00038 {
00039     if (print_state)
00040         SG_PRINT("Generating parameter combinations\n");
00041 
00042     /* Retrieve all possible parameter combinations */
00043     CDynamicObjectArray* combinations=
00044             (CDynamicObjectArray*)m_model_parameters->get_combinations();
00045 
00046     CCrossValidationResult* best_result = new CCrossValidationResult();
00047 
00048     CParameterCombination* best_combination=NULL;
00049     if (m_machine_eval->get_evaluation_direction()==ED_MAXIMIZE)
00050     {
00051         if (print_state) SG_PRINT("Direction is maximize\n");
00052         best_result->mean=CMath::ALMOST_NEG_INFTY;
00053     }
00054     else
00055     {
00056         if (print_state) SG_PRINT("Direction is maximize\n");
00057         best_result->mean=CMath::ALMOST_INFTY;
00058     }
00059 
00060     /* underlying learning machine */
00061     CMachine* machine=m_machine_eval->get_machine();
00062 
00063     /* apply all combinations and search for best one */
00064     for (index_t i=0; i<combinations->get_num_elements(); ++i)
00065     {
00066         CParameterCombination* current_combination=(CParameterCombination*)
00067                 combinations->get_element(i);
00068 
00069         /* eventually print */
00070         if (print_state)
00071         {
00072             SG_PRINT("trying combination:\n");
00073             current_combination->print_tree();
00074         }
00075 
00076         current_combination->apply_to_modsel_parameter(
00077                 machine->m_model_selection_parameters);
00078 
00079         /* note that this may implicitly lock and unlockthe machine */
00080         CCrossValidationResult* result =
00081                 (CCrossValidationResult*)(m_machine_eval->evaluate());
00082 
00083         if (result->get_result_type() != CROSSVALIDATION_RESULT)
00084             SG_ERROR("Evaluation result is not of type CCrossValidationResult!");
00085 
00086         if (print_state)
00087             result->print_result();
00088 
00089         /* check if current result is better, delete old combinations */
00090         if (m_machine_eval->get_evaluation_direction()==ED_MAXIMIZE)
00091         {
00092             if (result->mean>best_result->mean)
00093             {
00094                 if (best_combination)
00095                     SG_UNREF(best_combination);
00096 
00097                 best_combination=(CParameterCombination*)
00098                         combinations->get_element(i);
00099 
00100                 SG_UNREF(best_result);
00101                 SG_REF(result);
00102                 best_result = result;
00103             }
00104             else
00105             {
00106                 CParameterCombination* combination=(CParameterCombination*)
00107                         combinations->get_element(i);
00108                 SG_UNREF(combination);
00109             }
00110         }
00111         else
00112         {
00113             if (result->mean<best_result->mean)
00114             {
00115                 if (best_combination)
00116                     SG_UNREF(best_combination);
00117 
00118                 best_combination=(CParameterCombination*)
00119                         combinations->get_element(i);
00120 
00121                 SG_UNREF(best_result);
00122                 SG_REF(result);
00123                 best_result = result;
00124             }
00125             else
00126             {
00127                 CParameterCombination* combination=(CParameterCombination*)
00128                         combinations->get_element(i);
00129                 SG_UNREF(combination);
00130             }
00131         }
00132 
00133         SG_UNREF(result);
00134         SG_UNREF(current_combination);
00135     }
00136 
00137     SG_UNREF(best_result);
00138     SG_UNREF(machine);
00139     SG_UNREF(combinations);
00140 
00141     return best_combination;
00142 }
00143 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation