18 using namespace Eigen;
37 const malsar_options& options)
// NOTE(review): only the tail of the signature is visible in this excerpt, and
// many intermediate lines (braces, several declarations, loop bodies, else/case
// labels) are not shown. The comments below describe only the statements that
// are visible; anything about the missing parts is marked as an assumption.
//
// Visible outline: an outer loop updates three coupled unknowns (a weight
// matrix W, a bias vector C and a task matrix M); an inner loop takes a
// gradient step scaled by 1/gamma, projects the M step through an
// eigen-decomposition plus a small QP, and tests the candidate against a
// quadratic bound. The last candidate is copied out and returned.

// Number of tasks, taken from the options structure.
44 int n_tasks = options.n_tasks;
// Per-task loop whose body is not visible in this excerpt.
48 for (
int i=0; i<n_tasks; i++)
// Search points: zero weights (n_feats x n_tasks), zero biases, and a scaled
// identity (n_clusters/n_tasks on the diagonal) for the task matrix.
55 MatrixXd Ws = MatrixXd::Zero(n_feats, n_tasks);
56 VectorXd Cs = VectorXd::Zero(n_tasks);
57 MatrixXd Ms = MatrixXd::Identity(n_tasks, n_tasks)*options.n_clusters/n_tasks;
// Scratch matrices created as copies of Ms; both are reassigned inside the
// outer loop before they are read in the visible code.
59 MatrixXd IMsqinv = Ms;
60 MatrixXd invEtaMWt = Ms;
// Working copies, all starting from the initial search points:
//   *z      current iterate (combined with *z_old to form the search point)
//   *zp     candidate produced by the gradient step
//   *z_old  previous iterate
//   delta_* candidate minus search point
//   g*      gradient at the search point
62 MatrixXd Wz=Ws, Wzp=Ws, Wz_old=Ws, delta_Wzp=Ws, gWs=Ws;
63 VectorXd Cz=Cs, Czp=Cs, Cz_old=Cs, delta_Czp=Cs, gCs=Cs;
64 MatrixXd Mz=Ms, Mzp=Ms, Mz_old=Ms, delta_Mzp=Ms, gMs=Ms;
// Constants derived from rho1 and rho2 (declared outside this excerpt).
// eta shifts the diagonal of Ms below; c scales the trace term and its gradients.
67 double eta = rho2/rho1;
68 double c = rho1*eta*(1+eta);
// gamma divides the gradient in the candidate step, so 1/gamma is the step size.
// gamma_inc is not used in the visible lines — presumably the factor by which
// gamma grows when a candidate is rejected; TODO confirm.
71 double gamma=1, gamma_inc=2;
// Current and previous objective values, compared in the termination tests.
72 double obj=0.0, obj_old=0.0;
// Buffer of n_tasks doubles later passed to libqp_gsmo_solver as diag_H.
// NOTE(review): its initialisation and release are not visible in this
// excerpt — confirm it is freed on every exit path.
74 double* diag_H =
SG_MALLOC(
double, n_tasks);
// Toggles a runtime allocation check; presumably
// Eigen::internal::set_is_malloc_allowed, reachable through the
// using-directive, and only effective in builds that enable that check —
// TODO confirm. `false` flags heap allocation inside Eigen expressions.
81 internal::set_is_malloc_allowed(
false);
// Outer loop: runs until `done` is set or iter exceeds options.max_iter.
83 while (!done && iter <= options.max_iter)
// Extrapolation weight built from the previous and current values of t.
85 double alpha = double(t_old - 1)/t;
// Search points: current iterate extrapolated away from the previous iterate.
89 Ws = (1+alpha)*Wz - alpha*Wz_old;
90 Cs = (1+alpha)*Cz - alpha*Cz_old;
91 Ms = (1+alpha)*Mz - alpha*Mz_old;
// The inverses below need temporaries, so allocation is re-enabled around them.
96 internal::set_is_malloc_allowed(
true);
// IM = eta*I + Ms.
98 IM = (eta*MatrixXd::Identity(n_tasks,n_tasks)+Ms);
// IMsqinv = (IM*IM)^-1 and invEtaMWt = IM^-1 * Ws^T.
101 IMsqinv = (IM*IM).inverse();
102 invEtaMWt = IM.inverse()*Ws.transpose();
// Trace-term gradient with respect to Ms (assigned), and its contribution to
// the gradient with respect to Ws (accumulated; any reset of gWs happens
// outside this excerpt).
104 gMs.noalias() = -c*(Ws.transpose()*Ws)*IMsqinv;
105 gWs.noalias() += 2*c*invEtaMWt.transpose();
106 internal::set_is_malloc_allowed(
false);
// Loss at the search point, accumulated task by task into Fs.
110 for (task=0; task<n_tasks; task++)
// Number of vectors belonging to this task.
113 int n_vecs_task = task_idx.
vlen;
114 for (
int i=0; i<n_vecs_task; i++)
// aa = -label * (w_task . x + bias_task).
116 double aa = -y[task_idx[i]]*(features->
dense_dot(task_idx[i], Ws.col(task).data(), n_feats)+Cs[task]);
// Subtracting bb = max(aa, 0) inside the exponentials and adding it back
// evaluates log(1 + exp(aa)) without overflowing exp() for large aa.
117 double bb = CMath::max(aa,0.0);
120 Fs += (CMath::log(CMath::exp(-bb) + CMath::exp(aa-bb)) + bb)/n_vecs_task;
// Per-sample factor -label * (1 - 1/(1+exp(aa))) / n_vecs_task; the lines
// that consume b are not visible in this excerpt.
121 double b = -y[task_idx[i]]*(1 - 1/(1+CMath::exp(aa)))/n_vecs_task;
// Add the trace term c * tr(Ws * IM^-1 * Ws^T) to the objective.
131 Fs += c*(Ws*invEtaMWt).trace();
// Inner loop, bounded by 1000 iterations (the increment of inner_iter is not
// visible in this excerpt).
138 while (inner_iter <= 1000)
// Candidate W and C: gradient step from the search point with step 1/gamma.
140 Wzp = Ws - gWs/gamma;
141 Czp = Cs - gCs/gamma;
143 internal::set_is_malloc_allowed(
true);
// Eigen-decomposition of the gradient-stepped M.
144 EigenSolver<MatrixXd> eigensolver(Ms-gMs/gamma);
// Set up the QP over the eigenvalues.
148 for (
int i=0; i<n_tasks; i++)
// Linear QP term: -2 times the real part of the i-th eigenvalue.
151 f[i] = -2*eigensolver.eigenvalues()[i].real();
152 SG_SDEBUG(
"%dth eigenvalue %f\n",i,eigensolver.eigenvalues()[i].real());
// Starting point for the QP solution: n_clusters/n_tasks in every entry.
156 x[i] = double(options.n_clusters)/n_tasks;
// Scalar passed to the solver as b, set to n_clusters.
158 double b = options.n_clusters;
// Solve the QP for x. get_col, a, lb and ub are defined outside this excerpt;
// 1000 and 1e-6 are presumably the iteration cap and tolerance — TODO confirm
// against the libqp signature.
161 libqp_state_T problem_state =
libqp_gsmo_solver(&
get_col,diag_H,f,a,b,lb,ub,x,n_tasks,1000,1e-6,NULL);
162 SG_SDEBUG(
"Libqp objective = %f\n",problem_state.QP);
163 SG_SDEBUG(
"Exit code = %d\n",problem_state.exitflag);
164 SG_SDEBUG(
"%d iteration passed\n",problem_state.nIter);
// Per-task loop whose body is not visible in this excerpt.
166 for (
int i=0; i<n_tasks; i++)
// View the QP solution x as an Eigen vector without copying, then rebuild the
// candidate M as P * diag(x) * P^T from the real part of the eigenvectors.
169 Map<VectorXd> Mzp_DiagSigz(x,n_tasks);
170 Mzp_Pz = eigensolver.eigenvectors().real();
171 Mzp = Mzp_Pz*Mzp_DiagSigz.asDiagonal()*Mzp_Pz.transpose();
172 internal::set_is_malloc_allowed(
false);
// Shift every entry by eta. Mzp_DiagSigz maps x, so this also modifies x in place.
174 for (
int i=0; i<n_tasks; i++)
175 Mzp_DiagSigz[i] += eta;
176 internal::set_is_malloc_allowed(
true);
// Fragment of a larger expression using the element-wise inverse of the shifted
// values as a diagonal matrix; the rest of the statement is not visible here.
178 (Mzp_DiagSigz.cwiseInverse().asDiagonal())*
181 internal::set_is_malloc_allowed(
false);
// Loss at the candidate point, accumulated into Fzp with the same
// overflow-safe log(1 + exp(aa)) form as above.
184 for (task=0; task<n_tasks; task++)
187 int n_vecs_task = task_idx.
vlen;
188 for (
int i=0; i<n_vecs_task; i++)
// NOTE(review): the weights come from the candidate Wzp but the bias is read
// from the search point Cs, not from the candidate Czp computed above. The
// bound below is built from delta_Czp = Czp - Cs, so Czp[task] looks intended
// here — confirm before changing.
190 double aa = -y[task_idx[i]]*(features->
dense_dot(task_idx[i], Wzp.col(task).data(), n_feats)+Cs[task]);
191 double bb = CMath::max(aa,0.0);
193 Fzp += (CMath::log(CMath::exp(-bb) + CMath::exp(aa-bb)) + bb)/n_vecs_task;
// Trace term at the candidate, using whatever invEtaMWt holds at this point
// (presumably refreshed by the partly visible statement above — TODO confirm).
196 Fzp += c*(Wzp*invEtaMWt).trace();
// Displacement of each candidate from its search point.
199 delta_Wzp = Wzp - Ws;
200 delta_Czp = Czp - Cs;
201 delta_Mzp = Mzp - Ms;
// Despite the nrm_ prefix these are squared norms (squaredNorm()).
204 double nrm_delta_Wzp = delta_Wzp.squaredNorm();
205 double nrm_delta_Czp = delta_Czp.squaredNorm();
206 double nrm_delta_Mzp = delta_Mzp.squaredNorm();
// Mean squared displacement over the three unknowns; its use is not visible here.
208 double r_sum = (nrm_delta_Wzp + nrm_delta_Czp + nrm_delta_Mzp)/3;
// Quadratic bound at the candidate:
//   Fs + <delta, gradient> + (gamma/2) * ||delta||^2, summed over W, C and M.
// The two branches differ only in operand order inside each trace; since
// tr(A^T B) = tr(B^T A) they are mathematically equal. The else keyword
// separating them is not visible in this excerpt.
210 double Fzp_gamma = 0.0;
211 if (n_feats > n_tasks)
213 Fzp_gamma = Fs + (delta_Wzp.transpose()*gWs).trace() +
214 (delta_Czp.transpose()*gCs).trace() +
215 (delta_Mzp.transpose()*gMs).trace() +
216 (gamma/2)*nrm_delta_Wzp +
217 (gamma/2)*nrm_delta_Czp +
218 (gamma/2)*nrm_delta_Mzp;
222 Fzp_gamma = Fs + (gWs.transpose()*delta_Wzp).trace() +
223 (gCs.transpose()*delta_Czp).trace() +
224 (gMs.transpose()*delta_Mzp).trace() +
225 (gamma/2)*nrm_delta_Wzp +
226 (gamma/2)*nrm_delta_Czp +
227 (gamma/2)*nrm_delta_Mzp;
// Test whether the candidate objective is within the bound; the statements
// guarded by this test are not visible (presumably it ends the inner loop).
238 if (Fzp <= Fzp_gamma)
// Termination tests selected by options.termination (case labels not visible):
258 switch (options.termination)
// absolute change in objective within tolerance
263 if ( CMath::abs(obj-obj_old) <= options.tolerance )
// change in objective within tolerance relative to the previous objective
270 if ( CMath::abs(obj-obj_old) <= options.tolerance*CMath::abs(obj_old))
// absolute objective value within tolerance
275 if (CMath::abs(obj) <= options.tolerance)
// iteration budget exhausted
279 if (iter>=options.max_iter)
// Update t, which feeds the extrapolation weight alpha of the next iteration.
286 t = 0.5 * (1 + CMath::sqrt(1.0 + 4*t*t));
288 internal::set_is_malloc_allowed(
true);
289 SG_SDEBUG(
"%d iteration passed, objective = %f\n",iter,obj);
// Copy the last candidate weights (Wzp) and biases (Czp) into the result
// containers; tasks_w and tasks_c are declared outside this excerpt.
300 for (
int i=0; i<n_feats; i++)
302 for (task=0; task<n_tasks; task++)
303 tasks_w(i,task) = Wzp(i,task);
307 for (
int i=0; i<n_tasks; i++) tasks_c[i] = Czp[i];
308 return malsar_result_t(tasks_w, tasks_c);