SHOGUN  4.1.0
NesterovMomentumCorrection Class Reference

Detailed Description

This class implements the Nesterov Accelerated Gradient (NAG) correction.

Given a target variable \(w\) and a descend direction \(d_{ahead}\) with respect to the look-ahead point \(w_{ahead}\), the Nesterov momentum method performs the following update:

\begin{eqnarray*} w_{ahead} &=& w + \mu v \\ v^{new} &=& \mu v - d_{ahead} \\ w^{new} &=& w + v^{new} \end{eqnarray*}

where \(\mu\) is the momentum coefficient, \(d_{ahead}\) is the descend direction with respect to \(w_{ahead}\) (e.g., \(d_{ahead}=\lambda g_{ahead}\), where \(\lambda\) is the learning rate and \(g_{ahead}\) is the gradient with respect to \(w_{ahead}\)), \(v\) is the previous descend direction, and \(v^{new}\) is the corrected descend direction.

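Note that the Nesterov momentum correction uses \(d_{ahead}\), the descend direction evaluated at the look-ahead point \(w_{ahead}\), instead of \(d\), the descend direction evaluated at \(w\).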
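
The look-ahead update above can be sketched in a few lines of standalone C++. This is an illustration under stated assumptions, not Shogun code: the names LookaheadState, example_gradient, nag_lookahead_step and nag_lookahead_run are hypothetical, the objective \(f(w)=w^2/2\) is chosen only because its gradient is simply \(w\), and \(d_{ahead}=\lambda g_{ahead}\) is assumed.

```cpp
#include <cassert>

// Hypothetical names for illustration only; none of this is Shogun API.
struct LookaheadState {
    double w;  // target variable
    double v;  // velocity (previous descend direction)
};

// Gradient of the illustrative objective f(w) = 0.5 * w * w (an assumption).
inline double example_gradient(double w) { return w; }

// One step of the look-ahead form:
//   w_ahead = w + mu * v
//   v_new   = mu * v - d_ahead,  with d_ahead = lambda * g(w_ahead)
//   w_new   = w + v_new
inline LookaheadState nag_lookahead_step(LookaheadState s, double mu, double lambda) {
    assert(mu >= 0.0 && mu < 1.0);
    const double w_ahead = s.w + mu * s.v;
    const double d_ahead = lambda * example_gradient(w_ahead);
    const double v_new = mu * s.v - d_ahead;
    return {s.w + v_new, v_new};
}

// Applies the step n times and returns the final state.
inline LookaheadState nag_lookahead_run(LookaheadState s, double mu, double lambda, int n) {
    for (int i = 0; i < n; ++i) {
        s = nag_lookahead_step(s, mu, lambda);
    }
    return s;
}
```

Starting from \(w=1\), \(v=0\) with \(\mu=0.9\) and \(\lambda=0.1\), the first step gives \(v^{new}=-0.1\) and \(w^{new}=0.9\); repeated steps drive \(w\) toward the minimizer at 0.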

In practice, we use the following implementation:

\begin{eqnarray*} v^{old} &=& v \\ v^{new} &=& \mu v^{old} - d \\ w^{new} &=& w - \mu v^{old} + (1 + \mu) v^{new} \end{eqnarray*}

where \(d\) is the descend direction with respect to \(w\).

The trick in this implementation is to store \(w_{ahead}\) and rename it \(w\). Given a decaying descend direction (e.g., \(d=\lambda g_{ahead}\), where \(\lambda\) is a decaying learning rate), \(w_{ahead}\) is very close to \(w\). When an optimal solution \(w^{opt}\) is found, \(w_{ahead}=w^{opt}\) since \(d^{opt}=0\).
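
The practical update follows from the original one by tracking the look-ahead point from one iteration to the next. A short derivation, writing \(w_{ahead}^{new}\) for the next look-ahead point (a symbol introduced here only for illustration):

\begin{eqnarray*} w_{ahead}^{new} &=& w^{new} + \mu v^{new} \\ &=& w + (1 + \mu) v^{new} \\ &=& w_{ahead} - \mu v + (1 + \mu) v^{new} \end{eqnarray*}

The second line substitutes \(w^{new} = w + v^{new}\), and the third uses \(w = w_{ahead} - \mu v\). Renaming \(w_{ahead}\) as \(w\) and \(v\) as \(v^{old}\) gives \(w^{new} = w - \mu v^{old} + (1 + \mu) v^{new}\), which is the update used in practice.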

The get_corrected_descend_direction() method computes

\begin{eqnarray*} v^{old} &=& v \\ v^{new} &=& \mu v^{old} - d \end{eqnarray*}

and returns \( -\mu v^{old} + (1 + \mu) v^{new}\).
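
The per-coordinate computation described above can be sketched as a small standalone C++ class. This is a sketch under stated assumptions rather than the library's implementation: NesterovCorrectionSketch and its members are hypothetical names, the argument d is assumed to play the role of \(d\) in the formulas, the velocity is assumed to start at zero, and only the corrected direction is returned (the second member of DescendPair is not modeled).

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for illustration; not Shogun's class.
class NesterovCorrectionSketch {
public:
    // mu is the momentum coefficient; len is the number of coordinates.
    NesterovCorrectionSketch(double mu, std::size_t len)
        : mu_(mu), velocity_(len, 0.0) {}

    // For coordinate idx:
    //   v_old = v
    //   v_new = mu * v_old - d
    // stores v_new and returns -mu * v_old + (1 + mu) * v_new.
    double correct(double d, std::size_t idx) {
        assert(idx < velocity_.size());
        const double v_old = velocity_[idx];
        const double v_new = mu_ * v_old - d;
        velocity_[idx] = v_new;
        return -mu_ * v_old + (1.0 + mu_) * v_new;
    }

    // Stored velocity for coordinate idx.
    double velocity(std::size_t idx) const {
        assert(idx < velocity_.size());
        return velocity_[idx];
    }

private:
    double mu_;                     // momentum coefficient
    std::vector<double> velocity_;  // one velocity entry per coordinate
};
```

In this sketch a caller would add the returned value to the stored variable, i.e. \(w^{new} = w + (-\mu v^{old} + (1 + \mu) v^{new})\), which matches the practical update above.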

A good introduction to the momentum update can be found at http://cs231n.github.io/neural-networks-3/#sgd , where \(v\) is also called the velocity.

Definition at line 78 of file NesterovMomentumCorrection.h.

Inheritance diagram for NesterovMomentumCorrection: [graph not reproduced]

Public Member Functions

 NesterovMomentumCorrection ()
 
virtual ~NesterovMomentumCorrection ()
 
virtual DescendPair get_corrected_descend_direction (float64_t negative_descend_direction, index_t idx)
 
virtual bool is_initialized ()
 
virtual void initialize_previous_direction (index_t len)
 
virtual void update_context (CMinimizerContext *context)
 
virtual void load_from_context (CMinimizerContext *context)
 
virtual float64_t get_previous_descend_direction (index_t idx)
 
virtual float64_t get_length_previous_descend_direction ()
 
virtual void set_correction_weight (float64_t weight)
 

Protected Attributes

SGVector< float64_t > m_previous_descend_direction
 
float64_t m_weight
 

Constructor & Destructor Documentation

NesterovMomentumCorrection ( )

Definition at line 82 of file NesterovMomentumCorrection.h.

virtual ~NesterovMomentumCorrection ( )
virtual

Definition at line 89 of file NesterovMomentumCorrection.h.

Member Function Documentation

virtual DescendPair get_corrected_descend_direction ( float64_t  negative_descend_direction,
index_t  idx 
)
virtual

Get corrected descend direction

Parameters
negative_descend_direction  the negative descend direction
idx  the index of the direction
Returns
DescendPair (corrected descend direction and the change to correct descend direction)

Implements MomentumCorrection.

Definition at line 97 of file NesterovMomentumCorrection.h.

virtual float64_t get_length_previous_descend_direction ( )
virtual inherited

Get the length of the previous descend direction (velocity)

Returns
the length of the previous descend direction

Definition at line 143 of file MomentumCorrection.h.

virtual float64_t get_previous_descend_direction ( index_t  idx)
virtual inherited

Get the previous descend direction (velocity) given the index

Parameters
idx  index of the previous descend direction
Returns
the previous descend direction

Definition at line 132 of file MomentumCorrection.h.

virtual void initialize_previous_direction ( index_t  len)
virtual inherited

Initialize m_previous_descend_direction.

Parameters
len  the length of m_previous_descend_direction to be initialized

Reimplemented in AdaptMomentumCorrection.

Definition at line 73 of file MomentumCorrection.h.

virtual bool is_initialized ( )
virtual inherited

Is the m_previous_descend_direction initialized?

Returns
whether m_previous_descend_direction is initialized

Reimplemented in AdaptMomentumCorrection.

Definition at line 64 of file MomentumCorrection.h.

virtual void load_from_context ( CMinimizerContext *  context)
virtual inherited

Load the given context object to restore mutable variables

This method will be called by DescendUpdaterWithCorrection::load_from_context(CMinimizerContext* context)

Parameters
context  a context object

Implements DescendCorrection.

Reimplemented in AdaptMomentumCorrection.

Definition at line 116 of file MomentumCorrection.h.

virtual void set_correction_weight ( float64_t  weight)
virtual inherited

Set the weight used in descend correction

Parameters
weight  the weight

Reimplemented in AdaptMomentumCorrection.

Definition at line 74 of file DescendCorrection.h.

virtual void update_context ( CMinimizerContext *  context)
virtual inherited

Update a context object to store mutable variables used in descend update

This method will be called by DescendUpdaterWithCorrection::update_context()

Parameters
context  a context object

Implements DescendCorrection.

Reimplemented in AdaptMomentumCorrection.

Definition at line 98 of file MomentumCorrection.h.

Member Data Documentation

SGVector<float64_t> m_previous_descend_direction
protected inherited

used in momentum methods

Definition at line 149 of file MomentumCorrection.h.

float64_t m_weight
protected inherited

weight of correction

Definition at line 110 of file DescendCorrection.h.


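The documentation for this class was generated from the following file:

NesterovMomentumCorrection.h

SHOGUN Machine Learning Toolbox - Project Documentation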