ScatterSVM.cpp

00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2009 Soeren Sonnenburg
00008  * Written (W) 2009 Marius Kloft
00009  * Copyright (C) 2009 TU Berlin and Max-Planck-Society
00010  */
00011 #ifdef USE_SVMLIGHT
00012 #include <shogun/classifier/svm/SVMLightOneClass.h>
00013 #endif //USE_SVMLIGHT
00014 
00015 #include <shogun/kernel/Kernel.h>
00016 #include <shogun/multiclass/ScatterSVM.h>
00017 #include <shogun/kernel/normalizer/ScatterKernelNormalizer.h>
00018 #include <shogun/multiclass/MulticlassOneVsRestStrategy.h>
00019 #include <shogun/io/SGIO.h>
00020 
00021 using namespace shogun;
00022 
00023 CScatterSVM::CScatterSVM()
00024 : CMulticlassSVM(new CMulticlassOneVsRestStrategy()), scatter_type(NO_BIAS_LIBSVM),
00025   model(NULL), norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0)
00026 {
00027     SG_UNSTABLE("CScatterSVM::CScatterSVM()", "\n");
00028 }
00029 
00030 CScatterSVM::CScatterSVM(SCATTER_TYPE type)
00031 : CMulticlassSVM(new CMulticlassOneVsRestStrategy()), scatter_type(type), model(NULL),
00032     norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0)
00033 {
00034 }
00035 
00036 CScatterSVM::CScatterSVM(float64_t C, CKernel* k, CLabels* lab)
00037 : CMulticlassSVM(new CMulticlassOneVsRestStrategy(), C, k, lab), scatter_type(NO_BIAS_LIBSVM), model(NULL),
00038     norm_wc(NULL), norm_wcw(NULL), rho(0), m_num_classes(0)
00039 {
00040 }
00041 
00042 CScatterSVM::~CScatterSVM()
00043 {
00044     SG_FREE(norm_wc);
00045     SG_FREE(norm_wcw);
00046 }
00047 
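      // Training entry point: checks the multiclass labels, optionally (re)initializes
      // the kernel with the given data, counts examples per class, and then dispatches
      // to one of the solvers selected by scatter_type (LibSVM without bias, SVMLight
      // one-class, or the test-rule variants, which additionally validate nu).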
00048 bool CScatterSVM::train_machine(CFeatures* data)
00049 {
00050     ASSERT(m_labels && m_labels->get_num_labels());
00051     ASSERT(m_labels->get_label_type() == LT_MULTICLASS);
00052 
00053     m_num_classes = m_multiclass_strategy->get_num_classes();
00054     int32_t num_vectors = m_labels->get_num_labels();
00055 
00056     if (data)
00057     {
00058         if (m_labels->get_num_labels() != data->get_num_vectors())
00059             SG_ERROR("Number of training vectors does not match number of labels\n");
00060         m_kernel->init(data, data);
00061     }
00062 
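          // Count examples per class; Nc is the number of non-empty classes and
          // Nmin the size of the smallest non-empty class (used below to bound nu).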
00063     int32_t* numc=SG_MALLOC(int32_t, m_num_classes);
00064     SGVector<int32_t>::fill_vector(numc, m_num_classes, 0);
00065 
00066     for (int32_t i=0; i<num_vectors; i++)
00067         numc[(int32_t) ((CMulticlassLabels*) m_labels)->get_int_label(i)]++;
00068 
00069     int32_t Nc=0;
00070     int32_t Nmin=num_vectors;
00071     for (int32_t i=0; i<m_num_classes; i++)
00072     {
00073         if (numc[i]>0)
00074         {
00075             Nc++;
00076             Nmin=CMath::min(Nmin, numc[i]);
00077         }
00078 
00079     }
00080     SG_FREE(numc);
00082 
00083     bool result=false;
00084 
00085     if (scatter_type==NO_BIAS_LIBSVM)
00086     {
00087         result=train_no_bias_libsvm();
00088     }
00089 #ifdef USE_SVMLIGHT
00090     else if (scatter_type==NO_BIAS_SVMLIGHT)
00091     {
00092         result=train_no_bias_svmlight();
00093     }
00094 #endif //USE_SVMLIGHT
00095     else if (scatter_type==TEST_RULE1 || scatter_type==TEST_RULE2)
00096     {
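              // The admissible nu range depends on the class balance:
              // nu_min = Nc/N and nu_max = Nc*Nmin/N, with N the number of training vectors.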
00097         float64_t nu_min=((float64_t) Nc)/num_vectors;
00098         float64_t nu_max=((float64_t) Nc)*Nmin/num_vectors;
00099 
00100         SG_INFO("valid nu interval [%f ... %f]\n", nu_min, nu_max);
00101 
00102         if (get_nu()<nu_min || get_nu()>nu_max)
00103             SG_ERROR("nu out of valid range [%f ... %f]\n", nu_min, nu_max);
00104 
00105         result=train_testrule12();
00106     }
00107     else
00108         SG_ERROR("Unknown Scatter type\n");
00109 
00110     return result;
00111 }
00112 
00113 bool CScatterSVM::train_no_bias_libsvm()
00114 {
00115     struct svm_node* x_space;
00116 
00117     problem.l=m_labels->get_num_labels();
00118     SG_INFO("%d training labels\n", problem.l);
00119 
00120     problem.y=SG_MALLOC(float64_t, problem.l);
00121     problem.x=SG_MALLOC(struct svm_node*, problem.l);
00122     x_space=SG_MALLOC(struct svm_node, 2*problem.l);
00123 
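          // The kernel is precomputed on the Shogun side, so each LibSVM node only
          // stores the example index; the second node terminates the (index, value) list.
          // All targets are set to +1, since the class information enters through the
          // scatter-normalized kernel rather than through problem.y.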
00124     for (int32_t i=0; i<problem.l; i++)
00125     {
00126         problem.y[i]=+1;
00127         problem.x[i]=&x_space[2*i];
00128         x_space[2*i].index=i;
00129         x_space[2*i+1].index=-1;
00130     }
00131 
00132     int32_t weights_label[2]={-1,+1};
00133     float64_t weights[2]={1.0,get_C()/get_C()}; // class weights; with a single C this ratio is always 1.0
00134 
00135     ASSERT(m_kernel && m_kernel->has_features());
00136     ASSERT(m_kernel->get_num_vec_lhs()==problem.l);
00137 
00138     param.svm_type=C_SVC;
00139     param.kernel_type = LINEAR;
00140     param.degree = 3;
00141     param.gamma = 0;    // 1/k
00142     param.coef0 = 0;
00143     param.nu = get_nu(); // Nu
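      // Wrap the kernel in a scatter normalizer: same-class entries are scaled by
      // (num_classes - 1) and different-class entries by -1, which presumably folds
      // the scatter formulation into a single binary problem for LibSVM.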
00144     CKernelNormalizer* prev_normalizer=m_kernel->get_normalizer();
00145     m_kernel->set_normalizer(new CScatterKernelNormalizer(
00146                 m_num_classes-1, -1, m_labels, prev_normalizer));
00147     param.kernel=m_kernel;
00148     param.cache_size = m_kernel->get_cache_size();
00149     param.C = 0;
00150     param.eps = get_epsilon();
00151     param.p = 0.1;
00152     param.shrinking = 0;
00153     param.nr_weight = 2;
00154     param.weight_label = weights_label;
00155     param.weight = weights;
00156     param.nr_class=m_num_classes;
00157     param.use_bias = svm_proto()->get_bias_enabled();
00158 
00159     const char* error_msg = svm_check_parameter(&problem,&param);
00160 
00161     if(error_msg)
00162         SG_ERROR("Error: %s\n",error_msg);
00163 
00164     model = svm_train(&problem, &param);
00165     m_kernel->set_normalizer(prev_normalizer);
00166     SG_UNREF(prev_normalizer);
00167 
00168     if (model)
00169     {
00170         ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef));
00171 
00172         ASSERT(model->nr_class==m_num_classes);
00173         create_multiclass_svm(m_num_classes);
00174 
00175         rho=model->rho[0];
00176 
00177         SG_FREE(norm_wcw);
00178         norm_wcw = SG_MALLOC(float64_t, m_machines->get_num_elements());
00179 
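              // Split the joint LibSVM solution into one CSVM per class: copy that
              // class's support vectors and coefficients, use rho[i+1] as the per-class
              // bias, and keep the per-class norm values reported by the solver in norm_wcw.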
00180         for (int32_t i=0; i<m_num_classes; i++)
00181         {
00182             int32_t num_sv=model->nSV[i];
00183 
00184             CSVM* svm=new CSVM(num_sv);
00185             svm->set_bias(model->rho[i+1]);
00186             norm_wcw[i]=model->normwcw[i];
00187 
00188 
00189             for (int32_t j=0; j<num_sv; j++)
00190             {
00191                 svm->set_alpha(j, model->sv_coef[i][j]);
00192                 svm->set_support_vector(j, model->SV[i][j].index);
00193             }
00194 
00195             set_svm(i, svm);
00196         }
00197 
00198         SG_FREE(problem.x);
00199         SG_FREE(problem.y);
00200         SG_FREE(x_space);
00201         for (int32_t i=0; i<m_num_classes; i++)
00202         {
00203             SG_FREE(model->SV[i]);
00204             model->SV[i]=NULL;
00205         }
00206         svm_destroy_model(model);
00207 
00208         if (scatter_type==TEST_RULE2)
00209             compute_norm_wc();
00210 
00211         model=NULL;
00212         return true;
00213     }
00214     else
00215         return false;
00216 }
00217 
00218 #ifdef USE_SVMLIGHT
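// Train via SVMLight: solve a one-class SVM on the scatter-normalized kernel and
// copy its support vectors and alphas into the prototype machine (svm_proto()).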
00219 bool CScatterSVM::train_no_bias_svmlight()
00220 {
00221     CKernelNormalizer* prev_normalizer=m_kernel->get_normalizer();
00222     CScatterKernelNormalizer* n=new CScatterKernelNormalizer(
00223                  m_num_classes-1, -1, m_labels, prev_normalizer);
00224     m_kernel->set_normalizer(n);
00225     m_kernel->init_normalizer();
00226 
00227     CSVMLightOneClass* light=new CSVMLightOneClass(get_C(), m_kernel);
00228     light->set_linadd_enabled(false);
00229     light->train();
00230 
00231     SG_FREE(norm_wcw);
00232     norm_wcw = SG_MALLOC(float64_t, m_num_classes);
00233 
00234     int32_t num_sv=light->get_num_support_vectors();
00235     svm_proto()->create_new_model(num_sv);
00236 
00237     for (int32_t i=0; i<num_sv; i++)
00238     {
00239         svm_proto()->set_alpha(i, light->get_alpha(i));
00240         svm_proto()->set_support_vector(i, light->get_support_vector(i));
00241     }
00242 
00243     m_kernel->set_normalizer(prev_normalizer);
00244     return true;
00245 }
00246 #endif //USE_SVMLIGHT
00247 
00248 bool CScatterSVM::train_testrule12()
00249 {
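          // Same LibSVM setup as train_no_bias_libsvm(), except that the true multiclass
          // labels are passed as problem.y and the solver type is NU_MULTICLASS_SVC.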
00250     struct svm_node* x_space;
00251     problem.l=m_labels->get_num_labels();
00252     SG_INFO("%d training labels\n", problem.l);
00253 
00254     problem.y=SG_MALLOC(float64_t, problem.l);
00255     problem.x=SG_MALLOC(struct svm_node*, problem.l);
00256     x_space=SG_MALLOC(struct svm_node, 2*problem.l);
00257 
00258     for (int32_t i=0; i<problem.l; i++)
00259     {
00260         problem.y[i]=((CMulticlassLabels*) m_labels)->get_label(i);
00261         problem.x[i]=&x_space[2*i];
00262         x_space[2*i].index=i;
00263         x_space[2*i+1].index=-1;
00264     }
00265 
00266     int32_t weights_label[2]={-1,+1};
00267     float64_t weights[2]={1.0,get_C()/get_C()}; // class weights; with a single C this ratio is always 1.0
00268 
00269     ASSERT(m_kernel && m_kernel->has_features());
00270     ASSERT(m_kernel->get_num_vec_lhs()==problem.l);
00271 
00272     param.svm_type=NU_MULTICLASS_SVC; // Nu MC SVM
00273     param.kernel_type = LINEAR;
00274     param.degree = 3;
00275     param.gamma = 0;    // 1/k
00276     param.coef0 = 0;
00277     param.nu = get_nu(); // Nu
00278     param.kernel=m_kernel;
00279     param.cache_size = m_kernel->get_cache_size();
00280     param.C = 0;
00281     param.eps = get_epsilon();
00282     param.p = 0.1;
00283     param.shrinking = 0;
00284     param.nr_weight = 2;
00285     param.weight_label = weights_label;
00286     param.weight = weights;
00287     param.nr_class=m_num_classes;
00288     param.use_bias = svm_proto()->get_bias_enabled();
00289 
00290     const char* error_msg = svm_check_parameter(&problem,&param);
00291 
00292     if(error_msg)
00293         SG_ERROR("Error: %s\n",error_msg);
00294 
00295     model = svm_train(&problem, &param);
00296 
00297     if (model)
00298     {
00299         ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef));
00300 
00301         ASSERT(model->nr_class==m_num_classes);
00302         create_multiclass_svm(m_num_classes);
00303 
00304         rho=model->rho[0];
00305 
00306         SG_FREE(norm_wcw);
00307         norm_wcw = SG_MALLOC(float64_t, m_machines->get_num_elements());
00308 
00309         for (int32_t i=0; i<m_num_classes; i++)
00310         {
00311             int32_t num_sv=model->nSV[i];
00312 
00313             CSVM* svm=new CSVM(num_sv);
00314             svm->set_bias(model->rho[i+1]);
00315             norm_wcw[i]=model->normwcw[i];
00316 
00317 
00318             for (int32_t j=0; j<num_sv; j++)
00319             {
00320                 svm->set_alpha(j, model->sv_coef[i][j]);
00321                 svm->set_support_vector(j, model->SV[i][j].index);
00322             }
00323 
00324             set_svm(i, svm);
00325         }
00326 
00327         SG_FREE(problem.x);
00328         SG_FREE(problem.y);
00329         SG_FREE(x_space);
00330         for (int32_t i=0; i<m_num_classes; i++)
00331         {
00332             SG_FREE(model->SV[i]);
00333             model->SV[i]=NULL;
00334         }
00335         svm_destroy_model(model);
00336 
00337         if (scatter_type==TEST_RULE2)
00338             compute_norm_wc();
00339 
00340         model=NULL;
00341         return true;
00342     }
00343     else
00344         return false;
00345 }
00346 
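// For each per-class machine c, compute ||w_c|| = sqrt(sum_ij alpha_i alpha_j k(x_i, x_j))
// over its support vectors; these norms are later used to rescale the one-vs-rest outputs.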
00347 void CScatterSVM::compute_norm_wc()
00348 {
00349     SG_FREE(norm_wc);
00350     norm_wc = SG_MALLOC(float64_t, m_machines->get_num_elements());
00351     for (int32_t i=0; i<m_machines->get_num_elements(); i++)
00352         norm_wc[i]=0;
00353 
00354 
00355     for (int c=0; c<m_machines->get_num_elements(); c++)
00356     {
00357         CSVM* svm=get_svm(c);
00358         int32_t num_sv = svm->get_num_support_vectors();
00359 
00360         for (int32_t i=0; i<num_sv; i++)
00361         {
00362             int32_t ii=svm->get_support_vector(i);
00363             for (int32_t j=0; j<num_sv; j++)
00364             {
00365                 int32_t jj=svm->get_support_vector(j);
00366                 norm_wc[c]+=svm->get_alpha(i)*m_kernel->kernel(ii,jj)*svm->get_alpha(j);
00367             }
00368         }
00369     }
00370 
00371     for (int32_t i=0; i<m_machines->get_num_elements(); i++)
00372         norm_wc[i]=CMath::sqrt(norm_wc[i]);
00373 
00374     SGVector<float64_t>::display_vector(norm_wc, m_machines->get_num_elements(), "norm_wc");
00375 }
00376 
00377 CLabels* CScatterSVM::classify_one_vs_rest()
00378 {
00379     CMulticlassLabels* output=NULL;
00380     if (!m_kernel)
00381     {
00382         SG_ERROR("SVM cannot proceed without kernel!\n");
00383         return NULL;
00384     }
00385 
00386     if (!( m_kernel && m_kernel->get_num_vec_lhs() && m_kernel->get_num_vec_rhs()))
00387         return NULL;
00388 
00389     int32_t num_vectors=m_kernel->get_num_vec_rhs();
00390 
00391     output=new CMulticlassLabels(num_vectors);
00392     SG_REF(output);
00393 
00394     if (scatter_type == TEST_RULE1)
00395     {
00396         ASSERT(m_machines->get_num_elements()>0);
00397         for (int32_t i=0; i<num_vectors; i++)
00398             output->set_label(i, apply(i));
00399     }
00400 #ifdef USE_SVMLIGHT
00401     else if (scatter_type == NO_BIAS_SVMLIGHT)
00402     {
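              // Score every test vector against all classes from the single one-class
              // solution: each support vector contributes alpha_j * k(sv_j, i), weighted
              // by (num_classes - 1) for its own class and by -1 for all others; the
              // highest-scoring class wins.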
00403         float64_t* outputs=SG_MALLOC(float64_t, num_vectors*m_num_classes);
00404         SGVector<float64_t>::fill_vector(outputs,num_vectors*m_num_classes,0.0);
00405 
00406         for (int32_t i=0; i<num_vectors; i++)
00407         {
00408             for (int32_t j=0; j<svm_proto()->get_num_support_vectors(); j++)
00409             {
00410                 float64_t score=m_kernel->kernel(svm_proto()->get_support_vector(j), i)*svm_proto()->get_alpha(j);
00411                 int32_t label=((CMulticlassLabels*) m_labels)->get_int_label(svm_proto()->get_support_vector(j));
00412                 for (int32_t c=0; c<m_num_classes; c++)
00413                 {
00414                     float64_t s= (label==c) ? (m_num_classes-1) : (-1);
00415                     outputs[c+i*m_num_classes]+=s*score;
00416                 }
00417             }
00418         }
00419 
00420         for (int32_t i=0; i<num_vectors; i++)
00421         {
00422             int32_t winner=0;
00423             float64_t max_out=outputs[i*m_num_classes+0];
00424 
00425             for (int32_t j=1; j<m_num_classes; j++)
00426             {
00427                 float64_t out=outputs[i*m_num_classes+j];
00428 
00429                 if (out>max_out)
00430                 {
00431                     winner=j;
00432                     max_out=out;
00433                 }
00434             }
00435 
00436             output->set_label(i, winner);
00437         }
00438 
00439         SG_FREE(outputs);
00440     }
00441 #endif //USE_SVMLIGHT
00442     else
00443     {
00444         ASSERT(m_machines->get_num_elements()>0);
00445         ASSERT(num_vectors==output->get_num_labels());
00446         CLabels** outputs=SG_MALLOC(CLabels*, m_machines->get_num_elements());
00447 
00448         for (int32_t i=0; i<m_machines->get_num_elements(); i++)
00449         {
00450             //SG_PRINT("svm %d\n", i);
00451             CSVM *svm = get_svm(i);
00452             ASSERT(svm);
00453             svm->set_kernel(m_kernel);
00454             svm->set_labels(m_labels);
00455             outputs[i]=svm->apply();
00456             SG_UNREF(svm);
00457         }
00458 
00459         for (int32_t i=0; i<num_vectors; i++)
00460         {
00461             int32_t winner=0;
00462             float64_t max_out=((CRegressionLabels*) outputs[0])->get_label(i)/norm_wc[0];
00463 
00464             for (int32_t j=1; j<m_machines->get_num_elements(); j++)
00465             {
00466                 float64_t out=((CRegressionLabels*) outputs[j])->get_label(i)/norm_wc[j];
00467 
00468                 if (out>max_out)
00469                 {
00470                     winner=j;
00471                     max_out=out;
00472                 }
00473             }
00474 
00475             output->set_label(i, winner);
00476         }
00477 
00478         for (int32_t i=0; i<m_machines->get_num_elements(); i++)
00479             SG_UNREF(outputs[i]);
00480 
00481         SG_FREE(outputs);
00482     }
00483 
00484     return output;
00485 }
00486 
00487 float64_t CScatterSVM::apply(int32_t num)
00488 {
00489     ASSERT(m_machines->get_num_elements()>0);
00490     float64_t* outputs=SG_MALLOC(float64_t, m_machines->get_num_elements());
00491     int32_t winner=0;
00492 
00493     if (scatter_type == TEST_RULE1)
00494     {
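              // TEST_RULE1: start from each machine's bias minus the shared rho, add the
              // kernel expansion of machine c, subtract the mean contribution across all
              // machines, divide by norm_wcw, and take the argmax as the predicted class.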
00495         for (int32_t c=0; c<m_machines->get_num_elements(); c++)
00496             outputs[c]=get_svm(c)->get_bias()-rho;
00497 
00498         for (int32_t c=0; c<m_machines->get_num_elements(); c++)
00499         {
00500             float64_t v=0;
00501 
00502             for (int32_t i=0; i<get_svm(c)->get_num_support_vectors(); i++)
00503             {
00504                 float64_t alpha=get_svm(c)->get_alpha(i);
00505                 int32_t svidx=get_svm(c)->get_support_vector(i);
00506                 v += alpha*m_kernel->kernel(svidx, num);
00507             }
00508 
00509             outputs[c] += v;
00510             for (int32_t j=0; j<m_machines->get_num_elements(); j++)
00511                 outputs[j] -= v/m_machines->get_num_elements();
00512         }
00513 
00514         for (int32_t j=0; j<m_machines->get_num_elements(); j++)
00515             outputs[j]/=norm_wcw[j];
00516 
00517         float64_t max_out=outputs[0];
00518         for (int32_t j=0; j<m_machines->get_num_elements(); j++)
00519         {
00520             if (outputs[j]>max_out)
00521             {
00522                 max_out=outputs[j];
00523                 winner=j;
00524             }
00525         }
00526     }
00527 #ifdef USE_SVMLIGHT
00528     else if (scatter_type == NO_BIAS_SVMLIGHT)
00529     {
00530         SG_ERROR("Use classify...\n");
00531     }
00532 #endif //USE_SVMLIGHT
00533     else
00534     {
00535         float64_t max_out=get_svm(0)->apply_one(num)/norm_wc[0];
00536 
00537         for (int32_t i=1; i<m_machines->get_num_elements(); i++)
00538         {
00539             outputs[i]=get_svm(i)->apply_one(num)/norm_wc[i];
00540             if (outputs[i]>max_out)
00541             {
00542                 winner=i;
00543                 max_out=outputs[i];
00544             }
00545         }
00546     }
00547 
00548     SG_FREE(outputs);
00549     return winner;
00550 }
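      /* Usage sketch (not part of the original file): the constructors and training
       * entry points above suggest roughly the following workflow.  Kernel and label
       * construction is abbreviated, and the exact constructor signatures should be
       * checked against the Shogun version in use.
       *
       *   CKernel* kernel = new CGaussianKernel(10, 2.0);            // cache size, kernel width
       *   CMulticlassLabels* labels = new CMulticlassLabels(label_vector);
       *   CScatterSVM* svm = new CScatterSVM(1.0, kernel, labels);   // C, kernel, labels
       *   svm->train(train_features);                                // dispatches on scatter_type
       *   CLabels* predictions = svm->apply(test_features);
       *   SG_UNREF(predictions);
       *   SG_UNREF(svm);
       */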