StreamingVwFeatures.h

Go to the documentation of this file.
00001 /*
00002  * Copyright (c) 2009 Yahoo! Inc.  All rights reserved.  The copyrights
00003  * embodied in the content of this file are licensed under the BSD
00004  * (revised) open source license.
00005  *
00006  * This program is free software; you can redistribute it and/or modify
00007  * it under the terms of the GNU General Public License as published by
00008  * the Free Software Foundation; either version 3 of the License, or
00009  * (at your option) any later version.
00010  *
00011  * Written (W) 2011 Shashwat Lal Das
00012  * Adaptation of Vowpal Wabbit v5.1.
00013  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society.
00014  */
00015 
00016 #ifndef _STREAMING_VWFEATURES__H__
00017 #define _STREAMING_VWFEATURES__H__
00018 
00019 #include <shogun/lib/common.h>
00020 #include <shogun/lib/DataType.h>
00021 #include <shogun/mathematics/Math.h>
00022 
00023 #include <shogun/io/InputParser.h>
00024 #include <shogun/io/StreamingVwFile.h>
00025 #include <shogun/io/StreamingVwCacheFile.h>
00026 #include <shogun/features/StreamingDotFeatures.h>
00027 #include <shogun/classifier/vw/vw_common.h>
00028 #include <shogun/classifier/vw/vw_math.h>
00029 
00030 namespace shogun
00031 {
00039 class CStreamingVwFeatures : public CStreamingDotFeatures
00040 {
00041 public:
00042 
00050     CStreamingVwFeatures()
00051         : CStreamingDotFeatures()
00052     {
00053         init();
00054         set_read_functions();
00055     }
00056 
00065     CStreamingVwFeatures(CStreamingVwFile* file,
00066                  bool is_labelled,
00067                  int32_t size)
00068         : CStreamingDotFeatures()
00069     {
00070         init(file, is_labelled, size);
00071         set_read_functions();
00072     }
00073 
00082     CStreamingVwFeatures(CStreamingVwCacheFile* file,
00083                  bool is_labelled,
00084                  int32_t size)
00085         : CStreamingDotFeatures()
00086     {
00087         init(file, is_labelled, size);
00088         set_read_functions();
00089     }
00090 
00096     ~CStreamingVwFeatures()
00097     {
00098         parser.end_parser();
00099         SG_UNREF(env);
00100     }
00101 
00107     CFeatures* duplicate() const
00108     {
00109         return new CStreamingVwFeatures(*this);
00110     }
00111 
00121     virtual void set_vector_reader();
00122 
00132     virtual void set_vector_and_label_reader();
00133 
00139     virtual void start_parser();
00140 
00146     virtual void end_parser();
00147 
00152     virtual void reset_stream()
00153     {
00154         if (working_file->is_seekable())
00155         {
00156             working_file->reset_stream();
00157             parser.exit_parser();
00158             parser.init(working_file, has_labels, parser.get_ring_size());
00159             parser.set_free_vector_after_release(false);
00160             parser.start_parser();
00161         }
00162         else
00163             SG_ERROR("The input cannot be reset! Please use 1 pass.\n");
00164     }
00165 
00170     virtual CVwEnvironment* get_env()
00171     {
00172         SG_REF(env);
00173         return env;
00174     }
00175 
00181     virtual void set_env(CVwEnvironment* vw_env)
00182     {
00183         env = vw_env;
00184         SG_REF(env);
00185     }
00186 
00195     virtual bool get_next_example();
00196 
00202     virtual VwExample* get_example();
00203 
00211     virtual float64_t get_label();
00212 
00219     virtual void release_example();
00220 
00229     inline virtual void expand_if_required(float32_t*& vec, int32_t& len)
00230     {
00231         int32_t dim = 1 << env->num_bits;
00232         if (dim > len)
00233         {
00234             vec = SG_REALLOC(float32_t, vec, dim);
00235             memset(&vec[len], 0, (dim-len) * sizeof(float32_t));
00236             len = dim;
00237         }
00238     }
00239 
00248     inline virtual void expand_if_required(float64_t*& vec, int32_t& len)
00249     {
00250         int32_t dim = 1 << env->num_bits;
00251         if (dim > len)
00252         {
00253             vec = SG_REALLOC(float64_t, vec, dim);
00254             memset(&vec[len], 0, (dim-len) * sizeof(float64_t));
00255             len = dim;
00256         }
00257     }
00258 
00266     virtual int32_t get_dim_feature_space() const;
00267 
00276     inline virtual float32_t real_weight(float32_t w, float32_t gravity)
00277     {
00278         float32_t wprime = 0;
00279         if (gravity < fabsf(w))
00280             wprime = CMath::sign(w)*(fabsf(w) - gravity);
00281         return wprime;
00282     }
00283 
00294     virtual float32_t dot(CStreamingDotFeatures *df);
00295 
00304     virtual float32_t dense_dot(VwExample* &ex, const float32_t* vec2);
00305 
00315     virtual float32_t dense_dot(const float32_t* vec2, int32_t vec2_len);
00316 
00326     virtual float32_t dense_dot(SGSparseVector<float32_t>* vec1, const float32_t* vec2);
00327 
00338     virtual float32_t dense_dot_truncated(const float32_t* vec2, VwExample* &ex, float32_t gravity);
00339 
00350     virtual void add_to_dense_vec(float32_t alpha, VwExample* &ex, float32_t* vec2, int32_t vec2_len, bool abs_val = false);
00351 
00361     virtual void add_to_dense_vec(float32_t alpha, float32_t* vec2, int32_t vec2_len, bool abs_val = false);
00362 
00367     virtual inline int32_t get_nnz_features_for_vector()
00368     {
00369         return current_length;
00370     }
00371 
00377     virtual int32_t get_num_features();
00378 
00384     virtual inline EFeatureType get_feature_type();
00385 
00391     virtual EFeatureClass get_feature_class();
00392 
00398     inline virtual const char* get_name() const { return "StreamingVwFeatures"; }
00399 
00405     inline virtual int32_t get_num_vectors() const
00406     {
00407         if (current_example)
00408             return 1;
00409         else
00410             return 0;
00411     }
00412 
00418     virtual int32_t get_size() { return sizeof(VwExample); }
00419 
00420 private:
00425     virtual void init();
00426 
00434     virtual void init(CStreamingVwFile *file, bool is_labelled, int32_t size);
00435 
00443     virtual void init(CStreamingVwCacheFile *file, bool is_labelled, int32_t size);
00444 
00451     virtual void setup_example(VwExample* ae);
00452 
00453 protected:
00454 
00456     CInputParser<VwExample> parser;
00457 
00459     vw_size_t example_count;
00460 
00462     float64_t current_label;
00463 
00465     int32_t current_length;
00466 
00468     CVwEnvironment* env;
00469 
00471     VwExample* current_example;
00472 };
00473 }
00474 #endif // _STREAMING_VWFEATURES__H__
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation