SHOGUN  4.1.0
 全部  命名空间 文件 函数 变量 类型定义 枚举 枚举值 友元 宏定义  
CGMShiftedFamilySolver.cpp
浏览该文件的文档.
1 /*
2  * This program is free software; you can redistribute it and/or modify
3  * it under the terms of the GNU General Public License as published by
4  * the Free Software Foundation; either version 3 of the License, or
5  * (at your option) any later version.
6  *
7  * Written (W) 2013 Soumyajit De
8  */
9 
10 #include <shogun/lib/common.h>
11 
12 #ifdef HAVE_EIGEN3
13 
14 #include <shogun/lib/SGVector.h>
15 #include <shogun/lib/Time.h>
21 
22 using namespace Eigen;
23 
24 namespace shogun
25 {
26 
// Default constructor: takes no arguments and has an empty body.
// NOTE(review): source line 28 is absent from this listing - presumably a
// base-class initializer; confirm against the real file.
27 CCGMShiftedFamilySolver::CCGMShiftedFamilySolver()
29 {
30 }
31 
34 {
35 }
36 
38 {
39 }
40 
43 {
     // Delegates to solve_shifted_weighted using exactly one shift equal to
     // 0.0 and one weight equal to 1.0, then returns get_real() of the
     // complex result.
     // NOTE(review): the signature line is missing from this listing; A and b
     // are the parameters forwarded below.
44  SGVector<complex128_t> shifts(1);
45  shifts[0]=0.0;
46  SGVector<complex128_t> weights(1);
47  weights[0]=1.0;
48 
49  return solve_shifted_weighted(A, b, shifts, weights).get_real();
50 }
51 
55 {
     // CG-style iteration that builds one solution column per entry of
     // `shifts` in a single pass over the operator A, and returns the sum of
     // those columns, each scaled by the matching entry of `weights`.
     // Inputs used here: A (applied through A->apply), right-hand side b, and
     // the equally sized vectors shifts and weights.
     // NOTE(review): this listing omits the signature as well as the
     // declarations of `p_` and `it`; both are used below but are defined on
     // lines that are not shown.
56  SG_DEBUG("Entering\n");
57 
58  // sanity checks: operator is set, its dimension equals b.vlen, shifts and
     // weights are allocated and have the same length
59  REQUIRE(A, "Operator is NULL!\n");
60  REQUIRE(A->get_dimension()==b.vlen, "Dimension mismatch! [%d vs %d]\n",
61  A->get_dimension(), b.vlen);
62  REQUIRE(shifts.vector,"Shifts are not initialized!\n");
63  REQUIRE(weights.vector,"Weights are not initialized!\n");
64  REQUIRE(shifts.vlen==weights.vlen, "Number of shifts and number of "
65  "weights are not equal! [%d vs %d]\n", shifts.vlen, weights.vlen);
66 
67  // the solution matrix, one column per shift, initial guess 0 for all
68  MatrixXcd x_sh=MatrixXcd::Zero(b.vlen, shifts.vlen);
     // the shifted search directions, one column per shift (filled in below)
69  MatrixXcd p_sh=MatrixXcd::Zero(b.vlen, shifts.vlen);
70 
71  // non-shifted direction (the declaration of p_ is on a line missing from
     // this listing)
73 
74  // wrap the buffers of b and p_ as Eigen vectors so that dot products and
     // vector arithmetic below can be written with Eigen expressions
75  Map<VectorXd> b_map(b.vector, b.vlen);
76  Map<VectorXd> p(p_.vector, p_.vlen);
77 
78  // residual r_i=b-Ax_i, here x_0=[0], so r_0=b
79  VectorXd r=b_map;
80 
81  // initial direction is same as residual, for the non-shifted direction
     // and for every column of the shifted directions
82  p=r;
83  p_sh=r.replicate(1, shifts.vlen).cast<complex128_t>();
84 
85  // non shifted initializers; r_norm2 holds r^{T}r, i.e. the squared 2-norm
86  float64_t r_norm2=r.dot(r);
87  float64_t beta_old=1.0;
88  float64_t alpha=1.0;
89 
90  // shifted quantities, one entry per shift
91  SGVector<complex128_t> alpha_sh(shifts.vlen);
92  SGVector<complex128_t> beta_sh(shifts.vlen);
93  SGVector<complex128_t> zeta_sh_old(shifts.vlen);
94  SGVector<complex128_t> zeta_sh_cur(shifts.vlen);
95  SGVector<complex128_t> zeta_sh_new(shifts.vlen);
96 
97  // shifted initializers
98  zeta_sh_old.set_const(1.0);
99  zeta_sh_cur.set_const(1.0);
100 
101  // the iterator for this iterative solver (the declaration of `it` is on
     // lines missing from this listing)
104 
105  // start the timer
106  CTime time;
107  time.start();
108 
109  // reset the stored residuals, only when residual storing is enabled
110  if (m_store_residuals)
111  m_residuals.set_const(0.0);
112 
113  // CG iteration begins
114  for (it.begin(r); !it.end(r); ++it)
115  {
116 
117  SG_DEBUG("CG iteration %d, residual norm %f\n",
118  it.get_iter_info().iteration_count,
119  it.get_iter_info().residual_norm);
120 
     // record the residual norm reported by the iterator at the index of
     // the current iteration
     // NOTE(review): assumes m_residuals is long enough for every
     // iteration count reached - its allocation is not visible here
121  if (m_store_residuals)
122  {
123  m_residuals[it.get_iter_info().iteration_count]
124  =it.get_iter_info().residual_norm;
125  }
126 
127  // apply linear operator to the direction vector
128  SGVector<float64_t> Ap_=A->apply(p_);
129  Map<VectorXd> Ap(Ap_.vector, Ap_.vlen);
130 
131  // compute p^{T}Ap; if exactly zero, stop iterating, since beta below
     // would otherwise divide by zero
132  float64_t p_dot_Ap=p.dot(Ap);
133  if (p_dot_Ap==0.0)
134  break;
135 
136  // compute the beta parameter of CG_M (note the negative sign)
137  float64_t beta=-r_norm2/p_dot_Ap;
138 
139  // compute the zeta-shifted parameter of CG_M
140  compute_zeta_sh_new(zeta_sh_old, zeta_sh_cur, shifts, beta_old, beta,
141  alpha, zeta_sh_new);
142 
143  // compute beta-shifted parameter of CG_M
144  compute_beta_sh(zeta_sh_new, zeta_sh_cur, beta, beta_sh);
145 
146  // update the solution vector of every shifted system along its own
     // shifted direction
147  for (index_t i=0; i<shifts.vlen; ++i)
148  x_sh.col(i)-=beta_sh[i]*p_sh.col(i);
149 
150  // r_{i}=r_{i-1}+\beta_{i}Ap
151  r+=beta*Ap;
152 
153  // compute new r^{T}r (squared 2-norm); if exactly zero, converged
154  float64_t r_norm2_i=r.dot(r);
155  if (r_norm2_i==0.0)
156  break;
157 
158  // compute the alpha parameter of CG_M as the ratio of new to previous
     // r^{T}r
159  alpha=r_norm2_i/r_norm2;
160 
161  // update r^{T}r
162  r_norm2=r_norm2_i;
163 
164  // update direction; p maps the buffer of p_, so the next A->apply(p_)
     // sees this new direction
165  p=r+alpha*p;
166 
     // compute alpha-shifted parameter of CG_M
167  compute_alpha_sh(zeta_sh_new, zeta_sh_cur, beta_sh, beta, alpha, alpha_sh);
168 
     // update every shifted direction:
     // p_sh_i = alpha_sh_i*p_sh_i + zeta_sh_new_i*r
169  for (index_t i=0; i<shifts.vlen; ++i)
170  {
171  p_sh.col(i)*=alpha_sh[i];
172  p_sh.col(i)+=zeta_sh_new[i]*r;
173  }
174 
175  // update parameters: move the zeta history forward (cur->old, new->cur)
     // and remember beta for the next iteration
176  for (index_t i=0; i<shifts.vlen; ++i)
177  {
178  zeta_sh_old[i]=zeta_sh_cur[i];
179  zeta_sh_cur[i]=zeta_sh_new[i];
180  }
181  beta_old=beta;
182  }
183 
     // time spent since the timer was started above
184  float64_t elapsed=time.cur_time_diff();
185 
     // warn (but still return a result) when the iterator does not report
     // success for the final residual
186  if (!it.succeeded(r))
187  SG_WARNING("Did not converge!\n");
188 
189  SG_INFO("Iteration took %d times, residual norm=%.20lf, time elapsed=%f\n",
190  it.get_iter_info().iteration_count, it.get_iter_info().residual_norm, elapsed);
191 
192  // compute the final result vector multiplied by weights: the sum over all
     // shifts of weights[i] times the i-th solution column, accumulated
     // through an Eigen map over the result buffer
193  SGVector<complex128_t> result(b.vlen);
194  result.set_const(0.0);
195  Map<VectorXcd> x(result.vector, result.vlen);
196 
197  for (index_t i=0; i<x_sh.cols(); ++i)
198  x+=x_sh.col(i)*weights[i];
199 
200  SG_DEBUG("Leaving\n");
201  return result;
202 }
203 
204 }
205 #endif // HAVE_EIGEN3
Class Time that implements a stopwatch based on either cpu time or wall clock time.
Definition: Time.h:47
#define SG_INFO(...)
Definition: SGIO.h:118
std::complex< float64_t > complex128_t
Definition: common.h:67
void compute_zeta_sh_new(const SGVector< complex128_t > &zeta_sh_old, const SGVector< complex128_t > &zeta_sh_cur, const SGVector< complex128_t > &shifts, const float64_t &beta_old, const float64_t &beta_cur, const float64_t &alpha, SGVector< complex128_t > &zeta_sh_new)
SGVector< float64_t > get_real()
Definition: SGVector.cpp:883
void begin(const VectorXt &residual)
const index_t get_dimension() const
int32_t index_t
Definition: common.h:62
Definition: SGMatrix.h:20
#define REQUIRE(x,...)
Definition: SGIO.h:206
const bool end(const VectorXt &residual)
void compute_beta_sh(const SGVector< complex128_t > &zeta_sh_new, const SGVector< complex128_t > &zeta_sh_cur, const float64_t &beta_cur, SGVector< complex128_t > &beta_sh)
float64_t cur_time_diff(bool verbose=false)
Definition: Time.cpp:68
index_t vlen
Definition: SGVector.h:494
template class that is used as an iterator for an iterative linear solver. In the iteration of solvin...
double float64_t
Definition: common.h:50
float64_t start(bool verbose=false)
Definition: Time.cpp:59
virtual SGVector< T > apply(SGVector< T > b) const =0
#define SG_DEBUG(...)
Definition: SGIO.h:107
all of classes and functions are contained in the shogun namespace
Definition: class_list.h:18
const bool succeeded(const VectorXt &residual)
virtual SGVector< float64_t > solve(CLinearOperator< float64_t > *A, SGVector< float64_t > b)
#define SG_WARNING(...)
Definition: SGIO.h:128
virtual SGVector< complex128_t > solve_shifted_weighted(CLinearOperator< float64_t > *A, SGVector< float64_t > b, SGVector< complex128_t > shifts, SGVector< complex128_t > weights)
void compute_alpha_sh(const SGVector< complex128_t > &zeta_sh_cur, const SGVector< complex128_t > &zeta_sh_old, const SGVector< complex128_t > &beta_sh_old, const float64_t &beta_old, const float64_t &alpha, SGVector< complex128_t > &alpha_sh)
void set_const(T const_elem)
Definition: SGVector.cpp:152
abstract template base for CG based solvers to the solution of shifted linear systems of the form fo...

SHOGUN 机器学习工具包 - 项目文档