SHOGUN
v3.0.0
Main Page
Related Pages
Modules
Classes
Files
File List
File Members
All
Classes
Namespaces
Files
Functions
Variables
Typedefs
Enumerations
Enumerator
Friends
Macros
Groups
Pages
src
shogun
mathematics
linalg
linsolver
ConjugateOrthogonalCGSolver.cpp
Go to the documentation of this file.
1
/*
2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 3 of the License, or
5
* (at your option) any later version.
6
*
7
* Written (W) 2013 Soumyajit De
8
*/
9
10
#include <
shogun/lib/common.h
>
11
12
#ifdef HAVE_EIGEN3
13
14
#include <
shogun/lib/SGVector.h
>
15
#include <
shogun/lib/Time.h
>
16
#include <
shogun/mathematics/eigen3.h
>
17
#include <
shogun/mathematics/Math.h
>
18
#include <
shogun/mathematics/linalg/linop/LinearOperator.h
>
19
#include <
shogun/mathematics/linalg/linsolver/ConjugateOrthogonalCGSolver.h
>
20
#include <
shogun/mathematics/linalg/linsolver/IterativeSolverIterator.h
>
21
using namespace
Eigen;
22
23
namespace
shogun
24
{
25
26
/**
 * Default constructor.
 *
 * Delegates to the default constructor of the complex-valued iterative
 * solver base (complex128_t solution, float64_t right-hand side) and emits
 * a GC debug trace with the object's name and address.
 */
CConjugateOrthogonalCGSolver::CConjugateOrthogonalCGSolver()
	: CIterativeLinearSolver<complex128_t, float64_t>()
{
	SG_GCDEBUG("%s created (%p)\n", get_name(), this);
}
31
32
/**
 * Constructor.
 *
 * @param store_residuals forwarded unchanged to the iterative solver base;
 * solve() consults the resulting m_store_residuals flag to decide whether
 * to record per-iteration residual norms in m_residuals.
 */
CConjugateOrthogonalCGSolver::CConjugateOrthogonalCGSolver(bool store_residuals)
	: CIterativeLinearSolver<complex128_t, float64_t>(store_residuals)
{
	SG_GCDEBUG("%s created (%p)\n", get_name(), this);
}
37
38
/** Destructor: owns nothing extra, only emits a GC debug trace. */
CConjugateOrthogonalCGSolver::~CConjugateOrthogonalCGSolver()
{
	SG_GCDEBUG("%s destroyed (%p)\n", get_name(), this);
}
42
43
/**
 * Solves A x = b with the Conjugate Orthogonal Conjugate Gradient (COCG)
 * method, starting from the initial guess x_0 = 0.
 *
 * COCG is the CG recurrence with the Hermitian inner product u^{H}v
 * replaced by the unconjugated bilinear form u^{T}v (note transpose(),
 * not adjoint(), below). NOTE(review): the method is intended for complex
 * symmetric operators; symmetry of A is not checked here.
 *
 * Iteration control (iteration limit, relative/absolute tolerance) is
 * delegated to an IterativeSolverIterator. If m_store_residuals is set,
 * the residual norm reported by the iterator at each step is written to
 * m_residuals.
 *
 * @param A the linear operator; must be non-NULL and have
 *        get_dimension() == b.vlen
 * @param b the real right-hand side vector
 * @return the (approximate) complex solution vector x; a warning is
 *         logged if the iterator reports that it did not converge
 */
SGVector<complex128_t> CConjugateOrthogonalCGSolver::solve(
	CLinearOperator<complex128_t>* A, SGVector<float64_t> b)
{
	SG_DEBUG("CConjugateOrthogonalCGSolver::solve(): Entering..\n");

	// sanity check
	REQUIRE(A, "Operator is NULL!\n");
	REQUIRE(A->get_dimension()==b.vlen, "Dimension mismatch! %d vs %d\n",
		A->get_dimension(), b.vlen);

	// the final solution vector, initial guess is 0
	SGVector<complex128_t> result(b.vlen);
	result.set_const(0.0);

	// eigen3 views over the SGVector buffers, used for the vector algebra
	Map<VectorXcd> x(result.vector, result.vlen);
	Map<VectorXd> b_map(b.vector, b.vlen);

	// direction vector: p_ owns the storage so it can be handed to
	// A->apply(), p is the eigen3 view used for the arithmetic
	SGVector<complex128_t> p_(result.vlen);
	Map<VectorXcd> p(p_.vector, p_.vlen);

	// residual r_i=b-Ax_i, here x_0=[0], so r_0=b
	VectorXcd r=b_map.cast<complex128_t>();

	// initial direction is same as residual
	p=r;

	// the iterator for this iterative solver
	IterativeSolverIterator<complex128_t> it(r, m_max_iteration_limit,
		m_relative_tolerence, m_absolute_tolerence);

	// start the timer
	CTime time;
	time.start();

	// set the residuals to zero
	if (m_store_residuals)
		m_residuals.set_const(0.0);

	// COCG uses the unconjugated bilinear form r^{T}r (transpose, not adjoint)
	complex128_t r_norm2=r.transpose()*r;

	for (it.begin(r); !it.end(r); ++it)
	{
		SG_DEBUG("CG iteration %d, residual norm %f\n",
			it.get_iter_info().iteration_count,
			it.get_iter_info().residual_norm);

		if (m_store_residuals)
		{
			m_residuals[it.get_iter_info().iteration_count]
				=it.get_iter_info().residual_norm;
		}

		// apply linear operator to the direction vector
		SGVector<complex128_t> Ap_=A->apply(p_);
		Map<VectorXcd> Ap(Ap_.vector, Ap_.vlen);

		// compute p^{T}Ap; if zero, alpha is undefined (breakdown), stop
		complex128_t p_T_times_Ap=p.transpose()*Ap;
		if (p_T_times_Ap==0.0)
			break;

		// compute the alpha parameter of CG
		complex128_t alpha=r_norm2/p_T_times_Ap;

		// update the solution vector and residual
		// x_{i}=x_{i-1}+\alpha_{i}p
		x+=alpha*p;

		// r_{i}=r_{i-1}-\alpha_{i}Ap
		r-=alpha*Ap;

		// new r^{T}r. This is a bilinear form, not a norm: it can vanish
		// for r!=0 (breakdown), so a zero here only stops the recurrence;
		// whether the solve succeeded is reported by it.succeeded(r) below
		complex128_t r_norm2_i=r.transpose()*r;
		if (r_norm2_i==0.0)
			break;

		// compute the beta parameter of CG
		complex128_t beta=r_norm2_i/r_norm2;

		// update direction, and r^{T}r
		r_norm2=r_norm2_i;
		p=r+beta*p;
	}

	float64_t elapsed=time.cur_time_diff();

	if (!it.succeeded(r))
		SG_WARNING("Did not converge!\n");

	// iteration_count is printed with %d, matching the SG_DEBUG in the loop
	SG_INFO("Iteration took %d times, residual norm=%.20lf, time elapsed=%lf\n",
		it.get_iter_info().iteration_count, it.get_iter_info().residual_norm,
		elapsed);

	SG_DEBUG("CConjugateOrthogonalCGSolver::solve(): Leaving..\n");
	return result;
}
141
142
}
143
#endif // HAVE_EIGEN3
SHOGUN
Machine Learning Toolbox - Documentation