SHOGUN 4.1.0
The class implements the AdaDelta method.
\[ \begin{array}{l} g_\theta=(1-\lambda){(\frac{ \partial f(\cdot) }{\partial \theta })}^2+\lambda g_\theta\\ d_\theta=\alpha\frac{\sqrt{s_\theta+\epsilon}}{\sqrt{g_\theta+\epsilon}}\frac{ \partial f(\cdot) }{\partial \theta }\\ s_\theta=(1-\lambda){(d_\theta)}^2+\lambda s_\theta \end{array} \]
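Here, \( \frac{ \partial f(\cdot) }{\partial \theta } \) is a negative descent direction (e.g., the gradient) with respect to \(\theta\), \(\lambda\) is a decay factor, \(\epsilon\) is a small constant that prevents division by zero, \( \alpha \) is a built-in learning rate, \(g_\theta\) and \(s_\theta\) are decaying averages of the squared negative descent direction and of the squared corrected direction, respectively, and \(d_\theta\) is the corrected negative descent direction.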
Reference: Matthew D. Zeiler, ADADELTA: An Adaptive Learning Rate Method, arXiv:1212.5701
Definition at line 58 of file AdaDeltaUpdater.h.
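The three equations can be traced for a single coordinate with the standalone C++ sketch below. It is not Shogun's implementation: `AdaDeltaState` and `adadelta_step` are hypothetical names, and the sketch assumes the caller then applies \( \theta \leftarrow \theta - d_\theta \).

```cpp
#include <cassert>
#include <cmath>

// Running averages kept for one coordinate of theta.
struct AdaDeltaState {
    double g = 0.0;  // g_theta: decaying average of squared gradients
    double s = 0.0;  // s_theta: decaying average of squared corrected directions
};

// One AdaDelta step for a single coordinate, evaluating the three equations
// in order. Returns d_theta; the caller is assumed to apply theta -= d_theta.
double adadelta_step(AdaDeltaState& st, double grad,
                     double alpha, double epsilon, double lambda) {
    assert(epsilon > 0.0 && lambda >= 0.0 && lambda < 1.0);
    st.g = (1.0 - lambda) * grad * grad + lambda * st.g;
    const double d = alpha * std::sqrt(st.s + epsilon) /
                     std::sqrt(st.g + epsilon) * grad;
    st.s = (1.0 - lambda) * d * d + lambda * st.s;
    return d;
}
```

For example, with the illustrative values \( \alpha = 1 \), \( \epsilon = 10^{-6} \), \( \lambda = 0.95 \) and a first gradient of 2, one call gives \( g_\theta = 0.2 \) and \( d_\theta = 2\cdot 10^{-3}/\sqrt{0.200001} \approx 0.0045 \). Because \( s_\theta \) starts at zero, \( \epsilon \) sets the size of the first steps.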
Public Member Functions
  AdaDeltaUpdater ()
  AdaDeltaUpdater (float64_t learning_rate, float64_t epsilon, float64_t decay_factor)
  virtual ~AdaDeltaUpdater ()
  virtual void set_learning_rate (float64_t learning_rate)
  virtual void set_epsilon (float64_t epsilon)
  virtual void set_decay_factor (float64_t decay_factor)
  virtual void update_context (CMinimizerContext *context)
  virtual void load_from_context (CMinimizerContext *context)
  virtual void update_variable (SGVector< float64_t > variable_reference, SGVector< float64_t > raw_negative_descend_direction, float64_t learning_rate)
  virtual void set_descend_correction (DescendCorrection *correction)
  virtual bool enables_descend_correction ()
Protected Member Functions
  virtual float64_t get_negative_descend_direction (float64_t variable, float64_t gradient, index_t idx, float64_t learning_rate)
AdaDeltaUpdater ()
Default constructor
Definition at line 36 of file AdaDeltaUpdater.cpp.
AdaDeltaUpdater (float64_t learning_rate, float64_t epsilon, float64_t decay_factor)
Parameterized constructor
learning_rate: built-in learning rate \( \alpha \)
epsilon: \( \epsilon \)
decay_factor: decay factor \( \lambda \)
Definition at line 42 of file AdaDeltaUpdater.cpp.
~AdaDeltaUpdater ()  [virtual]
Destructor
Definition at line 73 of file AdaDeltaUpdater.cpp.
bool enables_descend_correction ()  [virtual, inherited]
Returns whether descent correction is enabled.
Definition at line 145 of file DescendUpdaterWithCorrection.h.
float64_t get_negative_descend_direction (float64_t variable, float64_t gradient, index_t idx, float64_t learning_rate)  [protected, virtual]
Get the negative descent direction given the current variable and gradient.
It is called by update_variable().
variable: current variable (e.g., \(\theta\))
gradient: current gradient (e.g., \( \frac{ \partial f(\cdot) }{\partial \theta }\))
idx: the index of the variable
learning_rate: learning rate (not used by AdaDelta, which relies on its built-in learning rate \( \alpha \))
Implements DescendUpdaterWithCorrection.
Definition at line 124 of file AdaDeltaUpdater.cpp.
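The contract of this hook (one coordinate per call, addressed by `idx`, with the `learning_rate` argument ignored) can be sketched as a standalone C++ class. `AdaDeltaDirection` is a hypothetical name, and keeping \( g_\theta \) and \( s_\theta \) in per-index vectors is an assumption about how the state is organized, not a description of Shogun's code.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the protected hook: g_theta and s_theta are kept
// per coordinate and reused across calls through idx.
class AdaDeltaDirection {
public:
    AdaDeltaDirection(std::size_t n, double alpha, double epsilon, double lambda)
        : g_(n, 0.0), s_(n, 0.0), alpha_(alpha), epsilon_(epsilon), lambda_(lambda) {}

    // Mirrors the documented parameters; `variable` and `learning_rate` are
    // accepted but unused, because AdaDelta relies on its built-in alpha.
    double get_negative_descend_direction(double /*variable*/, double gradient,
                                          std::size_t idx, double /*learning_rate*/) {
        assert(idx < g_.size());
        g_[idx] = (1.0 - lambda_) * gradient * gradient + lambda_ * g_[idx];
        const double d = alpha_ * std::sqrt(s_[idx] + epsilon_) /
                         std::sqrt(g_[idx] + epsilon_) * gradient;
        s_[idx] = (1.0 - lambda_) * d * d + lambda_ * s_[idx];
        return d;
    }

private:
    std::vector<double> g_, s_;  // running averages, one entry per coordinate
    double alpha_, epsilon_, lambda_;
};
```

Two instances fed the same gradients return identical directions even when they are passed very different `learning_rate` values, while repeated calls on the same `idx` change because the running averages carry over.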
void load_from_context (CMinimizerContext *context)  [virtual]
Restore the mutable variables from a context object. This is usually used in serialization.
This method is called by FirstOrderMinimizer::load_from_context(CMinimizerContext* context).
Reimplemented from DescendUpdaterWithCorrection.
Definition at line 106 of file AdaDeltaUpdater.cpp.
void set_decay_factor (float64_t decay_factor)  [virtual]
Set the decay factor \( \lambda \).
decay_factor: decay factor
Definition at line 65 of file AdaDeltaUpdater.cpp.
void set_descend_correction (DescendCorrection *correction)  [virtual, inherited]
Set the type of descent correction.
correction: the type of descent correction
Definition at line 135 of file DescendUpdaterWithCorrection.h.
void set_epsilon (float64_t epsilon)  [virtual]
void set_learning_rate (float64_t learning_rate)  [virtual]
Set the built-in learning rate \( \alpha \).
learning_rate: learning rate
Definition at line 51 of file AdaDeltaUpdater.cpp.
void update_context (CMinimizerContext *context)  [virtual]
Update a context object to store the mutable variables.
This method is called by FirstOrderMinimizer::save_to_context().
context: a context object
Reimplemented from DescendUpdaterWithCorrection.
Definition at line 86 of file AdaDeltaUpdater.cpp.
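In the equations, \( \alpha \), \( \epsilon \) and \( \lambda \) are configuration, while \( g_\theta \) and \( s_\theta \) change at every iteration, so they are the state that a save/restore must preserve. The sketch below models the update_context()/load_from_context() round trip with a plain `std::map`; `Context`, `AdaDeltaMutableState` and the key strings are hypothetical stand-ins, not the CMinimizerContext API, and this page does not show what Shogun actually stores.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical key-value context standing in for CMinimizerContext.
using Context = std::map<std::string, std::vector<double>>;

// Hypothetical updater state: the running averages g_theta and s_theta are
// the only quantities AdaDelta carries from one iteration to the next.
struct AdaDeltaMutableState {
    std::vector<double> g;
    std::vector<double> s;

    // Analogue of update_context(): copy the mutable state into the context.
    void update_context(Context& ctx) const {
        ctx["AdaDelta::g_theta"] = g;
        ctx["AdaDelta::s_theta"] = s;
    }

    // Analogue of load_from_context(): restore the state saved above.
    // Throws std::out_of_range if a key is missing.
    void load_from_context(const Context& ctx) {
        g = ctx.at("AdaDelta::g_theta");
        s = ctx.at("AdaDelta::s_theta");
        assert(g.size() == s.size());  // both averages cover the same coordinates
    }
};
```

Restoring into a fresh object reproduces both averages exactly, so a resumed minimization continues with the same effective step sizes instead of restarting from \( g_\theta = s_\theta = 0 \).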
void update_variable (SGVector< float64_t > variable_reference, SGVector< float64_t > raw_negative_descend_direction, float64_t learning_rate)  [virtual]
Update the target variable based on the given negative descent direction.
Note that this method updates the target variable in place. It is called by FirstOrderMinimizer::minimize().
variable_reference: a reference to the target variable
raw_negative_descend_direction: the negative descent direction given the current value
learning_rate: learning rate
Reimplemented from DescendUpdaterWithCorrection.
Definition at line 140 of file AdaDeltaUpdater.cpp.
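A minimal sketch of the in-place vector update, written as a standalone C++ class over `std::vector` rather than against Shogun's SGVector and DescendUpdaterWithCorrection types. `AdaDeltaSketch` is a hypothetical name, and the sketch assumes \( \theta \leftarrow \theta - d_\theta \) with no descent correction configured.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical standalone analogue of update_variable(): applies one AdaDelta
// step to every coordinate of `variable`, modifying it in place.
// Assumption: theta -= d_theta, with no descent correction (e.g. momentum).
class AdaDeltaSketch {
public:
    AdaDeltaSketch(double alpha, double epsilon, double lambda)
        : alpha_(alpha), epsilon_(epsilon), lambda_(lambda) {}

    void update_variable(std::vector<double>& variable,
                         const std::vector<double>& raw_negative_descend_direction,
                         double /*learning_rate: unused, alpha is built in*/) {
        assert(variable.size() == raw_negative_descend_direction.size());
        if (g_.size() != variable.size()) {  // lazily size the running averages
            g_.assign(variable.size(), 0.0);
            s_.assign(variable.size(), 0.0);
        }
        for (std::size_t i = 0; i < variable.size(); ++i) {
            const double grad = raw_negative_descend_direction[i];
            g_[i] = (1.0 - lambda_) * grad * grad + lambda_ * g_[i];
            const double d = alpha_ * std::sqrt(s_[i] + epsilon_) /
                             std::sqrt(g_[i] + epsilon_) * grad;
            s_[i] = (1.0 - lambda_) * d * d + lambda_ * s_[i];
            variable[i] -= d;  // in-place update of the target variable
        }
    }

private:
    double alpha_, epsilon_, lambda_;
    std::vector<double> g_, s_;  // g_theta and s_theta per coordinate
};
```

On the first call \( s_\theta = 0 \), so when \( (1-\lambda)(\partial f/\partial\theta)^2 \gg \epsilon \) each coordinate moves by roughly \( \alpha\sqrt{\epsilon/(1-\lambda)} \), whatever the gradient's magnitude. For instance, starting from \( \theta = (1, -3) \) with gradient \( \theta \), \( \alpha = 1 \), \( \epsilon = 10^{-6} \) and \( \lambda = 0.95 \), both coordinates move toward zero by about 0.0045.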
[protected]
Built-in learning rate \( \alpha \), applied at every iteration.
Definition at line 141 of file AdaDeltaUpdater.h.
[protected, inherited]
Descent correction object.
Definition at line 165 of file DescendUpdaterWithCorrection.h.
[protected]
Decay term (\( \lambda \)).
Definition at line 147 of file AdaDeltaUpdater.h.
[protected]
\( \epsilon \)
Definition at line 144 of file AdaDeltaUpdater.h.
\( g_\theta \): decaying average of the squared negative descent direction.
Definition at line 150 of file AdaDeltaUpdater.h.
\( s_\theta \): decaying average of the squared corrected direction \( d_\theta \).
Definition at line 153 of file AdaDeltaUpdater.h.