VwNativeCacheReader.cpp

Go to the documentation of this file.
00001 /*
00002  * Copyright (c) 2009 Yahoo! Inc.  All rights reserved.  The copyrights
00003  * embodied in the content of this file are licensed under the BSD
00004  * (revised) open source license.
00005  *
00006  * This program is free software; you can redistribute it and/or modify
00007  * it under the terms of the GNU General Public License as published by
00008  * the Free Software Foundation; either version 3 of the License, or
00009  * (at your option) any later version.
00010  *
00011  * Written (W) 2011 Shashwat Lal Das
00012  * Adaptation of Vowpal Wabbit v5.1.
00013  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society.
00014  */
00015 
#include <shogun/classifier/vw/cache/VwNativeCacheReader.h>

#include <string.h>

using namespace shogun;
00019 
00020 CVwNativeCacheReader::CVwNativeCacheReader()
00021     : CVwCacheReader(), int_size(6), char_size(2)
00022 {
00023     init();
00024 }
00025 
00026 CVwNativeCacheReader::CVwNativeCacheReader(char * fname, CVwEnvironment* env_to_use)
00027     : CVwCacheReader(fname, env_to_use), int_size(6), char_size(2)
00028 {
00029     init();
00030     buf.use_file(fd);
00031     check_cache_metadata();
00032 }
00033 
00034 CVwNativeCacheReader::CVwNativeCacheReader(int32_t f, CVwEnvironment* env_to_use)
00035     : CVwCacheReader(f, env_to_use), int_size(6), char_size(2)
00036 {
00037     init();
00038     buf.use_file(fd);
00039     check_cache_metadata();
00040 }
00041 
00042 CVwNativeCacheReader::~CVwNativeCacheReader()
00043 {
00044     buf.close_file();
00045 }
00046 
00047 void CVwNativeCacheReader::set_file(int32_t f)
00048 {
00049     if (fd > 0)
00050         buf.close_file();
00051 
00052     fd = f;
00053     buf.use_file(fd);
00054     check_cache_metadata();
00055 }
00056 
00057 void CVwNativeCacheReader::init()
00058 {
00059     neg_1 = 1;
00060     general = 2;
00061 }
00062 
00063 void CVwNativeCacheReader::check_cache_metadata()
00064 {
00065     const char* vw_version=env->vw_version;
00066     vw_size_t numbits = env->num_bits;
00067 
00068     vw_size_t v_length;
00069     buf.read_file((char*)&v_length, sizeof(v_length));
00070     if(v_length > 29)
00071         SG_SERROR("Cache version too long, cache file is probably invalid.\n");
00072 
00073     char* t=SG_MALLOC(char, v_length);
00074     buf.read_file(t,v_length);
00075     if (strcmp(t,vw_version) != 0)
00076     {
00077         SG_FREE(t);
00078         SG_SERROR("Cache has possibly incompatible version!\n");
00079     }
00080     SG_FREE(t);
00081 
00082     vw_size_t cache_numbits = 0;
00083     if (buf.read_file(&cache_numbits, sizeof(vw_size_t)) < ssize_t(sizeof(vw_size_t)))
00084         return;
00085 
00086     if (cache_numbits != numbits)
00087         SG_SERROR("Bug encountered in caching! Bits used for weight in cache: %d.\n", cache_numbits);
00088 }
00089 
00090 char* CVwNativeCacheReader::run_len_decode(char *p, vw_size_t& i)
00091 {
00092     // Read an int32_t 7 bits at a time.
00093     vw_size_t count = 0;
00094     while(*p & 128)\
00095         i = i | ((*(p++) & 127) << 7*count++);
00096     i = i | (*(p++) << 7*count);
00097     return p;
00098 }
00099 
00100 char* CVwNativeCacheReader::bufread_label(VwLabel* const ld, char* c)
00101 {
00102     ld->label = *(float32_t*)c;
00103     c += sizeof(ld->label);
00104     set_minmax(ld->label);
00105 
00106     ld->weight = *(float32_t*)c;
00107     c += sizeof(ld->weight);
00108     ld->initial = *(float32_t*)c;
00109     c += sizeof(ld->initial);
00110 
00111     return c;
00112 }
00113 
00114 vw_size_t CVwNativeCacheReader::read_cached_label(VwLabel* const ld)
00115 {
00116     char *c;
00117     vw_size_t total = sizeof(ld->label)+sizeof(ld->weight)+sizeof(ld->initial);
00118     if (buf.buf_read(c, total) < total)
00119         return 0;
00120     c = bufread_label(ld,c);
00121 
00122     return total;
00123 }
00124 
00125 vw_size_t CVwNativeCacheReader::read_cached_tag(VwExample* const ae)
00126 {
00127     char* c;
00128     vw_size_t tag_size;
00129     if (buf.buf_read(c, sizeof(tag_size)) < sizeof(tag_size))
00130         return 0;
00131     tag_size = *(vw_size_t*)c;
00132     c += sizeof(tag_size);
00133 
00134     buf.set(c);
00135     if (buf.buf_read(c, tag_size) < tag_size)
00136         return 0;
00137 
00138     ae->tag.erase();
00139     ae->tag.push_many(c, tag_size);
00140     return tag_size+sizeof(tag_size);
00141 }
00142 
00143 bool CVwNativeCacheReader::read_cached_example(VwExample* const ae)
00144 {
00145     vw_size_t mask =  env->mask;
00146     vw_size_t total = read_cached_label(ae->ld);
00147     if (total == 0)
00148         return false;
00149     if (read_cached_tag(ae) == 0)
00150         return false;
00151 
00152     char* c;
00153     unsigned char num_indices = 0;
00154     if (buf.buf_read(c, sizeof(num_indices)) < sizeof(num_indices))
00155         return false;
00156     num_indices = *(unsigned char*)c;
00157     c += sizeof(num_indices);
00158 
00159     buf.set(c);
00160 
00161     for (; num_indices > 0; num_indices--)
00162     {
00163         vw_size_t temp;
00164         unsigned char index = 0;
00165         temp = buf.buf_read(c, sizeof(index) + sizeof(vw_size_t));
00166 
00167         if (temp < sizeof(index) + sizeof(vw_size_t))
00168             SG_SERROR("Truncated example! %d < %d bytes expected.\n",
00169                   temp, char_size + sizeof(vw_size_t));
00170 
00171         index = *(unsigned char*) c;
00172         c += sizeof(index);
00173         ae->indices.push((vw_size_t) index);
00174 
00175         v_array<VwFeature>* ours = ae->atomics+index;
00176         float64_t* our_sum_feat_sq = ae->sum_feat_sq+index;
00177         vw_size_t storage = *(vw_size_t *)c;
00178         c += sizeof(vw_size_t);
00179 
00180         buf.set(c);
00181         total += storage;
00182         if (buf.buf_read(c, storage) < storage)
00183             SG_SERROR("Truncated example! Wanted %d bytes!\n", storage);
00184 
00185         char *end = c + storage;
00186 
00187         vw_size_t last = 0;
00188 
00189         for (; c!=end; )
00190         {
00191             VwFeature f = {1., 0};
00192             temp = f.weight_index;
00193             c = run_len_decode(c, temp);
00194             f.weight_index = temp;
00195 
00196             if (f.weight_index & neg_1)
00197                 f.x = -1.;
00198             else if (f.weight_index & general)
00199             {
00200                 f.x = ((one_float*)c)->f;
00201                 c += sizeof(float32_t);
00202             }
00203 
00204             *our_sum_feat_sq += f.x*f.x;
00205 
00206             vw_size_t diff = f.weight_index >> 2;
00207             int32_t s_diff = ZigZagDecode(diff);
00208             if (s_diff < 0)
00209                 ae->sorted = false;
00210 
00211             f.weight_index = last + s_diff;
00212             last = f.weight_index;
00213             f.weight_index = f.weight_index & mask;
00214 
00215             ours->push(f);
00216         }
00217         buf.set(c);
00218     }
00219 
00220     return true;
00221 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation