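/* Multiclass logistic regression with a tree-structured group-lasso
 * penalty, following the SLEP package: an accelerated proximal gradient
 * method (FISTA-style momentum) with a backtracking search on the
 * Lipschitz estimate L, where each proximal step is the tree-structured
 * shrinkage operator (altra / general_altra). */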
#include <shogun/lib/slep/slep_mc_tree_lr.h>
#ifdef HAVE_EIGEN3
#include <shogun/lib/slep/tree/general_altra.h>
#include <shogun/lib/slep/tree/altra.h>
#include <shogun/lib/slep/q1/eppMatrix.h>
#include <shogun/mathematics/Math.h>
#include <shogun/mathematics/eigen3.h>
#include <shogun/lib/Signal.h>
#include <shogun/lib/Time.h>
#include <iostream>

using namespace shogun;
using namespace Eigen;
using namespace std;

namespace shogun
{

slep_result_t slep_mc_tree_lr(
    CDotFeatures* features,
    CMulticlassLabels* labels,
    float64_t z,
    const slep_options& options)
{
    int i,j;

    int n_feats = features->get_dim_feature_space();
    int n_vecs = features->get_num_vectors();
    int n_classes = labels->get_num_classes();

    SGVector<float64_t> labels_vector = labels->get_labels();

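    // one weight vector (column of w) and one bias (entry of c) per class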
    MatrixXd w = MatrixXd::Zero(n_feats, n_classes);
    VectorXd c = VectorXd::Zero(n_classes);

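    // warm-start from the solution of a previous call if one was provided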
    if (options.last_result)
    {
        SGMatrix<float64_t> last_w = options.last_result->w;
        SGVector<float64_t> last_c = options.last_result->c;
        for (i=0; i<n_classes; i++)
        {
            c[i] = last_c[i];
            for (j=0; j<n_feats; j++)
                w(j,i) = last_w(j,i);
        }
    }

    MatrixXd wp = w, wwp = MatrixXd::Zero(n_feats, n_classes);
    VectorXd cp = c, ccp = VectorXd::Zero(n_classes);

    MatrixXd search_w = MatrixXd::Zero(n_feats, n_classes);
    VectorXd search_c = VectorXd::Zero(n_classes);

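    // Aw(i,j) = <x_i, w_j>: per-class linear responses, kept up to date to
    // avoid recomputing dot products; As holds them at the search point,
    // Awp at the previous iterate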
    MatrixXd Aw = MatrixXd::Zero(n_vecs, n_classes);
    for (j=0; j<n_classes; j++)
        features->dense_dot_range(Aw.col(j).data(), 0, n_vecs, NULL, w.col(j).data(), n_feats, 0.0);
    MatrixXd As = MatrixXd::Zero(n_vecs, n_classes);
    MatrixXd Awp = MatrixXd::Zero(n_vecs, n_classes);

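    // g and gc accumulate the loss gradient w.r.t. w and c; v is reused as
    // the gradient-step input and, after the proximal step, as w - search_w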
    MatrixXd g = MatrixXd::Zero(n_feats, n_classes);
    VectorXd gc = VectorXd::Zero(n_classes);

    MatrixXd v = MatrixXd::Zero(n_feats, n_classes);

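    // initial guess of the Lipschitz constant of the loss gradient (grown
    // by the line search below) and the FISTA momentum scalars alphap, alpha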
    double L = 1.0/(n_vecs*n_classes);
    double alphap = 0, alpha = 1;

    double lambda = z;

    double objective = 0.0;
    double objective_p = 0.0;

    int iter = 0;
    bool done = false;
    CTime time;
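    // forbid Eigen heap allocation inside the optimization loop; note this
    // guard only takes effect when Eigen is built with EIGEN_RUNTIME_NO_MALLOC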
    internal::set_is_malloc_allowed(false);
    while ((!done) && (iter<options.max_iter) && (!CSignal::cancel_computations()))
    {
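        // extrapolated search point s = x_k + beta*(x_k - x_{k-1}); the
        // responses As follow by linearity, so no extra dot products needed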
        double beta = (alphap-1)/alpha;
        search_w = w + beta*wwp;
        search_c = c + beta*ccp;
        As = Aw + beta*(Aw-Awp);

        double fun_s = 0;
        g.setZero();
        gc.setZero();

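        // one-vs-rest logistic loss at the search point: for sample i and
        // class j the margin is a = s*(<x_i,w_j> + c_j) with s = -1 for the
        // true class and +1 otherwise; log(1+exp(a)) is evaluated stably as
        // log(exp(-b) + exp(a-b)) + b with b = max(a,0), and the gradient
        // coefficient is s*(1 - sigma(-a))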
        for (i=0; i<n_vecs; i++)
        {
            int vec_class = (int)labels_vector[i];

            for (j=0; j<n_classes; j++)
            {
                double aa = ((vec_class == j) ? -1.0 : 1.0)*(As(i,j) + search_c(j));
                double bb = aa > 0.0 ? aa : 0.0;
                fun_s += CMath::log(CMath::exp(-bb) + CMath::exp(aa-bb)) + bb;
                double prob = 1.0/(1+CMath::exp(aa));
                double b = ((vec_class == j) ? -1.0 : 1.0)*(1-prob);

                gc[j] += b;
                features->add_to_dense_vec(b, i, g.col(j).data(), n_feats);
            }
        }

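        // remember the current iterate before stepping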
        wp = w;
        Awp = Aw;
        cp = c;

        int inner_iter = 0;
        double fun_x = 0;

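        // backtracking line search: try a step with the current L and grow
        // it until the quadratic upper bound holds (capped at 5000 trials)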
        while (inner_iter<5000)
        {
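            // gradient step from the search point with step size 1/L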
            v = search_w - g/L;
            c = search_c - gc/L;

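            // proximal step: tree-structured shrinkage (altra) maps v onto w;
            // afterwards v is reused to hold the step difference w - search_w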
            if (options.general)
                general_altra_mt(w.data(),v.data(),n_classes,n_feats,options.G,options.ind_t,options.n_nodes,lambda/L);
            else
                altra_mt(w.data(),v.data(),n_classes,n_feats,options.ind_t,options.n_nodes,lambda/L);
            v = w - search_w;

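            // refresh the per-class responses Aw for the new candidate w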
            for (j=0; j<n_classes; j++)
                features->dense_dot_range(Aw.col(j).data(), 0, n_vecs, NULL, w.col(j).data(), n_feats, 0.0);

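            // loss at the candidate point (same stable evaluation as above)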
            fun_x = 0;
            for (i=0; i<n_vecs; i++)
            {
                int vec_class = (int)labels_vector[i];
                for (j=0; j<n_classes; j++)
                {
                    double aa = ((vec_class == j) ? -1.0 : 1.0)*(Aw(i,j) + c(j));
                    double bb = aa > 0.0 ? aa : 0.0;
                    fun_x += CMath::log(CMath::exp(-bb) + CMath::exp(aa-bb)) + bb;
                }
            }

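            // accept the step if the quadratic upper bound holds:
            //   f(x) <= f(s) + <grad f(s), x-s> + (L/2)*||x-s||^2,
            // i.e. l_sum <= L*r_sum with r_sum = ||x-s||^2/2; a vanishing
            // r_sum means the step no longer moves, so we stop altogether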
            double r_sum = (v.squaredNorm() + (c-search_c).squaredNorm())/2;
            double l_sum = fun_x - fun_s - v.cwiseProduct(g).sum() - (c-search_c).dot(gc);

            if (r_sum <= 1e-20)
            {
                SG_SINFO("Gradient step makes little improvement (%f)\n",r_sum);
                done = true;
                break;
            }

            if (l_sum <= r_sum*L)
                break;
            else
                L = CMath::max(2*L, l_sum/r_sum);

            inner_iter++;
        }

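        // FISTA momentum update: alpha_{k+1} = (1 + sqrt(1 + 4*alpha_k^2))/2;
        // keep the differences to the previous iterate for extrapolation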
        alphap = alpha;
        alpha = (1+CMath::sqrt(4*alpha*alpha+1))/2;

        wwp = w - wp;
        ccp = c - cp;

        objective_p = objective;
        objective = fun_x;

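        // add the regularizer lambda * sum_j treenorm(w_j) to the objective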
        double tree_norm = 0.0;
        if (options.general)
        {
            for (i=0; i<n_classes; i++)
                tree_norm += general_treeNorm(w.col(i).data(),n_classes,n_feats,options.G,options.ind_t,options.n_nodes);
        }
        else
        {
            for (i=0; i<n_classes; i++)
                tree_norm += treeNorm(w.col(i).data(),n_classes,n_feats,options.ind_t,options.n_nodes);
        }

        objective += lambda*tree_norm;

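        // stop once the relative change of the objective drops below the
        // tolerance (skipping the first few iterations to let it settle)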
        if ((CMath::abs(objective - objective_p) < options.tolerance*CMath::abs(objective_p)) && (iter>2))
        {
            SG_SINFO("Objective changes less than tolerance\n");
            done = true;
        }

        iter++;
    }
    SG_SINFO("%d iterations passed, objective = %f\n",iter,objective);
    internal::set_is_malloc_allowed(true);

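    // copy the Eigen solution into Shogun containers for the result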
    SGMatrix<float64_t> r_w(n_feats,n_classes);
    for (j=0; j<n_classes; j++)
    {
        for (i=0; i<n_feats; i++)
            r_w(i,j) = w(i,j);
    }

    SGVector<float64_t> r_c(n_classes);
    for (j=0; j<n_classes; j++)
        r_c[j] = c[j];
    return slep_result_t(r_w, r_c);
}
}
#endif
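/* A minimal usage sketch (illustrative only; the feature/label objects and
 * the tree description arrays `ind` and `n_nodes` are assumptions, not part
 * of this file). The options fields are the ones this solver reads above:
 *
 *   slep_options options = slep_options::default_options();
 *   options.max_iter = 1000;
 *   options.tolerance = 1e-6;
 *   options.general = false;   // use altra_mt with a plain index tree
 *   options.ind_t = ind;       // tree node ranges/weights as SLEP expects
 *   options.n_nodes = n_nodes;
 *
 *   slep_result_t result = slep_mc_tree_lr(features, labels, z, options);
 *   SGMatrix<float64_t> w = result.w;  // per-class weights
 *   SGVector<float64_t> c = result.c;  // per-class biases
 */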