Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016 #ifndef _VOWPALWABBIT_H__
00017 #define _VOWPALWABBIT_H__
00018
00019 #include <shogun/classifier/vw/vw_common.h>
00020 #include <shogun/classifier/vw/learners/VwAdaptiveLearner.h>
00021 #include <shogun/classifier/vw/learners/VwNonAdaptiveLearner.h>
00022 #include <shogun/classifier/vw/VwRegressor.h>
00023
00024 #include <shogun/features/StreamingVwFeatures.h>
00025 #include <shogun/machine/OnlineLinearMachine.h>
00026
00027 namespace shogun
00028 {
00038 class CVowpalWabbit: public COnlineLinearMachine
00039 {
00040 public:
00044 CVowpalWabbit();
00045
00052 CVowpalWabbit(CStreamingVwFeatures* feat);
00053
00057 ~CVowpalWabbit();
00058
00063 void reinitialize_weights();
00064
00073 void set_no_training(bool dont_train) { no_training = dont_train; }
00074
00080 void set_adaptive(bool adaptive_learning);
00081
00088 void set_exact_adaptive_norm(bool exact_adaptive);
00089
00095 void set_num_passes(int32_t passes)
00096 {
00097 env->num_passes = passes;
00098 }
00099
00105 void load_regressor(char* file_name);
00106
00113 void set_regressor_out(char* file_name, bool is_text = true);
00114
00120 void set_prediction_out(char* file_name);
00121
00128 void add_quadratic_pair(char* pair);
00129
00135 virtual bool train_machine(CFeatures* feat = NULL);
00136
00144 virtual float32_t predict_and_finalize(VwExample* ex);
00145
00154 float32_t compute_exact_norm(VwExample* &ex, float32_t& sum_abs_x);
00155
00168 float32_t compute_exact_norm_quad(float32_t* weights, VwFeature& page_feature, v_array<VwFeature> &offer_features,
00169 vw_size_t mask, float32_t g, float32_t& sum_abs_x);
00170
00176 virtual CVwEnvironment* get_env()
00177 {
00178 SG_REF(env);
00179 return env;
00180 }
00181
00187 virtual const char* get_name() const { return "VowpalWabbit"; }
00188
00189 private:
00195 virtual void init(CStreamingVwFeatures* feat = NULL);
00196
00201 virtual void set_learner();
00202
00210 virtual float32_t inline_l1_predict(VwExample* &ex);
00211
00219 virtual float32_t inline_predict(VwExample* &ex);
00220
00228 virtual float32_t finalize_prediction(float32_t ret);
00229
00235 virtual void output_example(VwExample* &ex);
00236
00242 virtual void print_update(VwExample* &ex);
00243
00252 virtual void output_prediction(int32_t f, float32_t res, float32_t weight, v_array<char> tag);
00253
00259 void set_verbose(bool verbose);
00260
00261 protected:
00263 CStreamingVwFeatures* features;
00264
00266 CVwEnvironment* env;
00267
00269 CVwLearner* learner;
00270
00272 CVwRegressor* reg;
00273
00274 private:
00276 bool quiet;
00277
00279 bool no_training;
00280
00282 float32_t dump_interval;
00284 float32_t sum_loss_since_last_dump;
00286 float64_t old_weighted_examples;
00287
00289 char* reg_name;
00291 bool reg_dump_text;
00292
00294 bool save_predictions;
00296 int32_t prediction_fd;
00297 };
00298
00299 }
00300 #endif // _VOWPALWABBIT_H__