SVMLin.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2006-2009 Soeren Sonnenburg
00008  * Copyright (C) 2006-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include "classifier/svm/SVMLin.h"
00012 #include "features/Labels.h"
00013 #include "lib/Mathematics.h"
00014 #include "classifier/svm/ssl.h"
00015 #include "classifier/LinearClassifier.h"
00016 #include "features/DotFeatures.h"
00017 #include "features/Labels.h"
00018 
00019 using namespace shogun;
00020 
00021 CSVMLin::CSVMLin()
00022 : CLinearClassifier(), C1(1), C2(1), epsilon(1e-5), use_bias(true)
00023 {
00024 }
00025 
00026 CSVMLin::CSVMLin(
00027     float64_t C, CDotFeatures* traindat, CLabels* trainlab)
00028 : CLinearClassifier(), C1(C), C2(C), epsilon(1e-5), use_bias(true)
00029 {
00030     set_features(traindat);
00031     set_labels(trainlab);
00032 }
00033 
00034 
00035 CSVMLin::~CSVMLin()
00036 {
00037 }
00038 
00039 bool CSVMLin::train(CFeatures* data)
00040 {
00041     ASSERT(labels);
00042 
00043     if (data)
00044     {
00045         if (!data->has_property(FP_DOT))
00046             SG_ERROR("Specified features are not of type CDotFeatures\n");
00047         set_features((CDotFeatures*) data);
00048     }
00049 
00050     ASSERT(features);
00051 
00052     int32_t num_train_labels=0;
00053     float64_t* train_labels=labels->get_labels(num_train_labels);
00054     int32_t num_feat=features->get_dim_feature_space();
00055     int32_t num_vec=features->get_num_vectors();
00056 
00057     ASSERT(num_vec==num_train_labels);
00058     delete[] w;
00059 
00060     struct options Options;
00061     struct data Data;
00062     struct vector_double Weights;
00063     struct vector_double Outputs;
00064 
00065     Data.l=num_vec;
00066     Data.m=num_vec;
00067     Data.u=0; 
00068     Data.n=num_feat+1;
00069     Data.nz=num_feat+1;
00070     Data.Y=train_labels;
00071     Data.features=features;
00072     Data.C = new float64_t[Data.l];
00073 
00074     Options.algo = SVM;
00075     Options.lambda=1/(2*get_C1());
00076     Options.lambda_u=1/(2*get_C1());
00077     Options.S=10000;
00078     Options.R=0.5;
00079     Options.epsilon = get_epsilon();
00080     Options.cgitermax=10000;
00081     Options.mfnitermax=50;
00082     Options.Cp = get_C2()/get_C1();
00083     Options.Cn = 1;
00084     
00085     if (use_bias)
00086         Options.bias=1.0;
00087     else
00088         Options.bias=0.0;
00089 
00090     for (int32_t i=0;i<num_vec;i++)
00091     {
00092         if(train_labels[i]>0) 
00093             Data.C[i]=Options.Cp;
00094         else 
00095             Data.C[i]=Options.Cn;
00096     }
00097     ssl_train(&Data, &Options, &Weights, &Outputs);
00098     ASSERT(Weights.vec && Weights.d==num_feat+1);
00099 
00100     float64_t sgn=train_labels[0];
00101     for (int32_t i=0; i<num_feat+1; i++)
00102         Weights.vec[i]*=sgn;
00103 
00104     set_w(Weights.vec, num_feat);
00105     set_bias(Weights.vec[num_feat]);
00106 
00107     delete[] Weights.vec;
00108     delete[] Data.C;
00109     delete[] train_labels;
00110     delete[] Outputs.vec;
00111     return true;
00112 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation