CrossValidationMKLStorage.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2012 Sergey Lisitsyn
00008  * Written (W) 2012 Heiko Strathmann
00009  */
00010 
00011 #include <shogun/evaluation/CrossValidationMKLStorage.h>
00012 #include <shogun/kernel/CombinedKernel.h>
00013 #include <shogun/classifier/mkl/MKL.h>
00014 #include <shogun/classifier/mkl/MKLMulticlass.h>
00015 
00016 using namespace shogun;
00017 
00018 void CCrossValidationMKLStorage::update_trained_machine(
00019         CMachine* machine, const char* prefix)
00020 {
00021     REQUIRE(machine, "%s::update_trained_machine(): Provided Machine is NULL!\n",
00022             get_name());
00023 
00024     CMKL* mkl=dynamic_cast<CMKL*>(machine);
00025     CMKLMulticlass* mkl_multiclass=dynamic_cast<CMKLMulticlass*>(machine);
00026     REQUIRE(mkl || mkl_multiclass, "%s::update_trained_machine(): This method is only usable "
00027                 "with CMKL derived machines. This one is \"%s\"\n", get_name(),
00028                 machine->get_name());
00029 
00030     CKernel* kernel = NULL;
00031     if (mkl) 
00032         kernel = mkl->get_kernel();
00033     else
00034         kernel = mkl_multiclass->get_kernel();
00035 
00036     REQUIRE(kernel, "%s::update_trained_machine(): No kernel assigned to "
00037             "machine of type \"%s\"\n", get_name(), machine->get_name());
00038 
00039     CCombinedKernel* combined_kernel=dynamic_cast<CCombinedKernel*>(kernel);
00040     REQUIRE(combined_kernel, "%s::update_trained_machine(): This method is only"
00041             " usable with CCombinedKernel on machines. This one is \"s\"\n",
00042             get_name(), kernel->get_name());
00043 
00044     SGVector<float64_t> w=combined_kernel->get_subkernel_weights();
00045 
00046     /* evtl re-allocate memory (different number of runs from evaluation before) */
00047     if (m_mkl_weights.num_rows!=w.vlen ||
00048             m_mkl_weights.num_cols!=m_num_folds*m_num_runs)
00049     {
00050         if (m_mkl_weights.matrix)
00051         {
00052             SG_DEBUG("deleting memory for mkl weight matrix\n");
00053             m_mkl_weights=SGMatrix<float64_t>();
00054         }
00055     }
00056 
00057     /* evtl allocate memory (first call) */
00058     if (!m_mkl_weights.matrix)
00059     {
00060         SG_DEBUG("allocating memory for mkl weight matrix\n");
00061         m_mkl_weights=SGMatrix<float64_t>(w.vlen,m_num_folds*m_num_runs);
00062     }
00063 
00064     /* put current mkl weights into matrix, copy memory vector wise to make
00065      * things fast. Compute index of address to where vector goes */
00066 
00067     /* number of runs is w.vlen*m_num_folds shift */
00068     index_t run_shift=m_current_run_index*w.vlen*m_num_folds;
00069 
00070     /* fold shift is m_current_fold_index*w-vlen */
00071     index_t fold_shift=m_current_fold_index*w.vlen;
00072 
00073     /* add both index shifts */
00074     index_t first_idx=run_shift+fold_shift;
00075     SG_DEBUG("run %d, fold %d, matrix index %d\n",m_current_run_index,
00076             m_current_fold_index, first_idx);
00077 
00078     /* copy memory */
00079     memcpy(&m_mkl_weights.matrix[first_idx], w.vector,
00080             w.vlen*sizeof(float64_t));
00081 
00082     SG_UNREF(kernel);
00083 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation