18 using namespace shogun;
/* Default constructor; its body is not part of this excerpt. */
20 CProtobufFile::CProtobufFile()
/* Constructor taking a FILE pointer and a name; its initializer list and body are not part of this excerpt. */
25 CProtobufFile::CProtobufFile(FILE* f,
const char* name) :
/* Constructor taking a file name, an access-mode character and a name; all three are forwarded to the CFile base-class constructor. */
31 CProtobufFile::CProtobufFile(
const char* fname,
char rw,
const char* name) :
32 CFile(fname, rw, name)
/* Destructor; its body is not part of this excerpt. */
37 CProtobufFile::~CProtobufFile()
/*
 * Initialises the members used for message I/O.
 * message_size is set to 1024*1024 and buffer is allocated with
 * message_size*sizeof(uint32_t) bytes. Any further statements of init()
 * are not part of this excerpt.
 */
42 void CProtobufFile::init()
45 message_size=1024*1024;
47 buffer=SG_MALLOC(uint8_t, message_size*
sizeof(uint32_t));
/*
 * GET_VECTOR(sg_type) generates CProtobufFile::get_vector for one element type.
 * The generated function validates the global header against
 * ShogunVersion::VECTOR, reads the vector header, stores its length in len and
 * then reads the payload with read_memory_block, which allocates vector.
 * The macro instantiations are not part of this excerpt.
 */
50 #define GET_VECTOR(sg_type) \
51 void CProtobufFile::get_vector(sg_type*& vector, int32_t& len) \
53 read_and_validate_global_header(ShogunVersion::VECTOR); \
54 VectorHeader data_header=read_vector_header(); \
55 len=data_header.len(); \
56 read_memory_block(vector, len, data_header.num_messages()); \
/*
 * GET_MATRIX(read_func, sg_type) generates CProtobufFile::get_matrix.
 * The generated function validates the global header against
 * ShogunVersion::MATRIX, reads the matrix header, takes num_feat from
 * num_cols() and num_vec from num_rows(), and reads num_feat*num_vec elements
 * with read_memory_block. read_func is not referenced in the visible lines.
 * NOTE(review): num_feat*num_vec is an int32_t product - confirm it cannot
 * overflow for the matrix sizes this format is expected to hold.
 */
73 #define GET_MATRIX(read_func, sg_type) \
74 void CProtobufFile::get_matrix(sg_type*& matrix, int32_t& num_feat, int32_t& num_vec) \
76 read_and_validate_global_header(ShogunVersion::MATRIX); \
77 MatrixHeader data_header=read_matrix_header(); \
78 num_feat=data_header.num_cols(); \
79 num_vec=data_header.num_rows(); \
80 read_memory_block(matrix, num_feat*num_vec, data_header.num_messages()); \
/*
 * GET_NDARRAY(read_func, sg_type) generates CProtobufFile::get_ndarray with the
 * output parameters array, dims and num_dims. The body of the generated
 * function is not part of this excerpt.
 */
97 #define GET_NDARRAY(read_func, sg_type) \
98 void CProtobufFile::get_ndarray(sg_type*& array, int32_t*& dims, int32_t& num_dims) \
/*
 * GET_SPARSE_MATRIX(sg_type) generates CProtobufFile::get_sparse_matrix.
 * The generated function validates the global header against
 * ShogunVersion::SPARSE_MATRIX, reads the sparse matrix header, copies
 * num_features() and num_vectors() into num_feat and num_vec, and reads the
 * vectors with read_sparse_matrix. The instantiations are not part of this
 * excerpt.
 */
112 #define GET_SPARSE_MATRIX(sg_type) \
113 void CProtobufFile::get_sparse_matrix( \
114 SGSparseVector<sg_type>*& matrix, int32_t& num_feat, int32_t& num_vec) \
116 read_and_validate_global_header(ShogunVersion::SPARSE_MATRIX); \
117 SparseMatrixHeader data_header=read_sparse_matrix_header(); \
118 num_feat=data_header.num_features(); \
119 num_vec=data_header.num_vectors(); \
120 read_sparse_matrix(matrix, data_header); \
136 #undef GET_SPARSE_MATRIX
/*
 * SET_VECTOR(sg_type) generates CProtobufFile::set_vector.
 * The generated function computes the number of chunk messages for len
 * elements of sizeof(sg_type) bytes, then writes the global header
 * (ShogunVersion::VECTOR), the vector header and finally the data chunks.
 */
138 #define SET_VECTOR(sg_type) \
139 void CProtobufFile::set_vector(const sg_type* vector, int32_t len) \
141 int32_t num_messages=compute_num_messages(len, sizeof(sg_type)); \
142 write_global_header(ShogunVersion::VECTOR); \
143 write_vector_header(len, num_messages); \
144 write_memory_block(vector, len, num_messages); \
/*
 * SET_MATRIX(sg_type) generates CProtobufFile::set_matrix.
 * The generated function computes the number of chunk messages for
 * num_feat*num_vec elements, then writes the global header
 * (ShogunVersion::MATRIX), the matrix header and the data chunks.
 * NOTE(review): num_feat*num_vec is evaluated in int32_t before being widened
 * for compute_num_messages/write_memory_block - confirm it cannot overflow.
 */
161 #define SET_MATRIX(sg_type) \
162 void CProtobufFile::set_matrix(const sg_type* matrix, int32_t num_feat, int32_t num_vec) \
164 int32_t num_messages=compute_num_messages(num_feat*num_vec, sizeof(sg_type)); \
165 write_global_header(ShogunVersion::MATRIX); \
166 write_matrix_header(num_feat, num_vec, num_messages); \
167 write_memory_block(matrix, num_feat*num_vec, num_messages); \
/*
 * SET_SPARSE_MATRIX(sg_type) generates CProtobufFile::set_sparse_matrix.
 * The generated function writes the global header
 * (ShogunVersion::SPARSE_MATRIX), the sparse matrix header and then the
 * entries of all num_vec sparse vectors. The instantiations are not part of
 * this excerpt.
 */
184 #define SET_SPARSE_MATRIX(sg_type) \
185 void CProtobufFile::set_sparse_matrix( \
186 const SGSparseVector<sg_type>* matrix, int32_t num_feat, int32_t num_vec) \
188 write_global_header(ShogunVersion::SPARSE_MATRIX); \
189 write_sparse_matrix_header(matrix, num_feat, num_vec); \
190 write_sparse_matrix(matrix, num_vec); \
206 #undef SET_SPARSE_MATRIX
/*
 * GET_STRING_LIST(sg_type) generates CProtobufFile::get_string_list.
 * The generated function validates the global header against
 * ShogunVersion::STRING_LIST, reads the string list header, copies num_str()
 * and max_string_len() into the output parameters and reads the strings with
 * read_string_list. The instantiations are not part of this excerpt.
 */
208 #define GET_STRING_LIST(sg_type) \
209 void CProtobufFile::get_string_list( \
210 SGString<sg_type>*& strings, int32_t& num_str, \
211 int32_t& max_string_len) \
213 read_and_validate_global_header(ShogunVersion::STRING_LIST); \
214 StringListHeader data_header=read_string_list_header(); \
215 num_str=data_header.num_str(); \
216 max_string_len=data_header.max_string_len(); \
217 read_string_list(strings, data_header); \
232 #undef GET_STRING_LIST
/*
 * SET_STRING_LIST(sg_type) generates CProtobufFile::set_string_list.
 * The generated function writes the global header
 * (ShogunVersion::STRING_LIST), the string list header and then the characters
 * of all num_str strings. The instantiations are not part of this excerpt.
 */
234 #define SET_STRING_LIST(sg_type) \
235 void CProtobufFile::set_string_list( \
236 const SGString<sg_type>* strings, int32_t num_str) \
238 write_global_header(ShogunVersion::STRING_LIST); \
239 write_string_list_header(strings, num_str); \
240 write_string_list(strings, num_str); \
255 #undef SET_STRING_LIST
/*
 * Stores number into array[0..3] in big-endian order (most significant byte
 * in array[0]). size is the number of bytes available in array; the
 * "array is too small to write" error is reported through SG_ERROR, but the
 * condition guarding it is not part of this excerpt.
 */
257 void CProtobufFile::write_big_endian_uint(uint32_t number, uint8_t* array, uint32_t size)
260 SG_ERROR(
"array is too small to write\n");
262 array[0]=(number>>24)&0xffu;
263 array[1]=(number>>16)&0xffu;
264 array[2]=(number>>8)&0xffu;
265 array[3]=number&0xffu;
268 uint32_t CProtobufFile::read_big_endian_uint(uint8_t* array, uint32_t size)
271 SG_ERROR(
"array is too small to read\n");
273 return (array[0]<<24) | (array[1]<<16) | (array[2]<<8) | array[3];
/*
 * Computes how many chunk messages are needed to hold len elements of
 * sizeof_type bytes each. One message holds message_size/sizeof_type elements;
 * the remainder test detects a trailing, partially filled message. The
 * statement executed for a non-zero remainder and the return statement are
 * not part of this excerpt.
 * NOTE(review): len/elements_in_message is narrowed from uint64_t to uint32_t
 * here - confirm that len stays small enough for this to be lossless.
 */
276 int32_t CProtobufFile::compute_num_messages(uint64_t len, int32_t sizeof_type)
const
278 uint32_t elements_in_message=message_size/sizeof_type;
279 uint32_t num_messages=len/elements_in_message;
280 if (len % elements_in_message > 0)
/*
 * Reads a ShogunVersion message from the file and requires that its version
 * equals the member `version` and that its data type equals the expected
 * `type`; a mismatch is reported through REQUIRE with "wrong version" or
 * "wrong type".
 */
286 void CProtobufFile::read_and_validate_global_header(ShogunVersion_SGDataType type)
288 ShogunVersion header;
289 read_message(header);
290 REQUIRE(header.version()==version,
"wrong version\n")
291 REQUIRE(header.data_type()==type,
"wrong type\n")
/*
 * Writes the global ShogunVersion header: the member `version` together with
 * the data type of the payload that follows.
 */
294 void CProtobufFile::write_global_header(ShogunVersion_SGDataType type)
296 ShogunVersion header;
297 header.set_version(version);
298 header.set_data_type(type);
299 write_message(header);
/* Reads a VectorHeader message from the file; the return statement is not part of this excerpt. */
302 VectorHeader CProtobufFile::read_vector_header()
304 VectorHeader data_header;
305 read_message(data_header);
/* Reads a SparseMatrixHeader message from the file; the return statement is not part of this excerpt. */
310 SparseMatrixHeader CProtobufFile::read_sparse_matrix_header()
312 SparseMatrixHeader data_header;
313 read_message(data_header);
/* Reads a MatrixHeader message from the file; the return statement is not part of this excerpt. */
318 MatrixHeader CProtobufFile::read_matrix_header()
320 MatrixHeader data_header;
321 read_message(data_header);
/* Reads a StringListHeader message from the file; the return statement is not part of this excerpt. */
326 StringListHeader CProtobufFile::read_string_list_header()
328 StringListHeader data_header;
329 read_message(data_header);
/*
 * Writes a VectorHeader carrying the vector length and the number of chunk
 * messages that follow it.
 */
334 void CProtobufFile::write_vector_header(int32_t len, int32_t num_messages)
336 VectorHeader data_header;
337 data_header.set_len(len);
338 data_header.set_num_messages(num_messages);
339 write_message(data_header);
/*
 * Writes a MatrixHeader: num_feat is stored as num_cols, num_vec as num_rows,
 * together with the number of chunk messages that follow.
 */
342 void CProtobufFile::write_matrix_header(int32_t num_feat, int32_t num_vec, int32_t num_messages)
344 MatrixHeader data_header;
345 data_header.set_num_cols(num_feat);
346 data_header.set_num_rows(num_vec);
347 data_header.set_num_messages(num_messages);
348 write_message(data_header);
/*
 * WRITE_SPARSE_MATRIX_HEADER(sg_type) generates
 * CProtobufFile::write_sparse_matrix_header. The generated function stores
 * num_feat and num_vec in a SparseMatrixHeader, appends the number of entries
 * of every sparse vector (matrix[i].num_feat_entries) and writes the header.
 */
351 #define WRITE_SPARSE_MATRIX_HEADER(sg_type) \
352 void CProtobufFile::write_sparse_matrix_header( \
353 const SGSparseVector<sg_type>* matrix, int32_t num_feat, int32_t num_vec) \
355 SparseMatrixHeader data_header; \
356 data_header.set_num_features(num_feat); \
357 data_header.set_num_vectors(num_vec); \
358 for (int32_t i=0; i<num_vec; i++) \
360 data_header.add_num_feat_entries(matrix[i].num_feat_entries); \
363 write_message(data_header); \
/* Instantiations for the element types visible in this excerpt. */
366 WRITE_SPARSE_MATRIX_HEADER(
bool)
367 WRITE_SPARSE_MATRIX_HEADER(int8_t)
368 WRITE_SPARSE_MATRIX_HEADER(uint8_t)
369 WRITE_SPARSE_MATRIX_HEADER(
char)
370 WRITE_SPARSE_MATRIX_HEADER(int32_t)
371 WRITE_SPARSE_MATRIX_HEADER(uint32_t)
372 WRITE_SPARSE_MATRIX_HEADER(int64_t)
373 WRITE_SPARSE_MATRIX_HEADER(uint64_t)
377 WRITE_SPARSE_MATRIX_HEADER(int16_t)
378 WRITE_SPARSE_MATRIX_HEADER(uint16_t)
379 #undef WRITE_SPARSE_MATRIX_HEADER
/*
 * WRITE_STRING_LIST_HEADER(sg_type) generates
 * CProtobufFile::write_string_list_header. The generated function stores
 * num_str in a StringListHeader, appends the length (slen) of every string,
 * tracks the largest length seen and stores it as max_string_len, then writes
 * the header.
 */
381 #define WRITE_STRING_LIST_HEADER(sg_type) \
382 void CProtobufFile::write_string_list_header(const SGString<sg_type>* strings, int32_t num_str) \
384 int32_t max_string_len=0; \
385 StringListHeader data_header; \
386 data_header.set_num_str(num_str); \
387 for (int32_t i=0; i<num_str; i++) \
389 data_header.add_str_len(strings[i].slen); \
390 if (strings[i].slen>max_string_len) \
391 max_string_len=strings[i].slen; \
393 data_header.set_max_string_len(max_string_len); \
394 write_message(data_header); \
/* Instantiations for the element types visible in this excerpt. */
397 WRITE_STRING_LIST_HEADER(int8_t)
398 WRITE_STRING_LIST_HEADER(uint8_t)
399 WRITE_STRING_LIST_HEADER(
char)
400 WRITE_STRING_LIST_HEADER(int32_t)
401 WRITE_STRING_LIST_HEADER(uint32_t)
402 WRITE_STRING_LIST_HEADER(int64_t)
403 WRITE_STRING_LIST_HEADER(uint64_t)
407 WRITE_STRING_LIST_HEADER(int16_t)
408 WRITE_STRING_LIST_HEADER(uint16_t)
409 #undef WRITE_STRING_LIST_HEADER
411 void CProtobufFile::read_message(google::protobuf::Message& message)
413 uint32_t bytes_read=0;
417 bytes_read=fread(uint_buffer,
sizeof(
char),
sizeof(uint32_t), file);
418 REQUIRE(bytes_read==
sizeof(uint32_t),
"IO error\n");
419 msg_size=read_big_endian_uint(uint_buffer,
sizeof(uint32_t));
420 REQUIRE(msg_size>0,
"message size should be more than zero\n");
423 bytes_read=fread(buffer,
sizeof(
char), msg_size, file);
424 REQUIRE(bytes_read==msg_size,
"IO error\n");
427 REQUIRE(message.ParseFromArray(buffer, msg_size),
"cannot parse header\n");
430 void CProtobufFile::write_message(
const google::protobuf::Message& message)
432 uint32_t bytes_write=0;
433 uint32_t msg_size=message.ByteSize();
436 write_big_endian_uint(msg_size, uint_buffer,
sizeof(uint32_t));
437 bytes_write=fwrite(uint_buffer,
sizeof(
char),
sizeof(uint32_t), file);
438 REQUIRE(bytes_write==
sizeof(uint32_t),
"IO error\n");
441 message.SerializeToArray(buffer, msg_size);
442 bytes_write=fwrite(buffer,
sizeof(
char), msg_size, file);
443 REQUIRE(bytes_write==msg_size,
"IO error\n");
/*
 * READ_MEMORY_BLOCK(chunk_type, sg_type) generates
 * CProtobufFile::read_memory_block for one element type.
 *
 * The generated function allocates `vector` with room for len elements and
 * fills it from num_messages consecutive chunk messages of type chunk_type.
 * Every message holds up to message_size/sizeof(sg_type) elements; the last
 * one may hold fewer.
 *
 * The per-message element count is computed in uint64_t from the remaining
 * number of elements. (A test of the form `len-(i+1)*n<=0` on an unsigned len
 * is only true on exact equality, which would treat a partial last message as
 * a full one and run past both the chunk and the vector.)
 * A chunk that carries fewer elements than expected is reported via REQUIRE.
 */
#define READ_MEMORY_BLOCK(chunk_type, sg_type) \
void CProtobufFile::read_memory_block(sg_type*& vector, uint64_t len, int32_t num_messages) \
{ \
	vector=SG_MALLOC(sg_type, len); \
	\
	chunk_type chunk; \
	const uint64_t elements_in_message=message_size/sizeof(sg_type); \
	for (int32_t i=0; i<num_messages; i++) \
	{ \
		read_message(chunk); \
		\
		/* elements still missing once the previous messages are accounted for */ \
		const uint64_t offset=static_cast<uint64_t>(i)*elements_in_message; \
		uint64_t num_elements_to_read=0; \
		if (offset<len) \
		{ \
			const uint64_t remaining=len-offset; \
			num_elements_to_read= \
				(remaining<elements_in_message) ? remaining : elements_in_message; \
		} \
		REQUIRE(static_cast<uint64_t>(chunk.data_size())>=num_elements_to_read, \
			"chunk contains fewer elements than expected\n"); \
		\
		for (uint64_t j=0; j<num_elements_to_read; j++) \
			vector[offset+j]=chunk.data(static_cast<int>(j)); \
	} \
}
/*
 * Instantiations of read_memory_block. Each line pairs the chunk message type
 * used on the wire with the element type delivered to the caller; narrow
 * integer types (int8_t, int16_t, uint8_t, uint16_t, char) are carried in
 * 32-bit chunks.
 */
468 READ_MEMORY_BLOCK(Int32Chunk, int8_t)
469 READ_MEMORY_BLOCK(UInt32Chunk, uint8_t)
470 READ_MEMORY_BLOCK(UInt32Chunk,
char)
471 READ_MEMORY_BLOCK(Int32Chunk, int32_t)
472 READ_MEMORY_BLOCK(UInt32Chunk, uint32_t)
473 READ_MEMORY_BLOCK(Float32Chunk,
float32_t)
474 READ_MEMORY_BLOCK(Float64Chunk,
float64_t)
476 READ_MEMORY_BLOCK(Int32Chunk, int16_t)
477 READ_MEMORY_BLOCK(UInt32Chunk, uint16_t)
478 READ_MEMORY_BLOCK(Int64Chunk, int64_t)
479 READ_MEMORY_BLOCK(UInt64Chunk, uint64_t)
480 #undef READ_MEMORY_BLOCK
/*
 * WRITE_MEMORY_BLOCK(chunk_type, sg_type) generates
 * CProtobufFile::write_memory_block for one element type.
 *
 * The generated function writes the len elements of `vector` as num_messages
 * consecutive chunk messages of type chunk_type. Every message holds up to
 * message_size/sizeof(sg_type) elements; the last one may hold fewer.
 *
 * The per-message element count is computed in uint64_t from the remaining
 * number of elements. (A test of the form `len-(i+1)*n<=0` on an unsigned len
 * is only true on exact equality, which would treat a partial last message as
 * a full one and read past the end of `vector`.)
 * The chunk is cleared after every write so that each message carries only
 * its own slice of the data.
 */
#define WRITE_MEMORY_BLOCK(chunk_type, sg_type) \
void CProtobufFile::write_memory_block(const sg_type* vector, uint64_t len, int32_t num_messages) \
{ \
	chunk_type chunk; \
	const uint64_t elements_in_message=message_size/sizeof(sg_type); \
	for (int32_t i=0; i<num_messages; i++) \
	{ \
		/* elements still to be written once the previous messages are accounted for */ \
		const uint64_t offset=static_cast<uint64_t>(i)*elements_in_message; \
		uint64_t num_elements_to_write=0; \
		if (offset<len) \
		{ \
			const uint64_t remaining=len-offset; \
			num_elements_to_write= \
				(remaining<elements_in_message) ? remaining : elements_in_message; \
		} \
		\
		for (uint64_t j=0; j<num_elements_to_write; j++) \
			chunk.add_data(vector[offset+j]); \
		\
		write_message(chunk); \
		chunk.Clear(); \
	} \
}
/*
 * Instantiations of write_memory_block. Each line pairs the chunk message type
 * used on the wire with the element type supplied by the caller.
 * NOTE(review): uint32_t is written through UInt64Chunk here, whereas the
 * read_memory_block instantiation for uint32_t uses UInt32Chunk - confirm
 * that this asymmetry is intended.
 */
504 WRITE_MEMORY_BLOCK(Int32Chunk, int8_t)
505 WRITE_MEMORY_BLOCK(UInt32Chunk, uint8_t)
506 WRITE_MEMORY_BLOCK(UInt32Chunk,
char)
507 WRITE_MEMORY_BLOCK(Int32Chunk, int32_t)
508 WRITE_MEMORY_BLOCK(UInt64Chunk, uint32_t)
509 WRITE_MEMORY_BLOCK(Int64Chunk, int64_t)
510 WRITE_MEMORY_BLOCK(UInt64Chunk, uint64_t)
511 WRITE_MEMORY_BLOCK(Float32Chunk,
float32_t)
512 WRITE_MEMORY_BLOCK(Float64Chunk,
float64_t)
514 WRITE_MEMORY_BLOCK(Int32Chunk, int16_t)
515 WRITE_MEMORY_BLOCK(UInt32Chunk, uint16_t)
516 #undef WRITE_MEMORY_BLOCK
/*
 * READ_SPARSE_MATRIX(chunk_type, sg_type) generates
 * CProtobufFile::read_sparse_matrix.
 *
 * The generated function allocates an array of num_vectors() sparse vectors,
 * reads a first pair of messages (feature indices as UInt64Chunk, values as
 * chunk_type) and then, for every vector, creates an SGSparseVector with
 * num_feat_entries(i) entries and fills feat_index/entry from the two chunks
 * at position buffer_counter. When buffer_counter reaches
 * elements_in_message the next pair of messages is read. The statements that
 * advance and reset buffer_counter are not part of this excerpt.
 * NOTE(review): matrix[i] is assigned into storage obtained from SG_MALLOC -
 * confirm that SG_MALLOC leaves the elements in a state that is safe to
 * assign to.
 */
518 #define READ_SPARSE_MATRIX(chunk_type, sg_type) \
519 void CProtobufFile::read_sparse_matrix( \
520 SGSparseVector<sg_type>*& matrix, const SparseMatrixHeader& data_header) \
522 matrix=SG_MALLOC(SGSparseVector<sg_type>, data_header.num_vectors()); \
524 UInt64Chunk feat_index_chunk; \
525 chunk_type entry_chunk; \
526 read_message(feat_index_chunk); \
527 read_message(entry_chunk); \
529 int32_t elements_in_message=message_size/sizeof(sg_type); \
530 int32_t buffer_counter=0; \
531 for (uint32_t i=0; i<data_header.num_vectors(); i++) \
533 matrix[i]=SGSparseVector<sg_type>(data_header.num_feat_entries(i)); \
534 for (int32_t j=0; j<matrix[i].num_feat_entries; j++) \
536 matrix[i].features[j].feat_index=feat_index_chunk.data(buffer_counter); \
537 matrix[i].features[j].entry=entry_chunk.data(buffer_counter); \
540 if (buffer_counter==elements_in_message) \
542 read_message(feat_index_chunk); \
543 read_message(entry_chunk); \
/* Instantiations: chunk message type on the wire paired with the element type. */
550 READ_SPARSE_MATRIX(BoolChunk,
bool)
551 READ_SPARSE_MATRIX(Int32Chunk, int8_t)
552 READ_SPARSE_MATRIX(UInt32Chunk, uint8_t)
553 READ_SPARSE_MATRIX(UInt32Chunk,
char)
554 READ_SPARSE_MATRIX(Int32Chunk, int32_t)
555 READ_SPARSE_MATRIX(UInt32Chunk, uint32_t)
556 READ_SPARSE_MATRIX(Float32Chunk,
float32_t)
557 READ_SPARSE_MATRIX(Float64Chunk,
float64_t)
559 READ_SPARSE_MATRIX(Int32Chunk, int16_t)
560 READ_SPARSE_MATRIX(UInt32Chunk, uint16_t)
561 READ_SPARSE_MATRIX(Int64Chunk, int64_t)
562 READ_SPARSE_MATRIX(UInt64Chunk, uint64_t)
563 #undef READ_SPARSE_MATRIX
/*
 * WRITE_SPARSE_MATRIX(chunk_type, sg_type) generates
 * CProtobufFile::write_sparse_matrix.
 *
 * The generated function walks all entries of all num_vec sparse vectors and
 * appends each feature index to a UInt64Chunk and each value to a chunk_type
 * chunk. Whenever buffer_counter equals elements_in_message both chunks are
 * written and cleared; after the loops a final pair is written if
 * buffer_counter is non-zero. The statements that advance and reset
 * buffer_counter are not part of this excerpt.
 */
565 #define WRITE_SPARSE_MATRIX(chunk_type, sg_type) \
566 void CProtobufFile::write_sparse_matrix( \
567 const SGSparseVector<sg_type>* matrix, int32_t num_vec) \
569 UInt64Chunk feat_index_chunk; \
570 chunk_type entry_chunk; \
571 int32_t elements_in_message=message_size/sizeof(sg_type); \
572 int32_t buffer_counter=0; \
573 for (int32_t i=0; i<num_vec; i++) \
575 for (int32_t j=0; j<matrix[i].num_feat_entries; j++) \
577 feat_index_chunk.add_data(matrix[i].features[j].feat_index); \
578 entry_chunk.add_data(matrix[i].features[j].entry); \
581 if (buffer_counter==elements_in_message) \
583 write_message(feat_index_chunk); \
584 write_message(entry_chunk); \
585 feat_index_chunk.Clear(); \
586 entry_chunk.Clear(); \
592 if (buffer_counter!=0) \
594 write_message(feat_index_chunk); \
595 write_message(entry_chunk); \
/*
 * Instantiations: chunk message type on the wire paired with the element type.
 * NOTE(review): uint32_t is written through UInt64Chunk here, whereas the
 * read_sparse_matrix instantiation for uint32_t uses UInt32Chunk - confirm
 * that this asymmetry is intended.
 */
599 WRITE_SPARSE_MATRIX(BoolChunk,
bool)
600 WRITE_SPARSE_MATRIX(Int32Chunk, int8_t)
601 WRITE_SPARSE_MATRIX(UInt32Chunk, uint8_t)
602 WRITE_SPARSE_MATRIX(UInt32Chunk,
char)
603 WRITE_SPARSE_MATRIX(Int32Chunk, int32_t)
604 WRITE_SPARSE_MATRIX(UInt64Chunk, uint32_t)
605 WRITE_SPARSE_MATRIX(Int64Chunk, int64_t)
606 WRITE_SPARSE_MATRIX(UInt64Chunk, uint64_t)
607 WRITE_SPARSE_MATRIX(Float32Chunk,
float32_t)
608 WRITE_SPARSE_MATRIX(Float64Chunk,
float64_t)
610 WRITE_SPARSE_MATRIX(Int32Chunk, int16_t)
611 WRITE_SPARSE_MATRIX(UInt32Chunk, uint16_t)
612 #undef WRITE_SPARSE_MATRIX
/*
 * READ_STRING_LIST(chunk_type, sg_type) generates
 * CProtobufFile::read_string_list.
 *
 * The generated function allocates an array of num_str() strings, reads a
 * first chunk message and then, for every string, creates an SGString of
 * length str_len(i) and fills it element by element from the chunk at
 * position buffer_counter. When buffer_counter reaches elements_in_message
 * the next chunk message is read. The declaration of `chunk` and the
 * statements that advance and reset buffer_counter are not part of this
 * excerpt.
 * NOTE(review): strings[i] is assigned into storage obtained from SG_MALLOC -
 * confirm that SG_MALLOC leaves the elements in a state that is safe to
 * assign to.
 */
614 #define READ_STRING_LIST(chunk_type, sg_type) \
615 void CProtobufFile::read_string_list( \
616 SGString<sg_type>*& strings, const StringListHeader& data_header) \
618 strings=SG_MALLOC(SGString<sg_type>, data_header.num_str()); \
621 read_message(chunk); \
622 int32_t elements_in_message=message_size/sizeof(sg_type); \
623 int32_t buffer_counter=0; \
624 for (uint32_t i=0; i<data_header.num_str(); i++) \
626 strings[i]=SGString<sg_type>(data_header.str_len(i)); \
627 for (int32_t j=0; j<strings[i].slen; j++) \
629 strings[i].string[j]=chunk.data(buffer_counter); \
632 if (buffer_counter==elements_in_message) \
634 read_message(chunk); \
/* Instantiations: chunk message type on the wire paired with the element type. */
641 READ_STRING_LIST(Int32Chunk, int8_t)
642 READ_STRING_LIST(UInt32Chunk, uint8_t)
643 READ_STRING_LIST(UInt32Chunk,
char)
644 READ_STRING_LIST(Int32Chunk, int32_t)
645 READ_STRING_LIST(UInt32Chunk, uint32_t)
646 READ_STRING_LIST(Float32Chunk,
float32_t)
647 READ_STRING_LIST(Float64Chunk,
float64_t)
649 READ_STRING_LIST(Int32Chunk, int16_t)
650 READ_STRING_LIST(UInt32Chunk, uint16_t)
651 READ_STRING_LIST(Int64Chunk, int64_t)
652 READ_STRING_LIST(UInt64Chunk, uint64_t)
653 #undef READ_STRING_LIST
/*
 * WRITE_STRING_LIST(chunk_type, sg_type) generates
 * CProtobufFile::write_string_list.
 *
 * The generated function appends every element of every one of the num_str
 * strings to a chunk. Whenever buffer_counter equals elements_in_message the
 * chunk is written; after the loops a final chunk is written if
 * buffer_counter is non-zero. The declaration of `chunk` and the statements
 * that advance and reset buffer_counter are not part of this excerpt.
 */
655 #define WRITE_STRING_LIST(chunk_type, sg_type) \
656 void CProtobufFile::write_string_list( \
657 const SGString<sg_type>* strings, int32_t num_str) \
660 int32_t elements_in_message=message_size/sizeof(sg_type); \
661 int32_t buffer_counter=0; \
662 for (int32_t i=0; i<num_str; i++) \
664 for (int32_t j=0; j<strings[i].slen; j++) \
666 chunk.add_data(strings[i].string[j]); \
669 if (buffer_counter==elements_in_message) \
671 write_message(chunk); \
678 if (buffer_counter!=0) \
679 write_message(chunk); \
/*
 * Instantiations: chunk message type on the wire paired with the element type.
 * NOTE(review): uint32_t is written through UInt64Chunk here, whereas the
 * read_string_list instantiation for uint32_t uses UInt32Chunk - confirm
 * that this asymmetry is intended.
 */
682 WRITE_STRING_LIST(Int32Chunk, int8_t)
683 WRITE_STRING_LIST(UInt32Chunk, uint8_t)
684 WRITE_STRING_LIST(UInt32Chunk,
char)
685 WRITE_STRING_LIST(Int32Chunk, int32_t)
686 WRITE_STRING_LIST(UInt64Chunk, uint32_t)
687 WRITE_STRING_LIST(Int64Chunk, int64_t)
688 WRITE_STRING_LIST(UInt64Chunk, uint64_t)
689 WRITE_STRING_LIST(Float32Chunk,
float32_t)
690 WRITE_STRING_LIST(Float64Chunk,
float64_t)
692 WRITE_STRING_LIST(Int32Chunk, int16_t)
693 WRITE_STRING_LIST(UInt32Chunk, uint16_t)
694 #undef WRITE_STRING_LIST