GNPPSVM.cpp

Go to the documentation of this file.
00001 /* 
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2008 Vojtech Franc, xfrancv@cmp.felk.cvut.cz
00008  * Copyright (C) 1999-2008 Center for Machine Perception, CTU FEL Prague 
00009  */
00010 
00011 #include <shogun/io/SGIO.h>
00012 #include <shogun/classifier/svm/GNPPSVM.h>
00013 #include <shogun/classifier/svm/GNPPLib.h>
00014 #include <shogun/labels/BinaryLabels.h>
00015 
00016 using namespace shogun;
// Linear offset of element (ROW, COL) in a column-major array whose columns
// have DIM entries: offset = COL*DIM + ROW.
#define INDEX(ROW,COL,DIM) (((COL)*(DIM))+(ROW)) 
00018 
00019 CGNPPSVM::CGNPPSVM()
00020 : CSVM()
00021 {
00022 }
00023 
00024 CGNPPSVM::CGNPPSVM(float64_t C, CKernel* k, CLabels* lab)
00025 : CSVM(C, k, lab)
00026 {
00027 }
00028 
00029 CGNPPSVM::~CGNPPSVM()
00030 {
00031 }
00032 
00033 bool CGNPPSVM::train_machine(CFeatures* data)
00034 {
00035     ASSERT(kernel);
00036     ASSERT(m_labels && m_labels->get_num_labels());
00037     ASSERT(m_labels->get_label_type() == LT_BINARY);
00038 
00039     if (data)
00040     {
00041         if (m_labels->get_num_labels() != data->get_num_vectors())
00042             SG_ERROR("Number of training vectors does not match number of labels\n");
00043         kernel->init(data, data);
00044     }
00045 
00046     int32_t num_data=m_labels->get_num_labels();
00047     SG_INFO("%d trainlabels\n", num_data);
00048 
00049     float64_t* vector_y = SG_MALLOC(float64_t, num_data);
00050     for (int32_t i=0; i<num_data; i++)
00051     {
00052         float64_t lab=((CBinaryLabels*) m_labels)->get_label(i);
00053         if (lab==+1)
00054             vector_y[i]=1;
00055         else if (lab==-1)
00056             vector_y[i]=2;
00057         else
00058             SG_ERROR("label unknown (%f)\n", lab);
00059     }
00060 
00061     float64_t C=get_C1();
00062     int32_t tmax=1000000000;
00063     float64_t tolabs=0;
00064     float64_t tolrel=epsilon;
00065 
00066     float64_t reg_const=0;
00067     if (C!=0)
00068         reg_const=1/C;
00069 
00070     float64_t* diagK=SG_MALLOC(float64_t, num_data);
00071     for(int32_t i=0; i<num_data; i++) {
00072         diagK[i]=2*kernel->kernel(i,i)+reg_const;
00073     }
00074 
00075     float64_t* alpha=SG_MALLOC(float64_t, num_data);
00076     float64_t* vector_c=SG_MALLOC(float64_t, num_data);
00077     memset(vector_c, 0, num_data*sizeof(float64_t));
00078 
00079     float64_t thlb=10000000000.0;
00080     int32_t t=0;
00081     float64_t* History=NULL;
00082     int32_t verb=0;
00083     float64_t aHa11, aHa22;
00084 
00085     CGNPPLib npp(vector_y,kernel,num_data, reg_const);
00086 
00087     npp.gnpp_imdm(diagK, vector_c, vector_y, num_data, 
00088             tmax, tolabs, tolrel, thlb, alpha, &t, &aHa11, &aHa22, 
00089             &History, verb ); 
00090 
00091     int32_t num_sv = 0;
00092     float64_t nconst = History[INDEX(1,t,2)];
00093     float64_t trnerr = 0; /* counter of training error */
00094 
00095     for(int32_t i = 0; i < num_data; i++ )
00096     {
00097         if( alpha[i] != 0 ) num_sv++;
00098         if(vector_y[i] == 1) 
00099         {
00100             alpha[i] = alpha[i]*2/nconst;
00101             if( alpha[i]/(2*C) >= 1 ) trnerr++;
00102         }
00103         else
00104         {
00105             alpha[i] = -alpha[i]*2/nconst;
00106             if( alpha[i]/(2*C) <= -1 ) trnerr++;
00107         }
00108     }
00109 
00110     float64_t b = 0.5*(aHa22 - aHa11)/nconst;;
00111 
00112     create_new_model(num_sv);
00113     CSVM::set_objective(nconst);
00114 
00115     set_bias(b);
00116     int32_t j = 0;
00117     for (int32_t i=0; i<num_data; i++)
00118     {
00119         if( alpha[i] !=0)
00120         {
00121             set_support_vector(j, i);
00122             set_alpha(j, alpha[i]);
00123             j++;
00124         }
00125     }
00126 
00127     SG_FREE(vector_c);
00128     SG_FREE(alpha);
00129     SG_FREE(diagK);
00130     SG_FREE(vector_y);
00131     SG_FREE(History);
00132 
00133     return true;
00134 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation