This implements the adaptive momentum correction method.
A standard momentum correction performs an update based on a momentum (e.g., \(\mu\)), a previous descend direction (e.g., \(v\)), and a current descend direction (e.g., \(d\)).
The idea of the adaptive momentum correction method is as follows. If the signs of the last two momentum corrections differ, the current descend direction is discounted. If the signs are the same, the method raises the current descend direction. See get_corrected_descend_direction() for details.
Definition at line 50 of file AdaptMomentumCorrection.h.
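The sign rule described above can be illustrated with a minimal, self-contained sketch. It is not the library's implementation: the class `AdaptiveMomentumSketch`, its members, the multiplicative factors `1 + adapt_rate` and `1 - adapt_rate`, and the choice to apply the adjusted rate from the next call onward are all assumptions made for illustration.

```cpp
#include <cstddef>
#include <vector>

// Illustrative only: every name here is hypothetical, not the library's API.
class AdaptiveMomentumSketch
{
public:
    AdaptiveMomentumSketch(std::size_t len, double momentum,
                           double adapt_rate, double init_rate = 1.0)
        : m_velocity(len, 0.0), m_rate(len, init_rate),
          m_momentum(momentum), m_adapt_rate(adapt_rate) {}

    // Returns the corrected direction for element idx, given the
    // current (negative-gradient) direction d.
    double correct(double d, std::size_t idx)
    {
        const double previous = m_velocity[idx];
        // Standard momentum correction, scaled by the element-wise rate:
        // v <- mu * v + rate * d.
        const double current = m_momentum * previous + m_rate[idx] * d;
        // Assumed rule: same sign as the last correction raises the rate,
        // a different sign discounts it, a zero product leaves it unchanged.
        const double agreement = previous * current;
        if (agreement > 0.0)
            m_rate[idx] *= 1.0 + m_adapt_rate;
        else if (agreement < 0.0)
            m_rate[idx] *= 1.0 - m_adapt_rate;
        m_velocity[idx] = current;
        return current;
    }

    // Current element-wise rate, exposed for inspection.
    double rate(std::size_t idx) const { return m_rate[idx]; }

private:
    std::vector<double> m_velocity; // previous corrected directions (v)
    std::vector<double> m_rate;     // element-wise discount/raise rate
    double m_momentum;              // mu
    double m_adapt_rate;            // step used to discount/raise the rate
};
```

In this sketch, two consecutive steps in the same direction increase the rate for that element, while a reversal shrinks it, which damps oscillation along coordinates where the direction keeps flipping.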
Public Member Functions | |
AdaptMomentumCorrection () | |
virtual void | set_momentum_correction (MomentumCorrection *correction) |
virtual | ~AdaptMomentumCorrection () |
virtual bool | is_initialized () |
virtual void | set_correction_weight (float64_t weight) |
virtual void | initialize_previous_direction (index_t len) |
virtual DescendPair | get_corrected_descend_direction (float64_t negative_descend_direction, index_t idx) |
virtual void | update_context (CMinimizerContext *context) |
virtual void | load_from_context (CMinimizerContext *context) |
virtual void | set_adapt_rate (float64_t adapt_rate, float64_t rate_min=0.0, float64_t rate_max=CMath::INFTY) |
virtual void | set_init_descend_rate (float64_t init_descend_rate) |
virtual float64_t | get_previous_descend_direction (index_t idx) |
virtual float64_t | get_length_previous_descend_direction () |
Definition at line 55 of file AdaptMomentumCorrection.h.
|
virtual |
Definition at line 73 of file AdaptMomentumCorrection.h.
|
virtual |
Get corrected descend direction
negative_descend_direction | the negative descend direction |
idx | the index of the direction |
Implements MomentumCorrection.
Definition at line 111 of file AdaptMomentumCorrection.h.
|
virtual, inherited |
Get the length of the previous descend direction (velocity)
Definition at line 143 of file MomentumCorrection.h.
Get the previous descend direction (velocity) given the index
idx | index of the previous descend direction |
Definition at line 132 of file MomentumCorrection.h.
|
virtual |
Initialize m_previous_descend_direction
Reimplemented from MomentumCorrection.
Definition at line 99 of file AdaptMomentumCorrection.h.
|
virtual |
Is the standard momentum method initialized?
Reimplemented from MomentumCorrection.
Definition at line 79 of file AdaptMomentumCorrection.h.
|
virtual |
Load the given context object to restore mutable variables
This method will be called by DescendUpdaterWithCorrection::load_from_context(CMinimizerContext* context)
context | a context object |
Reimplemented from MomentumCorrection.
Definition at line 173 of file AdaptMomentumCorrection.h.
|
virtual |
Set adaptive weights used in this method
adapt_rate | the rate used to discount/raise the current descend direction (see get_corrected_descend_direction()) |
rate_min | minimum of the rate |
rate_max | maximum of the rate |
Definition at line 191 of file AdaptMomentumCorrection.h.
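To show how the three parameters of set_adapt_rate() might interact, here is a small sketch of a bounded rate update. The helper `next_rate` is hypothetical and the multiplicative update is an assumption; only the parameter names and the defaults (`rate_min=0.0`, `rate_max` unbounded) are taken from the signature above.

```cpp
#include <algorithm>
#include <limits>

// Hypothetical helper, not part of the library: computes the next
// element-wise rate and keeps it inside [rate_min, rate_max].
double next_rate(double rate, bool same_sign, double adapt_rate,
                 double rate_min = 0.0,
                 double rate_max = std::numeric_limits<double>::infinity())
{
    // Assumed update: raise when the signs agree, discount otherwise.
    const double scaled = same_sign ? rate * (1.0 + adapt_rate)
                                    : rate * (1.0 - adapt_rate);
    // Clamp so the rate can neither collapse below rate_min
    // nor grow beyond rate_max.
    return std::min(rate_max, std::max(rate_min, scaled));
}
```

With the default bounds the clamp has no effect for a non-negative rate; setting a finite `rate_max` would cap how far repeated same-sign steps can raise the rate.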
|
virtual |
Set the weight (momentum) for the standard momentum method
weight | momentum |
Reimplemented from DescendCorrection.
Definition at line 89 of file AdaptMomentumCorrection.h.
|
virtual |
Set the initial rate used to discount/raise the current descend direction
init_descend_rate | the initial rate (default 1.0) |
Definition at line 206 of file AdaptMomentumCorrection.h.
|
virtual |
Set a standard momentum method
correction | standard momentum method (eg, StandardMomentumCorrection) |
Definition at line 65 of file AdaptMomentumCorrection.h.
|
virtual |
Update a context object to store mutable variables used in descend update
This method will be called by DescendUpdaterWithCorrection::update_context()
context | a context object |
Reimplemented from MomentumCorrection.
Definition at line 153 of file AdaptMomentumCorrection.h.
|
protected |
the adapt rate
Definition at line 217 of file AdaptMomentumCorrection.h.
element-wise rate used to discount/raise the current descend direction
Definition at line 213 of file AdaptMomentumCorrection.h.
|
protected |
the initial rate
Definition at line 223 of file AdaptMomentumCorrection.h.
|
protected |
the standard momentum method
Definition at line 215 of file AdaptMomentumCorrection.h.
used in momentum methods
Definition at line 149 of file MomentumCorrection.h.
|
protected |
the maximum of the adapt rate
Definition at line 221 of file AdaptMomentumCorrection.h.
|
protected |
the minimum of the adapt rate
Definition at line 219 of file AdaptMomentumCorrection.h.
|
protected, inherited |
weight of correction
Definition at line 110 of file DescendCorrection.h.