00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #include <stdio.h>
00013 #include <string.h>
00014
00015 #include <shogun/lib/config.h>
00016 #include <shogun/io/SGIO.h>
00017 #include <shogun/structure/Plif.h>
00018
00019
00020
00021 using namespace shogun;
00022
00023 CPlif::CPlif(int32_t l)
00024 : CPlifBase()
00025 {
00026 limits=SGVector<float64_t>();
00027 penalties=SGVector<float64_t>();
00028 cum_derivatives=SGVector<float64_t>();
00029 id=-1;
00030 transform=T_LINEAR;
00031 name=NULL;
00032 max_value=0;
00033 min_value=0;
00034 cache=NULL;
00035 use_svm=0;
00036 use_cache=false;
00037 len=0;
00038 do_calc = true;
00039 if (l>0)
00040 set_plif_length(l);
00041 }
00042
00043 CPlif::~CPlif()
00044 {
00045 SG_FREE(name);
00046 SG_FREE(cache);
00047 }
00048
00049 bool CPlif::set_transform_type(const char *type_str)
00050 {
00051 invalidate_cache();
00052
00053 if (strcmp(type_str, "linear")==0)
00054 transform = T_LINEAR ;
00055 else if (strcmp(type_str, "")==0)
00056 transform = T_LINEAR ;
00057 else if (strcmp(type_str, "log")==0)
00058 transform = T_LOG ;
00059 else if (strcmp(type_str, "log(+1)")==0)
00060 transform = T_LOG_PLUS1 ;
00061 else if (strcmp(type_str, "log(+3)")==0)
00062 transform = T_LOG_PLUS3 ;
00063 else if (strcmp(type_str, "(+3)")==0)
00064 transform = T_LINEAR_PLUS3 ;
00065 else
00066 {
00067 SG_ERROR( "unknown transform type (%s)\n", type_str) ;
00068 return false ;
00069 }
00070 return true ;
00071 }
00072
00073 void CPlif::init_penalty_struct_cache()
00074 {
00075 if (!use_cache)
00076 return ;
00077 if (cache || use_svm)
00078 return ;
00079 if (max_value<=0)
00080 return ;
00081
00082 float64_t* local_cache=SG_MALLOC(float64_t, ((int32_t) max_value) + 2);
00083
00084 if (local_cache)
00085 {
00086 for (int32_t i=0; i<=max_value; i++)
00087 {
00088 if (i<min_value)
00089 local_cache[i] = -CMath::INFTY ;
00090 else
00091 local_cache[i] = lookup_penalty(i, NULL) ;
00092 }
00093 }
00094 this->cache=local_cache ;
00095 }
00096
00097 void CPlif::set_plif_name(char *p_name)
00098 {
00099 SG_FREE(name);
00100 name=SG_MALLOC(char, strlen(p_name)+3);
00101 strcpy(name,p_name) ;
00102 }
00103
00104 void CPlif::delete_penalty_struct(CPlif** PEN, int32_t P)
00105 {
00106 for (int32_t i=0; i<P; i++)
00107 delete PEN[i] ;
00108 SG_FREE(PEN);
00109 }
00110
00111 float64_t CPlif::lookup_penalty_svm(
00112 float64_t p_value, float64_t *d_values) const
00113 {
00114 ASSERT(use_svm>0);
00115 float64_t d_value=d_values[use_svm-1] ;
00116 #ifdef PLIF_DEBUG
00117 SG_PRINT("%s.lookup_penalty_svm(%f)\n", get_name(), d_value) ;
00118 #endif
00119
00120 if (!do_calc)
00121 return d_value;
00122 switch (transform)
00123 {
00124 case T_LINEAR:
00125 break ;
00126 case T_LOG:
00127 d_value = log(d_value) ;
00128 break ;
00129 case T_LOG_PLUS1:
00130 d_value = log(d_value+1) ;
00131 break ;
00132 case T_LOG_PLUS3:
00133 d_value = log(d_value+3) ;
00134 break ;
00135 case T_LINEAR_PLUS3:
00136 d_value = d_value+3 ;
00137 break ;
00138 default:
00139 SG_ERROR("unknown transform\n");
00140 break ;
00141 }
00142
00143 int32_t idx = 0 ;
00144 float64_t ret ;
00145 for (int32_t i=0; i<len; i++)
00146 if (limits[i]<=d_value)
00147 idx++ ;
00148 else
00149 break ;
00150
00151 #ifdef PLIF_DEBUG
00152 SG_PRINT(" -> idx = %i ", idx) ;
00153 #endif
00154
00155 if (idx==0)
00156 ret=penalties[0] ;
00157 else if (idx==len)
00158 ret=penalties[len-1] ;
00159 else
00160 {
00161 ret = (penalties[idx]*(d_value-limits[idx-1]) + penalties[idx-1]*
00162 (limits[idx]-d_value)) / (limits[idx]-limits[idx-1]) ;
00163 #ifdef PLIF_DEBUG
00164 SG_PRINT(" -> (%1.3f*%1.3f, %1.3f*%1.3f)", (d_value-limits[idx-1])/(limits[idx]-limits[idx-1]), penalties[idx], (limits[idx]-d_value)/(limits[idx]-limits[idx-1]), penalties[idx-1]) ;
00165 #endif
00166 }
00167 #ifdef PLIF_DEBUG
00168 SG_PRINT(" -> ret=%1.3f\n", ret) ;
00169 #endif
00170
00171 return ret ;
00172 }
00173
00174 float64_t CPlif::lookup_penalty(int32_t p_value, float64_t* svm_values) const
00175 {
00176 if (use_svm)
00177 return lookup_penalty_svm(p_value, svm_values) ;
00178
00179 if ((p_value<min_value) || (p_value>max_value))
00180 {
00181
00182 return -CMath::INFTY ;
00183 }
00184 if (!do_calc)
00185 return p_value;
00186 if (cache!=NULL && (p_value>=0) && (p_value<=max_value))
00187 {
00188 float64_t ret=cache[p_value] ;
00189 return ret ;
00190 }
00191 return lookup_penalty((float64_t) p_value, svm_values) ;
00192 }
00193
00194 float64_t CPlif::lookup_penalty(float64_t p_value, float64_t* svm_values) const
00195 {
00196 if (use_svm)
00197 return lookup_penalty_svm(p_value, svm_values) ;
00198
00199 #ifdef PLIF_DEBUG
00200 SG_PRINT("%s.lookup_penalty(%f)\n", get_name(), p_value) ;
00201 #endif
00202
00203
00204 if ((p_value<min_value) || (p_value>max_value))
00205 {
00206
00207 return -CMath::INFTY ;
00208 }
00209
00210 if (!do_calc)
00211 return p_value;
00212
00213 float64_t d_value = (float64_t) p_value ;
00214 switch (transform)
00215 {
00216 case T_LINEAR:
00217 break ;
00218 case T_LOG:
00219 d_value = log(d_value) ;
00220 break ;
00221 case T_LOG_PLUS1:
00222 d_value = log(d_value+1) ;
00223 break ;
00224 case T_LOG_PLUS3:
00225 d_value = log(d_value+3) ;
00226 break ;
00227 case T_LINEAR_PLUS3:
00228 d_value = d_value+3 ;
00229 break ;
00230 default:
00231 SG_ERROR( "unknown transform\n") ;
00232 break ;
00233 }
00234
00235 #ifdef PLIF_DEBUG
00236 SG_PRINT(" -> value = %1.4f ", d_value) ;
00237 #endif
00238
00239 int32_t idx = 0 ;
00240 float64_t ret ;
00241 for (int32_t i=0; i<len; i++)
00242 if (limits[i]<=d_value)
00243 idx++ ;
00244 else
00245 break ;
00246
00247 #ifdef PLIF_DEBUG
00248 SG_PRINT(" -> idx = %i ", idx) ;
00249 #endif
00250
00251 if (idx==0)
00252 ret=penalties[0] ;
00253 else if (idx==len)
00254 ret=penalties[len-1] ;
00255 else
00256 {
00257 ret = (penalties[idx]*(d_value-limits[idx-1]) + penalties[idx-1]*
00258 (limits[idx]-d_value)) / (limits[idx]-limits[idx-1]) ;
00259 #ifdef PLIF_DEBUG
00260 SG_PRINT(" -> (%1.3f*%1.3f, %1.3f*%1.3f) ", (d_value-limits[idx-1])/(limits[idx]-limits[idx-1]), penalties[idx], (limits[idx]-d_value)/(limits[idx]-limits[idx-1]), penalties[idx-1]) ;
00261 #endif
00262 }
00263
00264
00265 #ifdef PLIF_DEBUG
00266 SG_PRINT(" -> ret=%1.3f\n", ret) ;
00267 #endif
00268
00269 return ret ;
00270 }
00271
00272 void CPlif::penalty_clear_derivative()
00273 {
00274 for (int32_t i=0; i<len; i++)
00275 cum_derivatives[i]=0.0 ;
00276 }
00277
00278 void CPlif::penalty_add_derivative(float64_t p_value, float64_t* svm_values, float64_t factor)
00279 {
00280 if (use_svm)
00281 {
00282 penalty_add_derivative_svm(p_value, svm_values, factor) ;
00283 return ;
00284 }
00285
00286 if ((p_value<min_value) || (p_value>max_value))
00287 {
00288 return ;
00289 }
00290 float64_t d_value = (float64_t) p_value ;
00291 switch (transform)
00292 {
00293 case T_LINEAR:
00294 break ;
00295 case T_LOG:
00296 d_value = log(d_value) ;
00297 break ;
00298 case T_LOG_PLUS1:
00299 d_value = log(d_value+1) ;
00300 break ;
00301 case T_LOG_PLUS3:
00302 d_value = log(d_value+3) ;
00303 break ;
00304 case T_LINEAR_PLUS3:
00305 d_value = d_value+3 ;
00306 break ;
00307 default:
00308 SG_ERROR( "unknown transform\n") ;
00309 break ;
00310 }
00311
00312 int32_t idx = 0 ;
00313 for (int32_t i=0; i<len; i++)
00314 if (limits[i]<=d_value)
00315 idx++ ;
00316 else
00317 break ;
00318
00319 if (idx==0)
00320 cum_derivatives[0]+= factor ;
00321 else if (idx==len)
00322 cum_derivatives[len-1]+= factor ;
00323 else
00324 {
00325 cum_derivatives[idx] += factor * (d_value-limits[idx-1])/(limits[idx]-limits[idx-1]) ;
00326 cum_derivatives[idx-1]+= factor*(limits[idx]-d_value)/(limits[idx]-limits[idx-1]) ;
00327 }
00328 }
00329
00330 void CPlif::penalty_add_derivative_svm(float64_t p_value, float64_t *d_values, float64_t factor)
00331 {
00332 ASSERT(use_svm>0);
00333 float64_t d_value=d_values[use_svm-1] ;
00334
00335 if (d_value<-1e+20)
00336 return;
00337
00338 switch (transform)
00339 {
00340 case T_LINEAR:
00341 break ;
00342 case T_LOG:
00343 d_value = log(d_value) ;
00344 break ;
00345 case T_LOG_PLUS1:
00346 d_value = log(d_value+1) ;
00347 break ;
00348 case T_LOG_PLUS3:
00349 d_value = log(d_value+3) ;
00350 break ;
00351 case T_LINEAR_PLUS3:
00352 d_value = d_value+3 ;
00353 break ;
00354 default:
00355 SG_ERROR( "unknown transform\n") ;
00356 break ;
00357 }
00358
00359 int32_t idx = 0 ;
00360 for (int32_t i=0; i<len; i++)
00361 if (limits[i]<=d_value)
00362 idx++ ;
00363 else
00364 break ;
00365
00366 if (idx==0)
00367 cum_derivatives[0]+=factor ;
00368 else if (idx==len)
00369 cum_derivatives[len-1]+=factor ;
00370 else
00371 {
00372 cum_derivatives[idx] += factor*(d_value-limits[idx-1])/(limits[idx]-limits[idx-1]) ;
00373 cum_derivatives[idx-1] += factor*(limits[idx]-d_value)/(limits[idx]-limits[idx-1]) ;
00374 }
00375 }
00376
00377 void CPlif::get_used_svms(int32_t* num_svms, int32_t* svm_ids)
00378 {
00379 if (use_svm)
00380 {
00381 svm_ids[(*num_svms)] = use_svm;
00382 (*num_svms)++;
00383 }
00384 SG_PRINT("->use_svm:%i plif_id:%i name:%s trans_type:%s ",use_svm, get_id(), get_name(), get_transform_type());
00385 }
00386
00387 bool CPlif::get_do_calc()
00388 {
00389 return do_calc;
00390 }
00391
00392 void CPlif::set_do_calc(bool b)
00393 {
00394 do_calc = b;;
00395 }