SHOGUN
v3.0.0
Main Page
Related Pages
Modules
Classes
Files
File List
File Members
All
Classes
Namespaces
Files
Functions
Variables
Typedefs
Enumerations
Enumerator
Friends
Macros
Groups
Pages
src
shogun
mathematics
linalg
linsolver
ConjugateGradientSolver.cpp
Go to the documentation of this file.
1
/*
2
* This program is free software; you can redistribute it and/or modify
3
* it under the terms of the GNU General Public License as published by
4
* the Free Software Foundation; either version 3 of the License, or
5
* (at your option) any later version.
6
*
7
* Written (W) 2013 Soumyajit De
8
*/
9
10
#include <
shogun/lib/common.h
>
11
12
#ifdef HAVE_EIGEN3
13
14
#include <
shogun/lib/SGVector.h
>
15
#include <
shogun/lib/Time.h
>
16
#include <
shogun/mathematics/eigen3.h
>
17
#include <
shogun/mathematics/linalg/linop/LinearOperator.h
>
18
#include <
shogun/mathematics/linalg/linsolver/ConjugateGradientSolver.h
>
19
#include <
shogun/mathematics/linalg/linsolver/IterativeSolverIterator.h
>
20
21
using namespace
Eigen;
22
23
namespace
shogun
24
{
25
26
/**
 * Default constructor: builds the solver on top of a value-initialised
 * CIterativeLinearSolver<float64_t> base.
 */
CConjugateGradientSolver::CConjugateGradientSolver()
	: CIterativeLinearSolver<float64_t>()
{
	// trace the creation of this object (class name and address)
	SG_GCDEBUG("%s created (%p)\n", get_name(), this);
}
31
32
/**
 * Constructor forwarding the residual-storage flag to the base solver.
 *
 * @param store_residuals passed unchanged to
 * CIterativeLinearSolver<float64_t>; solve() reads the resulting
 * m_store_residuals to decide whether to record per-iteration residuals
 */
CConjugateGradientSolver::CConjugateGradientSolver(bool store_residuals)
	: CIterativeLinearSolver<float64_t>(store_residuals)
{
	// trace the creation of this object (class name and address)
	SG_GCDEBUG("%s created (%p)\n", get_name(), this);
}
37
38
/** Destructor: owns nothing beyond the base class, it only traces destruction. */
CConjugateGradientSolver::~CConjugateGradientSolver()
{
	// trace the destruction of this object (class name and address)
	SG_GCDEBUG("%s destroyed (%p)\n", get_name(), this);
}
42
43
/**
 * Solves the linear system Ax=b with the conjugate gradient (CG) method,
 * starting from the initial guess x_0=0.
 *
 * The loop is driven by an IterativeSolverIterator constructed from b,
 * m_max_iteration_limit, m_relative_tolerence and m_absolute_tolerence.
 * The loop also exits early in two cases:
 *  - p^{T}Ap is exactly zero (CG breakdown);
 *  - the updated residual is exactly zero.
 * A warning is emitted afterwards if the iterator does not report success.
 *
 * If m_store_residuals is set, m_residuals is first zeroed. At each
 * iteration, the residual norm reported by the iterator is then written to
 * m_residuals[iteration_count].
 *
 * @param A the linear operator of the system; must be non-NULL and its
 * dimension must equal b.vlen (both checked with REQUIRE)
 * @param b the right-hand side vector
 * @return a newly allocated vector holding the computed solution x
 */
SGVector<float64_t> CConjugateGradientSolver::solve(
	CLinearOperator<float64_t>* A, SGVector<float64_t> b)
{
	SG_DEBUG("CConjugateGradientSolver::solve(): Entering..\n");

	// sanity check
	REQUIRE(A, "Operator is NULL!\n");
	REQUIRE(A->get_dimension()==b.vlen, "Dimension mismatch!\n");

	// the final solution vector, initial guess is 0
	SGVector<float64_t> result(b.vlen);
	result.set_const(0.0);

	// eigen3 views over the shogun buffers; updating x writes into result
	Map<VectorXd> x(result.vector, result.vlen);
	Map<VectorXd> b_map(b.vector, b.vlen);

	// direction vector; kept as an SGVector because A->apply() takes one,
	// with an eigen3 view p for the vector arithmetic
	SGVector<float64_t> p_(result.vlen);
	Map<VectorXd> p(p_.vector, p_.vlen);

	// residual r_i=b-Ax_i, here x_0=[0], so r_0=b
	VectorXd r=b_map;

	// initial direction is same as residual
	p=r;

	// the iterator for this iterative solver
	IterativeSolverIterator<float64_t> it(b_map, m_max_iteration_limit,
		m_relative_tolerence, m_absolute_tolerence);

	// CG iteration begins; r_norm2 holds r^{T}r = ||r||_{2}^{2}
	float64_t r_norm2=r.dot(r);

	// start the timer
	CTime time;
	time.start();

	// set the residuals to zero
	if (m_store_residuals)
		m_residuals.set_const(0.0);

	for (it.begin(r); !it.end(r); ++it)
	{
		SG_DEBUG("CG iteration %d, residual norm %f\n",
			it.get_iter_info().iteration_count,
			it.get_iter_info().residual_norm);

		if (m_store_residuals)
		{
			m_residuals[it.get_iter_info().iteration_count]
				=it.get_iter_info().residual_norm;
		}

		// apply linear operator to the direction vector
		SGVector<float64_t> Ap_=A->apply(p_);
		Map<VectorXd> Ap(Ap_.vector, Ap_.vlen);

		// compute p^{T}Ap, if zero, failure (alpha would be undefined)
		float64_t p_dot_Ap=p.dot(Ap);
		if (p_dot_Ap==0.0)
			break;

		// compute the alpha parameter of CG
		float64_t alpha=r_norm2/p_dot_Ap;

		// update the solution vector and residual
		// x_{i}=x_{i-1}+\alpha_{i}p
		x+=alpha*p;

		// r_{i}=r_{i-1}-\alpha_{i}Ap
		r-=alpha*Ap;

		// compute new ||r||_{2}^{2}, if zero, converged
		float64_t r_norm2_i=r.dot(r);
		if (r_norm2_i==0.0)
			break;

		// compute the beta parameter of CG
		float64_t beta=r_norm2_i/r_norm2;

		// update direction, and ||r||_{2}^{2}
		r_norm2=r_norm2_i;
		p=r+beta*p;
	}

	float64_t elapsed=time.cur_time_diff();

	if (!it.succeeded(r))
		SG_WARNING("Did not converge!\n");

	// iteration_count is formatted with %d, as in the per-iteration debug
	// message above (the original %ld mismatched that field).
	// NOTE(review): assumes iteration_count is an int-sized index type — confirm.
	SG_INFO("Iteration took %d times, residual norm=%.20lf, time elapsed=%lf\n",
		it.get_iter_info().iteration_count,
		it.get_iter_info().residual_norm, elapsed);

	SG_DEBUG("CConjugateGradientSolver::solve(): Leaving..\n");
	return result;
}
140
141
}
142
#endif // HAVE_EIGEN3
SHOGUN
Machine Learning Toolbox - Documentation