CPLEXSVM.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include "classifier/svm/CPLEXSVM.h"
00012 #include "lib/common.h"
00013 
00014 #ifdef USE_CPLEX
00015 #include "lib/io.h"
00016 #include "lib/Mathematics.h"
00017 #include "lib/Cplex.h"
00018 #include "features/Labels.h"
00019 
00020 using namespace shogun;
00021 
00022 CCPLEXSVM::CCPLEXSVM()
00023 : CSVM()
00024 {
00025 }
00026 
00027 CCPLEXSVM::~CCPLEXSVM()
00028 {
00029 }
00030 
00031 bool CCPLEXSVM::train(CFeatures* data)
00032 {
00033     bool result = false;
00034     CCplex cplex;
00035 
00036     if (data)
00037     {
00038         if (labels->get_num_labels() != data->get_num_vectors())
00039             SG_ERROR("Number of training vectors does not match number of labels\n");
00040         kernel->init(data, data);
00041     }
00042 
00043     if (cplex.init(E_QP))
00044     {
00045         int32_t n,m;
00046         int32_t num_label=0;
00047         float64_t* y = labels->get_labels(num_label);
00048         float64_t* H = kernel->get_kernel_matrix<float64_t>(m, n, NULL);
00049         ASSERT(n>0 && n==m && n==num_label);
00050         float64_t* alphas=new float64_t[n];
00051         float64_t* lb=new float64_t[n];
00052         float64_t* ub=new float64_t[n];
00053 
00054         //hessian y'y.*K
00055         for (int32_t i=0; i<n; i++)
00056         {
00057             lb[i]=0;
00058             ub[i]=get_C1();
00059 
00060             for (int32_t j=0; j<n; j++)
00061                 H[i*n+j]*=y[j]*y[i];
00062         }
00063 
00064         //feed qp to cplex
00065 
00066 
00067         int32_t j=0;
00068         for (int32_t i=0; i<n; i++)
00069         {
00070             if (alphas[i]>0)
00071             {
00072                 //set_alpha(j, alphas[i]*labels->get_label(i)/etas[1]);
00073                 set_alpha(j, alphas[i]*labels->get_label(i));
00074                 set_support_vector(j, i);
00075                 j++;
00076             }
00077         }
00078         //compute_objective();
00079         SG_INFO( "obj = %.16f, rho = %.16f\n",get_objective(),get_bias());
00080         SG_INFO( "Number of SV: %ld\n", get_num_support_vectors());
00081 
00082         delete[] alphas;
00083         delete[] lb;
00084         delete[] ub;
00085         delete[] H;
00086         delete[] y;
00087 
00088         result = true;
00089     }
00090 
00091     if (!result)
00092         SG_ERROR( "cplex svm failed");
00093 
00094     return result;
00095 }
00096 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation