GMNPSVM.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2008 Vojtech Franc, xfrancv@cmp.felk.cvut.cz
00008  * Copyright (C) 1999-2008 Center for Machine Perception, CTU FEL Prague
00009  */
00010 
00011 #include <shogun/io/SGIO.h>
00012 #include <shogun/labels/MulticlassLabels.h>
00013 #include <shogun/multiclass/GMNPSVM.h>
00014 #include <shogun/multiclass/GMNPLib.h>
00015 #include <shogun/multiclass/MulticlassOneVsRestStrategy.h>
00016 
/* Linear offset of element (ROW,COL) in a column-major matrix with DIM rows. */
#define INDEX(ROW,COL,DIM) (((COL)*(DIM))+(ROW))
/* Integer sentinels used as "infinite" bounds. */
#define MINUS_INF INT_MIN
#define PLUS_INF  INT_MAX
/* Kronecker delta: 1 if A equals B, 0 otherwise. Arguments are parenthesized
 * so that expressions containing low-precedence operators expand correctly. */
#define KDELTA(A,B) ((A)==(B))
/* 1 if any two of the four arguments are equal, 0 otherwise. */
#define KDELTA4(A1,A2,A3,A4) \
    (((A1)==(A2))||((A1)==(A3))||((A1)==(A4))|| \
     ((A2)==(A3))||((A2)==(A4))||((A3)==(A4)))
00022 
00023 using namespace shogun;
00024 
00025 CGMNPSVM::CGMNPSVM()
00026 : CMulticlassSVM(new CMulticlassOneVsRestStrategy())
00027 {
00028     init();
00029 }
00030 
00031 CGMNPSVM::CGMNPSVM(float64_t C, CKernel* k, CLabels* lab)
00032 : CMulticlassSVM(new CMulticlassOneVsRestStrategy(), C, k, lab)
00033 {
00034     init();
00035 }
00036 
00037 CGMNPSVM::~CGMNPSVM()
00038 {
00039     if (m_basealphas != NULL) SG_FREE(m_basealphas);
00040 }
00041 
00042 void
00043 CGMNPSVM::init()
00044 {
00045     m_parameters->add_matrix(&m_basealphas,
00046                              &m_basealphas_y, &m_basealphas_x,
00047                              "m_basealphas",
00048                              "Is the basic untransformed alpha.");
00049 
00050     m_basealphas = NULL, m_basealphas_y = 0, m_basealphas_x = 0;
00051 }
00052 
00053 bool CGMNPSVM::train_machine(CFeatures* data)
00054 {
00055     ASSERT(m_kernel);
00056     ASSERT(m_labels && m_labels->get_num_labels());
00057     ASSERT(m_labels->get_label_type() == LT_MULTICLASS);
00058 
00059     if (data)
00060     {
00061         if (m_labels->get_num_labels() != data->get_num_vectors())
00062         {
00063             SG_ERROR("%s::train_machine(): Number of training vectors (%d) does"
00064                     " not match number of labels (%d)\n", get_name(),
00065                     data->get_num_vectors(), m_labels->get_num_labels());
00066         }
00067         m_kernel->init(data, data);
00068     }
00069 
00070     int32_t num_data = m_labels->get_num_labels();
00071     int32_t num_classes = m_multiclass_strategy->get_num_classes();
00072     int32_t num_virtual_data= num_data*(num_classes-1);
00073 
00074     SG_INFO( "%d trainlabels, %d classes\n", num_data, num_classes);
00075 
00076     float64_t* vector_y = SG_MALLOC(float64_t, num_data);
00077     for (int32_t i=0; i<num_data; i++)
00078     {
00079         vector_y[i] = ((CMulticlassLabels*) m_labels)->get_label(i)+1;
00080 
00081     }
00082 
00083     float64_t C = get_C();
00084     int32_t tmax = 1000000000;
00085     float64_t tolabs = 0;
00086     float64_t tolrel = get_epsilon();
00087 
00088     float64_t reg_const=0;
00089     if( C!=0 )
00090         reg_const = 1/(2*C);
00091 
00092 
00093     float64_t* alpha = SG_MALLOC(float64_t, num_virtual_data);
00094     float64_t* vector_c = SG_MALLOC(float64_t, num_virtual_data);
00095     memset(vector_c, 0, num_virtual_data*sizeof(float64_t));
00096 
00097     float64_t thlb = 10000000000.0;
00098     int32_t t = 0;
00099     float64_t* History = NULL;
00100     int32_t verb = 0;
00101 
00102     CGMNPLib mnp(vector_y,m_kernel,num_data, num_virtual_data, num_classes, reg_const);
00103 
00104     mnp.gmnp_imdm(vector_c, num_virtual_data, tmax,
00105                   tolabs, tolrel, thlb, alpha, &t, &History, verb);
00106 
00107     /* matrix alpha [num_classes x num_data] */
00108     float64_t* all_alphas= SG_MALLOC(float64_t, num_classes*num_data);
00109     memset(all_alphas,0,num_classes*num_data*sizeof(float64_t));
00110 
00111     /* bias vector b [num_classes x 1] */
00112     float64_t* all_bs=SG_MALLOC(float64_t, num_classes);
00113     memset(all_bs,0,num_classes*sizeof(float64_t));
00114 
00115     /* compute alpha/b from virt_data */
00116     for(int32_t i=0; i < num_classes; i++ )
00117     {
00118         for(int32_t j=0; j < num_virtual_data; j++ )
00119         {
00120             int32_t inx1=0;
00121             int32_t inx2=0;
00122 
00123             mnp.get_indices2( &inx1, &inx2, j );
00124 
00125             all_alphas[(inx1*num_classes)+i] +=
00126                 alpha[j]*(KDELTA(vector_y[inx1],i+1)-KDELTA(i+1,inx2));
00127             all_bs[i] += alpha[j]*(KDELTA(vector_y[inx1],i+1)-KDELTA(i+1,inx2));
00128         }
00129     }
00130 
00131     create_multiclass_svm(num_classes);
00132 
00133     for (int32_t i=0; i<num_classes; i++)
00134     {
00135         int32_t num_sv=0;
00136         for (int32_t j=0; j<num_data; j++)
00137         {
00138             if (all_alphas[j*num_classes+i] != 0)
00139                 num_sv++;
00140         }
00141         ASSERT(num_sv>0);
00142         SG_DEBUG("svm[%d] has %d sv, b=%f\n", i, num_sv, all_bs[i]);
00143 
00144         CSVM* svm=new CSVM(num_sv);
00145 
00146         int32_t k=0;
00147         for (int32_t j=0; j<num_data; j++)
00148         {
00149             if (all_alphas[j*num_classes+i] != 0)
00150             {
00151                 svm->set_alpha(k, all_alphas[j*num_classes+i]);
00152                 svm->set_support_vector(k, j);
00153                 k++;
00154             }
00155         }
00156 
00157         svm->set_bias(all_bs[i]);
00158         set_svm(i, svm);
00159     }
00160 
00161     if (m_basealphas != NULL) SG_FREE(m_basealphas);
00162     m_basealphas_y = num_classes, m_basealphas_x = num_data;
00163     m_basealphas = SG_MALLOC(float64_t, m_basealphas_y*m_basealphas_x);
00164     for (index_t i=0; i<m_basealphas_y*m_basealphas_x; i++)
00165         m_basealphas[i] = 0.0;
00166 
00167     for(index_t j=0; j<num_virtual_data; j++)
00168     {
00169         index_t inx1=0, inx2=0;
00170 
00171         mnp.get_indices2(&inx1, &inx2, j);
00172         m_basealphas[inx1*m_basealphas_y + (inx2-1)] = alpha[j];
00173     }
00174 
00175     SG_FREE(vector_c);
00176     SG_FREE(alpha);
00177     SG_FREE(all_alphas);
00178     SG_FREE(all_bs);
00179     SG_FREE(vector_y);
00180     SG_FREE(History);
00181 
00182     return true;
00183 }
00184 
00185 float64_t*
00186 CGMNPSVM::get_basealphas_ptr(index_t* y, index_t* x)
00187 {
00188     if (y == NULL || x == NULL) return NULL;
00189 
00190     *y = m_basealphas_y, *x = m_basealphas_x;
00191     return m_basealphas;
00192 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation