Go to the documentation of this file.00001 #include <shogun/labels/DenseLabels.h>
00002 #include <shogun/labels/BinaryLabels.h>
00003
00004 using namespace shogun;
00005
00006 CBinaryLabels::CBinaryLabels() : CDenseLabels()
00007 {
00008 }
00009
00010 CBinaryLabels::CBinaryLabels(int32_t num_labels) : CDenseLabels(num_labels)
00011 {
00012 }
00013
00014 CBinaryLabels::CBinaryLabels(SGVector<float64_t> src, float64_t threshold) : CDenseLabels()
00015 {
00016 SGVector<float64_t> labels(src.vlen);
00017 for (int32_t i=0; i<labels.vlen; i++)
00018 labels[i] = src[i]+threshold>=0 ? +1.0 : -1.0;
00019 set_labels(labels);
00020 set_values(src);
00021 }
00022
00023 CBinaryLabels::CBinaryLabels(CFile* loader) : CDenseLabels(loader)
00024 {
00025 }
00026
00027 CBinaryLabels* CBinaryLabels::obtain_from_generic(CLabels* base_labels)
00028 {
00029 if ( base_labels->get_label_type() == LT_BINARY )
00030 return (CBinaryLabels*) base_labels;
00031 else
00032 SG_SERROR("base_labels must be of dynamic type CBinaryLabels");
00033
00034 return NULL;
00035 }
00036
00037
00038 void CBinaryLabels::ensure_valid(const char* context)
00039 {
00040 CDenseLabels::ensure_valid(context);
00041 bool found_plus_one=false;
00042 bool found_minus_one=false;
00043
00044 int32_t subset_size=get_num_labels();
00045 for (int32_t i=0; i<subset_size; i++)
00046 {
00047 int32_t real_i=m_subset_stack->subset_idx_conversion(i);
00048 if (m_labels[real_i]==+1.0)
00049 found_plus_one=true;
00050 else if (m_labels[real_i]==-1.0)
00051 found_minus_one=true;
00052 else
00053 {
00054 SG_ERROR("%s%sNot a two class labeling label[%d]=%f (only +1/-1 "
00055 "allowed)\n", context?context:"", context?": ":"", i, m_labels[real_i]);
00056 }
00057 }
00058
00059 if (!found_plus_one)
00060 {
00061 SG_ERROR("%s%sNot a two class labeling - no positively labeled examples found\n",
00062 context?context:"", context?": ":"");
00063 }
00064
00065 if (!found_minus_one)
00066 {
00067 SG_ERROR("%s%sNot a two class labeling - no negatively labeled examples found\n",
00068 context?context:"", context?": ":"");
00069 }
00070 }
00071
00072 ELabelType CBinaryLabels::get_label_type()
00073 {
00074 return LT_BINARY;
00075 }
00076
00077 void CBinaryLabels::scores_to_probabilities()
00078 {
00079 SG_DEBUG("entering CBinaryLabels::scores_to_probabilities()\n");
00080
00081 REQUIRE(m_current_values.vector, "%s::scores_to_probabilities() requires "
00082 "values vector!\n", get_name());
00083
00084
00085 int32_t prior0=0;
00086 int32_t prior1=0;
00087 SG_DEBUG("counting number of positive and negative labels\n");
00088 {
00089 for (index_t i=0; i<m_current_values.vlen; ++i)
00090 {
00091 if (m_current_values[i]>0)
00092 prior1++;
00093 else
00094 prior0++;
00095 }
00096 }
00097 SG_DEBUG("%d pos; %d neg\n", prior1, prior0);
00098
00099
00100
00101 index_t maxiter=100;
00102
00103
00104 float64_t minstep=1E-10;
00105
00106
00107 float64_t sigma=1E-12;
00108 float64_t eps=1E-5;
00109
00110
00111 float64_t hiTarget=(prior1+1.0)/(prior1+2.0);
00112 float64_t loTarget=1/(prior0+2.0);
00113 index_t length=prior1+prior0;
00114
00115 SGVector<float64_t> t(length);
00116 for (index_t i=0; i<length; ++i)
00117 {
00118 if (m_current_values[i]>0)
00119 t[i]=hiTarget;
00120 else
00121 t[i]=loTarget;
00122 }
00123
00124
00125
00126 float64_t a=0;
00127 float64_t b=CMath::log((prior0+1.0)/(prior1+1.0));
00128 float64_t fval=0.0;
00129
00130 for (index_t i=0; i<length; ++i)
00131 {
00132 float64_t fApB=m_current_values[i]*a+b;
00133 if (fApB>=0)
00134 fval+=t[i]*fApB+CMath::log(1+CMath::exp(-fApB));
00135 else
00136 fval+=(t[i]-1)*fApB+CMath::log(1+CMath::exp(fApB));
00137 }
00138
00139 index_t it;
00140 float64_t g1;
00141 float64_t g2;
00142 for (it=0; it<maxiter; ++it)
00143 {
00144 SG_DEBUG("Iteration %d, a=%f, b=%f, fval=%f\n", it, a, b, fval);
00145
00146
00147 float64_t h11=sigma;
00148 float64_t h22=h11;
00149 float64_t h21=0;
00150 g1=0;
00151 g2=0;
00152
00153 for (index_t i=0; i<length; ++i)
00154 {
00155 float64_t fApB=m_current_values[i]*a+b;
00156 float64_t p;
00157 float64_t q;
00158 if (fApB>=0)
00159 {
00160 p=CMath::exp(-fApB)/(1.0+CMath::exp(-fApB));
00161 q=1.0/(1.0+CMath::exp(-fApB));
00162 }
00163 else
00164 {
00165 p=1.0/(1.0+CMath::exp(fApB));
00166 q=CMath::exp(fApB)/(1.0+CMath::exp(fApB));
00167 }
00168
00169 float64_t d2=p*q;
00170 h11+=m_current_values[i]*m_current_values[i]*d2;
00171 h22+=d2;
00172 h21+=m_current_values[i]*d2;
00173 float64_t d1=t[i]-p;
00174 g1+=m_current_values[i]*d1;
00175 g2+=d1;
00176 }
00177
00178
00179 if (CMath::abs(g1)<eps && CMath::abs(g2)<eps)
00180 break;
00181
00182
00183 float64_t det=h11*h22-h21*h21;
00184 float64_t dA=-(h22*g1-h21*g2)/det;
00185 float64_t dB=-(-h21*g1+h11*g2)/det;
00186 float64_t gd=g1*dA+g2*dB;
00187
00188
00189 float64_t stepsize=1;
00190
00191 while (stepsize>=minstep)
00192 {
00193 float64_t newA=a+stepsize*dA;
00194 float64_t newB=b+stepsize*dB;
00195
00196
00197 float64_t newf=0.0;
00198 for (index_t i=0; i<length; ++i)
00199 {
00200 float64_t fApB=m_current_values[i]*newA+newB;
00201 if (fApB>=0)
00202 newf+=t[i]*fApB+CMath::log(1+CMath::exp(-fApB));
00203 else
00204 newf+=(t[i]-1)*fApB+CMath::log(1+CMath::exp(fApB));
00205 }
00206
00207
00208 if (newf<fval+0.0001*stepsize*gd)
00209 {
00210 a=newA;
00211 b=newB;
00212 fval=newf;
00213 break;
00214 }
00215 else
00216 stepsize=stepsize/2.0;
00217 }
00218
00219 if (stepsize<minstep)
00220 {
00221 SG_WARNING("%s::scores_to_probabilities(): line search fails, A=%f, "
00222 "B=%f, g1=%f, g2=%f, dA=%f, dB=%f, gd=%f\n",
00223 get_name(), a, b, g1, g2, dA, dB, gd);
00224 }
00225 }
00226
00227 if (it>=maxiter-1)
00228 {
00229 SG_WARNING("%s::scores_to_probabilities(): reaching maximal iterations,"
00230 " g1=%f, g2=%f\n", get_name(), g1, g2);
00231 }
00232
00233 SG_DEBUG("fitted sigmoid: a=%f, b=%f\n", a, b);
00234
00235
00236 for (index_t i=0; i<m_current_values.vlen; ++i)
00237 {
00238 float64_t fApB=m_current_values[i]*a+b;
00239 m_current_values[i]=fApB>=0 ? CMath::exp(-fApB)/(1.0+exp(-fApB)) :
00240 1.0/(1+CMath::exp(fApB));
00241 }
00242
00243 SG_DEBUG("leaving CBinaryLabels::scores_to_probabilities()\n");
00244 }