Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/classifier/svm/CPLEXSVM.h>
00012 #include <shogun/lib/common.h>
00013
00014 #ifdef USE_CPLEX
00015 #include <shogun/io/SGIO.h>
00016 #include <shogun/mathematics/Math.h>
00017 #include <shogun/mathematics/Cplex.h>
00018 #include <shogun/labels/Labels.h>
00019
00020 using namespace shogun;
00021
00022 CCPLEXSVM::CCPLEXSVM()
00023 : CSVM()
00024 {
00025 }
00026
00027 CCPLEXSVM::~CCPLEXSVM()
00028 {
00029 }
00030
00031 bool CCPLEXSVM::train_machine(CFeatures* data)
00032 {
00033 ASSERT(m_labels);
00034 ASSERT(m_labels->get_label_type() == LT_BINARY);
00035
00036 bool result = false;
00037 CCplex cplex;
00038
00039 if (data)
00040 {
00041 if (m_labels->get_num_labels() != data->get_num_vectors())
00042 {
00043 SG_ERROR("%s::train_machine(): Number of training vectors (%d) does"
00044 " not match number of labels (%d)\n", get_name(),
00045 data->get_num_vectors(), m_labels->get_num_labels());
00046 }
00047 kernel->init(data, data);
00048 }
00049
00050 if (cplex.init(E_QP))
00051 {
00052 int32_t n,m;
00053 int32_t num_label=0;
00054 SGVector<float64_t> y=((CBinaryLabels*)m_labels)->get_labels();
00055 SGMatrix<float64_t> H=kernel->get_kernel_matrix();
00056 m=H.num_rows;
00057 n=H.num_cols;
00058 ASSERT(n>0 && n==m && n==num_label);
00059 float64_t* alphas=SG_MALLOC(float64_t, n);
00060 float64_t* lb=SG_MALLOC(float64_t, n);
00061 float64_t* ub=SG_MALLOC(float64_t, n);
00062
00063
00064 for (int32_t i=0; i<n; i++)
00065 {
00066 lb[i]=0;
00067 ub[i]=get_C1();
00068
00069 for (int32_t j=0; j<n; j++)
00070 H[i*n+j]*=y[j]*y[i];
00071 }
00072
00073
00074
00075
00076 int32_t j=0;
00077 for (int32_t i=0; i<n; i++)
00078 {
00079 if (alphas[i]>0)
00080 {
00081
00082 set_alpha(j, alphas[i]*((CBinaryLabels*) m_labels)->get_int_label(i));
00083 set_support_vector(j, i);
00084 j++;
00085 }
00086 }
00087
00088 SG_INFO( "obj = %.16f, rho = %.16f\n",get_objective(),get_bias());
00089 SG_INFO( "Number of SV: %ld\n", get_num_support_vectors());
00090
00091 SG_FREE(alphas);
00092 SG_FREE(lb);
00093 SG_FREE(ub);
00094
00095 result = true;
00096 }
00097
00098 if (!result)
00099 SG_ERROR( "cplex svm failed");
00100
00101 return result;
00102 }
00103 #endif