SHOGUN  v2.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
VwNativeCacheReader.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2009 Yahoo! Inc. All rights reserved. The copyrights
3  * embodied in the content of this file are licensed under the BSD
4  * (revised) open source license.
5  *
6  * This program is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation; either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * Written (W) 2011 Shashwat Lal Das
12  * Adaptation of Vowpal Wabbit v5.1.
13  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society.
14  */
15 
17 
#include <string.h>

using namespace shogun;
19 
21  : CVwCacheReader(), int_size(6), char_size(2)
22 {
23  init();
24 }
25 
27  : CVwCacheReader(fname, env_to_use), int_size(6), char_size(2)
28 {
29  init();
30  buf.use_file(fd);
32 }
33 
35  : CVwCacheReader(f, env_to_use), int_size(6), char_size(2)
36 {
37  init();
38  buf.use_file(fd);
40 }
41 
43 {
44  buf.close_file();
45 }
46 
48 {
49  if (fd > 0)
50  buf.close_file();
51 
52  fd = f;
53  buf.use_file(fd);
55 }
56 
57 void CVwNativeCacheReader::init()
58 {
59  neg_1 = 1;
60  general = 2;
61 }
62 
64 {
65  const char* vw_version=env->vw_version;
66  vw_size_t numbits = env->num_bits;
67 
68  vw_size_t v_length;
69  buf.read_file((char*)&v_length, sizeof(v_length));
70  if(v_length > 29)
71  SG_SERROR("Cache version too long, cache file is probably invalid.\n");
72 
73  char* t=SG_MALLOC(char, v_length);
74  buf.read_file(t,v_length);
75  if (strcmp(t,vw_version) != 0)
76  {
77  SG_FREE(t);
78  SG_SERROR("Cache has possibly incompatible version!\n");
79  }
80  SG_FREE(t);
81 
82  vw_size_t cache_numbits = 0;
83  if (buf.read_file(&cache_numbits, sizeof(vw_size_t)) < ssize_t(sizeof(vw_size_t)))
84  return;
85 
86  if (cache_numbits != numbits)
87  SG_SERROR("Bug encountered in caching! Bits used for weight in cache: %d.\n", cache_numbits);
88 }
89 
90 char* CVwNativeCacheReader::run_len_decode(char *p, vw_size_t& i)
91 {
92  // Read an int32_t 7 bits at a time.
93  vw_size_t count = 0;
94  while(*p & 128)\
95  i = i | ((*(p++) & 127) << 7*count++);
96  i = i | (*(p++) << 7*count);
97  return p;
98 }
99 
100 char* CVwNativeCacheReader::bufread_label(VwLabel* const ld, char* c)
101 {
102  ld->label = *(float32_t*)c;
103  c += sizeof(ld->label);
104  set_minmax(ld->label);
105 
106  ld->weight = *(float32_t*)c;
107  c += sizeof(ld->weight);
108  ld->initial = *(float32_t*)c;
109  c += sizeof(ld->initial);
110 
111  return c;
112 }
113 
114 vw_size_t CVwNativeCacheReader::read_cached_label(VwLabel* const ld)
115 {
116  char *c;
117  vw_size_t total = sizeof(ld->label)+sizeof(ld->weight)+sizeof(ld->initial);
118  if (buf.buf_read(c, total) < total)
119  return 0;
120  c = bufread_label(ld,c);
121 
122  return total;
123 }
124 
125 vw_size_t CVwNativeCacheReader::read_cached_tag(VwExample* const ae)
126 {
127  char* c;
128  vw_size_t tag_size;
129  if (buf.buf_read(c, sizeof(tag_size)) < sizeof(tag_size))
130  return 0;
131  tag_size = *(vw_size_t*)c;
132  c += sizeof(tag_size);
133 
134  buf.set(c);
135  if (buf.buf_read(c, tag_size) < tag_size)
136  return 0;
137 
138  ae->tag.erase();
139  ae->tag.push_many(c, tag_size);
140  return tag_size+sizeof(tag_size);
141 }
142 
144 {
145  vw_size_t mask = env->mask;
146  vw_size_t total = read_cached_label(ae->ld);
147  if (total == 0)
148  return false;
149  if (read_cached_tag(ae) == 0)
150  return false;
151 
152  char* c;
153  unsigned char num_indices = 0;
154  if (buf.buf_read(c, sizeof(num_indices)) < sizeof(num_indices))
155  return false;
156  num_indices = *(unsigned char*)c;
157  c += sizeof(num_indices);
158 
159  buf.set(c);
160 
161  for (; num_indices > 0; num_indices--)
162  {
163  vw_size_t temp;
164  unsigned char index = 0;
165  temp = buf.buf_read(c, sizeof(index) + sizeof(vw_size_t));
166 
167  if (temp < sizeof(index) + sizeof(vw_size_t))
168  SG_SERROR("Truncated example! %d < %d bytes expected.\n",
169  temp, char_size + sizeof(vw_size_t));
170 
171  index = *(unsigned char*) c;
172  c += sizeof(index);
173  ae->indices.push((vw_size_t) index);
174 
175  v_array<VwFeature>* ours = ae->atomics+index;
176  float64_t* our_sum_feat_sq = ae->sum_feat_sq+index;
177  vw_size_t storage = *(vw_size_t *)c;
178  c += sizeof(vw_size_t);
179 
180  buf.set(c);
181  total += storage;
182  if (buf.buf_read(c, storage) < storage)
183  SG_SERROR("Truncated example! Wanted %d bytes!\n", storage);
184 
185  char *end = c + storage;
186 
187  vw_size_t last = 0;
188 
189  for (; c!=end; )
190  {
191  VwFeature f = {1., 0};
192  temp = f.weight_index;
193  c = run_len_decode(c, temp);
194  f.weight_index = temp;
195 
196  if (f.weight_index & neg_1)
197  f.x = -1.;
198  else if (f.weight_index & general)
199  {
200  f.x = ((one_float*)c)->f;
201  c += sizeof(float32_t);
202  }
203 
204  *our_sum_feat_sq += f.x*f.x;
205 
206  vw_size_t diff = f.weight_index >> 2;
207  int32_t s_diff = ZigZagDecode(diff);
208  if (s_diff < 0)
209  ae->sorted = false;
210 
211  f.weight_index = last + s_diff;
212  last = f.weight_index;
213  f.weight_index = f.weight_index & mask;
214 
215  ours->push(f);
216  }
217  buf.set(c);
218  }
219 
220  return true;
221 }

SHOGUN Machine Learning Toolbox - Documentation