SVMLin.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2006-2009 Soeren Sonnenburg
00008  * Copyright (C) 2006-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include <shogun/classifier/svm/SVMLin.h>
00012 #include <shogun/labels/Labels.h>
00013 #include <shogun/mathematics/Math.h>
00014 #include <shogun/lib/external/ssl.h>
00015 #include <shogun/machine/LinearMachine.h>
00016 #include <shogun/features/DotFeatures.h>
00017 #include <shogun/labels/Labels.h>
00018 #include <shogun/labels/BinaryLabels.h>
00019 
00020 using namespace shogun;
00021 
00022 CSVMLin::CSVMLin()
00023 : CLinearMachine(), C1(1), C2(1), epsilon(1e-5), use_bias(true)
00024 {
00025 }
00026 
00027 CSVMLin::CSVMLin(
00028     float64_t C, CDotFeatures* traindat, CLabels* trainlab)
00029 : CLinearMachine(), C1(C), C2(C), epsilon(1e-5), use_bias(true)
00030 {
00031     set_features(traindat);
00032     set_labels(trainlab);
00033 }
00034 
00035 
00036 CSVMLin::~CSVMLin()
00037 {
00038 }
00039 
00040 bool CSVMLin::train_machine(CFeatures* data)
00041 {
00042     ASSERT(m_labels);
00043 
00044     if (data)
00045     {
00046         if (!data->has_property(FP_DOT))
00047             SG_ERROR("Specified features are not of type CDotFeatures\n");
00048         set_features((CDotFeatures*) data);
00049     }
00050 
00051     ASSERT(features);
00052 
00053     SGVector<float64_t> train_labels=((CBinaryLabels*) m_labels)->get_labels();
00054     int32_t num_feat=features->get_dim_feature_space();
00055     int32_t num_vec=features->get_num_vectors();
00056 
00057     ASSERT(num_vec==train_labels.vlen);
00058 
00059     struct options Options;
00060     struct data Data;
00061     struct vector_double Weights;
00062     struct vector_double Outputs;
00063 
00064     Data.l=num_vec;
00065     Data.m=num_vec;
00066     Data.u=0;
00067     Data.n=num_feat+1;
00068     Data.nz=num_feat+1;
00069     Data.Y=train_labels.vector;
00070     Data.features=features;
00071     Data.C = SG_MALLOC(float64_t, Data.l);
00072 
00073     Options.algo = SVM;
00074     Options.lambda=1/(2*get_C1());
00075     Options.lambda_u=1/(2*get_C1());
00076     Options.S=10000;
00077     Options.R=0.5;
00078     Options.epsilon = get_epsilon();
00079     Options.cgitermax=10000;
00080     Options.mfnitermax=50;
00081     Options.Cp = get_C2()/get_C1();
00082     Options.Cn = 1;
00083 
00084     if (use_bias)
00085         Options.bias=1.0;
00086     else
00087         Options.bias=0.0;
00088 
00089     for (int32_t i=0;i<num_vec;i++)
00090     {
00091         if(train_labels.vector[i]>0)
00092             Data.C[i]=Options.Cp;
00093         else
00094             Data.C[i]=Options.Cn;
00095     }
00096     ssl_train(&Data, &Options, &Weights, &Outputs);
00097     ASSERT(Weights.vec && Weights.d==num_feat+1);
00098 
00099     float64_t sgn=train_labels.vector[0];
00100     for (int32_t i=0; i<num_feat+1; i++)
00101         Weights.vec[i]*=sgn;
00102 
00103     set_w(SGVector<float64_t>(Weights.vec, num_feat));
00104     set_bias(Weights.vec[num_feat]);
00105 
00106     SG_FREE(Data.C);
00107     SG_FREE(Outputs.vec);
00108     return true;
00109 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation