LibSVMOneClass.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2006 Christian Gehl
00008  * Written (W) 2006-2009 Soeren Sonnenburg
00009  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00010  */
00011 
00012 #include "classifier/svm/LibSVMOneClass.h"
00013 #include "lib/io.h"
00014 
00015 using namespace shogun;
00016 
00017 CLibSVMOneClass::CLibSVMOneClass()
00018 : CSVM(), model(NULL)
00019 {
00020 }
00021 
00022 CLibSVMOneClass::CLibSVMOneClass(float64_t C, CKernel* k)
00023 : CSVM(C, k, NULL), model(NULL)
00024 {
00025 }
00026 
00027 CLibSVMOneClass::~CLibSVMOneClass()
00028 {
00029     free(model);
00030 }
00031 
00032 bool CLibSVMOneClass::train(CFeatures* data)
00033 {
00034     ASSERT(kernel);
00035     if (data)
00036         kernel->init(data, data);
00037 
00038     problem.l=kernel->get_num_vec_lhs();
00039 
00040     struct svm_node* x_space;
00041     SG_INFO("%d train data points\n", problem.l);
00042 
00043     problem.y=NULL;
00044     problem.x=new struct svm_node*[problem.l];
00045     x_space=new struct svm_node[2*problem.l];
00046 
00047     for (int32_t i=0; i<problem.l; i++)
00048     {
00049         problem.x[i]=&x_space[2*i];
00050         x_space[2*i].index=i;
00051         x_space[2*i+1].index=-1;
00052     }
00053 
00054     int32_t weights_label[2]={-1,+1};
00055     float64_t weights[2]={1.0,get_C2()/get_C1()};
00056 
00057     param.svm_type=ONE_CLASS; // C SVM
00058     param.kernel_type = LINEAR;
00059     param.degree = 3;
00060     param.gamma = 0;    // 1/k
00061     param.coef0 = 0;
00062     param.nu = get_nu();
00063     param.kernel=kernel;
00064     param.cache_size = kernel->get_cache_size();
00065     param.max_train_time = max_train_time;
00066     param.C = get_C1();
00067     param.eps = epsilon;
00068     param.p = 0.1;
00069     param.shrinking = 1;
00070     param.nr_weight = 2;
00071     param.weight_label = weights_label;
00072     param.weight = weights;
00073     param.use_bias = get_bias_enabled();
00074     
00075     const char* error_msg = svm_check_parameter(&problem,&param);
00076 
00077     if(error_msg)
00078         SG_ERROR("Error: %s\n",error_msg);
00079     
00080     model = svm_train(&problem, &param);
00081 
00082     if (model)
00083     {
00084         ASSERT(model->nr_class==2);
00085         ASSERT((model->l==0) || (model->l>0 && model->SV && model->sv_coef && model->sv_coef[0]));
00086 
00087         int32_t num_sv=model->l;
00088 
00089         create_new_model(num_sv);
00090         CSVM::set_objective(model->objective);
00091 
00092         set_bias(-model->rho[0]);
00093         for (int32_t i=0; i<num_sv; i++)
00094         {
00095             set_support_vector(i, (model->SV[i])->index);
00096             set_alpha(i, model->sv_coef[0][i]);
00097         }
00098 
00099         delete[] problem.x;
00100         delete[] x_space;
00101         svm_destroy_model(model);
00102         model=NULL;
00103 
00104         return true;
00105     }
00106     else
00107         return false;
00108 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation