VwNativeCacheReader.cpp

Go to the documentation of this file.
00001 /*
00002  * Copyright (c) 2009 Yahoo! Inc.  All rights reserved.  The copyrights
00003  * embodied in the content of this file are licensed under the BSD
00004  * (revised) open source license.
00005  *
00006  * This program is free software; you can redistribute it and/or modify
00007  * it under the terms of the GNU General Public License as published by
00008  * the Free Software Foundation; either version 3 of the License, or
00009  * (at your option) any later version.
00010  *
00011  * Written (W) 2011 Shashwat Lal Das
00012  * Adaptation of Vowpal Wabbit v5.1.
00013  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society.
00014  */
00015 
#include <shogun/classifier/vw/cache/VwNativeCacheReader.h>

#include <cstring>
00017 
00018 using namespace shogun;
00019 
00020 CVwNativeCacheReader::CVwNativeCacheReader()
00021     : CVwCacheReader(), int_size(6), char_size(2)
00022 {
00023     init();
00024 }
00025 
00026 CVwNativeCacheReader::CVwNativeCacheReader(char * fname, CVwEnvironment* env_to_use)
00027     : CVwCacheReader(fname, env_to_use), int_size(6), char_size(2)
00028 {
00029     init();
00030     buf.use_file(fd);
00031     check_cache_metadata();
00032 }
00033 
00034 CVwNativeCacheReader::CVwNativeCacheReader(int32_t f, CVwEnvironment* env_to_use)
00035     : CVwCacheReader(f, env_to_use), int_size(6), char_size(2)
00036 {
00037     init();
00038     buf.use_file(fd);
00039     check_cache_metadata();
00040 }
00041 
00042 CVwNativeCacheReader::~CVwNativeCacheReader()
00043 {
00044     buf.close_file();
00045 }
00046 
00047 void CVwNativeCacheReader::set_file(int32_t f)
00048 {
00049     if (fd > 0)
00050         buf.close_file();
00051 
00052     fd = f;
00053     buf.use_file(fd);
00054     check_cache_metadata();
00055 }
00056 
00057 void CVwNativeCacheReader::init()
00058 {
00059     neg_1 = 1;
00060     general = 2;
00061 }
00062 
00063 void CVwNativeCacheReader::check_cache_metadata()
00064 {
00065     const char* vw_version=env->vw_version;
00066     vw_size_t numbits = env->num_bits;
00067 
00068     vw_size_t v_length;
00069     buf.read_file((char*)&v_length, sizeof(v_length));
00070     if(v_length > 29)
00071         SG_SERROR("Cache version too long, cache file is probably invalid.\n");
00072 
00073     char t[v_length];
00074     buf.read_file(t,v_length);
00075     if (strcmp(t,vw_version) != 0)
00076         SG_SERROR("Cache has possibly incompatible version!\n");
00077 
00078     vw_size_t cache_numbits = 0;
00079     if (buf.read_file(&cache_numbits, sizeof(vw_size_t)) < ssize_t(sizeof(vw_size_t)))
00080         return;
00081 
00082     if (cache_numbits != numbits)
00083         SG_SERROR("Bug encountered in caching! Bits used for weight in cache: %d.\n", cache_numbits);
00084 }
00085 
00086 char* CVwNativeCacheReader::run_len_decode(char *p, vw_size_t& i)
00087 {
00088     // Read an int32_t 7 bits at a time.
00089     vw_size_t count = 0;
00090     while(*p & 128)\
00091         i = i | ((*(p++) & 127) << 7*count++);
00092     i = i | (*(p++) << 7*count);
00093     return p;
00094 }
00095 
00096 char* CVwNativeCacheReader::bufread_label(VwLabel* const ld, char* c)
00097 {
00098     ld->label = *(float32_t*)c;
00099     c += sizeof(ld->label);
00100     set_minmax(ld->label);
00101 
00102     ld->weight = *(float32_t*)c;
00103     c += sizeof(ld->weight);
00104     ld->initial = *(float32_t*)c;
00105     c += sizeof(ld->initial);
00106 
00107     return c;
00108 }
00109 
00110 vw_size_t CVwNativeCacheReader::read_cached_label(VwLabel* const ld)
00111 {
00112     char *c;
00113     vw_size_t total = sizeof(ld->label)+sizeof(ld->weight)+sizeof(ld->initial);
00114     if (buf.buf_read(c, total) < total)
00115         return 0;
00116     c = bufread_label(ld,c);
00117 
00118     return total;
00119 }
00120 
00121 vw_size_t CVwNativeCacheReader::read_cached_tag(VwExample* const ae)
00122 {
00123     char* c;
00124     vw_size_t tag_size;
00125     if (buf.buf_read(c, sizeof(tag_size)) < sizeof(tag_size))
00126         return 0;
00127     tag_size = *(vw_size_t*)c;
00128     c += sizeof(tag_size);
00129 
00130     buf.set(c);
00131     if (buf.buf_read(c, tag_size) < tag_size)
00132         return 0;
00133 
00134     ae->tag.erase();
00135     ae->tag.push_many(c, tag_size);
00136     return tag_size+sizeof(tag_size);
00137 }
00138 
00139 bool CVwNativeCacheReader::read_cached_example(VwExample* const ae)
00140 {
00141     vw_size_t mask =  env->mask;
00142     vw_size_t total = read_cached_label(ae->ld);
00143     if (total == 0)
00144         return false;
00145     if (read_cached_tag(ae) == 0)
00146         return false;
00147 
00148     char* c;
00149     unsigned char num_indices = 0;
00150     if (buf.buf_read(c, sizeof(num_indices)) < sizeof(num_indices))
00151         return false;
00152     num_indices = *(unsigned char*)c;
00153     c += sizeof(num_indices);
00154 
00155     buf.set(c);
00156 
00157     for (; num_indices > 0; num_indices--)
00158     {
00159         vw_size_t temp;
00160         unsigned char index = 0;
00161         temp = buf.buf_read(c, sizeof(index) + sizeof(vw_size_t));
00162 
00163         if (temp < sizeof(index) + sizeof(vw_size_t))
00164             SG_SERROR("Truncated example! %d < %d bytes expected.\n",
00165                   temp, char_size + sizeof(vw_size_t));
00166 
00167         index = *(unsigned char*) c;
00168         c += sizeof(index);
00169         ae->indices.push((vw_size_t) index);
00170 
00171         v_array<VwFeature>* ours = ae->atomics+index;
00172         float64_t* our_sum_feat_sq = ae->sum_feat_sq+index;
00173         vw_size_t storage = *(vw_size_t *)c;
00174         c += sizeof(vw_size_t);
00175 
00176         buf.set(c);
00177         total += storage;
00178         if (buf.buf_read(c, storage) < storage)
00179             SG_SERROR("Truncated example! Wanted %d bytes!\n", storage);
00180 
00181         char *end = c + storage;
00182 
00183         vw_size_t last = 0;
00184 
00185         for (; c!=end; )
00186         {
00187             VwFeature f = {1., 0};
00188             temp = f.weight_index;
00189             c = run_len_decode(c, temp);
00190             f.weight_index = temp;
00191 
00192             if (f.weight_index & neg_1)
00193                 f.x = -1.;
00194             else if (f.weight_index & general)
00195             {
00196                 f.x = ((one_float*)c)->f;
00197                 c += sizeof(float32_t);
00198             }
00199 
00200             *our_sum_feat_sq += f.x*f.x;
00201 
00202             vw_size_t diff = f.weight_index >> 2;
00203             int32_t s_diff = ZigZagDecode(diff);
00204             if (s_diff < 0)
00205                 ae->sorted = false;
00206 
00207             f.weight_index = last + s_diff;
00208             last = f.weight_index;
00209             f.weight_index = f.weight_index & mask;
00210 
00211             ours->push(f);
00212         }
00213         buf.set(c);
00214     }
00215 
00216     return true;
00217 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation