Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/classifier/svm/GPBTSVM.h>
00012 #include <shogun/lib/external/gpdt.h>
00013 #include <shogun/lib/external/gpdtsolve.h>
00014 #include <shogun/io/SGIO.h>
00015 #include <shogun/labels/BinaryLabels.h>
00016
00017 using namespace shogun;
00018
00019 CGPBTSVM::CGPBTSVM()
00020 : CSVM(), model(NULL)
00021 {
00022 }
00023
00024 CGPBTSVM::CGPBTSVM(float64_t C, CKernel* k, CLabels* lab)
00025 : CSVM(C, k, lab), model(NULL)
00026 {
00027 }
00028
00029 CGPBTSVM::~CGPBTSVM()
00030 {
00031 SG_FREE(model);
00032 }
00033
00034 bool CGPBTSVM::train_machine(CFeatures* data)
00035 {
00036 float64_t* solution;
00037 QPproblem prob;
00038
00039 ASSERT(kernel);
00040 ASSERT(m_labels && m_labels->get_num_labels());
00041 ASSERT(m_labels->get_label_type() == LT_BINARY);
00042 if (data)
00043 {
00044 if (m_labels->get_num_labels() != data->get_num_vectors())
00045 SG_ERROR("Number of training vectors does not match number of labels\n");
00046 kernel->init(data, data);
00047 }
00048
00049 SGVector<int32_t> lab=((CBinaryLabels*) m_labels)->get_int_labels();
00050 prob.KER=new sKernel(kernel, lab.vlen);
00051 prob.y=lab.vector;
00052 prob.ell=lab.vlen;
00053 SG_INFO( "%d trainlabels\n", prob.ell);
00054
00055
00056 prob.delta = epsilon;
00057 prob.maxmw = kernel->get_cache_size();
00058 prob.verbosity = 0;
00059 prob.preprocess_size = -1;
00060 prob.projection_projector = -1;
00061 prob.c_const = get_C1();
00062 prob.chunk_size = get_qpsize();
00063 prob.linadd = get_linadd_enabled();
00064
00065 if (prob.chunk_size < 2) prob.chunk_size = 2;
00066 if (prob.q <= 0) prob.q = prob.chunk_size / 3;
00067 if (prob.q < 2) prob.q = 2;
00068 if (prob.q > prob.chunk_size) prob.q = prob.chunk_size;
00069 prob.q = prob.q & (~1);
00070 if (prob.maxmw < 5)
00071 prob.maxmw = 5;
00072
00073
00074 SG_INFO( "\nTRAINING PARAMETERS:\n");
00075 SG_INFO( "\tNumber of training documents: %d\n", prob.ell);
00076 SG_INFO( "\tq: %d\n", prob.chunk_size);
00077 SG_INFO( "\tn: %d\n", prob.q);
00078 SG_INFO( "\tC: %lf\n", prob.c_const);
00079 SG_INFO( "\tkernel type: %d\n", prob.ker_type);
00080 SG_INFO( "\tcache size: %dMb\n", prob.maxmw);
00081 SG_INFO( "\tStopping tolerance: %lf\n", prob.delta);
00082
00083
00084 if (prob.preprocess_size == -1)
00085 prob.preprocess_size = (int32_t) ( (float64_t)prob.chunk_size * 1.5 );
00086
00087 if (prob.projection_projector == -1)
00088 {
00089 if (prob.chunk_size <= 20) prob.projection_projector = 0;
00090 else prob.projection_projector = 1;
00091 }
00092
00093
00094 solution = SG_MALLOC(float64_t, prob.ell);
00095 prob.gpdtsolve(solution);
00096
00097
00098 CSVM::set_objective(prob.objective_value);
00099
00100 int32_t num_sv=0;
00101 int32_t bsv=0;
00102 int32_t i=0;
00103 int32_t k=0;
00104
00105 for (i = 0; i < prob.ell; i++)
00106 {
00107 if (solution[i] > prob.DELTAsv)
00108 {
00109 num_sv++;
00110 if (solution[i] > (prob.c_const - prob.DELTAsv)) bsv++;
00111 }
00112 }
00113
00114 create_new_model(num_sv);
00115 set_bias(prob.bee);
00116
00117 SG_INFO("SV: %d BSV = %d\n", num_sv, bsv);
00118
00119 for (i = 0; i < prob.ell; i++)
00120 {
00121 if (solution[i] > prob.DELTAsv)
00122 {
00123 set_support_vector(k, i);
00124 set_alpha(k++, solution[i]*((CBinaryLabels*) m_labels)->get_label(i));
00125 }
00126 }
00127
00128 delete prob.KER;
00129 SG_FREE(solution);
00130
00131 return true;
00132 }