Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
#include <shogun/classifier/vw/cache/VwNativeCacheReader.h>

#include <cstring>
00017
00018 using namespace shogun;
00019
00020 CVwNativeCacheReader::CVwNativeCacheReader()
00021 : CVwCacheReader(), int_size(6), char_size(2)
00022 {
00023 init();
00024 }
00025
00026 CVwNativeCacheReader::CVwNativeCacheReader(char * fname, CVwEnvironment* env_to_use)
00027 : CVwCacheReader(fname, env_to_use), int_size(6), char_size(2)
00028 {
00029 init();
00030 buf.use_file(fd);
00031 check_cache_metadata();
00032 }
00033
00034 CVwNativeCacheReader::CVwNativeCacheReader(int32_t f, CVwEnvironment* env_to_use)
00035 : CVwCacheReader(f, env_to_use), int_size(6), char_size(2)
00036 {
00037 init();
00038 buf.use_file(fd);
00039 check_cache_metadata();
00040 }
00041
00042 CVwNativeCacheReader::~CVwNativeCacheReader()
00043 {
00044 buf.close_file();
00045 }
00046
00047 void CVwNativeCacheReader::set_file(int32_t f)
00048 {
00049 if (fd > 0)
00050 buf.close_file();
00051
00052 fd = f;
00053 buf.use_file(fd);
00054 check_cache_metadata();
00055 }
00056
00057 void CVwNativeCacheReader::init()
00058 {
00059 neg_1 = 1;
00060 general = 2;
00061 }
00062
00063 void CVwNativeCacheReader::check_cache_metadata()
00064 {
00065 const char* vw_version=env->vw_version;
00066 vw_size_t numbits = env->num_bits;
00067
00068 vw_size_t v_length;
00069 buf.read_file((char*)&v_length, sizeof(v_length));
00070 if(v_length > 29)
00071 SG_SERROR("Cache version too long, cache file is probably invalid.\n");
00072
00073 char* t=SG_MALLOC(char, v_length);
00074 buf.read_file(t,v_length);
00075 if (strcmp(t,vw_version) != 0)
00076 {
00077 SG_FREE(t);
00078 SG_SERROR("Cache has possibly incompatible version!\n");
00079 }
00080 SG_FREE(t);
00081
00082 vw_size_t cache_numbits = 0;
00083 if (buf.read_file(&cache_numbits, sizeof(vw_size_t)) < ssize_t(sizeof(vw_size_t)))
00084 return;
00085
00086 if (cache_numbits != numbits)
00087 SG_SERROR("Bug encountered in caching! Bits used for weight in cache: %d.\n", cache_numbits);
00088 }
00089
00090 char* CVwNativeCacheReader::run_len_decode(char *p, vw_size_t& i)
00091 {
00092
00093 vw_size_t count = 0;
00094 while(*p & 128)\
00095 i = i | ((*(p++) & 127) << 7*count++);
00096 i = i | (*(p++) << 7*count);
00097 return p;
00098 }
00099
00100 char* CVwNativeCacheReader::bufread_label(VwLabel* const ld, char* c)
00101 {
00102 ld->label = *(float32_t*)c;
00103 c += sizeof(ld->label);
00104 set_minmax(ld->label);
00105
00106 ld->weight = *(float32_t*)c;
00107 c += sizeof(ld->weight);
00108 ld->initial = *(float32_t*)c;
00109 c += sizeof(ld->initial);
00110
00111 return c;
00112 }
00113
00114 vw_size_t CVwNativeCacheReader::read_cached_label(VwLabel* const ld)
00115 {
00116 char *c;
00117 vw_size_t total = sizeof(ld->label)+sizeof(ld->weight)+sizeof(ld->initial);
00118 if (buf.buf_read(c, total) < total)
00119 return 0;
00120 c = bufread_label(ld,c);
00121
00122 return total;
00123 }
00124
00125 vw_size_t CVwNativeCacheReader::read_cached_tag(VwExample* const ae)
00126 {
00127 char* c;
00128 vw_size_t tag_size;
00129 if (buf.buf_read(c, sizeof(tag_size)) < sizeof(tag_size))
00130 return 0;
00131 tag_size = *(vw_size_t*)c;
00132 c += sizeof(tag_size);
00133
00134 buf.set(c);
00135 if (buf.buf_read(c, tag_size) < tag_size)
00136 return 0;
00137
00138 ae->tag.erase();
00139 ae->tag.push_many(c, tag_size);
00140 return tag_size+sizeof(tag_size);
00141 }
00142
00143 bool CVwNativeCacheReader::read_cached_example(VwExample* const ae)
00144 {
00145 vw_size_t mask = env->mask;
00146 vw_size_t total = read_cached_label(ae->ld);
00147 if (total == 0)
00148 return false;
00149 if (read_cached_tag(ae) == 0)
00150 return false;
00151
00152 char* c;
00153 unsigned char num_indices = 0;
00154 if (buf.buf_read(c, sizeof(num_indices)) < sizeof(num_indices))
00155 return false;
00156 num_indices = *(unsigned char*)c;
00157 c += sizeof(num_indices);
00158
00159 buf.set(c);
00160
00161 for (; num_indices > 0; num_indices--)
00162 {
00163 vw_size_t temp;
00164 unsigned char index = 0;
00165 temp = buf.buf_read(c, sizeof(index) + sizeof(vw_size_t));
00166
00167 if (temp < sizeof(index) + sizeof(vw_size_t))
00168 SG_SERROR("Truncated example! %d < %d bytes expected.\n",
00169 temp, char_size + sizeof(vw_size_t));
00170
00171 index = *(unsigned char*) c;
00172 c += sizeof(index);
00173 ae->indices.push((vw_size_t) index);
00174
00175 v_array<VwFeature>* ours = ae->atomics+index;
00176 float64_t* our_sum_feat_sq = ae->sum_feat_sq+index;
00177 vw_size_t storage = *(vw_size_t *)c;
00178 c += sizeof(vw_size_t);
00179
00180 buf.set(c);
00181 total += storage;
00182 if (buf.buf_read(c, storage) < storage)
00183 SG_SERROR("Truncated example! Wanted %d bytes!\n", storage);
00184
00185 char *end = c + storage;
00186
00187 vw_size_t last = 0;
00188
00189 for (; c!=end; )
00190 {
00191 VwFeature f = {1., 0};
00192 temp = f.weight_index;
00193 c = run_len_decode(c, temp);
00194 f.weight_index = temp;
00195
00196 if (f.weight_index & neg_1)
00197 f.x = -1.;
00198 else if (f.weight_index & general)
00199 {
00200 f.x = ((one_float*)c)->f;
00201 c += sizeof(float32_t);
00202 }
00203
00204 *our_sum_feat_sq += f.x*f.x;
00205
00206 vw_size_t diff = f.weight_index >> 2;
00207 int32_t s_diff = ZigZagDecode(diff);
00208 if (s_diff < 0)
00209 ae->sorted = false;
00210
00211 f.weight_index = last + s_diff;
00212 last = f.weight_index;
00213 f.weight_index = f.weight_index & mask;
00214
00215 ours->push(f);
00216 }
00217 buf.set(c);
00218 }
00219
00220 return true;
00221 }