RandomSearchModelSelection.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Copyright (C) 2011 Heiko Strathmann
00008  * Copyright (C) 2012 Sergey Lisitsyn
00009  */
00010 
00011 #include <shogun/modelselection/RandomSearchModelSelection.h>
00012 #include <shogun/modelselection/ParameterCombination.h>
00013 #include <shogun/modelselection/ModelSelectionParameters.h>
00014 #include <shogun/evaluation/CrossValidation.h>
00015 #include <shogun/mathematics/Statistics.h>
00016 #include <shogun/machine/Machine.h>
00017 
00018 using namespace shogun;
00019 
00020 CRandomSearchModelSelection::CRandomSearchModelSelection() :
00021     CModelSelection(NULL, NULL)
00022 {
00023     set_ratio(0.5);
00024 }
00025 
00026 CRandomSearchModelSelection::CRandomSearchModelSelection(
00027         CModelSelectionParameters* model_parameters,
00028         CMachineEvaluation* machine_eval, float64_t ratio) :
00029     CModelSelection(model_parameters, machine_eval)
00030 {
00031     set_ratio(ratio);
00032 }
00033 
00034 CRandomSearchModelSelection::~CRandomSearchModelSelection()
00035 {
00036 }
00037 
00038 CParameterCombination* CRandomSearchModelSelection::select_model(bool print_state)
00039 {
00040     if (print_state)
00041         SG_PRINT("Generating parameter combinations\n");
00042 
00043     /* Retrieve all possible parameter combinations */
00044     CDynamicObjectArray* all_combinations=
00045             (CDynamicObjectArray*)m_model_parameters->get_combinations();
00046 
00047     int32_t n_all_combinations = all_combinations->get_num_elements();
00048     SGVector<index_t> combinations_indices = CStatistics::sample_indices(n_all_combinations*m_ratio, n_all_combinations);
00049 
00050     CDynamicObjectArray* combinations = new CDynamicObjectArray();
00051 
00052     for (int32_t i=0; i<combinations_indices.vlen; i++)
00053         combinations->append_element(all_combinations->get_element(i));
00054 
00055     CCrossValidationResult* best_result = new CCrossValidationResult();
00056 
00057     CParameterCombination* best_combination=NULL;
00058     if (m_machine_eval->get_evaluation_direction()==ED_MAXIMIZE)
00059     {
00060         if (print_state) SG_PRINT("Direction is maximize\n");
00061         best_result->mean=CMath::ALMOST_NEG_INFTY;
00062     }
00063     else
00064     {
00065         if (print_state) SG_PRINT("Direction is maximize\n");
00066         best_result->mean=CMath::ALMOST_INFTY;
00067     }
00068 
00069     /* underlying learning machine */
00070     CMachine* machine=m_machine_eval->get_machine();
00071 
00072     /* apply all combinations and search for best one */
00073     for (index_t i=0; i<combinations->get_num_elements(); ++i)
00074     {
00075         CParameterCombination* current_combination=(CParameterCombination*)
00076                 combinations->get_element(i);
00077 
00078         /* eventually print */
00079         if (print_state)
00080         {
00081             SG_PRINT("trying combination:\n");
00082             current_combination->print_tree();
00083         }
00084 
00085         current_combination->apply_to_modsel_parameter(
00086                 machine->m_model_selection_parameters);
00087 
00088         /* note that this may implicitly lock and unlockthe machine */
00089         CCrossValidationResult* result =
00090                 (CCrossValidationResult*)(m_machine_eval->evaluate());
00091 
00092         if (result->get_result_type() != CROSSVALIDATION_RESULT)
00093             SG_ERROR("Evaluation result is not of type CCrossValidationResult!");
00094 
00095         if (print_state)
00096             result->print_result();
00097 
00098         /* check if current result is better, delete old combinations */
00099         if (m_machine_eval->get_evaluation_direction()==ED_MAXIMIZE)
00100         {
00101             if (result->mean>best_result->mean)
00102             {
00103                 if (best_combination)
00104                     SG_UNREF(best_combination);
00105 
00106                 best_combination=(CParameterCombination*)
00107                         combinations->get_element(i);
00108 
00109                 SG_UNREF(best_result);
00110                 SG_REF(result);
00111                 best_result = result;
00112             }
00113             else
00114             {
00115                 CParameterCombination* combination=(CParameterCombination*)
00116                         combinations->get_element(i);
00117                 SG_UNREF(combination);
00118             }
00119         }
00120         else
00121         {
00122             if (result->mean<best_result->mean)
00123             {
00124                 if (best_combination)
00125                     SG_UNREF(best_combination);
00126 
00127                 best_combination=(CParameterCombination*)
00128                         combinations->get_element(i);
00129 
00130                 SG_UNREF(best_result);
00131                 SG_REF(result);
00132                 best_result = result;
00133             }
00134             else
00135             {
00136                 CParameterCombination* combination=(CParameterCombination*)
00137                         combinations->get_element(i);
00138                 SG_UNREF(combination);
00139             }
00140         }
00141 
00142         SG_UNREF(result);
00143         SG_UNREF(current_combination);
00144     }
00145 
00146     SG_UNREF(best_result);
00147     SG_UNREF(machine);
00148     SG_UNREF(combinations);
00149 
00150     return best_combination;
00151 }
00152 
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation