VwRegressor.cpp

Go to the documentation of this file.
00001 /*
00002  * Copyright (c) 2009 Yahoo! Inc.  All rights reserved.  The copyrights
00003  * embodied in the content of this file are licensed under the BSD
00004  * (revised) open source license.
00005  *
00006  * This program is free software; you can redistribute it and/or modify
00007  * it under the terms of the GNU General Public License as published by
00008  * the Free Software Foundation; either version 3 of the License, or
00009  * (at your option) any later version.
00010  *
00011  * Written (W) 2011 Shashwat Lal Das
00012  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society.
00013  */
00014 
00015 #include <shogun/classifier/vw/VwRegressor.h>
00016 #include <shogun/loss/SquaredLoss.h>
00017 #include <shogun/io/IOBuffer.h>
00018 
00019 using namespace shogun;
00020 
00021 CVwRegressor::CVwRegressor()
00022     : CSGObject()
00023 {
00024     weight_vectors = NULL;
00025     loss = new CSquaredLoss();
00026     init(NULL);
00027 }
00028 
00029 CVwRegressor::CVwRegressor(CVwEnvironment* env_to_use)
00030     : CSGObject()
00031 {
00032     weight_vectors = NULL;
00033     loss = new CSquaredLoss();
00034     init(env_to_use);
00035 }
00036 
00037 CVwRegressor::~CVwRegressor()
00038 {
00039     SG_FREE(weight_vectors);
00040     SG_UNREF(loss);
00041     SG_UNREF(env);
00042 }
00043 
00044 void CVwRegressor::init(CVwEnvironment* env_to_use)
00045 {
00046     if (!env_to_use)
00047         env_to_use = new CVwEnvironment();
00048 
00049     env = env_to_use;
00050     SG_REF(env);
00051 
00052     // For each feature, there should be 'stride' number of
00053     // elements in the weight vector
00054     vw_size_t length = ((vw_size_t) 1) << env->num_bits;
00055     env->thread_mask = (env->stride * (length >> env->thread_bits)) - 1;
00056 
00057     // Only one learning thread for now
00058     vw_size_t num_threads = 1;
00059     weight_vectors = SG_MALLOC(float32_t*, num_threads);
00060 
00061     for (vw_size_t i = 0; i < num_threads; i++)
00062     {
00063         weight_vectors[i] = SG_CALLOC(float32_t, env->stride * length / num_threads);
00064 
00065         if (env->random_weights)
00066         {
00067             for (vw_size_t j = 0; j < length/num_threads; j++)
00068                 weight_vectors[i][j] = drand48() - 0.5;
00069         }
00070 
00071         if (env->initial_weight != 0.)
00072             for (vw_size_t j = 0; j < env->stride*length/num_threads; j+=env->stride)
00073                 weight_vectors[i][j] = env->initial_weight;
00074 
00075         if (env->adaptive)
00076             for (vw_size_t j = 1; j < env->stride*length/num_threads; j+=env->stride)
00077                 weight_vectors[i][j] = 1;
00078     }
00079 }
00080 
00081 void CVwRegressor::dump_regressor(char* reg_name, bool as_text)
00082 {
00083     CIOBuffer io_temp;
00084     int32_t f = io_temp.open_file(reg_name,'w');
00085 
00086     if (f < 0)
00087         SG_SERROR("Can't open: %s for writing! Exiting.\n", reg_name);
00088 
00089     const char* vw_version = env->vw_version;
00090     vw_size_t v_length = env->v_length;
00091 
00092     if (!as_text)
00093     {
00094         // Write version info
00095         io_temp.write_file((char*)&v_length, sizeof(v_length));
00096         io_temp.write_file(vw_version,v_length);
00097 
00098         // Write max and min labels
00099         io_temp.write_file((char*)&env->min_label, sizeof(env->min_label));
00100         io_temp.write_file((char*)&env->max_label, sizeof(env->max_label));
00101 
00102         // Write weight vector bits information
00103         io_temp.write_file((char *)&env->num_bits, sizeof(env->num_bits));
00104         io_temp.write_file((char *)&env->thread_bits, sizeof(env->thread_bits));
00105 
00106         // For paired namespaces forming quadratic features
00107         int32_t len = env->pairs.get_num_elements();
00108         io_temp.write_file((char *)&len, sizeof(len));
00109 
00110         for (int32_t k = 0; k < env->pairs.get_num_elements(); k++)
00111             io_temp.write_file(env->pairs.get_element(k), 2);
00112 
00113         // ngram and skips information
00114         io_temp.write_file((char*)&env->ngram, sizeof(env->ngram));
00115         io_temp.write_file((char*)&env->skips, sizeof(env->skips));
00116     }
00117     else
00118     {
00119         // Write as human readable form
00120         char buff[512];
00121         int32_t len;
00122 
00123         len = sprintf(buff, "Version %s\n", vw_version);
00124         io_temp.write_file(buff, len);
00125         len = sprintf(buff, "Min label:%f max label:%f\n", env->min_label, env->max_label);
00126         io_temp.write_file(buff, len);
00127         len = sprintf(buff, "bits:%d thread_bits:%d\n", (int32_t)env->num_bits, (int32_t)env->thread_bits);
00128         io_temp.write_file(buff, len);
00129 
00130         if (env->pairs.get_num_elements() > 0)
00131         {
00132             len = sprintf(buff, "\n");
00133             io_temp.write_file(buff, len);
00134         }
00135 
00136         len = sprintf(buff, "ngram:%d skips:%d\nindex:weight pairs:\n", (int32_t)env->ngram, (int32_t)env->skips);
00137         io_temp.write_file(buff, len);
00138     }
00139 
00140     uint32_t length = 1 << env->num_bits;
00141     vw_size_t num_threads = env->num_threads();
00142     vw_size_t stride = env->stride;
00143 
00144     // Write individual weights
00145     for(uint32_t i = 0; i < length; i++)
00146     {
00147         float32_t v;
00148         v = weight_vectors[i%num_threads][stride*(i/num_threads)];
00149         if (v != 0.)
00150         {
00151             if (!as_text)
00152             {
00153                 io_temp.write_file((char *)&i, sizeof (i));
00154                 io_temp.write_file((char *)&v, sizeof (v));
00155             }
00156             else
00157             {
00158                 char buff[512];
00159                 int32_t len = sprintf(buff, "%d:%f\n", i, v);
00160                 io_temp.write_file(buff, len);
00161             }
00162         }
00163     }
00164 
00165     io_temp.close_file();
00166 }
00167 
00168 void CVwRegressor::load_regressor(char* file)
00169 {
00170     CIOBuffer source;
00171     int32_t fd = source.open_file(file, 'r');
00172 
00173     if (fd < 0)
00174         SG_SERROR("Unable to open file for loading regressor!\n");
00175 
00176     // Read version info
00177     vw_size_t v_length;
00178     source.read_file((char*)&v_length, sizeof(v_length));
00179     char* t = SG_MALLOC(char, v_length);
00180     source.read_file(t,v_length);
00181     if (strcmp(t,env->vw_version) != 0)
00182     {
00183         SG_FREE(t);
00184         SG_SERROR("Regressor source has an incompatible VW version!\n");
00185     }
00186     SG_FREE(t);
00187 
00188     // Read min and max label
00189     source.read_file((char*)&env->min_label, sizeof(env->min_label));
00190     source.read_file((char*)&env->max_label, sizeof(env->max_label));
00191 
00192     // Read num_bits, multiple sources are not supported
00193     vw_size_t local_num_bits;
00194     source.read_file((char *)&local_num_bits, sizeof(local_num_bits));
00195 
00196     if ((vw_size_t) env->num_bits != local_num_bits)
00197         SG_SERROR("Wrong number of bits in regressor source!\n");
00198 
00199     env->num_bits = local_num_bits;
00200 
00201     vw_size_t local_thread_bits;
00202     source.read_file((char*)&local_thread_bits, sizeof(local_thread_bits));
00203 
00204     env->thread_bits = local_thread_bits;
00205 
00206     int32_t len;
00207     source.read_file((char *)&len, sizeof(len));
00208 
00209     // Read paired namespace information
00210     DynArray<char*> local_pairs;
00211     for (; len > 0; len--)
00212     {
00213         char pair[3];
00214         source.read_file(pair, sizeof(char)*2);
00215         pair[2]='\0';
00216         local_pairs.push_back(pair);
00217     }
00218 
00219     env->pairs = local_pairs;
00220 
00221     // Initialize the weight vector
00222     if (weight_vectors)
00223         SG_FREE(weight_vectors);
00224     init(env);
00225 
00226     vw_size_t local_ngram;
00227     source.read_file((char*)&local_ngram, sizeof(local_ngram));
00228     vw_size_t local_skips;
00229     source.read_file((char*)&local_skips, sizeof(local_skips));
00230 
00231     env->ngram = local_ngram;
00232     env->skips = local_skips;
00233 
00234     // Read individual weights
00235     vw_size_t stride = env->stride;
00236     while (true)
00237     {
00238         uint32_t hash;
00239         ssize_t hash_bytes = source.read_file((char *)&hash, sizeof(hash));
00240         if (hash_bytes <= 0)
00241             break;
00242 
00243         float32_t w = 0.;
00244         ssize_t weight_bytes = source.read_file((char *)&w, sizeof(float32_t));
00245         if (weight_bytes <= 0)
00246             break;
00247 
00248         vw_size_t num_threads = env->num_threads();
00249 
00250         weight_vectors[hash % num_threads][(hash*stride)/num_threads]
00251             = weight_vectors[hash % num_threads][(hash*stride)/num_threads] + w;
00252     }
00253     source.close_file();
00254 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation