Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include "classifier/svm/SVMLin.h"
00012 #include "features/Labels.h"
00013 #include "lib/Mathematics.h"
00014 #include "classifier/svm/ssl.h"
00015 #include "classifier/LinearClassifier.h"
00016 #include "features/DotFeatures.h"
00017 #include "features/Labels.h"
00018
00019 using namespace shogun;
00020
00021 CSVMLin::CSVMLin()
00022 : CLinearClassifier(), C1(1), C2(1), epsilon(1e-5), use_bias(true)
00023 {
00024 }
00025
00026 CSVMLin::CSVMLin(
00027 float64_t C, CDotFeatures* traindat, CLabels* trainlab)
00028 : CLinearClassifier(), C1(C), C2(C), epsilon(1e-5), use_bias(true)
00029 {
00030 set_features(traindat);
00031 set_labels(trainlab);
00032 }
00033
00034
00035 CSVMLin::~CSVMLin()
00036 {
00037 }
00038
00039 bool CSVMLin::train(CFeatures* data)
00040 {
00041 ASSERT(labels);
00042
00043 if (data)
00044 {
00045 if (!data->has_property(FP_DOT))
00046 SG_ERROR("Specified features are not of type CDotFeatures\n");
00047 set_features((CDotFeatures*) data);
00048 }
00049
00050 ASSERT(features);
00051
00052 int32_t num_train_labels=0;
00053 float64_t* train_labels=labels->get_labels(num_train_labels);
00054 int32_t num_feat=features->get_dim_feature_space();
00055 int32_t num_vec=features->get_num_vectors();
00056
00057 ASSERT(num_vec==num_train_labels);
00058 delete[] w;
00059
00060 struct options Options;
00061 struct data Data;
00062 struct vector_double Weights;
00063 struct vector_double Outputs;
00064
00065 Data.l=num_vec;
00066 Data.m=num_vec;
00067 Data.u=0;
00068 Data.n=num_feat+1;
00069 Data.nz=num_feat+1;
00070 Data.Y=train_labels;
00071 Data.features=features;
00072 Data.C = new float64_t[Data.l];
00073
00074 Options.algo = SVM;
00075 Options.lambda=1/(2*get_C1());
00076 Options.lambda_u=1/(2*get_C1());
00077 Options.S=10000;
00078 Options.R=0.5;
00079 Options.epsilon = get_epsilon();
00080 Options.cgitermax=10000;
00081 Options.mfnitermax=50;
00082 Options.Cp = get_C2()/get_C1();
00083 Options.Cn = 1;
00084
00085 if (use_bias)
00086 Options.bias=1.0;
00087 else
00088 Options.bias=0.0;
00089
00090 for (int32_t i=0;i<num_vec;i++)
00091 {
00092 if(train_labels[i]>0)
00093 Data.C[i]=Options.Cp;
00094 else
00095 Data.C[i]=Options.Cn;
00096 }
00097 ssl_train(&Data, &Options, &Weights, &Outputs);
00098 ASSERT(Weights.vec && Weights.d==num_feat+1);
00099
00100 float64_t sgn=train_labels[0];
00101 for (int32_t i=0; i<num_feat+1; i++)
00102 Weights.vec[i]*=sgn;
00103
00104 set_w(Weights.vec, num_feat);
00105 set_bias(Weights.vec[num_feat]);
00106
00107 delete[] Weights.vec;
00108 delete[] Data.C;
00109 delete[] train_labels;
00110 delete[] Outputs.vec;
00111 return true;
00112 }