// NOTE(review): this is a fragmentary excerpt — the leading integer on each
// line is the original file's line number, and the jumps between them mean
// many statements (loop bodies, the alpha/alphap update, the prob/fun_x/fun_s
// computations) are not visible here. Comments below describe only what the
// visible lines establish; inferences are marked as such.
23 using namespace Eigen;
// Continuation of the solver's signature; the full parameter list (features,
// labels, regularization weight) starts above this excerpt.
33 const slep_options& options)
// Primal variables: weight matrix (one column per class) and per-class bias.
46 MatrixXd w = MatrixXd::Zero(n_feats, n_classes);
48 VectorXd c = VectorXd::Zero(n_classes);
// Warm start: copy a previous solution into w (per-class, per-feature loops)
// when the caller supplied one.
50 if (options.last_result)
54 for (i=0; i<n_classes; i++)
57 for (j=0; j<n_feats; j++)
// Previous iterates (wp, cp) and iterate differences (wwp, ccp) — the
// differences feed the momentum extrapolation below; presumably they are
// updated to w - wp / c - cp in lines not shown here.
62 MatrixXd wp = w, wwp = MatrixXd::Zero(n_feats, n_classes);
63 VectorXd cp = c, ccp = VectorXd::Zero(n_classes);
// Extrapolated search point at which the gradient is evaluated.
65 MatrixXd search_w = MatrixXd::Zero(n_feats, n_classes);
67 VectorXd search_c = VectorXd::Zero(n_classes);
// Aw(i,j) = <x_i, w_j>: per-class dot products of every feature vector with
// the current weights, filled column-by-column via dense_dot_range.
69 MatrixXd Aw = MatrixXd::Zero(n_vecs, n_classes);
70 for (j=0; j<n_classes; j++)
71 features->
dense_dot_range(Aw.col(j).data(), 0, n_vecs, NULL, w.col(j).data(), n_feats, 0.0);
// As: products at the extrapolated search point; Awp: products at the
// previous iterate (lets As be formed without an extra pass over features).
72 MatrixXd As = MatrixXd::Zero(n_vecs, n_classes);
73 MatrixXd Awp = MatrixXd::Zero(n_vecs, n_classes);
// Gradient of the smooth loss w.r.t. weights (g) and biases (gc).
75 MatrixXd g = MatrixXd::Zero(n_feats, n_classes);
76 VectorXd gc = VectorXd::Zero(n_classes);
// v: workspace for the gradient/proximal step (see altra_mt calls below).
78 MatrixXd v = MatrixXd::Zero(n_feats, n_classes);
// Initial estimate of the Lipschitz constant; grown during the inner
// line search until the quadratic upper bound holds (l_sum <= r_sum*L).
81 double L = 1.0/(n_vecs*n_classes);
// Acceleration (momentum) coefficients of the FISTA-style scheme.
83 double alphap = 0, alpha = 1;
88 double objective = 0.0;
89 double objective_p = 0.0;
// --- outer iteration (loop header not visible in this excerpt) ---
// Momentum weight and extrapolated search point:
//   search = current + beta * (current - previous).
97 double beta = (alphap-1)/alpha;
99 search_w = w + beta*wwp;
100 search_c = c + beta*ccp;
// Same extrapolation applied to the cached feature products, avoiding a
// fresh dense_dot_range pass for the search point.
103 As = Aw + beta*(Aw-Awp);
// Smooth-loss gradient at the search point. Labels enter through the sign:
// the true class of vector i gets -1, every other class +1 (one-vs-rest
// logistic form).
110 for (i=0; i<n_vecs; i++)
113 int vec_class = labels_vector[i];
115 for (j=0; j<n_classes; j++)
// aa = +-(margin); bb = max(aa, 0) — presumably used for a numerically
// stable log(1+exp(aa)) evaluation (the exp/log lines are not shown).
118 double aa = ((vec_class == j) ? -1.0 : 1.0)*(As(i,j) + search_c(j));
119 double bb = aa > 0.0 ? aa : 0.0;
// Per-sample gradient contribution; prob is computed in lines not shown.
123 double b = ((vec_class == j) ? -1.0 : 1.0)*(1-prob);
// --- line search over L: try a prox-gradient step, increase L on failure ---
140 while (inner_iter<5000)
// Tree-structured proximal operator (SLEP's altra routine) with threshold
// lambda/L; the "general" variant additionally takes the group matrix G.
148 general_altra_mt(w.data(),v.data(),n_classes,n_feats,options.G,options.ind_t,options.n_nodes,lambda/L);
150 altra_mt(w.data(),v.data(),n_classes,n_feats,options.ind_t,options.n_nodes,lambda/L);
// Recompute feature products at the candidate point w.
154 for (j=0; j<n_classes; j++)
155 features->
dense_dot_range(Aw.col(j).data(), 0, n_vecs, NULL, w.col(j).data(), n_feats, 0.0);
// Evaluate the smooth loss at the candidate point (fun_x; accumulation
// lines not shown) with the same signed-margin form as above.
159 for (i=0; i<n_vecs; i++)
161 int vec_class = labels_vector[i];
162 for (j=0; j<n_classes; j++)
164 double aa = ((vec_class == j) ? -1.0 : 1.0)*(Aw(i,j) + c(j));
165 double bb = aa > 0.0 ? aa : 0.0;
// Sufficient-decrease test: r_sum = ||step||^2/2, l_sum = f(x) - f(s) - <step, grad>.
// Accept when l_sum <= r_sum*L (quadratic upper bound holds); otherwise L is
// presumably doubled in lines not shown.
172 double r_sum = (v.squaredNorm() + (c-search_c).squaredNorm())/2;
173 double l_sum = fun_x - fun_s - v.cwiseProduct(g).sum() - (c-search_c).
dot(gc);
// Tiny step ==> gradient no longer makes progress; log and (presumably) stop.
178 SG_SINFO(
"Gradient step makes little improvement (%f)\n",r_sum)
183 if (l_sum <= r_sum*L)
// --- bookkeeping: objective = smooth loss + lambda * tree norm of w ---
200 objective_p = objective;
204 double tree_norm = 0.0;
// NOTE(review): the column pointer varies with i but the dimensions passed
// are (n_classes, n_feats) — confirm the expected (ldx, n) argument order of
// treeNorm against its declaration before relying on this.
207 for (i=0; i<n_classes; i++)
208 tree_norm +=
general_treeNorm(w.col(i).data(),n_classes,n_feats,options.G,options.ind_t,options.n_nodes);
212 for (i=0; i<n_classes; i++)
213 tree_norm +=
treeNorm(w.col(i).data(),n_classes,n_feats,options.ind_t,options.n_nodes);
217 objective += lambda*tree_norm;
// Converged when the relative objective change drops below tolerance
// (skipping the first two iterations to let the momentum warm up).
222 if ((
CMath::abs(objective - objective_p) < options.tolerance*
CMath::abs(objective_p)) && (iter>2))
224 SG_SINFO(
"Objective changes less than tolerance\n")
230 SG_SINFO(
"%d iterations passed, objective = %f\n",iter,objective)
// Copy the solution into the result containers and return (weights + biases).
235 for (j=0; j<n_classes; j++)
237 for (i=0; i<n_feats; i++)
242 for (j=0; j<n_classes; j++)
244 return slep_result_t(r_w, r_c);
Class Time implements a stopwatch based on either CPU time or wall clock time.
slep_result_t slep_mc_tree_lr(CDotFeatures *features, CMulticlassLabels *labels, float64_t z, const slep_options &options)
double treeNorm(double *x, int ldx, int n, double *ind, int nodes)
virtual void dense_dot_range(float64_t *output, int32_t start, int32_t stop, float64_t *alphas, float64_t *vec, int32_t dim, float64_t b)
Vector::Scalar dot(Vector a, Vector b)
virtual int32_t get_num_vectors() const =0
virtual void add_to_dense_vec(float64_t alpha, int32_t vec_idx1, float64_t *vec2, int32_t vec2_len, bool abs_val=false)=0
Features that support dot products among other operations.
int32_t get_num_classes()
virtual int32_t get_dim_feature_space() const =0
SGVector< float64_t > get_labels()
Multiclass Labels for multi-class classification.
double general_treeNorm(double *x, int ldx, int n, double *G, double *ind, int nodes)
void general_altra_mt(double *X, double *V, int n, int k, double *G, double *ind, int nodes, double mult)
static bool cancel_computations()
All classes and functions are contained in the shogun namespace.
static float64_t exp(float64_t x)
static float64_t log(float64_t v)
static float32_t sqrt(float32_t x)
void altra_mt(double *X, double *V, int n, int k, double *ind, int nodes, double mult)