SGObject.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2008-2009 Soeren Sonnenburg
00008  * Copyright (C) 2008-2009 Fraunhofer Institute FIRST and Max Planck Society
00009  */
00010 
00011 #include <shogun/lib/config.h>
00012 #include <shogun/base/SGObject.h>
00013 #include <shogun/io/SGIO.h>
00014 #include <shogun/base/Parallel.h>
00015 #include <shogun/base/init.h>
00016 #include <shogun/base/Version.h>
00017 #include <shogun/base/Parameter.h>
00018 
00019 #include <stdlib.h>
00020 #include <stdio.h>
00021 
00022 
00023 namespace shogun
00024 {
00025     class CMath;
00026     class Parallel;
00027     class IO;
00028     class Version;
00029 
00030     extern CMath* sg_math;
00031     extern Parallel* sg_parallel;
00032     extern SGIO* sg_io;
00033     extern Version* sg_version;
00034 
00035     template<> void CSGObject::set_generic<bool>()
00036     {
00037         m_generic = PT_BOOL;
00038     }
00039 
00040     template<> void CSGObject::set_generic<char>()
00041     {
00042         m_generic = PT_CHAR;
00043     }
00044 
00045     template<> void CSGObject::set_generic<int8_t>()
00046     {
00047         m_generic = PT_INT8;
00048     }
00049 
00050     template<> void CSGObject::set_generic<uint8_t>()
00051     {
00052         m_generic = PT_UINT8;
00053     }
00054 
00055     template<> void CSGObject::set_generic<int16_t>()
00056     {
00057         m_generic = PT_INT16;
00058     }
00059 
00060     template<> void CSGObject::set_generic<uint16_t>()
00061     {
00062         m_generic = PT_UINT16;
00063     }
00064 
00065     template<> void CSGObject::set_generic<int32_t>()
00066     {
00067         m_generic = PT_INT32;
00068     }
00069 
00070     template<> void CSGObject::set_generic<uint32_t>()
00071     {
00072         m_generic = PT_UINT32;
00073     }
00074 
00075     template<> void CSGObject::set_generic<int64_t>()
00076     {
00077         m_generic = PT_INT64;
00078     }
00079 
00080     template<> void CSGObject::set_generic<uint64_t>()
00081     {
00082         m_generic = PT_UINT64;
00083     }
00084 
00085     template<> void CSGObject::set_generic<float32_t>()
00086     {
00087         m_generic = PT_FLOAT32;
00088     }
00089 
00090     template<> void CSGObject::set_generic<float64_t>()
00091     {
00092         m_generic = PT_FLOAT64;
00093     }
00094 
00095     template<> void CSGObject::set_generic<floatmax_t>()
00096     {
00097         m_generic = PT_FLOATMAX;
00098     }
00099 
00100 } /* namespace shogun  */
00101 
00102 using namespace shogun;
00103 
00104 CSGObject::CSGObject()
00105 {
00106     init();
00107     set_global_objects();
00108 
00109     SG_GCDEBUG("SGObject created (%p)\n", this);
00110 }
00111 
00112 CSGObject::CSGObject(const CSGObject& orig)
00113 :io(orig.io), parallel(orig.parallel), version(orig.version)
00114 {
00115     init();
00116     set_global_objects();
00117 }
00118 
00119 CSGObject::~CSGObject()
00120 {
00121     SG_GCDEBUG("SGObject destroyed (%p)\n", this);
00122 
00123 #ifdef HAVE_PTHREAD
00124     PTHREAD_LOCK_DESTROY(&m_ref_lock);
00125 #endif
00126     unset_global_objects();
00127     delete m_parameters;
00128     delete m_model_selection_parameters;
00129 }
00130 
00131 #ifdef USE_REFERENCE_COUNTING
00132 
00133 int32_t CSGObject::ref()
00134 {
00135 #ifdef HAVE_PTHREAD
00136         PTHREAD_LOCK(&m_ref_lock);
00137 #endif //HAVE_PTHREAD
00138         ++m_refcount;
00139         int32_t count=m_refcount;
00140 #ifdef HAVE_PTHREAD
00141         PTHREAD_UNLOCK(&m_ref_lock);
00142 #endif //HAVE_PTHREAD
00143         SG_GCDEBUG("ref() refcount %ld obj %s (%p) increased\n", count, this->get_name(), this);
00144         return m_refcount;
00145 }
00146 
00147 int32_t CSGObject::ref_count()
00148 {
00149 #ifdef HAVE_PTHREAD
00150     PTHREAD_LOCK(&m_ref_lock);
00151 #endif //HAVE_PTHREAD
00152     int32_t count=m_refcount;
00153 #ifdef HAVE_PTHREAD
00154     PTHREAD_UNLOCK(&m_ref_lock);
00155 #endif //HAVE_PTHREAD
00156     SG_GCDEBUG("ref_count(): refcount %d, obj %s (%p)\n", count, this->get_name(), this);
00157     return count;
00158 }
00159 
00160 int32_t CSGObject::unref()
00161 {
00162 #ifdef HAVE_PTHREAD
00163     PTHREAD_LOCK(&m_ref_lock);
00164 #endif //HAVE_PTHREAD
00165     if (m_refcount==0 || --m_refcount==0)
00166     {
00167         SG_GCDEBUG("unref() refcount %ld, obj %s (%p) destroying\n", m_refcount, this->get_name(), this);
00168 #ifdef HAVE_PTHREAD
00169         PTHREAD_UNLOCK(&m_ref_lock);
00170 #endif //HAVE_PTHREAD
00171         delete this;
00172         return 0;
00173     }
00174     else
00175     {
00176         SG_GCDEBUG("unref() refcount %ld obj %s (%p) decreased\n", m_refcount, this->get_name(), this);
00177 #ifdef HAVE_PTHREAD
00178         PTHREAD_UNLOCK(&m_ref_lock);
00179 #endif //HAVE_PTHREAD
00180         return m_refcount;
00181     }
00182 }
00183 #endif //USE_REFERENCE_COUNTING
00184 
00185 
00186 void CSGObject::set_global_objects()
00187 {
00188     if (!sg_io || !sg_parallel || !sg_version)
00189     {
00190         fprintf(stderr, "call init_shogun() before using the library, dying.\n");
00191         exit(1);
00192     }
00193 
00194     SG_REF(sg_io);
00195     SG_REF(sg_parallel);
00196     SG_REF(sg_version);
00197 
00198     io=sg_io;
00199     parallel=sg_parallel;
00200     version=sg_version;
00201 }
00202 
00203 void CSGObject::unset_global_objects()
00204 {
00205     SG_UNREF(version);
00206     SG_UNREF(parallel);
00207     SG_UNREF(io);
00208 }
00209 
00210 void CSGObject::set_global_io(SGIO* new_io)
00211 {
00212     SG_UNREF(sg_io);
00213     sg_io=new_io;
00214     SG_REF(sg_io);
00215 }
00216 
00217 SGIO* CSGObject::get_global_io()
00218 {
00219     SG_REF(sg_io);
00220     return sg_io;
00221 }
00222 
00223 void CSGObject::set_global_parallel(Parallel* new_parallel)
00224 {
00225     SG_UNREF(sg_parallel);
00226     sg_parallel=new_parallel;
00227     SG_REF(sg_parallel);
00228 }
00229 
00230 Parallel* CSGObject::get_global_parallel()
00231 {
00232     SG_REF(sg_parallel);
00233     return sg_parallel;
00234 }
00235 
00236 void CSGObject::set_global_version(Version* new_version)
00237 {
00238     SG_UNREF(sg_version);
00239     sg_version=new_version;
00240     SG_REF(sg_version);
00241 }
00242 
00243 Version* CSGObject::get_global_version()
00244 {
00245     SG_REF(sg_version);
00246     return sg_version;
00247 }
00248 
00249 bool CSGObject::is_generic(EPrimitiveType* generic) const
00250 {
00251     *generic = m_generic;
00252 
00253     return m_generic != PT_NOT_GENERIC;
00254 }
00255 
00256 void CSGObject::unset_generic()
00257 {
00258     m_generic = PT_NOT_GENERIC;
00259 }
00260 
00261 void CSGObject::print_serializable(const char* prefix)
00262 {
00263     SG_PRINT("\n%s\n================================================================================\n", get_name());
00264     m_parameters->print(prefix);
00265 }
00266 
00267 bool CSGObject::save_serializable(CSerializableFile* file,
00268                                    const char* prefix)
00269 {
00270     SG_DEBUG("START SAVING CSGObject '%s'\n", get_name());
00271     try
00272     {
00273         save_serializable_pre();
00274     }
00275     catch (ShogunException e)
00276     {
00277         SG_SWARNING("%s%s::save_serializable_pre(): ShogunException: "
00278                    "%s\n", prefix, get_name(),
00279                    e.get_exception_string());
00280         return false;
00281     }
00282     if (!m_save_pre_called)
00283     {
00284         SG_SWARNING("%s%s::save_serializable_pre(): Implementation "
00285                    "error: BASE_CLASS::LOAD_SERIALIZABLE_PRE() not "
00286                    "called!\n", prefix, get_name());
00287         return false;
00288     }
00289 
00290     /* save parameter version */
00291     if (!save_parameter_version(file, prefix))
00292         return false;
00293 
00294     if (!m_parameters->save(file, prefix))
00295         return false;
00296 
00297     try
00298     {
00299         save_serializable_post();
00300     }
00301     catch (ShogunException e)
00302     {
00303         SG_SWARNING("%s%s::save_serializable_post(): ShogunException: "
00304                    "%s\n", prefix, get_name(),
00305                    e.get_exception_string());
00306         return false;
00307     }
00308 
00309     if (!m_save_post_called)
00310     {
00311         SG_SWARNING("%s%s::save_serializable_post(): Implementation "
00312                    "error: BASE_CLASS::LOAD_SERIALIZABLE_POST() not "
00313                    "called!\n", prefix, get_name());
00314         return false;
00315     }
00316 
00317     if (prefix == NULL || *prefix == '\0')
00318         file->close();
00319 
00320     SG_DEBUG("DONE SAVING CSGObject '%s' (%p)\n", get_name(), this);
00321 
00322     return true;;
00323 }
00324 
00325 bool CSGObject::load_serializable(CSerializableFile* file,
00326                                    const char* prefix)
00327 {
00328     SG_DEBUG("START LOADING CSGObject '%s'\n", get_name());
00329     try
00330     {
00331         load_serializable_pre();
00332     }
00333     catch (ShogunException e)
00334     {
00335         SG_SWARNING("%s%s::load_serializable_pre(): ShogunException: "
00336                    "%s\n", prefix, get_name(),
00337                    e.get_exception_string());
00338         return false;
00339     }
00340     if (!m_load_pre_called)
00341     {
00342         SG_SWARNING("%s%s::load_serializable_pre(): Implementation "
00343                    "error: BASE_CLASS::LOAD_SERIALIZABLE_PRE() not "
00344                    "called!\n", prefix, get_name());
00345         return false;
00346     }
00347 
00348     /* try to load version of parameters */
00349     int32_t file_version=load_parameter_version(file, prefix);
00350 
00351     if (file_version<0)
00352     {
00353         SG_WARNING("%s%s::load_serializable(): File contains no parameter "
00354                    "version. Seems like your file is from the days before this "
00355                    "was introduced. Ignore warning or serialize with this version "
00356                    "of shogun to get rid of above and this warnings.\n",
00357                    prefix, get_name());
00358     }
00359 
00360     if (file_version>version->get_version_parameter())
00361     {
00362         SG_WARNING("%s%s::load_serializable(): parameter version of file "
00363                    "larger than the one of shogun. Try with a more recent version "
00364                    "of shogun.\n", prefix, get_name());
00365         return false;
00366     }
00367 
00368     if (!m_parameters->load(file, prefix))
00369         return false;
00370 
00371     try
00372     {
00373         load_serializable_post();
00374     }
00375     catch (ShogunException e)
00376     {
00377         SG_SWARNING("%s%s::load_serializable_post(): ShogunException: "
00378                     "%s\n", prefix, get_name(),
00379                     e.get_exception_string());
00380         return false;
00381     }
00382 
00383     if (!m_load_post_called)
00384     {
00385         SG_SWARNING("%s%s::load_serializable_post(): Implementation "
00386                     "error: BASE_CLASS::LOAD_SERIALIZABLE_POST() not "
00387                     "called!\n", prefix, get_name());
00388         return false;
00389     }
00390     SG_DEBUG("DONE LOADING CSGObject '%s' (%p)\n", get_name(), this);
00391 
00392     return true;
00393 }
00394 
00395 bool CSGObject::save_parameter_version(CSerializableFile* file,
00396         const char* prefix)
00397 {
00398     TSGDataType t(CT_SCALAR, ST_NONE, PT_INT32);
00399     int32_t v=version->get_version_parameter();
00400     TParameter p(&t, &v, "version_parameter",
00401                  "Version of parameters of this object");
00402     return p.save(file, prefix);
00403 }
00404 
00405 int32_t CSGObject::load_parameter_version(CSerializableFile* file,
00406         const char* prefix)
00407 {
00408     TSGDataType t(CT_SCALAR, ST_NONE, PT_INT32);
00409     int32_t v;
00410     TParameter tp(&t, &v, "version_parameter", "");
00411     if (tp.load(file, prefix))
00412         return v;
00413     else
00414         return -1;
00415 }
00416 
00417 void CSGObject::load_serializable_pre() throw (ShogunException)
00418 {
00419     m_load_pre_called = true;
00420 }
00421 
00422 void CSGObject::load_serializable_post() throw (ShogunException)
00423 {
00424     m_load_post_called = true;
00425 }
00426 
00427 void CSGObject::save_serializable_pre() throw (ShogunException)
00428 {
00429     m_save_pre_called = true;
00430 }
00431 
00432 void CSGObject::save_serializable_post() throw (ShogunException)
00433 {
00434     m_save_post_called = true;
00435 }
00436 
00437 #ifdef TRACE_MEMORY_ALLOCS
00438 #include <shogun/lib/Set.h>
00439 extern CSet<shogun::MemoryBlock>* sg_mallocs;
00440 #endif
00441 
00442 void CSGObject::init()
00443 {
00444 #ifdef HAVE_PTHREAD
00445     PTHREAD_LOCK_INIT(&m_ref_lock);
00446 #endif
00447 
00448 #ifdef TRACE_MEMORY_ALLOCS
00449     if (sg_mallocs)
00450     {
00451         int32_t idx=sg_mallocs->index_of(MemoryBlock(this));
00452         if (idx>-1)
00453         {
00454             MemoryBlock* b=sg_mallocs->get_element_ptr(idx);
00455             b->set_sgobject();
00456         }
00457     }
00458 #endif
00459 
00460     m_refcount = 0;
00461     io = NULL;
00462     parallel = NULL;
00463     version = NULL;
00464     m_parameters = new Parameter();
00465     m_model_selection_parameters = new Parameter();
00466     m_generic = PT_NOT_GENERIC;
00467     m_load_pre_called = false;
00468     m_load_post_called = false;
00469 }
00470 
00471 SGVector<char*> CSGObject::get_modelsel_names()
00472 {
00473     SGVector<char*> result=SGVector<char*>(
00474             m_model_selection_parameters->get_num_parameters());
00475 
00476     for (index_t i=0; i<result.vlen; ++i)
00477         result.vector[i]=m_model_selection_parameters->get_parameter(i)->m_name;
00478 
00479     return result;
00480 }
00481 
00482 char* CSGObject::get_modsel_param_descr(const char* param_name)
00483 {
00484     index_t index=get_modsel_param_index(param_name);
00485 
00486     if (index<0)
00487     {
00488         SG_ERROR("There is no model selection parameter called \"%s\" for %s",
00489                 param_name, get_name());
00490     }
00491 
00492     return m_model_selection_parameters->get_parameter(index)->m_description;
00493 }
00494 
00495 index_t CSGObject::get_modsel_param_index(const char* param_name)
00496 {
00497     /* use fact that names extracted from below method are in same order than
00498      * in m_model_selection_parameters variable */
00499     SGVector<char*> names=get_modelsel_names();
00500 
00501     /* search for parameter with provided name */
00502     index_t index=-1;
00503     for (index_t i=0; i<names.vlen; ++i)
00504     {
00505         TParameter* current=m_model_selection_parameters->get_parameter(i);
00506         if (!strcmp(param_name, current->m_name))
00507         {
00508             index=i;
00509             break;
00510         }
00511     }
00512 
00513     /* clean up */
00514     names.destroy_vector();
00515 
00516     return index;
00517 }
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines

SHOGUN Machine Learning Toolbox - Documentation