SHOGUN 4.1.0
StandardMomentumCorrection Class Reference

Detailed Description

This implements the plain momentum correction.

Given a target variable, \(w\), and a current descend direction, \(d\), the momentum method performs the following update:

\begin{eqnarray*} v^{new} &=& \mu v - d \\ w^{new} &=& w + v^{new} \end{eqnarray*}

where \(\mu\) is the momentum, \(v\) is the previous descend direction, \(d\) is the current descend direction (e.g., \(d=\lambda g\), where \(\lambda\) is the learning rate and \(g\) is the gradient), and \(v^{new}\) is the corrected descend direction.

The get_corrected_descend_direction() method computes

\[ v^{new} = \mu v - d \]

and returns \(v^{new}\).

A good introduction to the momentum update can be found at http://cs231n.github.io/neural-networks-3/#sgd , where \(v\) is also called the velocity.
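The update rule above can be sketched in a few lines of standalone C++. This is only an illustration under stated assumptions: `momentum_step` is a hypothetical helper that is not part of SHOGUN, and it uses `std::vector<double>` in place of SHOGUN's own vector and scalar types.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of SHOGUN): applies one plain momentum step
// in place. `v` holds the previous descend direction (velocity), `d` the
// current descend direction (for example lambda * g), `mu` the momentum.
void momentum_step(std::vector<double>& w, std::vector<double>& v,
                   const std::vector<double>& d, double mu)
{
    assert(w.size() == v.size() && v.size() == d.size());
    for (std::size_t i = 0; i < w.size(); ++i)
    {
        v[i] = mu * v[i] - d[i]; // v_new = mu * v - d
        w[i] += v[i];            // w_new = w + v_new
    }
}
```

Calling the helper repeatedly with the same `v` carries the velocity from one step to the next, which is what distinguishes the momentum update from a plain gradient step.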

Definition at line 62 of file StandardMomentumCorrection.h.

Inheritance diagram for StandardMomentumCorrection:
[Inheritance graph]

Public Member Functions

 StandardMomentumCorrection ()
 
virtual ~StandardMomentumCorrection ()
 
virtual DescendPair get_corrected_descend_direction (float64_t negative_descend_direction, index_t idx)
 
virtual bool is_initialized ()
 
virtual void initialize_previous_direction (index_t len)
 
virtual void update_context (CMinimizerContext *context)
 
virtual void load_from_context (CMinimizerContext *context)
 
virtual float64_t get_previous_descend_direction (index_t idx)
 
virtual float64_t get_length_previous_descend_direction ()
 
virtual void set_correction_weight (float64_t weight)
 

Protected Attributes

SGVector< float64_t > m_previous_descend_direction
 
float64_t m_weight
 

Constructor & Destructor Documentation

StandardMomentumCorrection ( )

Definition at line 66 of file StandardMomentumCorrection.h.

virtual ~StandardMomentumCorrection ( )
virtual

Definition at line 73 of file StandardMomentumCorrection.h.

Member Function Documentation

virtual DescendPair get_corrected_descend_direction ( float64_t negative_descend_direction, index_t idx )
virtual

Get the corrected descend direction

Parameters
    negative_descend_direction  the negative descend direction
    idx  the index of the direction
Returns
    DescendPair (the corrected descend direction and the change applied to correct the descend direction)

Implements MomentumCorrection.

Definition at line 82 of file StandardMomentumCorrection.h.
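Because the first argument is the negative descend direction, \(-d\), the update \(v^{new} = \mu v - d\) becomes a sum when written per index. The sketch below illustrates this with a hypothetical stand-in class, `MomentumSketch`, which is not the SHOGUN implementation. It rests on several assumptions: the correction weight plays the role of \(\mu\), the stored velocity starts at zero, and the result is stored as the new previous direction. It returns only the corrected direction, not a full DescendPair.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for illustration only; this is not the SHOGUN class.
class MomentumSketch
{
public:
    explicit MomentumSketch(double weight) : m_weight(weight) {}

    bool is_initialized() const { return !m_previous.empty(); }

    // Assumption: the stored velocity starts at zero.
    void initialize_previous_direction(std::size_t len)
    {
        m_previous.assign(len, 0.0);
    }

    // The argument is the negative descend direction, -d, so the update
    // v_new = mu * v - d becomes mu * v + negative_descend_direction.
    // Assumption: the result replaces the stored previous direction.
    double get_corrected_descend_direction(double negative_descend_direction,
                                           std::size_t idx)
    {
        assert(idx < m_previous.size());
        m_previous[idx] = m_weight * m_previous[idx] + negative_descend_direction;
        return m_previous[idx];
    }

private:
    std::vector<double> m_previous; // previous descend direction (velocity)
    double m_weight;                // assumed to act as the momentum, mu
};
```

A caller would check `is_initialized()`, call `initialize_previous_direction(len)` once if needed, and then request the corrected direction for each index in turn.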

virtual float64_t get_length_previous_descend_direction ( )
virtual inherited

Get the length of the previous descend direction (velocity)

Returns
    the length of the previous descend direction

Definition at line 143 of file MomentumCorrection.h.

virtual float64_t get_previous_descend_direction ( index_t  idx)
virtual inherited

Get the previous descend direction (velocity) at the given index

Parameters
    idx  index of the previous descend direction
Returns
    the previous descend direction

Definition at line 132 of file MomentumCorrection.h.

virtual void initialize_previous_direction ( index_t  len)
virtual inherited

Initialize m_previous_descend_direction

Parameters
    len  the length of m_previous_descend_direction to be initialized

Reimplemented in AdaptMomentumCorrection.

Definition at line 73 of file MomentumCorrection.h.

virtual bool is_initialized ( )
virtual inherited

Is m_previous_descend_direction initialized?

Returns
    whether m_previous_descend_direction is initialized

Reimplemented in AdaptMomentumCorrection.

Definition at line 64 of file MomentumCorrection.h.

virtual void load_from_context ( CMinimizerContext * context)
virtual inherited

Load the given context object to restore mutable variables

This method will be called by DescendUpdaterWithCorrection::load_from_context(CMinimizerContext* context)

Parameters
    context  a context object

Implements DescendCorrection.

Reimplemented in AdaptMomentumCorrection.

Definition at line 116 of file MomentumCorrection.h.

virtual void set_correction_weight ( float64_t  weight)
virtual inherited

Set the weight used in descend correction

Parameters
    weight  the weight

Reimplemented in AdaptMomentumCorrection.

Definition at line 74 of file DescendCorrection.h.

virtual void update_context ( CMinimizerContext * context)
virtual inherited

Update a context object to store mutable variables used in descend update

This method will be called by DescendUpdaterWithCorrection::update_context()

Parameters
    context  a context object

Implements DescendCorrection.

Reimplemented in AdaptMomentumCorrection.

Definition at line 98 of file MomentumCorrection.h.

Member Data Documentation

SGVector<float64_t> m_previous_descend_direction
protected inherited

used in momentum methods

Definition at line 149 of file MomentumCorrection.h.

float64_t m_weight
protected inherited

weight of correction

Definition at line 110 of file DescendCorrection.h.


The documentation for this class was generated from the following file:

SHOGUN Machine Learning Toolbox - Documentation