00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #include <shogun/machine/Machine.h>
00013 #include <shogun/base/Parameter.h>
00014 #include <shogun/mathematics/Math.h>
00015 #include <shogun/base/ParameterMap.h>
00016
00017 using namespace shogun;
00018
00019 CMachine::CMachine() : CSGObject(), m_max_train_time(0), m_labels(NULL),
00020 m_solver_type(ST_AUTO)
00021 {
00022 m_data_locked=false;
00023 m_store_model_features=false;
00024
00025 SG_ADD(&m_max_train_time, "max_train_time",
00026 "Maximum training time.", MS_NOT_AVAILABLE);
00027 SG_ADD((machine_int_t*) &m_solver_type, "solver_type",
00028 "Type of solver.", MS_NOT_AVAILABLE);
00029 SG_ADD((CSGObject**) &m_labels, "labels",
00030 "Labels to be used.", MS_NOT_AVAILABLE);
00031 SG_ADD(&m_store_model_features, "store_model_features",
00032 "Should feature data of model be stored after training?", MS_NOT_AVAILABLE);
00033 SG_ADD(&m_data_locked, "data_locked",
00034 "Indicates whether data is locked", MS_NOT_AVAILABLE);
00035
00036 m_parameter_map->put(
00037 new SGParamInfo("data_locked", CT_SCALAR, ST_NONE, PT_BOOL, 1),
00038 new SGParamInfo()
00039 );
00040
00041 m_parameter_map->finalize_map();
00042 }
00043
00044 CMachine::~CMachine()
00045 {
00046 SG_UNREF(m_labels);
00047 }
00048
00049 bool CMachine::train(CFeatures* data)
00050 {
00051
00052 if (m_data_locked)
00053 {
00054 SG_ERROR("%s::train data_lock() was called, only train_locked() is"
00055 " possible. Call data_unlock if you want to call train()\n",
00056 get_name());
00057 }
00058
00059 if (train_require_labels())
00060 {
00061 if (m_labels == NULL)
00062 SG_ERROR("%s@%p: No labels given", get_name(), this);
00063
00064 m_labels->ensure_valid(get_name());
00065 }
00066
00067 bool result = train_machine(data);
00068
00069 if (m_store_model_features)
00070 store_model_features();
00071
00072 return result;
00073 }
00074
00075 void CMachine::set_labels(CLabels* lab)
00076 {
00077 if (lab != NULL)
00078 if (!is_label_valid(lab))
00079 SG_ERROR("Invalid label for %s", get_name());
00080
00081 SG_UNREF(m_labels);
00082 SG_REF(lab);
00083 m_labels = lab;
00084 }
00085
00086 CLabels* CMachine::get_labels()
00087 {
00088 SG_REF(m_labels);
00089 return m_labels;
00090 }
00091
00092 void CMachine::set_max_train_time(float64_t t)
00093 {
00094 m_max_train_time = t;
00095 }
00096
00097 float64_t CMachine::get_max_train_time()
00098 {
00099 return m_max_train_time;
00100 }
00101
00102 EMachineType CMachine::get_classifier_type()
00103 {
00104 return CT_NONE;
00105 }
00106
00107 void CMachine::set_solver_type(ESolverType st)
00108 {
00109 m_solver_type = st;
00110 }
00111
00112 ESolverType CMachine::get_solver_type()
00113 {
00114 return m_solver_type;
00115 }
00116
00117 void CMachine::set_store_model_features(bool store_model)
00118 {
00119 m_store_model_features = store_model;
00120 }
00121
00122 void CMachine::data_lock(CLabels* labs, CFeatures* features)
00123 {
00124 SG_DEBUG("entering %s::data_lock\n", get_name());
00125 if (!supports_locking())
00126 {
00127 {
00128 SG_ERROR("%s::data_lock(): Machine does not support data locking!\n",
00129 get_name());
00130 }
00131 }
00132
00133 if (!labs)
00134 {
00135 SG_ERROR("%s::data_lock() is not possible will NULL labels!\n",
00136 get_name());
00137 }
00138
00139
00140 set_labels(labs);
00141
00142 if (m_data_locked)
00143 {
00144 SG_ERROR("%s::data_lock() was already called. Dont lock twice!",
00145 get_name());
00146 }
00147
00148 m_data_locked=true;
00149 post_lock(labs,features);
00150 SG_DEBUG("leaving %s::data_lock\n", get_name());
00151 }
00152
00153 void CMachine::data_unlock()
00154 {
00155 SG_DEBUG("entering %s::data_lock\n", get_name());
00156 if (m_data_locked)
00157 m_data_locked=false;
00158
00159 SG_DEBUG("leaving %s::data_lock\n", get_name());
00160 }
00161
00162 CLabels* CMachine::apply(CFeatures* data)
00163 {
00164 SG_DEBUG("entering %s::apply(%s at %p)\n",
00165 get_name(), data ? data->get_name() : "NULL", data);
00166
00167 CLabels* result=NULL;
00168
00169 switch (get_machine_problem_type())
00170 {
00171 case PT_BINARY:
00172 result=apply_binary(data);
00173 break;
00174 case PT_REGRESSION:
00175 result=apply_regression(data);
00176 break;
00177 case PT_MULTICLASS:
00178 result=apply_multiclass(data);
00179 break;
00180 case PT_STRUCTURED:
00181 result=apply_structured(data);
00182 break;
00183 case PT_LATENT:
00184 result=apply_latent(data);
00185 break;
00186 default:
00187 SG_ERROR("Unknown problem type");
00188 break;
00189 }
00190
00191 SG_DEBUG("leaving %s::apply(%s at %p)\n",
00192 get_name(), data ? data->get_name() : "NULL", data);
00193
00194 return result;
00195 }
00196
00197 CLabels* CMachine::apply_locked(SGVector<index_t> indices)
00198 {
00199 switch (get_machine_problem_type())
00200 {
00201 case PT_BINARY:
00202 return apply_locked_binary(indices);
00203 case PT_REGRESSION:
00204 return apply_locked_regression(indices);
00205 case PT_MULTICLASS:
00206 return apply_locked_multiclass(indices);
00207 case PT_STRUCTURED:
00208 return apply_locked_structured(indices);
00209 case PT_LATENT:
00210 return apply_locked_latent(indices);
00211 default:
00212 SG_ERROR("Unknown problem type");
00213 break;
00214 }
00215 return NULL;
00216 }
00217
00218 CBinaryLabels* CMachine::apply_binary(CFeatures* data)
00219 {
00220 SG_ERROR("This machine does not support apply_binary()\n");
00221 return NULL;
00222 }
00223
00224 CRegressionLabels* CMachine::apply_regression(CFeatures* data)
00225 {
00226 SG_ERROR("This machine does not support apply_regression()\n");
00227 return NULL;
00228 }
00229
00230 CMulticlassLabels* CMachine::apply_multiclass(CFeatures* data)
00231 {
00232 SG_ERROR("This machine does not support apply_multiclass()\n");
00233 return NULL;
00234 }
00235
00236 CStructuredLabels* CMachine::apply_structured(CFeatures* data)
00237 {
00238 SG_ERROR("This machine does not support apply_structured()\n");
00239 return NULL;
00240 }
00241
00242 CLatentLabels* CMachine::apply_latent(CFeatures* data)
00243 {
00244 SG_ERROR("This machine does not support apply_latent()\n");
00245 return NULL;
00246 }
00247
00248 CBinaryLabels* CMachine::apply_locked_binary(SGVector<index_t> indices)
00249 {
00250 SG_ERROR("apply_locked_binary(SGVector<index_t>) is not yet implemented "
00251 "for %s\n", get_name());
00252 return NULL;
00253 }
00254
00255 CRegressionLabels* CMachine::apply_locked_regression(SGVector<index_t> indices)
00256 {
00257 SG_ERROR("apply_locked_regression(SGVector<index_t>) is not yet implemented "
00258 "for %s\n", get_name());
00259 return NULL;
00260 }
00261
00262 CMulticlassLabels* CMachine::apply_locked_multiclass(SGVector<index_t> indices)
00263 {
00264 SG_ERROR("apply_locked_multiclass(SGVector<index_t>) is not yet implemented "
00265 "for %s\n", get_name());
00266 return NULL;
00267 }
00268
00269 CStructuredLabels* CMachine::apply_locked_structured(SGVector<index_t> indices)
00270 {
00271 SG_ERROR("apply_locked_structured(SGVector<index_t>) is not yet implemented "
00272 "for %s\n", get_name());
00273 return NULL;
00274 }
00275
00276 CLatentLabels* CMachine::apply_locked_latent(SGVector<index_t> indices)
00277 {
00278 SG_ERROR("apply_locked_latent(SGVector<index_t>) is not yet implemented "
00279 "for %s\n", get_name());
00280 return NULL;
00281 }
00282
00283