GPBTSVM.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include <shogun/classifier/svm/GPBTSVM.h>
00012 #include <shogun/lib/external/gpdt.h>
00013 #include <shogun/lib/external/gpdtsolve.h>
00014 #include <shogun/io/SGIO.h>
00015 #include <shogun/labels/BinaryLabels.h>
00016 
00017 using namespace shogun;
00018 
00019 CGPBTSVM::CGPBTSVM()
00020 : CSVM(), model(NULL)
00021 {
00022 }
00023 
00024 CGPBTSVM::CGPBTSVM(float64_t C, CKernel* k, CLabels* lab)
00025 : CSVM(C, k, lab), model(NULL)
00026 {
00027 }
00028 
00029 CGPBTSVM::~CGPBTSVM()
00030 {
00031     SG_FREE(model);
00032 }
00033 
00034 bool CGPBTSVM::train_machine(CFeatures* data)
00035 {
00036     float64_t* solution;                     /* store the solution found       */
00037     QPproblem prob;                          /* object containing the solvers  */
00038 
00039     ASSERT(kernel);
00040     ASSERT(m_labels && m_labels->get_num_labels());
00041     ASSERT(m_labels->get_label_type() == LT_BINARY);
00042     if (data)
00043     {
00044         if (m_labels->get_num_labels() != data->get_num_vectors())
00045             SG_ERROR("Number of training vectors does not match number of labels\n");
00046         kernel->init(data, data);
00047     }
00048 
00049     SGVector<int32_t> lab=((CBinaryLabels*) m_labels)->get_int_labels();
00050     prob.KER=new sKernel(kernel, lab.vlen);
00051     prob.y=lab.vector;
00052     prob.ell=lab.vlen;
00053     SG_INFO( "%d trainlabels\n", prob.ell);
00054 
00055     //  /*** set options defaults ***/
00056     prob.delta = epsilon;
00057     prob.maxmw = kernel->get_cache_size();
00058     prob.verbosity       = 0;
00059     prob.preprocess_size = -1;
00060     prob.projection_projector = -1;
00061     prob.c_const = get_C1();
00062     prob.chunk_size = get_qpsize();
00063     prob.linadd = get_linadd_enabled();
00064 
00065     if (prob.chunk_size < 2)      prob.chunk_size = 2;
00066     if (prob.q <= 0)              prob.q = prob.chunk_size / 3;
00067     if (prob.q < 2)               prob.q = 2;
00068     if (prob.q > prob.chunk_size) prob.q = prob.chunk_size;
00069     prob.q = prob.q & (~1);
00070     if (prob.maxmw < 5)
00071         prob.maxmw = 5;
00072 
00073     /*** set the problem description for final report ***/
00074     SG_INFO( "\nTRAINING PARAMETERS:\n");
00075     SG_INFO( "\tNumber of training documents: %d\n", prob.ell);
00076     SG_INFO( "\tq: %d\n", prob.chunk_size);
00077     SG_INFO( "\tn: %d\n", prob.q);
00078     SG_INFO( "\tC: %lf\n", prob.c_const);
00079     SG_INFO( "\tkernel type: %d\n", prob.ker_type);
00080     SG_INFO( "\tcache size: %dMb\n", prob.maxmw);
00081     SG_INFO( "\tStopping tolerance: %lf\n", prob.delta);
00082 
00083     //  /*** compute the number of cache rows up to maxmw Mb. ***/
00084     if (prob.preprocess_size == -1)
00085         prob.preprocess_size = (int32_t) ( (float64_t)prob.chunk_size * 1.5 );
00086 
00087     if (prob.projection_projector == -1)
00088     {
00089         if (prob.chunk_size <= 20) prob.projection_projector = 0;
00090         else prob.projection_projector = 1;
00091     }
00092 
00093     /*** compute the problem solution *******************************************/
00094     solution = SG_MALLOC(float64_t, prob.ell);
00095     prob.gpdtsolve(solution);
00096     /****************************************************************************/
00097 
00098     CSVM::set_objective(prob.objective_value);
00099 
00100     int32_t num_sv=0;
00101     int32_t bsv=0;
00102     int32_t i=0;
00103     int32_t k=0;
00104 
00105     for (i = 0; i < prob.ell; i++)
00106     {
00107         if (solution[i] > prob.DELTAsv)
00108         {
00109             num_sv++;
00110             if (solution[i] > (prob.c_const - prob.DELTAsv)) bsv++;
00111         }
00112     }
00113 
00114     create_new_model(num_sv);
00115     set_bias(prob.bee);
00116 
00117     SG_INFO("SV: %d BSV = %d\n", num_sv, bsv);
00118 
00119     for (i = 0; i < prob.ell; i++)
00120     {
00121         if (solution[i] > prob.DELTAsv)
00122         {
00123             set_support_vector(k, i);
00124             set_alpha(k++, solution[i]*((CBinaryLabels*) m_labels)->get_label(i));
00125         }
00126     }
00127 
00128     delete prob.KER;
00129     SG_FREE(solution);
00130 
00131     return true;
00132 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation