00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #include <shogun/lib/config.h>
00012 #include <shogun/base/SGObject.h>
00013 #include <shogun/io/SGIO.h>
00014 #include <shogun/base/Parallel.h>
00015 #include <shogun/base/init.h>
00016 #include <shogun/base/Version.h>
00017 #include <shogun/base/Parameter.h>
00018
00019 #include <stdlib.h>
00020 #include <stdio.h>
00021
00022
00023 namespace shogun
00024 {
00025 class CMath;
00026 class Parallel;
00027 class IO;
00028 class Version;
00029
00030 extern CMath* sg_math;
00031 extern Parallel* sg_parallel;
00032 extern SGIO* sg_io;
00033 extern Version* sg_version;
00034
00035 template<> void CSGObject::set_generic<bool>()
00036 {
00037 m_generic = PT_BOOL;
00038 }
00039
00040 template<> void CSGObject::set_generic<char>()
00041 {
00042 m_generic = PT_CHAR;
00043 }
00044
00045 template<> void CSGObject::set_generic<int8_t>()
00046 {
00047 m_generic = PT_INT8;
00048 }
00049
00050 template<> void CSGObject::set_generic<uint8_t>()
00051 {
00052 m_generic = PT_UINT8;
00053 }
00054
00055 template<> void CSGObject::set_generic<int16_t>()
00056 {
00057 m_generic = PT_INT16;
00058 }
00059
00060 template<> void CSGObject::set_generic<uint16_t>()
00061 {
00062 m_generic = PT_UINT16;
00063 }
00064
00065 template<> void CSGObject::set_generic<int32_t>()
00066 {
00067 m_generic = PT_INT32;
00068 }
00069
00070 template<> void CSGObject::set_generic<uint32_t>()
00071 {
00072 m_generic = PT_UINT32;
00073 }
00074
00075 template<> void CSGObject::set_generic<int64_t>()
00076 {
00077 m_generic = PT_INT64;
00078 }
00079
00080 template<> void CSGObject::set_generic<uint64_t>()
00081 {
00082 m_generic = PT_UINT64;
00083 }
00084
00085 template<> void CSGObject::set_generic<float32_t>()
00086 {
00087 m_generic = PT_FLOAT32;
00088 }
00089
00090 template<> void CSGObject::set_generic<float64_t>()
00091 {
00092 m_generic = PT_FLOAT64;
00093 }
00094
00095 template<> void CSGObject::set_generic<floatmax_t>()
00096 {
00097 m_generic = PT_FLOATMAX;
00098 }
00099
00100 }
00101
00102 using namespace shogun;
00103
00104 CSGObject::CSGObject()
00105 {
00106 init();
00107 set_global_objects();
00108
00109 SG_GCDEBUG("SGObject created (%p)\n", this);
00110 }
00111
00112 CSGObject::CSGObject(const CSGObject& orig)
00113 :io(orig.io), parallel(orig.parallel), version(orig.version)
00114 {
00115 init();
00116 set_global_objects();
00117 }
00118
00119 CSGObject::~CSGObject()
00120 {
00121 SG_GCDEBUG("SGObject destroyed (%p)\n", this);
00122
00123 #ifdef HAVE_PTHREAD
00124 PTHREAD_LOCK_DESTROY(&m_ref_lock);
00125 #endif
00126 unset_global_objects();
00127 delete m_parameters;
00128 delete m_model_selection_parameters;
00129 }
00130
00131 #ifdef USE_REFERENCE_COUNTING
00132
00133 int32_t CSGObject::ref()
00134 {
00135 #ifdef HAVE_PTHREAD
00136 PTHREAD_LOCK(&m_ref_lock);
00137 #endif //HAVE_PTHREAD
00138 ++m_refcount;
00139 int32_t count=m_refcount;
00140 #ifdef HAVE_PTHREAD
00141 PTHREAD_UNLOCK(&m_ref_lock);
00142 #endif //HAVE_PTHREAD
00143 SG_GCDEBUG("ref() refcount %ld obj %s (%p) increased\n", count, this->get_name(), this);
00144 return m_refcount;
00145 }
00146
00147 int32_t CSGObject::ref_count()
00148 {
00149 #ifdef HAVE_PTHREAD
00150 PTHREAD_LOCK(&m_ref_lock);
00151 #endif //HAVE_PTHREAD
00152 int32_t count=m_refcount;
00153 #ifdef HAVE_PTHREAD
00154 PTHREAD_UNLOCK(&m_ref_lock);
00155 #endif //HAVE_PTHREAD
00156 SG_GCDEBUG("ref_count(): refcount %d, obj %s (%p)\n", count, this->get_name(), this);
00157 return count;
00158 }
00159
00160 int32_t CSGObject::unref()
00161 {
00162 #ifdef HAVE_PTHREAD
00163 PTHREAD_LOCK(&m_ref_lock);
00164 #endif //HAVE_PTHREAD
00165 if (m_refcount==0 || --m_refcount==0)
00166 {
00167 SG_GCDEBUG("unref() refcount %ld, obj %s (%p) destroying\n", m_refcount, this->get_name(), this);
00168 #ifdef HAVE_PTHREAD
00169 PTHREAD_UNLOCK(&m_ref_lock);
00170 #endif //HAVE_PTHREAD
00171 delete this;
00172 return 0;
00173 }
00174 else
00175 {
00176 SG_GCDEBUG("unref() refcount %ld obj %s (%p) decreased\n", m_refcount, this->get_name(), this);
00177 #ifdef HAVE_PTHREAD
00178 PTHREAD_UNLOCK(&m_ref_lock);
00179 #endif //HAVE_PTHREAD
00180 return m_refcount;
00181 }
00182 }
00183 #endif //USE_REFERENCE_COUNTING
00184
00185
00186 void CSGObject::set_global_objects()
00187 {
00188 if (!sg_io || !sg_parallel || !sg_version)
00189 {
00190 fprintf(stderr, "call init_shogun() before using the library, dying.\n");
00191 exit(1);
00192 }
00193
00194 SG_REF(sg_io);
00195 SG_REF(sg_parallel);
00196 SG_REF(sg_version);
00197
00198 io=sg_io;
00199 parallel=sg_parallel;
00200 version=sg_version;
00201 }
00202
00203 void CSGObject::unset_global_objects()
00204 {
00205 SG_UNREF(version);
00206 SG_UNREF(parallel);
00207 SG_UNREF(io);
00208 }
00209
00210 void CSGObject::set_global_io(SGIO* new_io)
00211 {
00212 SG_UNREF(sg_io);
00213 sg_io=new_io;
00214 SG_REF(sg_io);
00215 }
00216
00217 SGIO* CSGObject::get_global_io()
00218 {
00219 SG_REF(sg_io);
00220 return sg_io;
00221 }
00222
00223 void CSGObject::set_global_parallel(Parallel* new_parallel)
00224 {
00225 SG_UNREF(sg_parallel);
00226 sg_parallel=new_parallel;
00227 SG_REF(sg_parallel);
00228 }
00229
00230 Parallel* CSGObject::get_global_parallel()
00231 {
00232 SG_REF(sg_parallel);
00233 return sg_parallel;
00234 }
00235
00236 void CSGObject::set_global_version(Version* new_version)
00237 {
00238 SG_UNREF(sg_version);
00239 sg_version=new_version;
00240 SG_REF(sg_version);
00241 }
00242
00243 Version* CSGObject::get_global_version()
00244 {
00245 SG_REF(sg_version);
00246 return sg_version;
00247 }
00248
00249 bool CSGObject::is_generic(EPrimitiveType* generic) const
00250 {
00251 *generic = m_generic;
00252
00253 return m_generic != PT_NOT_GENERIC;
00254 }
00255
00256 void CSGObject::unset_generic()
00257 {
00258 m_generic = PT_NOT_GENERIC;
00259 }
00260
00261 void CSGObject::print_serializable(const char* prefix)
00262 {
00263 SG_PRINT("\n%s\n================================================================================\n", get_name());
00264 m_parameters->print(prefix);
00265 }
00266
00267 bool CSGObject::save_serializable(CSerializableFile* file,
00268 const char* prefix)
00269 {
00270 SG_DEBUG("START SAVING CSGObject '%s'\n", get_name());
00271 try
00272 {
00273 save_serializable_pre();
00274 }
00275 catch (ShogunException e)
00276 {
00277 SG_SWARNING("%s%s::save_serializable_pre(): ShogunException: "
00278 "%s\n", prefix, get_name(),
00279 e.get_exception_string());
00280 return false;
00281 }
00282 if (!m_save_pre_called)
00283 {
00284 SG_SWARNING("%s%s::save_serializable_pre(): Implementation "
00285 "error: BASE_CLASS::LOAD_SERIALIZABLE_PRE() not "
00286 "called!\n", prefix, get_name());
00287 return false;
00288 }
00289
00290
00291 if (!save_parameter_version(file, prefix))
00292 return false;
00293
00294 if (!m_parameters->save(file, prefix))
00295 return false;
00296
00297 try
00298 {
00299 save_serializable_post();
00300 }
00301 catch (ShogunException e)
00302 {
00303 SG_SWARNING("%s%s::save_serializable_post(): ShogunException: "
00304 "%s\n", prefix, get_name(),
00305 e.get_exception_string());
00306 return false;
00307 }
00308
00309 if (!m_save_post_called)
00310 {
00311 SG_SWARNING("%s%s::save_serializable_post(): Implementation "
00312 "error: BASE_CLASS::LOAD_SERIALIZABLE_POST() not "
00313 "called!\n", prefix, get_name());
00314 return false;
00315 }
00316
00317 if (prefix == NULL || *prefix == '\0')
00318 file->close();
00319
00320 SG_DEBUG("DONE SAVING CSGObject '%s' (%p)\n", get_name(), this);
00321
00322 return true;;
00323 }
00324
00325 bool CSGObject::load_serializable(CSerializableFile* file,
00326 const char* prefix)
00327 {
00328 SG_DEBUG("START LOADING CSGObject '%s'\n", get_name());
00329 try
00330 {
00331 load_serializable_pre();
00332 }
00333 catch (ShogunException e)
00334 {
00335 SG_SWARNING("%s%s::load_serializable_pre(): ShogunException: "
00336 "%s\n", prefix, get_name(),
00337 e.get_exception_string());
00338 return false;
00339 }
00340 if (!m_load_pre_called)
00341 {
00342 SG_SWARNING("%s%s::load_serializable_pre(): Implementation "
00343 "error: BASE_CLASS::LOAD_SERIALIZABLE_PRE() not "
00344 "called!\n", prefix, get_name());
00345 return false;
00346 }
00347
00348
00349 int32_t file_version=load_parameter_version(file, prefix);
00350
00351 if (file_version<0)
00352 {
00353 SG_WARNING("%s%s::load_serializable(): File contains no parameter "
00354 "version. Seems like your file is from the days before this "
00355 "was introduced. Ignore warning or serialize with this version "
00356 "of shogun to get rid of above and this warnings.\n",
00357 prefix, get_name());
00358 }
00359
00360 if (file_version>version->get_version_parameter())
00361 {
00362 SG_WARNING("%s%s::load_serializable(): parameter version of file "
00363 "larger than the one of shogun. Try with a more recent version "
00364 "of shogun.\n", prefix, get_name());
00365 return false;
00366 }
00367
00368 if (!m_parameters->load(file, prefix))
00369 return false;
00370
00371 try
00372 {
00373 load_serializable_post();
00374 }
00375 catch (ShogunException e)
00376 {
00377 SG_SWARNING("%s%s::load_serializable_post(): ShogunException: "
00378 "%s\n", prefix, get_name(),
00379 e.get_exception_string());
00380 return false;
00381 }
00382
00383 if (!m_load_post_called)
00384 {
00385 SG_SWARNING("%s%s::load_serializable_post(): Implementation "
00386 "error: BASE_CLASS::LOAD_SERIALIZABLE_POST() not "
00387 "called!\n", prefix, get_name());
00388 return false;
00389 }
00390 SG_DEBUG("DONE LOADING CSGObject '%s' (%p)\n", get_name(), this);
00391
00392 return true;
00393 }
00394
00395 bool CSGObject::save_parameter_version(CSerializableFile* file,
00396 const char* prefix)
00397 {
00398 TSGDataType t(CT_SCALAR, ST_NONE, PT_INT32);
00399 int32_t v=version->get_version_parameter();
00400 TParameter p(&t, &v, "version_parameter",
00401 "Version of parameters of this object");
00402 return p.save(file, prefix);
00403 }
00404
00405 int32_t CSGObject::load_parameter_version(CSerializableFile* file,
00406 const char* prefix)
00407 {
00408 TSGDataType t(CT_SCALAR, ST_NONE, PT_INT32);
00409 int32_t v;
00410 TParameter tp(&t, &v, "version_parameter", "");
00411 if (tp.load(file, prefix))
00412 return v;
00413 else
00414 return -1;
00415 }
00416
00417 void CSGObject::load_serializable_pre() throw (ShogunException)
00418 {
00419 m_load_pre_called = true;
00420 }
00421
00422 void CSGObject::load_serializable_post() throw (ShogunException)
00423 {
00424 m_load_post_called = true;
00425 }
00426
00427 void CSGObject::save_serializable_pre() throw (ShogunException)
00428 {
00429 m_save_pre_called = true;
00430 }
00431
00432 void CSGObject::save_serializable_post() throw (ShogunException)
00433 {
00434 m_save_post_called = true;
00435 }
00436
00437 #ifdef TRACE_MEMORY_ALLOCS
00438 #include <shogun/lib/Set.h>
00439 extern CSet<shogun::MemoryBlock>* sg_mallocs;
00440 #endif
00441
00442 void CSGObject::init()
00443 {
00444 #ifdef HAVE_PTHREAD
00445 PTHREAD_LOCK_INIT(&m_ref_lock);
00446 #endif
00447
00448 #ifdef TRACE_MEMORY_ALLOCS
00449 if (sg_mallocs)
00450 {
00451 int32_t idx=sg_mallocs->index_of(MemoryBlock(this));
00452 if (idx>-1)
00453 {
00454 MemoryBlock* b=sg_mallocs->get_element_ptr(idx);
00455 b->set_sgobject();
00456 }
00457 }
00458 #endif
00459
00460 m_refcount = 0;
00461 io = NULL;
00462 parallel = NULL;
00463 version = NULL;
00464 m_parameters = new Parameter();
00465 m_model_selection_parameters = new Parameter();
00466 m_generic = PT_NOT_GENERIC;
00467 m_load_pre_called = false;
00468 m_load_post_called = false;
00469 }
00470
00471 SGVector<char*> CSGObject::get_modelsel_names()
00472 {
00473 SGVector<char*> result=SGVector<char*>(
00474 m_model_selection_parameters->get_num_parameters());
00475
00476 for (index_t i=0; i<result.vlen; ++i)
00477 result.vector[i]=m_model_selection_parameters->get_parameter(i)->m_name;
00478
00479 return result;
00480 }
00481
00482 char* CSGObject::get_modsel_param_descr(const char* param_name)
00483 {
00484 index_t index=get_modsel_param_index(param_name);
00485
00486 if (index<0)
00487 {
00488 SG_ERROR("There is no model selection parameter called \"%s\" for %s",
00489 param_name, get_name());
00490 }
00491
00492 return m_model_selection_parameters->get_parameter(index)->m_description;
00493 }
00494
00495 index_t CSGObject::get_modsel_param_index(const char* param_name)
00496 {
00497
00498
00499 SGVector<char*> names=get_modelsel_names();
00500
00501
00502 index_t index=-1;
00503 for (index_t i=0; i<names.vlen; ++i)
00504 {
00505 TParameter* current=m_model_selection_parameters->get_parameter(i);
00506 if (!strcmp(param_name, current->m_name))
00507 {
00508 index=i;
00509 break;
00510 }
00511 }
00512
00513
00514 names.destroy_vector();
00515
00516 return index;
00517 }