VwRegressor.cpp

Go to the documentation of this file.
00001 /*
00002  * Copyright (c) 2009 Yahoo! Inc.  All rights reserved.  The copyrights
00003  * embodied in the content of this file are licensed under the BSD
00004  * (revised) open source license.
00005  *
00006  * This program is free software; you can redistribute it and/or modify
00007  * it under the terms of the GNU General Public License as published by
00008  * the Free Software Foundation; either version 3 of the License, or
00009  * (at your option) any later version.
00010  *
00011  * Written (W) 2011 Shashwat Lal Das
00012  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society.
00013  */
00014 
00015 #include <shogun/classifier/vw/VwRegressor.h>
00016 #include <shogun/loss/SquaredLoss.h>
00017 #include <shogun/io/IOBuffer.h>
00018 
00019 using namespace shogun;
00020 
00021 CVwRegressor::CVwRegressor()
00022     : CSGObject()
00023 {
00024     weight_vectors = NULL;
00025     loss = new CSquaredLoss();
00026     init(NULL);
00027 }
00028 
00029 CVwRegressor::CVwRegressor(CVwEnvironment* env_to_use)
00030     : CSGObject()
00031 {
00032     weight_vectors = NULL;
00033     loss = new CSquaredLoss();
00034     init(env_to_use);
00035 }
00036 
00037 CVwRegressor::~CVwRegressor()
00038 {
00039     SG_FREE(weight_vectors);
00040     SG_UNREF(loss);
00041     SG_UNREF(env);
00042 }
00043 
00044 void CVwRegressor::init(CVwEnvironment* env_to_use)
00045 {
00046     if (!env_to_use)
00047         env_to_use = new CVwEnvironment();
00048 
00049     env = env_to_use;
00050     SG_REF(env);
00051 
00052     // For each feature, there should be 'stride' number of
00053     // elements in the weight vector
00054     vw_size_t length = ((vw_size_t) 1) << env->num_bits;
00055     env->thread_mask = (env->stride * (length >> env->thread_bits)) - 1;
00056 
00057     // Only one learning thread for now
00058     vw_size_t num_threads = 1;
00059     weight_vectors = SG_MALLOC(float32_t*, num_threads);
00060 
00061     for (vw_size_t i = 0; i < num_threads; i++)
00062     {
00063         weight_vectors[i] = SG_CALLOC(float32_t, env->stride * length / num_threads);
00064 
00065         if (env->random_weights)
00066         {
00067             for (vw_size_t j = 0; j < length/num_threads; j++)
00068                 weight_vectors[i][j] = drand48() - 0.5;
00069         }
00070 
00071         if (env->initial_weight != 0.)
00072             for (vw_size_t j = 0; j < env->stride*length/num_threads; j+=env->stride)
00073                 weight_vectors[i][j] = env->initial_weight;
00074 
00075         if (env->adaptive)
00076             for (vw_size_t j = 1; j < env->stride*length/num_threads; j+=env->stride)
00077                 weight_vectors[i][j] = 1;
00078     }
00079 }
00080 
00081 void CVwRegressor::dump_regressor(char* reg_name, bool as_text)
00082 {
00083     CIOBuffer io_temp;
00084     int32_t f = io_temp.open_file(reg_name,'w');
00085 
00086     if (f < 0)
00087         SG_SERROR("Can't open: %s for writing! Exiting.\n", reg_name);
00088 
00089     const char* vw_version = env->vw_version;
00090     vw_size_t v_length = env->v_length;
00091 
00092     if (!as_text)
00093     {
00094         // Write version info
00095         io_temp.write_file((char*)&v_length, sizeof(v_length));
00096         io_temp.write_file(vw_version,v_length);
00097 
00098         // Write max and min labels
00099         io_temp.write_file((char*)&env->min_label, sizeof(env->min_label));
00100         io_temp.write_file((char*)&env->max_label, sizeof(env->max_label));
00101 
00102         // Write weight vector bits information
00103         io_temp.write_file((char *)&env->num_bits, sizeof(env->num_bits));
00104         io_temp.write_file((char *)&env->thread_bits, sizeof(env->thread_bits));
00105 
00106         // For paired namespaces forming quadratic features
00107         int32_t len = env->pairs.get_num_elements();
00108         io_temp.write_file((char *)&len, sizeof(len));
00109 
00110         for (int32_t k = 0; k < env->pairs.get_num_elements(); k++)
00111             io_temp.write_file(env->pairs.get_element(k), 2);
00112 
00113         // ngram and skips information
00114         io_temp.write_file((char*)&env->ngram, sizeof(env->ngram));
00115         io_temp.write_file((char*)&env->skips, sizeof(env->skips));
00116     }
00117     else
00118     {
00119         // Write as human readable form
00120         char buff[512];
00121         int32_t len;
00122 
00123         len = sprintf(buff, "Version %s\n", vw_version);
00124         io_temp.write_file(buff, len);
00125         len = sprintf(buff, "Min label:%f max label:%f\n", env->min_label, env->max_label);
00126         io_temp.write_file(buff, len);
00127         len = sprintf(buff, "bits:%d thread_bits:%d\n", (int32_t)env->num_bits, (int32_t)env->thread_bits);
00128         io_temp.write_file(buff, len);
00129 
00130         if (env->pairs.get_num_elements() > 0)
00131         {
00132             len = sprintf(buff, "\n");
00133             io_temp.write_file(buff, len);
00134         }
00135 
00136         len = sprintf(buff, "ngram:%d skips:%d\nindex:weight pairs:\n", (int32_t)env->ngram, (int32_t)env->skips);
00137         io_temp.write_file(buff, len);
00138     }
00139 
00140     uint32_t length = 1 << env->num_bits;
00141     vw_size_t num_threads = env->num_threads();
00142     vw_size_t stride = env->stride;
00143 
00144     // Write individual weights
00145     for(uint32_t i = 0; i < length; i++)
00146     {
00147         float32_t v;
00148         v = weight_vectors[i%num_threads][stride*(i/num_threads)];
00149         if (v != 0.)
00150         {
00151             if (!as_text)
00152             {
00153                 io_temp.write_file((char *)&i, sizeof (i));
00154                 io_temp.write_file((char *)&v, sizeof (v));
00155             }
00156             else
00157             {
00158                 char buff[512];
00159                 int32_t len = sprintf(buff, "%d:%f\n", i, v);
00160                 io_temp.write_file(buff, len);
00161             }
00162         }
00163     }
00164 
00165     io_temp.close_file();
00166 }
00167 
00168 void CVwRegressor::load_regressor(char* file)
00169 {
00170     CIOBuffer source;
00171     int32_t fd = source.open_file(file, 'r');
00172 
00173     if (fd < 0)
00174         SG_SERROR("Unable to open file for loading regressor!\n");
00175 
00176     // Read version info
00177     vw_size_t v_length;
00178     source.read_file((char*)&v_length, sizeof(v_length));
00179     char t[v_length];
00180     source.read_file(t,v_length);
00181     if (strcmp(t,env->vw_version) != 0)
00182         SG_SERROR("Regressor source has an incompatible VW version!\n");
00183 
00184     // Read min and max label
00185     source.read_file((char*)&env->min_label, sizeof(env->min_label));
00186     source.read_file((char*)&env->max_label, sizeof(env->max_label));
00187 
00188     // Read num_bits, multiple sources are not supported
00189     vw_size_t local_num_bits;
00190     source.read_file((char *)&local_num_bits, sizeof(local_num_bits));
00191 
00192     if ((vw_size_t) env->num_bits != local_num_bits)
00193         SG_SERROR("Wrong number of bits in regressor source!\n");
00194 
00195     env->num_bits = local_num_bits;
00196 
00197     vw_size_t local_thread_bits;
00198     source.read_file((char*)&local_thread_bits, sizeof(local_thread_bits));
00199 
00200     env->thread_bits = local_thread_bits;
00201 
00202     int32_t len;
00203     source.read_file((char *)&len, sizeof(len));
00204 
00205     // Read paired namespace information
00206     DynArray<char*> local_pairs;
00207     for (; len > 0; len--)
00208     {
00209         char pair[3];
00210         source.read_file(pair, sizeof(char)*2);
00211         pair[2]='\0';
00212         local_pairs.push_back(pair);
00213     }
00214 
00215     env->pairs = local_pairs;
00216 
00217     // Initialize the weight vector
00218     if (weight_vectors)
00219         SG_FREE(weight_vectors);
00220     init(env);
00221 
00222     vw_size_t local_ngram;
00223     source.read_file((char*)&local_ngram, sizeof(local_ngram));
00224     vw_size_t local_skips;
00225     source.read_file((char*)&local_skips, sizeof(local_skips));
00226 
00227     env->ngram = local_ngram;
00228     env->skips = local_skips;
00229 
00230     // Read individual weights
00231     vw_size_t stride = env->stride;
00232     while (true)
00233     {
00234         uint32_t hash;
00235         ssize_t hash_bytes = source.read_file((char *)&hash, sizeof(hash));
00236         if (hash_bytes <= 0)
00237             break;
00238 
00239         float32_t w = 0.;
00240         ssize_t weight_bytes = source.read_file((char *)&w, sizeof(float32_t));
00241         if (weight_bytes <= 0)
00242             break;
00243 
00244         vw_size_t num_threads = env->num_threads();
00245 
00246         weight_vectors[hash % num_threads][(hash*stride)/num_threads]
00247             = weight_vectors[hash % num_threads][(hash*stride)/num_threads] + w;
00248     }
00249     source.close_file();
00250 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation