SHOGUN 4.1.0
AdaptMomentumCorrection Class Reference

Detailed Description

This implements the adaptive momentum correction method.

A standard momentum correction performs an update based on a momentum (e.g., \(\mu\)), a previous descend direction (e.g., \(v\)), and a current descend direction (e.g., \(d\)).
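For reference, a classical (heavy-ball) momentum keeps one velocity per coordinate and updates it as \(v \leftarrow \mu v + d\), then steps along \(v\). The sketch below illustrates that convention only; `DescendPairSketch` and `StandardMomentumSketch` are hypothetical stand-ins, and Shogun's own sign convention for the `negative_descend_direction` argument may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a (direction, change) pair; not Shogun's DescendPair.
struct DescendPairSketch {
    double descend_direction;  // corrected direction to step along
    double delta;              // how much the correction changed the raw direction
};

// Minimal classical momentum, per coordinate: v <- mu * v + d.
class StandardMomentumSketch {
public:
    explicit StandardMomentumSketch(double mu) : mu_(mu) {}

    // All velocities start at zero.
    void initialize(std::size_t len) { velocity_.assign(len, 0.0); }

    DescendPairSketch correct(double d, std::size_t idx) {
        assert(idx < velocity_.size());
        const double v_new = mu_ * velocity_[idx] + d;
        DescendPairSketch pair{v_new, v_new - d};
        velocity_[idx] = v_new;  // remember the velocity for the next call
        return pair;
    }

    double previous(std::size_t idx) const {
        assert(idx < velocity_.size());
        return velocity_[idx];
    }

private:
    double mu_;
    std::vector<double> velocity_;
};
```

With \(\mu = 0.5\) and a constant direction \(d = 1\), the velocity grows 1, 1.5, 1.75, ... toward \(d / (1 - \mu) = 2\).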

The idea of the adaptive momentum correction method is as follows: if the signs of the last two momentum corrections differ, the current descend direction is discounted; if the signs are the same, the method raises the current descend direction. See get_corrected_descend_direction() for details.
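One common way to realize this idea, in the spirit of sign-based step-size adaptation, is to grow a per-coordinate rate additively when the sign is unchanged and shrink it multiplicatively when it flips. The function below is a hypothetical sketch under that assumption; the name `adapt_descend_rate` and the exact increase/decrease rule are illustrative and may not match Shogun's implementation.

```cpp
#include <cassert>

// Hypothetical sketch (not Shogun's code) of adapting one coordinate's rate.
// previous, current: the last two momentum corrections for that coordinate.
// Assumed rule: same sign -> raise the rate additively by adapt_rate;
// opposite signs -> discount it multiplicatively by (1 - adapt_rate);
// a zero correction (e.g. the very first step) leaves the rate unchanged.
double adapt_descend_rate(double rate, double previous, double current,
                          double adapt_rate) {
    assert(adapt_rate >= 0.0 && adapt_rate < 1.0);
    const double agreement = previous * current;
    if (agreement > 0.0) {
        return rate + adapt_rate;          // signs agree: raise
    }
    if (agreement < 0.0) {
        return rate * (1.0 - adapt_rate);  // signs differ: discount
    }
    return rate;                           // no sign to compare against
}
```

For example, with `adapt_rate = 0.25` and a starting rate of 1.0, agreeing signs give 1.25 and a sign flip gives 0.75; the resulting rate then scales the current descend direction.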

Definition at line 50 of file AdaptMomentumCorrection.h.

Inheritance diagram for AdaptMomentumCorrection (graph image not reproduced).

Public Member Functions

 AdaptMomentumCorrection ()
 
virtual void set_momentum_correction (MomentumCorrection *correction)
 
virtual ~AdaptMomentumCorrection ()
 
virtual bool is_initialized ()
 
virtual void set_correction_weight (float64_t weight)
 
virtual void initialize_previous_direction (index_t len)
 
virtual DescendPair get_corrected_descend_direction (float64_t negative_descend_direction, index_t idx)
 
virtual void update_context (CMinimizerContext *context)
 
virtual void load_from_context (CMinimizerContext *context)
 
virtual void set_adapt_rate (float64_t adapt_rate, float64_t rate_min=0.0, float64_t rate_max=CMath::INFTY)
 
virtual void set_init_descend_rate (float64_t init_descend_rate)
 
virtual float64_t get_previous_descend_direction (index_t idx)
 
virtual float64_t get_length_previous_descend_direction ()
 

Protected Attributes

SGVector< float64_t > m_descend_rate
 
MomentumCorrection * m_momentum_correction
 
float64_t m_adapt_rate
 
float64_t m_rate_min
 
float64_t m_rate_max
 
float64_t m_init_descend_rate
 
SGVector< float64_t > m_previous_descend_direction
 
float64_t m_weight
 

Constructor & Destructor Documentation

AdaptMomentumCorrection ( )

Definition at line 55 of file AdaptMomentumCorrection.h.

virtual ~AdaptMomentumCorrection ( )
virtual

Definition at line 73 of file AdaptMomentumCorrection.h.

Member Function Documentation

virtual DescendPair get_corrected_descend_direction ( float64_t  negative_descend_direction,
index_t  idx 
)
virtual

Get the corrected descend direction.

Parameters
negative_descend_direction  the negative descend direction
idx  the index of the direction
Returns
DescendPair (the corrected descend direction and the change applied to correct it)

Implements MomentumCorrection.

Definition at line 111 of file AdaptMomentumCorrection.h.
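The composition implied by m_momentum_correction and m_descend_rate can be sketched as a decorator: scale the incoming direction by the per-element rate, delegate to the wrapped standard momentum method, then adapt the rate from whether the wrapped method's velocity kept its sign. This is a hypothetical, self-contained sketch (C++17, for `std::clamp`); the classes, the additive-increase/multiplicative-decrease rule, and the sign convention are assumptions, not Shogun's implementation.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical stand-in for DescendPair.
struct PairSketch {
    double descend_direction;  // corrected direction
    double delta;              // change made by the correction
};

// Hypothetical interface playing the role of MomentumCorrection.
class MomentumSketch {
public:
    virtual ~MomentumSketch() = default;
    virtual void initialize(std::size_t len) = 0;
    virtual double previous(std::size_t idx) const = 0;
    virtual PairSketch correct(double d, std::size_t idx) = 0;
};

// Classical momentum v <- mu * v + d (assumed convention).
class StandardSketch : public MomentumSketch {
public:
    explicit StandardSketch(double mu) : mu_(mu) {}
    void initialize(std::size_t len) override { v_.assign(len, 0.0); }
    double previous(std::size_t idx) const override { return v_.at(idx); }
    PairSketch correct(double d, std::size_t idx) override {
        const double v_new = mu_ * v_.at(idx) + d;
        v_.at(idx) = v_new;
        return {v_new, v_new - d};
    }

private:
    double mu_;
    std::vector<double> v_;
};

// Adaptive wrapper: scales d by a per-element rate, delegates to the
// wrapped method, then adapts the rate from the sign of old * new velocity.
class AdaptSketch {
public:
    AdaptSketch(MomentumSketch* inner, double adapt_rate,
                double init_rate = 1.0, double rate_min = 0.0,
                double rate_max = std::numeric_limits<double>::infinity())
        : inner_(inner), adapt_rate_(adapt_rate), init_rate_(init_rate),
          rate_min_(rate_min), rate_max_(rate_max) {
        assert(inner_ != nullptr && rate_min_ <= rate_max_);
    }

    void initialize(std::size_t len) {
        inner_->initialize(len);
        rate_.assign(len, init_rate_);  // every element starts at init_rate
    }

    PairSketch correct(double d, std::size_t idx) {
        assert(idx < rate_.size());
        const double before = inner_->previous(idx);
        PairSketch pair = inner_->correct(rate_[idx] * d, idx);
        const double agreement = before * inner_->previous(idx);
        double r = rate_[idx];
        if (agreement > 0.0) {
            r += adapt_rate_;        // same sign: raise
        } else if (agreement < 0.0) {
            r *= 1.0 - adapt_rate_;  // sign flip: discount
        }                            // zero: nothing to compare, keep rate
        rate_[idx] = std::clamp(r, rate_min_, rate_max_);
        return pair;
    }

    double rate(std::size_t idx) const { return rate_.at(idx); }

private:
    MomentumSketch* inner_;  // not owned by the wrapper
    double adapt_rate_, init_rate_, rate_min_, rate_max_;
    std::vector<double> rate_;
};
```

Keeping the wrapped method behind an interface pointer mirrors how set_momentum_correction() lets any standard momentum method (e.g., StandardMomentumCorrection) be plugged in.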

virtual float64_t get_length_previous_descend_direction ( )
virtual, inherited

Get the length of the previous descend direction (velocity).

Returns
the length of the previous descend direction

Definition at line 143 of file MomentumCorrection.h.

virtual float64_t get_previous_descend_direction ( index_t  idx)
virtual, inherited

Get the previous descend direction (velocity) at the given index.

Parameters
idx  index of the previous descend direction
Returns
the previous descend direction

Definition at line 132 of file MomentumCorrection.h.

virtual void initialize_previous_direction ( index_t  len)
virtual

Initialize m_previous_descend_direction.

Parameters
len  the length of m_previous_descend_direction to be initialized

Reimplemented from MomentumCorrection.

Definition at line 99 of file AdaptMomentumCorrection.h.
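Because m_descend_rate is element-wise, initialization presumably has to size it to the same length as the previous descend direction. The sketch below assumes velocities start at zero and every rate starts at the init rate (default 1.0, per set_init_descend_rate()); `AdaptStateSketch` is hypothetical and `std::vector` stands in for `SGVector<float64_t>`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical state holder; std::vector stands in for SGVector<float64_t>.
struct AdaptStateSketch {
    std::vector<double> previous_descend_direction;  // velocity per element
    std::vector<double> descend_rate;                // element-wise rate
    double init_descend_rate = 1.0;                  // documented default

    // Assumed behaviour: velocities start at zero and every element-wise
    // rate starts at init_descend_rate.
    void initialize_previous_direction(std::size_t len) {
        assert(len > 0);
        previous_descend_direction.assign(len, 0.0);
        descend_rate.assign(len, init_descend_rate);
    }
};
```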

virtual bool is_initialized ( )
virtual

Is the standard momentum method initialized?

Returns
whether the standard momentum method is initialized

Reimplemented from MomentumCorrection.

Definition at line 79 of file AdaptMomentumCorrection.h.

virtual void load_from_context ( CMinimizerContext * context)
virtual

Load the given context object to restore mutable variables.

This method will be called by DescendUpdaterWithCorrection::load_from_context(CMinimizerContext* context).

Parameters
context  a context object

Reimplemented from MomentumCorrection.

Definition at line 173 of file AdaptMomentumCorrection.h.

virtual void set_adapt_rate ( float64_t  adapt_rate,
float64_t  rate_min = 0.0,
float64_t  rate_max = CMath::INFTY 
)
virtual

Set the adaptive rate and its bounds used in this method.

Parameters
adapt_rate  the rate used to discount/raise the current descend direction (see get_corrected_descend_direction())
rate_min  minimum of the rate
rate_max  maximum of the rate

Definition at line 191 of file AdaptMomentumCorrection.h.
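With the defaults rate_min = 0.0 and rate_max = CMath::INFTY, the allowed range is effectively [0, ∞). A hypothetical sketch of keeping an element-wise rate inside the configured bounds, with `std::numeric_limits<double>::infinity()` standing in for `CMath::INFTY` (C++17, for `std::clamp`); `AdaptRateSketch` and `clamp_rate` are illustrative names, not Shogun API.

```cpp
#include <algorithm>
#include <cassert>
#include <limits>

// Hypothetical holder mirroring set_adapt_rate()'s parameters.
struct AdaptRateSketch {
    double adapt_rate = 0.0;
    double rate_min = 0.0;
    double rate_max = std::numeric_limits<double>::infinity();

    void set_adapt_rate(double rate, double lo = 0.0,
                        double hi = std::numeric_limits<double>::infinity()) {
        assert(rate >= 0.0 && lo <= hi);  // std::clamp requires lo <= hi
        adapt_rate = rate;
        rate_min = lo;
        rate_max = hi;
    }

    // Keep an element-wise descend rate inside [rate_min, rate_max].
    double clamp_rate(double r) const { return std::clamp(r, rate_min, rate_max); }
};
```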

virtual void set_correction_weight ( float64_t  weight)
virtual

Set the weight (momentum) for the standard momentum method.

Parameters
weight  the momentum

Reimplemented from DescendCorrection.

Definition at line 89 of file AdaptMomentumCorrection.h.

virtual void set_init_descend_rate ( float64_t  init_descend_rate)
virtual

Set the initial rate used to discount/raise the current descend direction.

Parameters
init_descend_rate  the initial rate (default 1.0)

Definition at line 206 of file AdaptMomentumCorrection.h.

virtual void set_momentum_correction ( MomentumCorrection * correction)
virtual

Set a standard momentum method.

Parameters
correction  a standard momentum method (e.g., StandardMomentumCorrection)

Definition at line 65 of file AdaptMomentumCorrection.h.

virtual void update_context ( CMinimizerContext * context)
virtual

Update a context object to store the mutable variables used in the descend update.

This method will be called by DescendUpdaterWithCorrection::update_context().

Parameters
context  a context object

Reimplemented from MomentumCorrection.

Definition at line 153 of file AdaptMomentumCorrection.h.
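update_context() and load_from_context() form a save/restore pair for mutable state such as the element-wise rates and the previous descend direction. The sketch below uses a hypothetical key-value map in place of CMinimizerContext, whose real API is not described here; the struct and key names are illustrative only.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for CMinimizerContext: named vectors of doubles.
using ContextSketch = std::map<std::string, std::vector<double>>;

struct MutableStateSketch {
    std::vector<double> descend_rate;
    std::vector<double> previous_descend_direction;

    // Store the mutable variables so the minimizer can be resumed later.
    void update_context(ContextSketch& context) const {
        context["descend_rate"] = descend_rate;
        context["previous_descend_direction"] = previous_descend_direction;
    }

    // Restore the variables saved by update_context(); throws
    // std::out_of_range if a key is missing.
    void load_from_context(const ContextSketch& context) {
        descend_rate = context.at("descend_rate");
        previous_descend_direction = context.at("previous_descend_direction");
        assert(descend_rate.size() == previous_descend_direction.size());
    }
};
```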

Member Data Documentation

float64_t m_adapt_rate
protected

the adapt rate

Definition at line 217 of file AdaptMomentumCorrection.h.

SGVector<float64_t> m_descend_rate
protected

element-wise rate used to discount/raise the current descend direction

Definition at line 213 of file AdaptMomentumCorrection.h.

float64_t m_init_descend_rate
protected

the initial rate

Definition at line 223 of file AdaptMomentumCorrection.h.

MomentumCorrection* m_momentum_correction
protected

the standard momentum method

Definition at line 215 of file AdaptMomentumCorrection.h.

SGVector<float64_t> m_previous_descend_direction
protected, inherited

the previous descend direction (velocity) used in momentum methods

Definition at line 149 of file MomentumCorrection.h.

float64_t m_rate_max
protected

the maximum of the adapt rate

Definition at line 221 of file AdaptMomentumCorrection.h.

float64_t m_rate_min
protected

the minimum of the adapt rate

Definition at line 219 of file AdaptMomentumCorrection.h.

float64_t m_weight
protected, inherited

weight of the correction

Definition at line 110 of file DescendCorrection.h.



SHOGUN Machine Learning Toolbox - Project Documentation