MKLMulticlass.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2009 Alexander Binder
00008  * Copyright (C) 2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include <shogun/multiclass/MulticlassOneVsRestStrategy.h>
00012 #include <shogun/classifier/mkl/MKLMulticlass.h>
00013 #include <shogun/io/SGIO.h>
00014 #include <shogun/labels/MulticlassLabels.h>
00015 
00016 using namespace shogun;
00017 
00018 
00019 CMKLMulticlass::CMKLMulticlass()
00020 : CMulticlassSVM(new CMulticlassOneVsRestStrategy())
00021 {
00022     svm=NULL;
00023     lpw=NULL;
00024 
00025     mkl_eps=0.01;
00026     max_num_mkl_iters=999;
00027     pnorm=1;
00028 }
00029 
00030 CMKLMulticlass::CMKLMulticlass(float64_t C, CKernel* k, CLabels* lab)
00031 : CMulticlassSVM(new CMulticlassOneVsRestStrategy(), C, k, lab)
00032 {
00033     svm=NULL;
00034     lpw=NULL;
00035 
00036     mkl_eps=0.01;
00037     max_num_mkl_iters=999;
00038     pnorm=1;
00039 }
00040 
00041 
00042 CMKLMulticlass::~CMKLMulticlass()
00043 {
00044     SG_UNREF(svm);
00045     svm=NULL;
00046     delete lpw;
00047     lpw=NULL;
00048 }
00049 
00050 CMKLMulticlass::CMKLMulticlass( const CMKLMulticlass & cm)
00051 : CMulticlassSVM(new CMulticlassOneVsRestStrategy())
00052 {
00053     svm=NULL;
00054     lpw=NULL;
00055     SG_ERROR(
00056             " CMKLMulticlass::CMKLMulticlass(const CMKLMulticlass & cm): must "
00057             "not be called, glpk structure is currently not copyable");
00058 }
00059 
00060 CMKLMulticlass CMKLMulticlass::operator=( const CMKLMulticlass & cm)
00061 {
00062         SG_ERROR(
00063             " CMKLMulticlass CMKLMulticlass::operator=(...): must "
00064             "not be called, glpk structure is currently not copyable");
00065     return (*this);
00066 }
00067 
00068 
00069 void CMKLMulticlass::initsvm()
00070 {
00071     if (!m_labels)
00072     {
00073         SG_ERROR("CMKLMulticlass::initsvm(): the set labels is NULL\n");
00074     }
00075 
00076     SG_UNREF(svm);
00077     svm=new CGMNPSVM;
00078     SG_REF(svm);
00079 
00080     svm->set_C(get_C());
00081     svm->set_epsilon(get_epsilon());
00082 
00083     if (m_labels->get_num_labels()<=0)
00084     {
00085         SG_ERROR("CMKLMulticlass::initsvm(): the number of labels is "
00086                 "nonpositive, do not know how to handle this!\n");
00087     }
00088 
00089     svm->set_labels(m_labels);
00090 }
00091 
00092 void CMKLMulticlass::initlpsolver()
00093 {
00094     if (!m_kernel)
00095     {
00096         SG_ERROR("CMKLMulticlass::initlpsolver(): the set kernel is NULL\n");
00097     }
00098 
00099     if (m_kernel->get_kernel_type()!=K_COMBINED)
00100     {
00101         SG_ERROR("CMKLMulticlass::initlpsolver(): given kernel is not of type"
00102                 " K_COMBINED %d required by Multiclass Mkl \n",
00103                 m_kernel->get_kernel_type());
00104     }
00105 
00106     int numker=dynamic_cast<CCombinedKernel *>(m_kernel)->get_num_subkernels();
00107 
00108     ASSERT(numker>0);
00109     /*
00110     if (lpw)
00111     {
00112         delete lpw;
00113     }
00114     */
00115 
00116     //lpw=new MKLMulticlassGLPK;
00117     if(pnorm>1)
00118     {
00119         lpw=new MKLMulticlassGradient;
00120         lpw->set_mkl_norm(pnorm);
00121     }
00122     else
00123     {
00124         lpw=new MKLMulticlassGLPK;
00125     }
00126     lpw->setup(numker);
00127 
00128 }
00129 
00130 
00131 bool CMKLMulticlass::evaluatefinishcriterion(const int32_t
00132         numberofsilpiterations)
00133 {
00134     if ( (max_num_mkl_iters>0) && (numberofsilpiterations>=max_num_mkl_iters) )
00135     {
00136         return(true);
00137     }
00138 
00139     if (weightshistory.size()>1)
00140     {
00141         std::vector<float64_t> wold,wnew;
00142 
00143         wold=weightshistory[ weightshistory.size()-2 ];
00144         wnew=weightshistory.back();
00145         float64_t delta=0;
00146 
00147         ASSERT (wold.size()==wnew.size());
00148 
00149 
00150         if((pnorm<=1)&&(!normweightssquared.empty()))
00151         {
00152 
00153             delta=0;
00154             for (size_t i=0;i< wnew.size();++i)
00155             {
00156                 delta+=(wold[i]-wnew[i])*(wold[i]-wnew[i]);
00157             }
00158             delta=sqrt(delta);
00159             SG_SDEBUG("L1 Norm chosen, weight delta %f \n",delta);
00160 
00161 
00162             //check dual gap part for mkl
00163             int32_t maxind=0;
00164             float64_t maxval=normweightssquared[maxind];
00165             delta=0;
00166             for (size_t i=0;i< wnew.size();++i)
00167             {
00168                 delta+=normweightssquared[i]*wnew[i];
00169                 if(wnew[i]>maxval)
00170                 {
00171                     maxind=i;
00172                     maxval=wnew[i];
00173                 }
00174             }
00175             delta-=normweightssquared[maxind];
00176             delta=fabs(delta);
00177             SG_SDEBUG("L1 Norm chosen, MKL part of duality gap %f \n",delta);
00178             if( (delta < mkl_eps) && (numberofsilpiterations>=1) )
00179             {
00180                 return(true);
00181             }
00182 
00183 
00184 
00185         }
00186         else
00187         {
00188             delta=0;
00189             for (size_t i=0;i< wnew.size();++i)
00190             {
00191                 delta+=(wold[i]-wnew[i])*(wold[i]-wnew[i]);
00192             }
00193             delta=sqrt(delta);
00194             SG_SDEBUG("Lp Norm chosen, weight delta %f \n",delta);
00195 
00196             if( (delta < mkl_eps) && (numberofsilpiterations>=1) )
00197             {
00198                 return(true);
00199             }
00200 
00201         }
00202     }
00203 
00204     return(false);
00205 }
00206 
00207 void CMKLMulticlass::addingweightsstep( const std::vector<float64_t> &
00208         curweights)
00209 {
00210 
00211     if (weightshistory.size()>2)
00212     {
00213         weightshistory.erase(weightshistory.begin());
00214     }
00215 
00216     SGVector<float64_t> weights(curweights.size());
00217     std::copy(curweights.begin(),curweights.end(),weights.vector);
00218 
00219     m_kernel->set_subkernel_weights(weights);
00220 
00221     initsvm();
00222 
00223     svm->set_kernel(m_kernel);
00224     svm->train();
00225 
00226     float64_t sumofsignfreealphas=getsumofsignfreealphas();
00227     int32_t numkernels=
00228             dynamic_cast<CCombinedKernel *>(m_kernel)->get_num_subkernels();
00229 
00230 
00231     normweightssquared.resize(numkernels);
00232     for (int32_t ind=0; ind < numkernels; ++ind )
00233     {
00234         normweightssquared[ind]=getsquarenormofprimalcoefficients( ind );
00235     }
00236 
00237     lpw->addconstraint(normweightssquared,sumofsignfreealphas);
00238 }
00239 
00240 float64_t CMKLMulticlass::getsumofsignfreealphas()
00241 {
00242     std::vector<int> trainlabels2(m_labels->get_num_labels());
00243     SGVector<int32_t> lab=((CMulticlassLabels*) m_labels)->get_int_labels();
00244     std::copy(lab.vector,lab.vector+lab.vlen, trainlabels2.begin());
00245 
00246     ASSERT (trainlabels2.size()>0);
00247     float64_t sum=0;
00248 
00249     for (int32_t nc=0; nc< ((CMulticlassLabels*) m_labels)->get_num_classes();++nc)
00250     {
00251         CSVM * sm=svm->get_svm(nc);
00252 
00253         float64_t bia=sm->get_bias();
00254         sum+= bia*bia;
00255 
00256         SG_UNREF(sm);
00257     }
00258 
00259     index_t basealphas_y = 0, basealphas_x = 0;
00260     float64_t* basealphas = svm->get_basealphas_ptr(&basealphas_y,
00261                                                     &basealphas_x);
00262 
00263     for (size_t lb=0; lb< trainlabels2.size();++lb)
00264     {
00265         for (int32_t nc=0; nc< ((CMulticlassLabels*) m_labels)->get_num_classes();++nc)
00266         {
00267             CSVM * sm=svm->get_svm(nc);
00268 
00269             if ((int)nc!=trainlabels2[lb])
00270             {
00271                 CSVM * sm2=svm->get_svm(trainlabels2[lb]);
00272 
00273                 float64_t bia1=sm2->get_bias();
00274                 float64_t bia2=sm->get_bias();
00275                 SG_UNREF(sm2);
00276 
00277                 sum+= -basealphas[lb*basealphas_y + nc]*(bia1-bia2-1);
00278             }
00279             SG_UNREF(sm);
00280         }
00281     }
00282 
00283     return(sum);
00284 }
00285 
00286 float64_t CMKLMulticlass::getsquarenormofprimalcoefficients(
00287         const int32_t ind)
00288 {
00289     CKernel * ker=dynamic_cast<CCombinedKernel *>(m_kernel)->get_kernel(ind);
00290 
00291     float64_t tmp=0;
00292 
00293     for (int32_t classindex=0; classindex< ((CMulticlassLabels*) m_labels)->get_num_classes();
00294             ++classindex)
00295     {
00296         CSVM * sm=svm->get_svm(classindex);
00297 
00298         for (int32_t i=0; i < sm->get_num_support_vectors(); ++i)
00299         {
00300             float64_t alphai=sm->get_alpha(i);
00301             int32_t svindi= sm->get_support_vector(i);
00302 
00303             for (int32_t k=0; k < sm->get_num_support_vectors(); ++k)
00304             {
00305                 float64_t alphak=sm->get_alpha(k);
00306                 int32_t svindk=sm->get_support_vector(k);
00307 
00308                 tmp+=alphai*ker->kernel(svindi,svindk)
00309                 *alphak;
00310 
00311             }
00312         }
00313         SG_UNREF(sm);
00314     }
00315     SG_UNREF(ker);
00316     ker=NULL;
00317 
00318     return(tmp);
00319 }
00320 
00321 
00322 bool CMKLMulticlass::train_machine(CFeatures* data)
00323 {
00324     ASSERT(m_kernel);
00325     ASSERT(m_labels && m_labels->get_num_labels());
00326     ASSERT(m_labels->get_label_type() == LT_MULTICLASS);
00327     int numcl=((CMulticlassLabels*) m_labels)->get_num_classes();
00328 
00329     if (data)
00330     {
00331         if (m_labels->get_num_labels() != data->get_num_vectors())
00332         {
00333             SG_ERROR("%s::train_machine(): Number of training vectors (%d) does"
00334                     " not match number of labels (%d)\n", get_name(),
00335                     data->get_num_vectors(), m_labels->get_num_labels());
00336         }
00337         m_kernel->init(data, data);
00338     }
00339 
00340     initlpsolver();
00341 
00342     weightshistory.clear();
00343 
00344     int32_t numkernels=
00345             dynamic_cast<CCombinedKernel *>(m_kernel)->get_num_subkernels();
00346 
00347     ::std::vector<float64_t> curweights(numkernels,1.0/numkernels);
00348     weightshistory.push_back(curweights);
00349 
00350     addingweightsstep(curweights);
00351 
00352     int32_t numberofsilpiterations=0;
00353     bool final=false;
00354     while (!final)
00355     {
00356 
00357         //curweights.clear();
00358         lpw->computeweights(curweights);
00359         weightshistory.push_back(curweights);
00360 
00361 
00362         final=evaluatefinishcriterion(numberofsilpiterations);
00363         ++numberofsilpiterations;
00364 
00365         addingweightsstep(curweights);
00366 
00367     } // while(false==final)
00368 
00369 
00370     //set alphas, bias, support vecs
00371     ASSERT(numcl>=1);
00372     create_multiclass_svm(numcl);
00373 
00374     for (int32_t i=0; i<numcl; i++)
00375     {
00376         CSVM* osvm=svm->get_svm(i);
00377         CSVM* nsvm=new CSVM(osvm->get_num_support_vectors());
00378 
00379         for (int32_t k=0; k<osvm->get_num_support_vectors() ; k++)
00380         {
00381             nsvm->set_alpha(k, osvm->get_alpha(k) );
00382             nsvm->set_support_vector(k,osvm->get_support_vector(k) );
00383         }
00384         nsvm->set_bias(osvm->get_bias() );
00385         set_svm(i, nsvm);
00386 
00387         SG_UNREF(osvm);
00388         osvm=NULL;
00389     }
00390 
00391     SG_UNREF(svm);
00392     svm=NULL;
00393     if (lpw)
00394     {
00395         delete lpw;
00396     }
00397     lpw=NULL;
00398     return(true);
00399 }
00400 
00401 
00402 
00403 
00404 float64_t* CMKLMulticlass::getsubkernelweights(int32_t & numweights)
00405 {
00406     if ( weightshistory.empty() )
00407     {
00408         numweights=0;
00409         return NULL;
00410     }
00411 
00412     std::vector<float64_t> subkerw=weightshistory.back();
00413     numweights=weightshistory.back().size();
00414 
00415     float64_t* res=SG_MALLOC(float64_t, numweights);
00416     std::copy(weightshistory.back().begin(), weightshistory.back().end(),res);
00417     return res;
00418 }
00419 
00420 void CMKLMulticlass::set_mkl_epsilon(float64_t eps )
00421 {
00422     mkl_eps=eps;
00423 }
00424 
00425 void CMKLMulticlass::set_max_num_mkliters(int32_t maxnum)
00426 {
00427     max_num_mkl_iters=maxnum;
00428 }
00429 
00430 void CMKLMulticlass::set_mkl_norm(float64_t norm)
00431 {
00432     pnorm=norm;
00433     if(pnorm<1 )
00434         SG_ERROR("CMKLMulticlass::set_mkl_norm(float64_t norm) : parameter pnorm<1");
00435 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation