00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015 #include <shogun/classifier/vw/VwRegressor.h>
00016 #include <shogun/loss/SquaredLoss.h>
00017 #include <shogun/io/IOBuffer.h>
00018
00019 using namespace shogun;
00020
00021 CVwRegressor::CVwRegressor()
00022 : CSGObject()
00023 {
00024 weight_vectors = NULL;
00025 loss = new CSquaredLoss();
00026 init(NULL);
00027 }
00028
00029 CVwRegressor::CVwRegressor(CVwEnvironment* env_to_use)
00030 : CSGObject()
00031 {
00032 weight_vectors = NULL;
00033 loss = new CSquaredLoss();
00034 init(env_to_use);
00035 }
00036
00037 CVwRegressor::~CVwRegressor()
00038 {
00039 SG_FREE(weight_vectors);
00040 SG_UNREF(loss);
00041 SG_UNREF(env);
00042 }
00043
00044 void CVwRegressor::init(CVwEnvironment* env_to_use)
00045 {
00046 if (!env_to_use)
00047 env_to_use = new CVwEnvironment();
00048
00049 env = env_to_use;
00050 SG_REF(env);
00051
00052
00053
00054 vw_size_t length = ((vw_size_t) 1) << env->num_bits;
00055 env->thread_mask = (env->stride * (length >> env->thread_bits)) - 1;
00056
00057
00058 vw_size_t num_threads = 1;
00059 weight_vectors = SG_MALLOC(float32_t*, num_threads);
00060
00061 for (vw_size_t i = 0; i < num_threads; i++)
00062 {
00063 weight_vectors[i] = SG_CALLOC(float32_t, env->stride * length / num_threads);
00064
00065 if (env->random_weights)
00066 {
00067 for (vw_size_t j = 0; j < length/num_threads; j++)
00068 weight_vectors[i][j] = drand48() - 0.5;
00069 }
00070
00071 if (env->initial_weight != 0.)
00072 for (vw_size_t j = 0; j < env->stride*length/num_threads; j+=env->stride)
00073 weight_vectors[i][j] = env->initial_weight;
00074
00075 if (env->adaptive)
00076 for (vw_size_t j = 1; j < env->stride*length/num_threads; j+=env->stride)
00077 weight_vectors[i][j] = 1;
00078 }
00079 }
00080
00081 void CVwRegressor::dump_regressor(char* reg_name, bool as_text)
00082 {
00083 CIOBuffer io_temp;
00084 int32_t f = io_temp.open_file(reg_name,'w');
00085
00086 if (f < 0)
00087 SG_SERROR("Can't open: %s for writing! Exiting.\n", reg_name);
00088
00089 const char* vw_version = env->vw_version;
00090 vw_size_t v_length = env->v_length;
00091
00092 if (!as_text)
00093 {
00094
00095 io_temp.write_file((char*)&v_length, sizeof(v_length));
00096 io_temp.write_file(vw_version,v_length);
00097
00098
00099 io_temp.write_file((char*)&env->min_label, sizeof(env->min_label));
00100 io_temp.write_file((char*)&env->max_label, sizeof(env->max_label));
00101
00102
00103 io_temp.write_file((char *)&env->num_bits, sizeof(env->num_bits));
00104 io_temp.write_file((char *)&env->thread_bits, sizeof(env->thread_bits));
00105
00106
00107 int32_t len = env->pairs.get_num_elements();
00108 io_temp.write_file((char *)&len, sizeof(len));
00109
00110 for (int32_t k = 0; k < env->pairs.get_num_elements(); k++)
00111 io_temp.write_file(env->pairs.get_element(k), 2);
00112
00113
00114 io_temp.write_file((char*)&env->ngram, sizeof(env->ngram));
00115 io_temp.write_file((char*)&env->skips, sizeof(env->skips));
00116 }
00117 else
00118 {
00119
00120 char buff[512];
00121 int32_t len;
00122
00123 len = sprintf(buff, "Version %s\n", vw_version);
00124 io_temp.write_file(buff, len);
00125 len = sprintf(buff, "Min label:%f max label:%f\n", env->min_label, env->max_label);
00126 io_temp.write_file(buff, len);
00127 len = sprintf(buff, "bits:%d thread_bits:%d\n", (int32_t)env->num_bits, (int32_t)env->thread_bits);
00128 io_temp.write_file(buff, len);
00129
00130 if (env->pairs.get_num_elements() > 0)
00131 {
00132 len = sprintf(buff, "\n");
00133 io_temp.write_file(buff, len);
00134 }
00135
00136 len = sprintf(buff, "ngram:%d skips:%d\nindex:weight pairs:\n", (int32_t)env->ngram, (int32_t)env->skips);
00137 io_temp.write_file(buff, len);
00138 }
00139
00140 uint32_t length = 1 << env->num_bits;
00141 vw_size_t num_threads = env->num_threads();
00142 vw_size_t stride = env->stride;
00143
00144
00145 for(uint32_t i = 0; i < length; i++)
00146 {
00147 float32_t v;
00148 v = weight_vectors[i%num_threads][stride*(i/num_threads)];
00149 if (v != 0.)
00150 {
00151 if (!as_text)
00152 {
00153 io_temp.write_file((char *)&i, sizeof (i));
00154 io_temp.write_file((char *)&v, sizeof (v));
00155 }
00156 else
00157 {
00158 char buff[512];
00159 int32_t len = sprintf(buff, "%d:%f\n", i, v);
00160 io_temp.write_file(buff, len);
00161 }
00162 }
00163 }
00164
00165 io_temp.close_file();
00166 }
00167
00168 void CVwRegressor::load_regressor(char* file)
00169 {
00170 CIOBuffer source;
00171 int32_t fd = source.open_file(file, 'r');
00172
00173 if (fd < 0)
00174 SG_SERROR("Unable to open file for loading regressor!\n");
00175
00176
00177 vw_size_t v_length;
00178 source.read_file((char*)&v_length, sizeof(v_length));
00179 char* t = SG_MALLOC(char, v_length);
00180 source.read_file(t,v_length);
00181 if (strcmp(t,env->vw_version) != 0)
00182 {
00183 SG_FREE(t);
00184 SG_SERROR("Regressor source has an incompatible VW version!\n");
00185 }
00186 SG_FREE(t);
00187
00188
00189 source.read_file((char*)&env->min_label, sizeof(env->min_label));
00190 source.read_file((char*)&env->max_label, sizeof(env->max_label));
00191
00192
00193 vw_size_t local_num_bits;
00194 source.read_file((char *)&local_num_bits, sizeof(local_num_bits));
00195
00196 if ((vw_size_t) env->num_bits != local_num_bits)
00197 SG_SERROR("Wrong number of bits in regressor source!\n");
00198
00199 env->num_bits = local_num_bits;
00200
00201 vw_size_t local_thread_bits;
00202 source.read_file((char*)&local_thread_bits, sizeof(local_thread_bits));
00203
00204 env->thread_bits = local_thread_bits;
00205
00206 int32_t len;
00207 source.read_file((char *)&len, sizeof(len));
00208
00209
00210 DynArray<char*> local_pairs;
00211 for (; len > 0; len--)
00212 {
00213 char pair[3];
00214 source.read_file(pair, sizeof(char)*2);
00215 pair[2]='\0';
00216 local_pairs.push_back(pair);
00217 }
00218
00219 env->pairs = local_pairs;
00220
00221
00222 if (weight_vectors)
00223 SG_FREE(weight_vectors);
00224 init(env);
00225
00226 vw_size_t local_ngram;
00227 source.read_file((char*)&local_ngram, sizeof(local_ngram));
00228 vw_size_t local_skips;
00229 source.read_file((char*)&local_skips, sizeof(local_skips));
00230
00231 env->ngram = local_ngram;
00232 env->skips = local_skips;
00233
00234
00235 vw_size_t stride = env->stride;
00236 while (true)
00237 {
00238 uint32_t hash;
00239 ssize_t hash_bytes = source.read_file((char *)&hash, sizeof(hash));
00240 if (hash_bytes <= 0)
00241 break;
00242
00243 float32_t w = 0.;
00244 ssize_t weight_bytes = source.read_file((char *)&w, sizeof(float32_t));
00245 if (weight_bytes <= 0)
00246 break;
00247
00248 vw_size_t num_threads = env->num_threads();
00249
00250 weight_vectors[hash % num_threads][(hash*stride)/num_threads]
00251 = weight_vectors[hash % num_threads][(hash*stride)/num_threads] + w;
00252 }
00253 source.close_file();
00254 }