This class implements Nesterov's Accelerated Gradient (NAG) correction.
Given a target variable, \(w\), and a descent direction, \(d_{ahead}\), with respect to \(w_{ahead}\), the momentum method performs the following update:
\begin{eqnarray*} w_{ahead} &=& w + \mu v \\ v^{new} &=& \mu v - d_{ahead} \\ w^{new} &=& w + v^{new} \end{eqnarray*}
where \(\mu\) is the momentum, \(d_{ahead}\) is the descent direction with respect to \(w_{ahead}\) (e.g., \( d_{ahead}=\lambda g_{ahead}\), where \(\lambda\) is the learning rate and \(g_{ahead}\) is the gradient with respect to \(w_{ahead}\)), \(v\) is the previous descent direction, and \(v^{new}\) is the corrected descent direction.
Note that the Nesterov momentum correction uses \(d_{ahead}\) instead of \(d\), the descent direction with respect to \(w\).
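The look-ahead update above can be sketched in C++ as follows. This is a minimal illustration, not the library's code: the function name `nag_step_lookahead`, the `std::vector<double>` types, and the `grad` callback are assumptions made for this example, and \(d_{ahead}=\lambda g_{ahead}\) is assumed.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical helper (not part of the library): one look-ahead NAG step.
// w: target variable, v: previous descent direction (velocity),
// mu: momentum, lambda: learning rate,
// grad: callback returning the gradient at a given point.
void nag_step_lookahead(
    std::vector<double>& w, std::vector<double>& v, double mu, double lambda,
    const std::function<std::vector<double>(const std::vector<double>&)>& grad)
{
    assert(w.size() == v.size());

    // w_ahead = w + mu * v
    std::vector<double> w_ahead(w.size());
    for (std::size_t i = 0; i < w.size(); ++i)
        w_ahead[i] = w[i] + mu * v[i];

    // g_ahead is the gradient evaluated at the look-ahead point
    const std::vector<double> g_ahead = grad(w_ahead);
    assert(g_ahead.size() == w.size());

    for (std::size_t i = 0; i < w.size(); ++i)
    {
        v[i] = mu * v[i] - lambda * g_ahead[i]; // v_new = mu * v - d_ahead
        w[i] += v[i];                           // w_new = w + v_new
    }
}
```

For \(f(w)=\frac{1}{2}w^2\), whose gradient is \(w\), starting from \(w=1\), \(v=0\) with \(\mu=0.9\) and \(\lambda=0.1\), one step gives \(v^{new}=-0.1\) and \(w^{new}=0.9\).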
In practice, we use the following implementation:
\begin{eqnarray*} v^{old} &=& v \\ v^{new} &=& \mu v^{old} - d \\ w^{new} &=& w - \mu v^{old} + (1 + \mu) v^{new} \end{eqnarray*}
where \(d\) is the descent direction with respect to \(w\).
The trick used in this implementation is that we store \(w_{ahead}\) and rename it as \(w\). Given a decaying descent direction (e.g., \(d=\lambda g_{ahead}\), where \(\lambda\) is a decaying learning rate), \(w_{ahead}\) is very close to \(w\). When an optimal solution \(w^{opt}\) is found, \(w_{ahead}=w^{opt}\) since \(d^{opt}=0\).
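To see why the two formulations agree, here is a short derivation. It assumes the stored variable is \(\tilde{w} = w + \mu v\); the symbol \(\tilde{w}\) is introduced only for this illustration, since the text simply reuses the name \(w\). Using \(w^{new} = w + v^{new}\) and \(w = \tilde{w} - \mu v^{old}\):
\begin{eqnarray*} \tilde{w}^{new} &=& w^{new} + \mu v^{new} \\ &=& w + v^{new} + \mu v^{new} \\ &=& \tilde{w} - \mu v^{old} + (1 + \mu) v^{new} \end{eqnarray*}
Renaming \(\tilde{w}\) to \(w\) gives the last line of the practical implementation, and the descent direction evaluated at the stored variable is exactly \(d_{ahead}\).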
The get_corrected_descend_direction() method computes
\begin{eqnarray*} v^{old} &=& v \\ v^{new} &=& \mu v^{old} - d \end{eqnarray*}
and returns \( -\mu v^{old} + (1 + \mu) v^{new}\).
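The per-index computation described above might look roughly like the sketch below. It is a hypothetical stand-in and not the library's implementation: the struct `NesterovSketch`, its members, and the plain `double` return value are illustrative assumptions (the documented method returns a `DescendPair`). It also assumes that the argument is \(-d\), as the parameter name `negative_descend_direction` suggests.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the class documented here; names are illustrative.
struct NesterovSketch
{
    double mu;             // momentum (the correction weight)
    std::vector<double> v; // previous descent direction (velocity) per index

    // Assumption: negative_d equals -d for the entry at idx.
    // Stores v_new in v[idx] and returns the step to add to w[idx].
    double corrected_direction(double negative_d, std::size_t idx)
    {
        assert(idx < v.size());
        const double v_old = v[idx];
        const double v_new = mu * v_old + negative_d; // v_new = mu * v_old - d
        v[idx] = v_new;
        return -mu * v_old + (1.0 + mu) * v_new;
    }
};
```

With \(\mu=0.9\), a zero initial velocity, and \(d=0.1\), the first call returns \((1+\mu)(-0.1)=-0.19\) and leaves the stored velocity at \(-0.1\).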
A good introduction to the momentum update can be found at http://cs231n.github.io/neural-networks-3/#sgd , where \(v\) is also called the velocity.
Definition at line 78 of file NesterovMomentumCorrection.h.
Public Member Functions | |
NesterovMomentumCorrection () | |
virtual | ~NesterovMomentumCorrection () |
virtual DescendPair | get_corrected_descend_direction (float64_t negative_descend_direction, index_t idx) |
virtual bool | is_initialized () |
virtual void | initialize_previous_direction (index_t len) |
virtual void | update_context (CMinimizerContext *context) |
virtual void | load_from_context (CMinimizerContext *context) |
virtual float64_t | get_previous_descend_direction (index_t idx) |
virtual float64_t | get_length_previous_descend_direction () |
virtual void | set_correction_weight (float64_t weight) |
Protected Attributes | |
SGVector< float64_t > | m_previous_descend_direction |
float64_t | m_weight |
Definition at line 82 of file NesterovMomentumCorrection.h.
|
virtual |
Definition at line 89 of file NesterovMomentumCorrection.h.
|
virtual |
Get corrected descend direction
negative_descend_direction | the negative descend direction |
idx | the index of the direction |
Implements MomentumCorrection.
Definition at line 97 of file NesterovMomentumCorrection.h.
|
virtual inherited |
Get the length of the previous descend direction (velocity)
Definition at line 143 of file MomentumCorrection.h.
Get the previous descend direction (velocity) given the index
idx | index of the previous descend direction |
Definition at line 132 of file MomentumCorrection.h.
|
virtual inherited |
Initialize m_previous_descend_direction
Reimplemented in AdaptMomentumCorrection.
Definition at line 73 of file MomentumCorrection.h.
|
virtual inherited |
Is the m_previous_descend_direction initialized?
Reimplemented in AdaptMomentumCorrection.
Definition at line 64 of file MomentumCorrection.h.
|
virtual inherited |
Load the given context object to restore mutable variables
This method will be called by DescendUpdaterWithCorrection::load_from_context(CMinimizerContext* context)
context | a context object |
Implements DescendCorrection.
Reimplemented in AdaptMomentumCorrection.
Definition at line 116 of file MomentumCorrection.h.
|
virtual inherited |
Set the weight used in descend correction
weight | the weight |
Reimplemented in AdaptMomentumCorrection.
Definition at line 74 of file DescendCorrection.h.
|
virtual inherited |
Update a context object to store mutable variables used in descend update
This method will be called by DescendUpdaterWithCorrection::update_context()
context | a context object |
Implements DescendCorrection.
Reimplemented in AdaptMomentumCorrection.
Definition at line 98 of file MomentumCorrection.h.
The previous descend direction (velocity) used in momentum methods
Definition at line 149 of file MomentumCorrection.h.
|
protected inherited |
weight of correction
Definition at line 110 of file DescendCorrection.h.