32 #include <unordered_map>
62 return map.find(tag) !=
map.end();
74 template<>
void CSGObject::set_generic<bool>()
79 template<>
void CSGObject::set_generic<char>()
84 template<>
void CSGObject::set_generic<int8_t>()
89 template<>
void CSGObject::set_generic<uint8_t>()
94 template<>
void CSGObject::set_generic<int16_t>()
99 template<>
void CSGObject::set_generic<uint16_t>()
101 m_generic = PT_UINT16;
104 template<>
void CSGObject::set_generic<int32_t>()
106 m_generic = PT_INT32;
109 template<>
void CSGObject::set_generic<uint32_t>()
111 m_generic = PT_UINT32;
114 template<>
void CSGObject::set_generic<int64_t>()
116 m_generic = PT_INT64;
119 template<>
void CSGObject::set_generic<uint64_t>()
121 m_generic = PT_UINT64;
124 template<>
void CSGObject::set_generic<float32_t>()
126 m_generic = PT_FLOAT32;
129 template<>
void CSGObject::set_generic<float64_t>()
131 m_generic = PT_FLOAT64;
134 template<>
void CSGObject::set_generic<floatmax_t>()
136 m_generic = PT_FLOATMAX;
139 template<>
void CSGObject::set_generic<CSGObject*>()
141 m_generic = PT_SGOBJECT;
144 template<>
void CSGObject::set_generic<complex128_t>()
146 m_generic = PT_COMPLEX128;
156 set_global_objects();
163 : self(), io(orig.io), parallel(orig.parallel), version(orig.version)
166 set_global_objects();
176 unset_global_objects();
183 #ifdef USE_REFERENCE_COUNTING
184 int32_t CSGObject::ref()
186 int32_t count = m_refcount->
ref();
188 return m_refcount->ref_count();
195 return m_refcount->ref_count();
200 int32_t count = m_refcount->
unref();
203 SG_SGCDEBUG(
"unref() refcount %ld, obj %s (%p) destroying\n", count, this->
get_name(),
this)
210 return m_refcount->ref_count();
213 #endif //USE_REFERENCE_COUNTING
215 #ifdef TRACE_MEMORY_ALLOCS
219 void CSGObject::list_memory_allocs()
221 shogun::list_memory_allocs();
237 void CSGObject::set_global_objects()
241 fprintf(stderr,
"call init_shogun() before using the library, dying.\n");
254 void CSGObject::unset_global_objects()
289 get_parameter_incremental_hash(
m_hash, carry, length);
303 get_parameter_incremental_hash(hash, carry, length);
331 *
generic = m_generic;
343 SG_PRINT(
"\n%s\n================================================================================\n",
get_name())
357 SG_SWARNING(
"%s%s::save_serializable_pre(): ShogunException: "
363 if (!m_save_pre_called)
365 SG_SWARNING(
"%s%s::save_serializable_pre(): Implementation "
366 "error: BASE_CLASS::SAVE_SERIALIZABLE_PRE() not "
380 SG_SWARNING(
"%s%s::save_serializable_post(): ShogunException: "
386 if (!m_save_post_called)
388 SG_SWARNING(
"%s%s::save_serializable_post(): Implementation "
389 "error: BASE_CLASS::SAVE_SERIALIZABLE_POST() not "
394 if (prefix == NULL || *prefix ==
'\0')
405 REQUIRE(file != NULL,
"Serializable file object should be != NULL\n");
414 SG_SWARNING(
"%s%s::load_serializable_pre(): ShogunException: "
419 if (!m_load_pre_called)
421 SG_SWARNING(
"%s%s::load_serializable_pre(): Implementation "
422 "error: BASE_CLASS::LOAD_SERIALIZABLE_PRE() not "
436 SG_SWARNING(
"%s%s::load_serializable_post(): ShogunException: "
442 if (!m_load_post_called)
444 SG_SWARNING(
"%s%s::load_serializable_post(): Implementation "
445 "error: BASE_CLASS::LOAD_SERIALIZABLE_POST() not "
456 m_load_pre_called =
true;
461 m_load_post_called =
true;
466 m_save_pre_called =
true;
471 m_save_post_called =
true;
474 #ifdef TRACE_MEMORY_ALLOCS
479 void CSGObject::init()
481 #ifdef TRACE_MEMORY_ALLOCS
484 int32_t idx=sg_mallocs->
index_of(
this);
500 m_load_pre_called =
false;
501 m_load_post_called =
false;
502 m_save_pre_called =
false;
503 m_save_post_called =
false;
516 for (
index_t i=0; i<num_param; i++)
520 char* type=SG_MALLOC(
char, l);
539 for (
index_t i=0; i<num_param; i++)
546 if (len>max_string_length)
547 max_string_length=len;
561 SG_ERROR(
"There is no model selection parameter called \"%s\" for %s",
579 if (!strcmp(param_name, current->
m_name))
589 void CSGObject::get_parameter_incremental_hash(uint32_t& hash, uint32_t& carry,
590 uint32_t& total_length)
606 child->get_parameter_incremental_hash(hash, carry,
619 child[j]->get_parameter_incremental_hash(hash, carry,
672 SG_INFO(
"leaving %s::equals(): name of other object differs\n",
get_name());
680 SG_INFO(
"leaving %s::equals(): number of parameters of other object "
687 SG_DEBUG(
"comparing parameter %d\n", i);
694 if (!this_param && !other_param)
697 if (!this_param && other_param)
699 SG_DEBUG(
"leaving %s::equals(): parameter %d is NULL where other's "
704 if (this_param && !other_param)
706 SG_DEBUG(
"leaving %s::equals(): parameter %d is \"%s\" where other's "
711 SG_DEBUG(
"comparing parameter \"%s\" to other's \"%s\"\n",
715 if (!strcmp(
"DynamicObjectArray",
get_name()) &&
716 !strcmp(this_param->
m_name,
"num_elements") &&
717 !strcmp(other_param->
m_name,
"num_elements"))
719 SG_DEBUG(
"Ignoring DynamicObjectArray::num_elements field\n");
724 if (!strcmp(
"DynamicArray",
get_name()) &&
725 !strcmp(this_param->
m_name,
"num_elements") &&
726 !strcmp(other_param->
m_name,
"num_elements"))
728 SG_DEBUG(
"Ignoring DynamicArray::num_elements field\n");
733 if (!this_param->
equals(other_param, accuracy, tolerant))
735 SG_INFO(
"leaving %s::equals(): parameters at position %d with name"
736 " \"%s\" differs from other object parameter with name "
756 REQUIRE(copy,
"Could not create empty instance of \"%s\". The reason for "
757 "this usually is that get_name() of the class returns something "
758 "wrong, or that a class has a wrongly set generic type.\n",
763 SG_DEBUG(
"cloning parameter \"%s\" at index %d\n",
768 SG_DEBUG(
"leaving %s::clone(): Clone failed. Returning NULL\n",
778 void CSGObject::set_with_base_tag(
const BaseTag& _tag,
const Any& any)
780 self->
set(_tag, any);
783 Any CSGObject::get_with_base_tag(
const BaseTag& _tag)
const
785 Any any =
self->get(_tag);
788 SG_ERROR(
"There is no parameter called \"%s\" in %s",
794 bool CSGObject::has_with_base_tag(
const BaseTag& _tag)
const
796 return self->has(_tag);
virtual const char * get_name() const =0
SGStringList< char > get_modelsel_names()
template class SGStringList
Parallel * get_global_parallel()
virtual void update_parameter_hash()
int32_t index_of(const K &key)
virtual int32_t get_num_parameters()
virtual CSGObject * clone()
Base class for all tags. This class stores name and not the type information for a shogun object...
Class ShogunException defines an exception which is thrown whenever an error inside of shogun occurs...
virtual CSGObject * shallow_copy() const
TParameter * get_parameter(int32_t idx)
Version * get_global_version()
virtual void save_serializable_pre()
virtual bool is_generic(EPrimitiveType *generic) const
#define SG_NOTIMPLEMENTED
Allows to store objects of arbitrary types by using a BaseAnyPolicy and provides a type agnostic API...
virtual bool load_serializable(CSerializableFile *file, const char *prefix="")
virtual void print(const char *prefix="")
static uint32_t FinalizeIncrementalMurmurHash3(uint32_t h, uint32_t carry, uint32_t total_length)
virtual bool load(CSerializableFile *file, const char *prefix="")
std::map< BaseTag, Any > ParametersMap
T * get_element_ptr(int32_t index)
char * get_modsel_param_descr(const char *param_name)
bool has(const BaseTag &tag) const
bool equals(TParameter *other, float64_t accuracy=0.0, bool tolerant=false)
void set(const Tag< T > &_tag, const T &value)
int32_t add(const K &key, const T &data)
Class SGObject is the base class of all shogun objects.
void build_gradient_parameter_dictionary(CMap< TParameter *, CSGObject * > *dict)
virtual bool save(CSerializableFile *file, const char *prefix="")
virtual void save_serializable_post()
void print_modsel_params()
CSGObject * new_sgserializable(const char *sgserializable_name, EPrimitiveType generic)
Class Version provides version information.
Parameter * m_model_selection_parameters
index_t max_string_length
void get_incremental_hash(uint32_t &hash, uint32_t &carry, uint32_t &total_length)
virtual bool equals(CSGObject *other, float64_t accuracy=0.0, bool tolerant=false)
virtual CSGObject * deep_copy() const
void set_global_parallel(Parallel *parallel)
void to_string(char *dest, size_t n) const
virtual void load_serializable_pre()
virtual void load_serializable_post()
Class Parallel provides helper functions for multithreading.
virtual bool save_serializable(CSerializableFile *file, const char *prefix="")
const char * get_exception_string()
all of classes and functions are contained in the shogun namespace
index_t get_modsel_param_index(const char *param_name)
void set_global_io(SGIO *io)
bool copy(TParameter *target)
void set(const BaseTag &tag, const Any &any)
Class SGIO, used to do input output operations throughout shogun.
Parameter * m_gradient_parameters
virtual void print_serializable(const char *prefix="")
virtual bool parameter_hash_changed()
void set_global_version(Version *version)
the class CMap, a map based on the hash-table. w: http://en.wikipedia.org/wiki/Hash_table ...