SHOGUN  4.1.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
VwRegressor.cpp
Go to the documentation of this file.
1 /*
2  * Copyright (c) 2009 Yahoo! Inc. All rights reserved. The copyrights
3  * embodied in the content of this file are licensed under the BSD
4  * (revised) open source license.
5  *
6  * This program is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation; either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * Written (W) 2011 Shashwat Lal Das
12  * Copyright (C) 2011 Berlin Institute of Technology and Max-Planck-Society.
13  */
14 
17 #include <shogun/io/IOBuffer.h>
18 
19 using namespace shogun;
20 
// Appears to be the default constructor (its signature line is not present
// in this listing). Starts with no weight vectors and a squared loss, then
// init(NULL) creates a fresh CVwEnvironment and allocates the weight vectors.
22  : CSGObject()
23 {
24  weight_vectors = NULL;
25  loss = new CSquaredLoss();
26  init(NULL);
27 }
28 
// Appears to be the constructor taking a caller-supplied environment (its
// signature line is not present in this listing; the parameter name comes
// from the init call below). init() stores env_to_use in env and SG_REFs it,
// or creates a new CVwEnvironment when env_to_use is NULL.
30  : CSGObject()
31 {
32  weight_vectors = NULL;
33  loss = new CSquaredLoss();
34  init(env_to_use);
35 }
36 
// Appears to be the destructor (its signature line is not present in this
// listing). Frees the per-thread weight arrays allocated by init(), then the
// array of pointers itself, and drops this object's references on loss and env.
38 {
39  // TODO: the number of weight_vectors depends on num_threads
40  // this should be reimplemented using SGVector (for reference counting)
41  if (weight_vectors)
42  {
// Must match the single-thread count hard-coded in init().
43  vw_size_t num_threads = 1;
44  for (vw_size_t i = 0; i < num_threads; i++)
45  {
46  SG_FREE(weight_vectors[i]);
47  }
48  }
49 
50  SG_FREE(weight_vectors);
51  SG_UNREF(loss);
52  SG_UNREF(env);
53 }
54 
/*
 * Bind the regressor to an environment and allocate its weight vectors.
 *
 * env_to_use: environment to use; when NULL a new CVwEnvironment is created.
 *
 * The chosen environment is stored in env and SG_REF'd unconditionally. Any
 * environment previously stored in env is not released here, so a caller
 * that passes the env it already holds (load_regressor calls init(env))
 * gains an extra reference.
 *
 * weight_vectors is (re)allocated for a single thread, each array holding
 * stride * 2^num_bits / num_threads zero-initialised floats. Previously
 * allocated weight_vectors are overwritten, not freed.
 */
55 void CVwRegressor::init(CVwEnvironment* env_to_use)
56 {
57  if (!env_to_use)
58  env_to_use = new CVwEnvironment();
59 
60  env = env_to_use;
61  SG_REF(env);
62 
63  // For each feature, there should be 'stride' number of
64  // elements in the weight vector
65  vw_size_t length = ((vw_size_t) 1) << env->num_bits;
// Mask used by the regressor for learning: stride * (length >> thread_bits) - 1.
66  env->thread_mask = (env->stride * (length >> env->thread_bits)) - 1;
67 
68  // Only one learning thread for now
69  vw_size_t num_threads = 1;
70  weight_vectors = SG_MALLOC(float32_t*, num_threads);
71 
72  for (vw_size_t i = 0; i < num_threads; i++)
73  {
74  weight_vectors[i] = SG_CALLOC(float32_t, env->stride * length / num_threads);
75 
// NOTE(review): this loop fills the first length/num_threads contiguous
// slots, whereas the initial_weight and adaptive loops below step by
// env->stride. With stride > 1 the random values do not land only on the
// per-feature weight slots. Confirm whether index j*stride was intended.
76  if (env->random_weights)
77  {
78  for (vw_size_t j = 0; j < length/num_threads; j++)
79  weight_vectors[i][j] = CMath::random(-0.5, 0.5);
80  }
81 
// Set slot 0 of every stride group to env->initial_weight when it is non-zero.
// NOTE(review): this for-loop's body statement is absent from this listing
// (the embedded numbering jumps from 83 to 85). As shown here, the
// `if (env->adaptive)` statement below would become its body.
82  if (env->initial_weight != 0.)
83  for (vw_size_t j = 0; j < env->stride*length/num_threads; j+=env->stride)
85 
// With adaptive learning, slot 1 of every stride group starts at 1.
86  if (env->adaptive)
87  for (vw_size_t j = 1; j < env->stride*length/num_threads; j+=env->stride)
88  weight_vectors[i][j] = 1;
89  }
90 }
91 
92 // TODO: remove this, as we have serialization FW
93 void CVwRegressor::dump_regressor(char* reg_name, bool as_text)
94 {
95  CIOBuffer io_temp;
96  int32_t f = io_temp.open_file(reg_name,'w');
97 
98  if (f < 0)
99  SG_SERROR("Can't open: %s for writing! Exiting.\n", reg_name)
100 
101  const char* vw_version = env->vw_version;
102  vw_size_t v_length = env->v_length;
103 
104  if (!as_text)
105  {
106  // Write version info
107  io_temp.write_file((char*)&v_length, sizeof(v_length));
108  io_temp.write_file(vw_version,v_length);
109 
110  // Write max and min labels
111  io_temp.write_file((char*)&env->min_label, sizeof(env->min_label));
112  io_temp.write_file((char*)&env->max_label, sizeof(env->max_label));
113 
114  // Write weight vector bits information
115  io_temp.write_file((char *)&env->num_bits, sizeof(env->num_bits));
116  io_temp.write_file((char *)&env->thread_bits, sizeof(env->thread_bits));
117 
118  // For paired namespaces forming quadratic features
119  int32_t len = env->pairs.get_num_elements();
120  io_temp.write_file((char *)&len, sizeof(len));
121 
122  for (int32_t k = 0; k < env->pairs.get_num_elements(); k++)
123  io_temp.write_file(env->pairs.get_element(k), 2);
124 
125  // ngram and skips information
126  io_temp.write_file((char*)&env->ngram, sizeof(env->ngram));
127  io_temp.write_file((char*)&env->skips, sizeof(env->skips));
128  }
129  else
130  {
131  // Write as human readable form
132  char buff[512];
133  int32_t len;
134 
135  len = sprintf(buff, "Version %s\n", vw_version);
136  io_temp.write_file(buff, len);
137  len = sprintf(buff, "Min label:%f max label:%f\n", env->min_label, env->max_label);
138  io_temp.write_file(buff, len);
139  len = sprintf(buff, "bits:%d thread_bits:%d\n", (int32_t)env->num_bits, (int32_t)env->thread_bits);
140  io_temp.write_file(buff, len);
141 
142  if (env->pairs.get_num_elements() > 0)
143  {
144  len = sprintf(buff, "\n");
145  io_temp.write_file(buff, len);
146  }
147 
148  len = sprintf(buff, "ngram:%d skips:%d\nindex:weight pairs:\n", (int32_t)env->ngram, (int32_t)env->skips);
149  io_temp.write_file(buff, len);
150  }
151 
152  uint32_t length = 1 << env->num_bits;
153  vw_size_t num_threads = env->num_threads();
154  vw_size_t stride = env->stride;
155 
156  // Write individual weights
157  for(uint32_t i = 0; i < length; i++)
158  {
159  float32_t v;
160  v = weight_vectors[i%num_threads][stride*(i/num_threads)];
161  if (v != 0.)
162  {
163  if (!as_text)
164  {
165  io_temp.write_file((char *)&i, sizeof (i));
166  io_temp.write_file((char *)&v, sizeof (v));
167  }
168  else
169  {
170  char buff[512];
171  int32_t len = sprintf(buff, "%d:%f\n", i, v);
172  io_temp.write_file(buff, len);
173  }
174  }
175  }
176 
177  io_temp.close_file();
178 }
179 
180 // TODO: remove this, as we have serialization FW
182 {
183  CIOBuffer source;
184  int32_t fd = source.open_file(file, 'r');
185 
186  if (fd < 0)
187  SG_SERROR("Unable to open file for loading regressor!\n")
188 
189  // Read version info
190  vw_size_t v_length;
191  source.read_file((char*)&v_length, sizeof(v_length));
192  char* t = SG_MALLOC(char, v_length);
193  source.read_file(t,v_length);
194  if (strcmp(t,env->vw_version) != 0)
195  {
196  SG_FREE(t);
197  SG_SERROR("Regressor source has an incompatible VW version!\n")
198  }
199  SG_FREE(t);
200 
201  // Read min and max label
202  source.read_file((char*)&env->min_label, sizeof(env->min_label));
203  source.read_file((char*)&env->max_label, sizeof(env->max_label));
204 
205  // Read num_bits, multiple sources are not supported
206  vw_size_t local_num_bits;
207  source.read_file((char *)&local_num_bits, sizeof(local_num_bits));
208 
209  if ((vw_size_t) env->num_bits != local_num_bits)
210  SG_SERROR("Wrong number of bits in regressor source!\n")
211 
212  env->num_bits = local_num_bits;
213 
214  vw_size_t local_thread_bits;
215  source.read_file((char*)&local_thread_bits, sizeof(local_thread_bits));
216 
217  env->thread_bits = local_thread_bits;
218 
219  int32_t len;
220  source.read_file((char *)&len, sizeof(len));
221 
222  // Read paired namespace information
223  DynArray<char*> local_pairs;
224  for (; len > 0; len--)
225  {
226  char pair[3];
227  source.read_file(pair, sizeof(char)*2);
228  pair[2]='\0';
229  local_pairs.push_back(pair);
230  }
231 
232  env->pairs = local_pairs;
233 
234  // Initialize the weight vector
235  if (weight_vectors)
236  SG_FREE(weight_vectors);
237  init(env);
238 
239  vw_size_t local_ngram;
240  source.read_file((char*)&local_ngram, sizeof(local_ngram));
241  vw_size_t local_skips;
242  source.read_file((char*)&local_skips, sizeof(local_skips));
243 
244  env->ngram = local_ngram;
245  env->skips = local_skips;
246 
247  // Read individual weights
248  vw_size_t stride = env->stride;
249  while (true)
250  {
251  uint32_t hash;
252  ssize_t hash_bytes = source.read_file((char *)&hash, sizeof(hash));
253  if (hash_bytes <= 0)
254  break;
255 
256  float32_t w = 0.;
257  ssize_t weight_bytes = source.read_file((char *)&w, sizeof(float32_t));
258  if (weight_bytes <= 0)
259  break;
260 
261  vw_size_t num_threads = env->num_threads();
262 
263  weight_vectors[hash % num_threads][(hash*stride)/num_threads]
264  = weight_vectors[hash % num_threads][(hash*stride)/num_threads] + w;
265  }
266  source.close_file();
267 }
An I/O buffer class.
Definition: IOBuffer.h:41
uint32_t vw_size_t
vw_size_t typedef to work across platforms
Definition: vw_constants.h:26
T get_element(int32_t index) const
Definition: DynArray.h:142
bool random_weights
Whether to use random weights.
virtual void load_regressor(char *file_name)
virtual void init(CVwEnvironment *env_to_use=NULL)
Definition: VwRegressor.cpp:55
Class CVwEnvironment is the environment used by VW.
Definition: VwEnvironment.h:41
CLossFunction * loss
Loss function.
Definition: VwRegressor.h:118
float64_t min_label
Smallest label seen.
float32_t ** weight_vectors
Weight vectors, one array for each thread.
Definition: VwRegressor.h:116
vw_size_t num_bits
log_2 of the number of features
int32_t get_num_elements() const
Definition: DynArray.h:130
virtual bool close_file()
Definition: IOBuffer.cpp:126
float64_t max_label
Largest label seen.
#define SG_REF(x)
Definition: SGObject.h:51
static uint64_t random()
Definition: Math.h:1019
virtual ssize_t write_file(const void *buf, size_t nbytes)
Definition: IOBuffer.cpp:110
Class SGObject is the base class of all shogun objects.
Definition: SGObject.h:112
void push_back(T element)
Definition: DynArray.h:254
vw_size_t ngram
ngrams to generate
DynArray< char * > pairs
Pairs of features to cross for quadratic updates.
vw_size_t stride
Number of elements in weight vector per feature.
vw_size_t skips
Skips in ngrams.
virtual int open_file(const char *name, char flag='r')
Definition: IOBuffer.cpp:55
virtual ssize_t read_file(void *buf, size_t nbytes)
Definition: IOBuffer.cpp:87
float float32_t
Definition: common.h:49
#define SG_UNREF(x)
Definition: SGObject.h:52
CSquaredLoss implements the squared loss function.
Definition: SquaredLoss.h:29
all of classes and functions are contained in the shogun namespace
Definition: class_list.h:18
vw_size_t thread_bits
log_2 of the number of threads
float32_t initial_weight
Initial value of all elements in weight vector.
#define SG_SERROR(...)
Definition: SGIO.h:179
const char * vw_version
VW version.
CVwEnvironment * env
Environment.
Definition: VwRegressor.h:122
vw_size_t thread_mask
Mask used by regressor for learning.
bool adaptive
Whether adaptive learning is used.
virtual void dump_regressor(char *reg_name, bool as_text)
Definition: VwRegressor.cpp:93
vw_size_t v_length
Length of version string.

SHOGUN Machine Learning Toolbox - Documentation