SHOGUN
4.1.0
Base class for descent updaters that apply a descent-based correction to the update direction.
Given a target variable \(w\) and its negative descent direction \(g\), the class first corrects the descent direction \(g\) and then updates \(w\) using the corrected direction \(g^{corrected}\) (e.g., by subtracting it: \(w \leftarrow w - g^{corrected}\)).
A simple choice of \(g\) is the gradient with respect to \(w\). For an example of a descent-based correction, see StandardMomentumCorrection.
Definition at line 52 of file DescendUpdaterWithCorrection.h.
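Per element \(i\), the correct-then-update scheme described above can be written as follows. The symbols \(\nabla_i\), \(\eta\), \(c\), \(\mu\) and \(v_i\) are introduced here for illustration only, and the momentum form of \(c\) below is an assumed classical-momentum rule, not necessarily the exact formula used by StandardMomentumCorrection.

\[
g_i^{corrected} = c(g_i), \qquad w_i \leftarrow w_i - g_i^{corrected}
\]

Taking \(g_i = \eta\,\nabla_i\) (learning rate times the gradient with respect to \(w_i\)) and a momentum correction

\[
v_i \leftarrow \mu\, v_i + g_i, \qquad g_i^{corrected} = v_i,
\]

and starting from \(w_i = 1\), \(v_i = 0\), \(\nabla_i = 0.5\), \(\eta = 0.1\), \(\mu = 0.9\): the first step gives \(g_i = 0.05\), \(v_i = 0.05\), \(w_i = 0.95\); a second step with the same gradient gives \(v_i = 0.9 \cdot 0.05 + 0.05 = 0.095\) and \(w_i = 0.95 - 0.095 = 0.855\).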
Public Member Functions
virtual | ~DescendUpdaterWithCorrection ()
virtual void | update_variable (SGVector< float64_t > variable_reference, SGVector< float64_t > raw_negative_descend_direction, float64_t learning_rate)
virtual void | update_context (CMinimizerContext *context)
virtual void | load_from_context (CMinimizerContext *context)
virtual void | set_descend_correction (DescendCorrection *correction)
virtual bool | enables_descend_correction ()
Protected Member Functions
virtual float64_t | get_negative_descend_direction (float64_t variable, float64_t raw_negative_descend_direction, index_t idx, float64_t learning_rate)=0
Protected Attributes
DescendCorrection * | m_correction
virtual ~DescendUpdaterWithCorrection ()
Definition at line 56 of file DescendUpdaterWithCorrection.h.
virtual bool enables_descend_correction ()
Returns whether descent correction is enabled.
Definition at line 145 of file DescendUpdaterWithCorrection.h.
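virtual float64_t get_negative_descend_direction (float64_t variable, float64_t raw_negative_descend_direction, index_t idx, float64_t learning_rate) = 0 [protected, pure virtual]
Returns the negative descent direction for one element of the variable, given that element's current value and its raw negative descent direction.
This method is called by update_variable().
Parameters:
variable | current value of the element
raw_negative_descend_direction | current raw negative descent direction of the element
idx | index of the element within the variable
learning_rate | learning rate
Implemented in AdaDeltaUpdater, RmsPropUpdater, AdaGradUpdater, and GradientDescendUpdater.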
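Because get_negative_descend_direction() receives the element index idx, an implementation can keep per-element state. The following self-contained C++ sketch (no Shogun dependency) shows an AdaGrad-style hook. The class name AdaGradHookSketch, the epsilon parameter and the exact formula \(G_i \leftarrow G_i + g_i^2,\ d_i = \eta\, g_i / \sqrt{G_i + \epsilon}\) are assumptions for illustration, not the code of Shogun's AdaGradUpdater.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical AdaGrad-style per-element hook: a sketch of what an
// implementation of get_negative_descend_direction() might compute.
// Assumed rule: G_i += g_i^2;  d_i = learning_rate * g_i / sqrt(G_i + epsilon).
class AdaGradHookSketch {
public:
    explicit AdaGradHookSketch(std::size_t dim, double epsilon = 1e-6)
        : epsilon_(epsilon), accumulated_(dim, 0.0) {}

    // The current value of the element is unused in this rule, but the
    // parameter mirrors the documented signature.
    double negative_direction(double /*variable*/, double raw_negative_direction,
                              std::size_t idx, double learning_rate) {
        assert(idx < accumulated_.size());
        // idx selects this element's own accumulator of squared directions.
        accumulated_[idx] += raw_negative_direction * raw_negative_direction;
        return learning_rate * raw_negative_direction /
               std::sqrt(accumulated_[idx] + epsilon_);
    }

    double accumulated(std::size_t idx) const { return accumulated_[idx]; }

private:
    double epsilon_;
    std::vector<double> accumulated_;  // per-element sum of squared raw directions
};
```

Repeated calls with the same raw direction yield shrinking steps for that element only: with \(\epsilon = 0\), \(\eta = 0.1\) and raw direction 2, the first call returns 0.1 and the second about 0.0707, while the other elements' accumulators stay at zero.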
virtual void load_from_context (CMinimizerContext *context)
Loads the mutable variables stored in the given context object to restore the updater's state.
This method is called by FirstOrderMinimizer::load_from_context(CMinimizerContext* context).
Parameters:
context | a context object
Implements DescendUpdater.
Reimplemented in AdaDeltaUpdater, RmsPropUpdater, and AdaGradUpdater.
Definition at line 124 of file DescendUpdaterWithCorrection.h.
virtual void set_descend_correction (DescendCorrection *correction)
Sets the descent correction to apply.
Parameters:
correction | the descent correction object
Definition at line 135 of file DescendUpdaterWithCorrection.h.
virtual void update_context (CMinimizerContext *context)
Stores the mutable variables used in the descent update in the given context object.
This method is called by FirstOrderMinimizer::save_to_context().
Parameters:
context | a context object
Implements DescendUpdater.
Reimplemented in AdaDeltaUpdater, RmsPropUpdater, and AdaGradUpdater.
Definition at line 110 of file DescendUpdaterWithCorrection.h.
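As described above, update_context() stores the updater's mutable variables and load_from_context() restores them. A minimal round-trip sketch follows, with no Shogun dependency: CMinimizerContext is replaced by a hypothetical std::map-based ContextSketch, and the key name and the StatefulUpdaterSketch class are likewise invented for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for CMinimizerContext: named vectors of doubles.
using ContextSketch = std::map<std::string, std::vector<double>>;

// Hypothetical updater holding one piece of mutable state.
class StatefulUpdaterSketch {
public:
    explicit StatefulUpdaterSketch(std::size_t dim) : accumulated_(dim, 0.0) {}

    // Analogue of update_context(): write the mutable state into the context.
    void update_context(ContextSketch& context) const {
        context["StatefulUpdaterSketch::accumulated"] = accumulated_;
    }

    // Analogue of load_from_context(): restore the mutable state if present.
    void load_from_context(const ContextSketch& context) {
        auto it = context.find("StatefulUpdaterSketch::accumulated");
        if (it != context.end()) {
            accumulated_ = it->second;
        }
    }

    void accumulate(std::size_t idx, double g) {
        assert(idx < accumulated_.size());
        accumulated_[idx] += g * g;
    }

    double accumulated(std::size_t idx) const { return accumulated_[idx]; }

private:
    std::vector<double> accumulated_;
};
```

Saving from one instance and loading into a fresh one reproduces the accumulated state, which is the property a save/restore cycle of the minimizer relies on.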
virtual void update_variable (SGVector< float64_t > variable_reference, SGVector< float64_t > raw_negative_descend_direction, float64_t learning_rate)
Updates the target variable in place based on the given negative descent direction.
This method is called by FirstOrderMinimizer::minimize().
Parameters:
variable_reference | a reference to the target variable, which is modified in place
raw_negative_descend_direction | the raw negative descent direction at the current value of the variable
learning_rate | learning rate
Implements DescendUpdater.
Reimplemented in AdaDeltaUpdater, RmsPropUpdater, and AdaGradUpdater.
Definition at line 67 of file DescendUpdaterWithCorrection.h.
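To make the correct-then-update flow of update_variable() concrete, here is a minimal, self-contained C++ sketch that does not use Shogun. All names (CorrectionSketch, MomentumSketch, UpdaterWithCorrectionSketch, GradientDescentSketch) are hypothetical stand-ins for DescendCorrection, StandardMomentumCorrection, DescendUpdaterWithCorrection and GradientDescendUpdater. The order of operations (per-element direction, then optional correction, then subtraction) and the momentum formula are assumptions based on the class description, not Shogun's actual code.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for DescendCorrection: maps one element's negative
// descent direction to a corrected one.
class CorrectionSketch {
public:
    virtual ~CorrectionSketch() = default;
    virtual double correct(double negative_direction, std::size_t idx) = 0;
};

// Assumed classical momentum: v_i <- mu * v_i + d_i; corrected d_i = v_i.
class MomentumSketch : public CorrectionSketch {
public:
    MomentumSketch(double mu, std::size_t dim) : mu_(mu), velocity_(dim, 0.0) {}

    double correct(double negative_direction, std::size_t idx) override {
        assert(idx < velocity_.size());
        velocity_[idx] = mu_ * velocity_[idx] + negative_direction;
        return velocity_[idx];
    }

private:
    double mu_;
    std::vector<double> velocity_;
};

// Hypothetical stand-in for DescendUpdaterWithCorrection.
class UpdaterWithCorrectionSketch {
public:
    virtual ~UpdaterWithCorrectionSketch() = default;

    void set_correction(std::shared_ptr<CorrectionSketch> correction) {
        correction_ = std::move(correction);
    }
    bool enables_correction() const { return correction_ != nullptr; }

    // Updates `variable` in place, one element at a time.
    void update_variable(std::vector<double>& variable,
                         const std::vector<double>& raw_negative_direction,
                         double learning_rate) {
        assert(variable.size() == raw_negative_direction.size());
        for (std::size_t idx = 0; idx < variable.size(); ++idx) {
            // 1. Subclass hook computes this element's direction.
            double direction = negative_direction(
                variable[idx], raw_negative_direction[idx], idx, learning_rate);
            // 2. Optional correction (e.g. momentum) adjusts it.
            if (correction_) {
                direction = correction_->correct(direction, idx);
            }
            // 3. Subtract the (corrected) negative descent direction.
            variable[idx] -= direction;
        }
    }

protected:
    virtual double negative_direction(double variable, double raw_negative_direction,
                                      std::size_t idx, double learning_rate) = 0;

private:
    std::shared_ptr<CorrectionSketch> correction_;
};

// Plain gradient descent hook: d_i = learning_rate * raw_i.
class GradientDescentSketch : public UpdaterWithCorrectionSketch {
protected:
    double negative_direction(double /*variable*/, double raw_negative_direction,
                              std::size_t /*idx*/, double learning_rate) override {
        return learning_rate * raw_negative_direction;
    }
};
```

For a one-element variable starting at 1 with raw direction 0.5 and learning rate 0.1, one call without correction yields 0.95; with MomentumSketch(0.9, 1) attached, two calls yield 0.95 and then 0.855. Keeping the per-element hook pure virtual leaves the loop, size check and correction handling in the base class, so a subclass only decides how a single element's direction is computed.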
DescendCorrection* m_correction [protected]
The descent correction object.
Definition at line 165 of file DescendUpdaterWithCorrection.h.