Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
#include <shogun/classifier/vw/cache/VwNativeCacheReader.h>

#include <string.h>
00017
00018 using namespace shogun;
00019
00020 CVwNativeCacheReader::CVwNativeCacheReader()
00021 : CVwCacheReader(), int_size(6), char_size(2)
00022 {
00023 init();
00024 }
00025
00026 CVwNativeCacheReader::CVwNativeCacheReader(char * fname, CVwEnvironment* env_to_use)
00027 : CVwCacheReader(fname, env_to_use), int_size(6), char_size(2)
00028 {
00029 init();
00030 buf.use_file(fd);
00031 check_cache_metadata();
00032 }
00033
00034 CVwNativeCacheReader::CVwNativeCacheReader(int32_t f, CVwEnvironment* env_to_use)
00035 : CVwCacheReader(f, env_to_use), int_size(6), char_size(2)
00036 {
00037 init();
00038 buf.use_file(fd);
00039 check_cache_metadata();
00040 }
00041
00042 CVwNativeCacheReader::~CVwNativeCacheReader()
00043 {
00044 buf.close_file();
00045 }
00046
00047 void CVwNativeCacheReader::set_file(int32_t f)
00048 {
00049 if (fd > 0)
00050 buf.close_file();
00051
00052 fd = f;
00053 buf.use_file(fd);
00054 check_cache_metadata();
00055 }
00056
00057 void CVwNativeCacheReader::init()
00058 {
00059 neg_1 = 1;
00060 general = 2;
00061 }
00062
00063 void CVwNativeCacheReader::check_cache_metadata()
00064 {
00065 const char* vw_version=env->vw_version;
00066 vw_size_t numbits = env->num_bits;
00067
00068 vw_size_t v_length;
00069 buf.read_file((char*)&v_length, sizeof(v_length));
00070 if(v_length > 29)
00071 SG_SERROR("Cache version too long, cache file is probably invalid.\n");
00072
00073 char t[v_length];
00074 buf.read_file(t,v_length);
00075 if (strcmp(t,vw_version) != 0)
00076 SG_SERROR("Cache has possibly incompatible version!\n");
00077
00078 vw_size_t cache_numbits = 0;
00079 if (buf.read_file(&cache_numbits, sizeof(vw_size_t)) < ssize_t(sizeof(vw_size_t)))
00080 return;
00081
00082 if (cache_numbits != numbits)
00083 SG_SERROR("Bug encountered in caching! Bits used for weight in cache: %d.\n", cache_numbits);
00084 }
00085
00086 char* CVwNativeCacheReader::run_len_decode(char *p, vw_size_t& i)
00087 {
00088
00089 vw_size_t count = 0;
00090 while(*p & 128)\
00091 i = i | ((*(p++) & 127) << 7*count++);
00092 i = i | (*(p++) << 7*count);
00093 return p;
00094 }
00095
00096 char* CVwNativeCacheReader::bufread_label(VwLabel* const ld, char* c)
00097 {
00098 ld->label = *(float32_t*)c;
00099 c += sizeof(ld->label);
00100 set_minmax(ld->label);
00101
00102 ld->weight = *(float32_t*)c;
00103 c += sizeof(ld->weight);
00104 ld->initial = *(float32_t*)c;
00105 c += sizeof(ld->initial);
00106
00107 return c;
00108 }
00109
00110 vw_size_t CVwNativeCacheReader::read_cached_label(VwLabel* const ld)
00111 {
00112 char *c;
00113 vw_size_t total = sizeof(ld->label)+sizeof(ld->weight)+sizeof(ld->initial);
00114 if (buf.buf_read(c, total) < total)
00115 return 0;
00116 c = bufread_label(ld,c);
00117
00118 return total;
00119 }
00120
00121 vw_size_t CVwNativeCacheReader::read_cached_tag(VwExample* const ae)
00122 {
00123 char* c;
00124 vw_size_t tag_size;
00125 if (buf.buf_read(c, sizeof(tag_size)) < sizeof(tag_size))
00126 return 0;
00127 tag_size = *(vw_size_t*)c;
00128 c += sizeof(tag_size);
00129
00130 buf.set(c);
00131 if (buf.buf_read(c, tag_size) < tag_size)
00132 return 0;
00133
00134 ae->tag.erase();
00135 ae->tag.push_many(c, tag_size);
00136 return tag_size+sizeof(tag_size);
00137 }
00138
00139 bool CVwNativeCacheReader::read_cached_example(VwExample* const ae)
00140 {
00141 vw_size_t mask = env->mask;
00142 vw_size_t total = read_cached_label(ae->ld);
00143 if (total == 0)
00144 return false;
00145 if (read_cached_tag(ae) == 0)
00146 return false;
00147
00148 char* c;
00149 unsigned char num_indices = 0;
00150 if (buf.buf_read(c, sizeof(num_indices)) < sizeof(num_indices))
00151 return false;
00152 num_indices = *(unsigned char*)c;
00153 c += sizeof(num_indices);
00154
00155 buf.set(c);
00156
00157 for (; num_indices > 0; num_indices--)
00158 {
00159 vw_size_t temp;
00160 unsigned char index = 0;
00161 temp = buf.buf_read(c, sizeof(index) + sizeof(vw_size_t));
00162
00163 if (temp < sizeof(index) + sizeof(vw_size_t))
00164 SG_SERROR("Truncated example! %d < %d bytes expected.\n",
00165 temp, char_size + sizeof(vw_size_t));
00166
00167 index = *(unsigned char*) c;
00168 c += sizeof(index);
00169 ae->indices.push((vw_size_t) index);
00170
00171 v_array<VwFeature>* ours = ae->atomics+index;
00172 float64_t* our_sum_feat_sq = ae->sum_feat_sq+index;
00173 vw_size_t storage = *(vw_size_t *)c;
00174 c += sizeof(vw_size_t);
00175
00176 buf.set(c);
00177 total += storage;
00178 if (buf.buf_read(c, storage) < storage)
00179 SG_SERROR("Truncated example! Wanted %d bytes!\n", storage);
00180
00181 char *end = c + storage;
00182
00183 vw_size_t last = 0;
00184
00185 for (; c!=end; )
00186 {
00187 VwFeature f = {1., 0};
00188 temp = f.weight_index;
00189 c = run_len_decode(c, temp);
00190 f.weight_index = temp;
00191
00192 if (f.weight_index & neg_1)
00193 f.x = -1.;
00194 else if (f.weight_index & general)
00195 {
00196 f.x = ((one_float*)c)->f;
00197 c += sizeof(float32_t);
00198 }
00199
00200 *our_sum_feat_sq += f.x*f.x;
00201
00202 vw_size_t diff = f.weight_index >> 2;
00203 int32_t s_diff = ZigZagDecode(diff);
00204 if (s_diff < 0)
00205 ae->sorted = false;
00206
00207 f.weight_index = last + s_diff;
00208 last = f.weight_index;
00209 f.weight_index = f.weight_index & mask;
00210
00211 ours->push(f);
00212 }
00213 buf.set(c);
00214 }
00215
00216 return true;
00217 }