VwParser.cpp

Go to the documentation of this file.
00001 /*
00002  * Copyright (c) 2009 Yahoo! Inc.  All rights reserved.  The copyrights
00003  * embodied in the content of this file are licensed under the BSD
00004  * (revised) open source license.
00005  *
00006  * This program is free software; you can redistribute it and/or modify
00007  * it under the terms of the GNU General Public License as published by
00008  * the Free Software Foundation; either version 3 of the License, or
00009  * (at your option) any later version.
00010  *
00011  * Written (W) 2011 Shashwat Lal Das
00012  * Adaptation of Vowpal Wabbit v5.1.
00013  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society.
00014  */
00015 
00016 #include <shogun/classifier/vw/VwParser.h>
00017 #include <shogun/classifier/vw/cache/VwNativeCacheWriter.h>
00018 
00019 using namespace shogun;
00020 
00021 CVwParser::CVwParser()
00022     : CSGObject()
00023 {
00024     env = new CVwEnvironment();
00025     hasher = CHash::MurmurHashString;
00026     write_cache = false;
00027     cache_writer = NULL;
00028 }
00029 
00030 CVwParser::CVwParser(CVwEnvironment* env_to_use)
00031     : CSGObject()
00032 {
00033     ASSERT(env_to_use);
00034 
00035     env = env_to_use;
00036     hasher = CHash::MurmurHashString;
00037     write_cache = false;
00038     cache_writer = NULL;
00039     SG_REF(env);
00040 }
00041 
00042 CVwParser::~CVwParser()
00043 {
00044     SG_UNREF(env);
00045     SG_UNREF(cache_writer);
00046 }
00047 
00048 int32_t CVwParser::read_features(CIOBuffer* buf, VwExample*& ae)
00049 {
00050     char *line=NULL;
00051     int32_t num_chars = buf->read_line(line);
00052     if (num_chars == 0)
00053         return num_chars;
00054 
00055     /* Mark begin and end of example in the buffer */
00056     substring example_string = {line, line + num_chars};
00057 
00058     /* Channels containing separate namespaces/label information*/
00059     channels.erase();
00060 
00061     /* Split at '|' character */
00062     tokenize('|', example_string, channels);
00063 
00064     /* If first char is not '|', then the first channel contains label data */
00065     substring* feature_start = &channels[1];
00066 
00067     if (*line == '|')
00068         feature_start = &channels[0]; /* Unlabelled data */
00069     else
00070     {
00071         /* First channel has label info */
00072         substring label_space = channels[0];
00073         char* tab_location = safe_index(label_space.start, '\t', label_space.end);
00074         if (tab_location != label_space.end)
00075             label_space.start = tab_location+1;
00076 
00077         /* Split the label space on spaces */
00078         tokenize(' ',label_space,words);
00079         if (words.index() > 0 && words.last().end == label_space.end) //The last field is a tag, so record and strip it off
00080         {
00081             substring tag = words.pop();
00082             ae->tag.push_many(tag.start, tag.end - tag.start);
00083         }
00084 
00085         ae->ld->label_from_substring(words);
00086         set_minmax(ae->ld->label);
00087     }
00088 
00089     vw_size_t mask = env->mask;
00090 
00091     /* Now parse the individual channels, i.e., namespaces */
00092     for (substring* i = feature_start; i != channels.end; i++)
00093     {
00094         substring channel = *i;
00095 
00096         tokenize(' ',channel, words);
00097         if (words.begin == words.end)
00098             continue;
00099 
00100         /* Set default scale value for channel */
00101         float32_t channel_v = 1.;
00102         vw_size_t channel_hash;
00103 
00104         /* Index by which to refer to the namespace */
00105         vw_size_t index = 0;
00106         bool new_index = false;
00107         vw_size_t feature_offset = 0;
00108 
00109         if (channel.start[0] != ' ')
00110         {
00111             /* Nonanonymous namespace specified */
00112             feature_offset++;
00113             feature_value(words[0], name, channel_v);
00114 
00115             if (name.index() > 0)
00116             {
00117                 index = (unsigned char)(*name[0].start);
00118                 if (ae->atomics[index].begin == ae->atomics[index].end)
00119                 {
00120                     ae->sum_feat_sq[index] = 0;
00121                     new_index = true;
00122                 }
00123             }
00124             channel_hash = hasher(name[0], hash_base);
00125         }
00126         else
00127         {
00128             /* Use default namespace with index below */
00129             index = (unsigned char)' ';
00130             if (ae->atomics[index].begin == ae->atomics[index].end)
00131             {
00132                 ae->sum_feat_sq[index] = 0;
00133                 new_index = true;
00134             }
00135             channel_hash = 0;
00136         }
00137 
00138         for (substring* j = words.begin+feature_offset; j != words.end; j++)
00139         {
00140             /* Get individual features and multiply by scale value */
00141             float32_t v = 0.0;
00142             feature_value(*j, name, v);
00143             v *= channel_v;
00144 
00145             /* Hash feature */
00146             vw_size_t word_hash = (hasher(name[0], channel_hash)) & mask;
00147             VwFeature f = {v,word_hash};
00148             ae->sum_feat_sq[index] += v*v;
00149             ae->atomics[index].push(f);
00150         }
00151 
00152         /* Add index to list of indices if required */
00153         if (new_index && ae->atomics[index].begin != ae->atomics[index].end)
00154             ae->indices.push(index);
00155 
00156     }
00157 
00158     if (write_cache)
00159         cache_writer->cache_example(ae);
00160 
00161     return num_chars;
00162 }
00163 
00164 int32_t CVwParser::read_svmlight_features(CIOBuffer* buf, VwExample*& ae)
00165 {
00166     char *line=NULL;
00167     int32_t num_chars = buf->read_line(line);
00168     if (num_chars == 0)
00169         return num_chars;
00170 
00171     /* Mark begin and end of example in the buffer */
00172     substring example_string = {line, line + num_chars};
00173 
00174     vw_size_t mask = env->mask;
00175     tokenize(' ', example_string, words);
00176 
00177     ae->ld->label = float_of_substring(words[0]);
00178     ae->ld->weight = 1.;
00179     ae->ld->initial = 0.;
00180     set_minmax(ae->ld->label);
00181 
00182     substring* feature_start = &words[1];
00183 
00184     vw_size_t index = (unsigned char)' ';   // Any default namespace is ok
00185     vw_size_t channel_hash = 0;
00186     ae->sum_feat_sq[index] = 0;
00187     ae->indices.push(index);
00188     /* Now parse the individual features */
00189     for (substring* i = feature_start; i != words.end; i++)
00190     {
00191         float32_t v;
00192         feature_value(*i, name, v);
00193 
00194         vw_size_t word_hash = (hasher(name[0], channel_hash)) & mask;
00195         VwFeature f = {v,word_hash};
00196         ae->sum_feat_sq[index] += v*v;
00197         ae->atomics[index].push(f);
00198     }
00199 
00200     if (write_cache)
00201         cache_writer->cache_example(ae);
00202 
00203     return num_chars;
00204 }
00205 
00206 int32_t CVwParser::read_dense_features(CIOBuffer* buf, VwExample*& ae)
00207 {
00208     char *line=NULL;
00209     int32_t num_chars = buf->read_line(line);
00210     if (num_chars == 0)
00211         return num_chars;
00212 
00213     // Mark begin and end of example in the buffer
00214     substring example_string = {line, line + num_chars};
00215 
00216     vw_size_t mask = env->mask;
00217     tokenize(' ', example_string, words);
00218 
00219     ae->ld->label = float_of_substring(words[0]);
00220     ae->ld->weight = 1.;
00221     ae->ld->initial = 0.;
00222     set_minmax(ae->ld->label);
00223 
00224     substring* feature_start = &words[1];
00225 
00226     vw_size_t index = (unsigned char)' ';
00227 
00228     ae->sum_feat_sq[index] = 0;
00229     ae->indices.push(index);
00230     // Now parse individual features
00231     int32_t j=0;
00232     for (substring* i = feature_start; i != words.end; i++)
00233     {
00234         float32_t v = float_of_substring(*i);
00235         vw_size_t word_hash = j & mask;
00236         VwFeature f = {v,word_hash};
00237         ae->sum_feat_sq[index] += v*v;
00238         ae->atomics[index].push(f);
00239         j++;
00240     }
00241 
00242     if (write_cache)
00243         cache_writer->cache_example(ae);
00244 
00245     return num_chars;
00246 }
00247 
00248 void CVwParser::init_cache(char * fname, EVwCacheType type)
00249 {
00250     char* file_name = fname;
00251     char default_cache_name[] = "vw_cache.dat.cache";
00252 
00253     if (!fname)
00254         file_name = default_cache_name;
00255 
00256     write_cache = true;
00257     cache_type = type;
00258 
00259     switch (type)
00260     {
00261     case C_NATIVE:
00262         cache_writer = new CVwNativeCacheWriter(file_name, env);
00263         return;
00264     case C_PROTOBUF:
00265         SG_ERROR("Protocol buffers cache support is not implemented yet.\n");
00266     }
00267 
00268     SG_ERROR("Unexpected cache type specified!\n");
00269 }
00270 
00271 void CVwParser::feature_value(substring &s, v_array<substring>& feat_name, float32_t &v)
00272 {
00273     // Get the value of the feature in the substring
00274     tokenize(':', s, feat_name);
00275 
00276     switch (feat_name.index())
00277     {
00278     // If feature value is not specified, assume 1.0
00279     case 0:
00280     case 1:
00281         v = 1.;
00282         break;
00283     case 2:
00284         v = float_of_substring(feat_name[1]);
00285         if (isnan(v))
00286             SG_SERROR("error NaN value for feature %s! Terminating!\n",
00287                   c_string_of_substring(feat_name[0]));
00288         break;
00289     default:
00290         SG_SERROR("Examples with a weird name, i.e., '%s'\n",
00291               c_string_of_substring(s));
00292     }
00293 }
00294 
00295 void CVwParser::tokenize(char delim, substring s, v_array<substring>& ret)
00296 {
00297     ret.erase();
00298     char *last = s.start;
00299     for (; s.start != s.end; s.start++)
00300     {
00301         if (*s.start == delim)
00302         {
00303             if (s.start != last)
00304             {
00305                 substring temp = {last,s.start};
00306                 ret.push(temp);
00307             }
00308             last = s.start+1;
00309         }
00310     }
00311     if (s.start != last)
00312     {
00313         substring final = {last, s.start};
00314         ret.push(final);
00315     }
00316 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation