00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023 #include "classifier/svm/SVMSGD.h"
00024 #include "base/Parameter.h"
00025 #include "lib/Signal.h"
00026
00027 using namespace shogun;
00028
00029
00030 #define HINGELOSS 1
00031 #define SMOOTHHINGELOSS 2
00032 #define SQUAREDHINGELOSS 3
00033 #define LOGLOSS 10
00034 #define LOGLOSSMARGIN 11
00035
00036
00037 #define LOSS HINGELOSS
00038
00039
00040 #define REGULARIZEBIAS 0
00041
/** Per-example loss for margin z = y*f(x), selected at compile time via the
 * LOSS macro (currently HINGELOSS). Only the branch matching LOSS is compiled.
 */
inline
float64_t loss(float64_t z)
{
#if LOSS == LOGLOSS
	// logistic loss log(1+exp(-z)), split into two forms so that exp() is
	// only ever called on a non-positive argument (avoids overflow)
	if (z >= 0)
		return log(1+exp(-z));
	else
		return -z + log(1+exp(z));
#elif LOSS == LOGLOSSMARGIN
	// logistic loss shifted by a unit margin: log(1+exp(1-z)),
	// same overflow-safe two-branch form as above
	if (z >= 1)
		return log(1+exp(1-z));
	else
		return 1-z + log(1+exp(z-1));
#elif LOSS == SMOOTHHINGELOSS
	// quadratically smoothed hinge: linear for z<0, quadratic on [0,1), 0 beyond
	if (z < 0)
		return 0.5 - z;
	if (z < 1)
		return 0.5 * (1-z) * (1-z);
	return 0;
#elif LOSS == SQUAREDHINGELOSS
	// squared hinge: 0.5*(1-z)^2 inside the margin, 0 outside
	if (z < 1)
		return 0.5 * (1 - z) * (1 - z);
	return 0;
#elif LOSS == HINGELOSS
	// standard hinge: max(0, 1-z)
	if (z < 1)
		return 1 - z;
	return 0;
#else
# error "Undefined loss"
#endif
}
00073
/** Negative derivative -dloss/dz of the compile-time selected loss.
 * train() uses the returned value directly as the (positive) update
 * magnitude: w += eta * dloss(z) * y * x.
 */
inline
float64_t dloss(float64_t z)
{
#if LOSS == LOGLOSS
	// -d/dz log(1+exp(-z)) = 1/(1+exp(z)), overflow-safe two-branch form
	if (z < 0)
		return 1 / (exp(z) + 1);
	float64_t ez = exp(-z);
	return ez / (ez + 1);
#elif LOSS == LOGLOSSMARGIN
	// margin-shifted logistic derivative: 1/(1+exp(z-1))
	if (z < 1)
		return 1 / (exp(z-1) + 1);
	float64_t ez = exp(1-z);
	return ez / (ez + 1);
#elif LOSS == SMOOTHHINGELOSS
	// piecewise derivative of the smoothed hinge
	if (z < 0)
		return 1;
	if (z < 1)
		return 1-z;
	return 0;
#elif LOSS == SQUAREDHINGELOSS
	if (z < 1)
		return (1 - z);
	return 0;
#else
	// hinge loss (also the fallback for any LOSS value not handled above;
	// unlike loss() there is no #error branch here)
	if (z < 1)
		return 1;
	return 0;
#endif
}
00103
00104
/** Default constructor: sets all hyper-parameter defaults and registers
 * them with the parameter framework via init().
 */
CSVMSGD::CSVMSGD()
: CLinearClassifier()
{
	init();
}
00110
00111 CSVMSGD::CSVMSGD(float64_t C)
00112 : CLinearClassifier()
00113 {
00114 init();
00115
00116 C1=C;
00117 C2=C;
00118 }
00119
00120 CSVMSGD::CSVMSGD(float64_t C, CDotFeatures* traindat, CLabels* trainlab)
00121 : CLinearClassifier()
00122 {
00123 init();
00124 C1=C;
00125 C2=C;
00126
00127 set_features(traindat);
00128 set_labels(trainlab);
00129 }
00130
/** Destructor. Nothing owned directly by this class; the weight vector w
 * and attached features/labels appear to be released by the base classes
 * — NOTE(review): confirm against CLinearClassifier.
 */
CSVMSGD::~CSVMSGD()
{
}
00134
/** Train a linear SVM by stochastic gradient descent (Bottou-style svmsgd).
 *
 * @param data optional training features; must have the FP_DOT property.
 *             If given, replaces the currently set features.
 * @return true on completion
 *
 * Requires labels to be set and two-class; asserts that the number of
 * feature vectors matches the number of labels.
 */
bool CSVMSGD::train(CFeatures* data)
{

	ASSERT(labels);

	if (data)
	{
		if (!data->has_property(FP_DOT))
			SG_ERROR("Specified features are not of type CDotFeatures\n");
		set_features((CDotFeatures*) data);
	}

	ASSERT(features);
	ASSERT(labels->is_two_class_labeling());

	int32_t num_train_labels=labels->get_num_labels();
	w_dim=features->get_dim_feature_space();
	int32_t num_vec=features->get_num_vectors();

	ASSERT(num_vec==num_train_labels);
	ASSERT(num_vec>0);

	// (re)allocate and zero the weight vector
	delete[] w;
	w=new float64_t[w_dim];
	memset(w, 0, w_dim*sizeof(float64_t));
	bias=0;

	// regularization strength: C plays the role of 1/(lambda*n)
	float64_t lambda= 1.0/(C1*num_vec);

	// Heuristic initial learning rate (from Bottou's svmsgd): pick eta0 from
	// the typical weight scale, then seed the step counter t so that the
	// first step size 1/(lambda*t) equals eta0.
	float64_t maxw = 1.0 / sqrt(lambda);
	float64_t typw = sqrt(maxw);
	float64_t eta0 = typw / CMath::max(1.0,dloss(-typw));
	t = 1 / (eta0 * lambda);

	SG_INFO("lambda=%f, epochs=%d, eta0=%f\n", lambda, epochs, eta0);

	// estimate bscale and skip from the data sparsity
	calibrate();

	SG_INFO("Training on %d vectors\n", num_vec);
	CSignal::clear_cancel();

	for(int32_t e=0; e<epochs && (!CSignal::cancel_computations()); e++)
	{
		count = skip;
		for (int32_t i=0; i<num_vec; i++)
		{
			// decaying step size eta_t = 1/(lambda*t)
			float64_t eta = 1.0 / (lambda * t);
			float64_t y = labels->get_label(i);
			float64_t z = y * (features->dense_dot(i, w, w_dim) + bias);

// for margin-based losses (hinge variants) the gradient is zero when
// z >= 1, so the update can be skipped entirely
#if LOSS < LOGLOSS
			if (z < 1)
#endif
			{
				// gradient step on the loss term: w += eta*dloss(z)*y*x
				float64_t etd = eta * dloss(z);
				features->add_to_dense_vec(etd * y / wscale, i, w, w_dim);

				if (use_bias)
				{
					if (use_regularized_bias)
						bias *= 1 - eta * lambda * bscale;
					bias += etd * y * bscale;
				}
			}

			// Apply the (dense) regularization shrink only every `skip`
			// iterations, scaled up accordingly — cheap for sparse data.
			if (--count <= 0)
			{
				float64_t r = 1 - eta * lambda * skip;
				if (r < 0.8)
					// linear approximation too coarse; use the exact product
					r = pow(1 - eta * lambda, skip);
				CMath::scale_vector(r, w, w_dim);
				count = skip;
			}
			t++;
		}
	}

	float64_t wnorm = CMath::dot(w,w, w_dim);
	SG_INFO("Norm: %.6f, Bias: %.6f\n", wnorm, bias);

	return true;
}
00222
00223 void CSVMSGD::calibrate()
00224 {
00225 ASSERT(features);
00226 int32_t num_vec=features->get_num_vectors();
00227 int32_t c_dim=features->get_dim_feature_space();
00228
00229 ASSERT(num_vec>0);
00230 ASSERT(c_dim>0);
00231
00232 float64_t* c=new float64_t[c_dim];
00233 memset(c, 0, c_dim*sizeof(float64_t));
00234
00235 SG_INFO("Estimating sparsity and bscale num_vec=%d num_feat=%d.\n", num_vec, c_dim);
00236
00237
00238 int32_t n = 0;
00239 float64_t m = 0;
00240 float64_t r = 0;
00241
00242 for (int32_t j=0; j<num_vec && m<=1000; j++, n++)
00243 {
00244 r += features->get_nnz_features_for_vector(j);
00245 features->add_to_dense_vec(1, j, c, c_dim, true);
00246
00247
00248
00249 m=CMath::max(c, c_dim);
00250 }
00251
00252
00253 bscale = m/n;
00254
00255
00256 skip = (int32_t) ((16 * n * c_dim) / r);
00257 SG_INFO("using %d examples. skip=%d bscale=%.6f\n", n, skip, bscale);
00258
00259 delete[] c;
00260 }
00261
/** Set hyper-parameter defaults and register them with the parameter
 * framework (used by all constructors).
 */
void CSVMSGD::init()
{
	t=1;			// SGD step counter (re-seeded in train())
	C1=1;			// cost constant, class 1
	C2=1;			// cost constant, class 2
	wscale=1;		// weight vector scale factor
	bscale=1;		// bias update scale (re-estimated in calibrate())
	epochs=5;		// passes over the training data
	skip=1000;		// steps between regularization shrinks (re-estimated in calibrate())
	count=1000;		// countdown to the next shrink

	use_bias=true;

	use_regularized_bias=false;

	// NOTE: registration order is preserved; it may matter for serialized
	// models — do not reorder.
	m_parameters->add(&C1, "C1", "Cost constant 1.");
	m_parameters->add(&C2, "C2", "Cost constant 2.");
	m_parameters->add(&wscale, "wscale", "W scale");
	m_parameters->add(&bscale, "bscale", "b scale");
	m_parameters->add(&epochs, "epochs", "epochs");
	m_parameters->add(&skip, "skip", "skip");
	m_parameters->add(&count, "count", "count");
	m_parameters->add(&use_bias, "use_bias", "Indicates if bias is used.");
	m_parameters->add(&use_regularized_bias, "use_regularized_bias", "Indicates if bias is regularized.");
}