Go to the documentation of this file.
00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016 #ifndef _VOWPALWABBIT_H__
00017 #define _VOWPALWABBIT_H__
00018
00019 #include <shogun/classifier/vw/vw_common.h>
00020 #include <shogun/classifier/vw/learners/VwAdaptiveLearner.h>
00021 #include <shogun/classifier/vw/learners/VwNonAdaptiveLearner.h>
00022 #include <shogun/classifier/vw/VwRegressor.h>
00023
00024 #include <shogun/features/streaming/StreamingVwFeatures.h>
00025 #include <shogun/machine/OnlineLinearMachine.h>
00026
00027 namespace shogun
00028 {
00038 class CVowpalWabbit: public COnlineLinearMachine
00039 {
00040 public:
00041
00043 MACHINE_PROBLEM_TYPE(PT_BINARY);
00044
00048 CVowpalWabbit();
00049
00056 CVowpalWabbit(CStreamingVwFeatures* feat);
00057
00061 CVowpalWabbit(CVowpalWabbit *vw);
00062
00066 ~CVowpalWabbit();
00067
00072 void reinitialize_weights();
00073
00082 void set_no_training(bool dont_train) { no_training = dont_train; }
00083
00089 void set_adaptive(bool adaptive_learning);
00090
00097 void set_exact_adaptive_norm(bool exact_adaptive);
00098
00104 void set_num_passes(int32_t passes)
00105 {
00106 env->num_passes = passes;
00107 }
00108
00114 void load_regressor(char* file_name);
00115
00122 void set_regressor_out(char* file_name, bool is_text = true);
00123
00129 void set_prediction_out(char* file_name);
00130
00137 void add_quadratic_pair(char* pair);
00138
00144 virtual bool train_machine(CFeatures* feat = NULL);
00145
00153 virtual float32_t predict_and_finalize(VwExample* ex);
00154
00163 float32_t compute_exact_norm(VwExample* &ex, float32_t& sum_abs_x);
00164
00177 float32_t compute_exact_norm_quad(float32_t* weights, VwFeature& page_feature, v_array<VwFeature> &offer_features,
00178 vw_size_t mask, float32_t g, float32_t& sum_abs_x);
00179
00185 virtual CVwEnvironment* get_env()
00186 {
00187 SG_REF(env);
00188 return env;
00189 }
00190
00196 virtual const char* get_name() const { return "VowpalWabbit"; }
00197
00202 virtual void set_learner();
00203
00207 CVwLearner* get_learner() { return learner; }
00208
00209 private:
00215 virtual void init(CStreamingVwFeatures* feat = NULL);
00216
00224 virtual float32_t inline_l1_predict(VwExample* &ex);
00225
00233 virtual float32_t inline_predict(VwExample* &ex);
00234
00242 virtual float32_t finalize_prediction(float32_t ret);
00243
00249 virtual void output_example(VwExample* &ex);
00250
00256 virtual void print_update(VwExample* &ex);
00257
00266 virtual void output_prediction(int32_t f, float32_t res, float32_t weight, v_array<char> tag);
00267
00273 void set_verbose(bool verbose);
00274
00275 protected:
00277 CStreamingVwFeatures* features;
00278
00280 CVwEnvironment* env;
00281
00283 CVwLearner* learner;
00284
00286 CVwRegressor* reg;
00287
00288 private:
00290 bool quiet;
00291
00293 bool no_training;
00294
00296 float32_t dump_interval;
00298 float32_t sum_loss_since_last_dump;
00300 float64_t old_weighted_examples;
00301
00303 char* reg_name;
00305 bool reg_dump_text;
00306
00308 bool save_predictions;
00310 int32_t prediction_fd;
00311 };
00312
00313 }
00314 #endif // _VOWPALWABBIT_H__