SHOGUN  4.1.0
NesterovMomentumCorrection Class Reference

Detailed Description

This class implements the Nesterov Accelerated Gradient (NAG) momentum correction.

Given a target variable \(w\) and a descent direction \(d_{ahead}\) with respect to the look-ahead point \(w_{ahead}\), the Nesterov momentum method performs the following update:

\begin{eqnarray*} w_{ahead} &=& w + \mu v \\ v^{new} &=& \mu v - d_{ahead} \\ w^{new} &=& w + v^{new} \end{eqnarray*}

where \(\mu\) is the momentum coefficient, \(d_{ahead}\) is the descent direction with respect to \(w_{ahead}\) (e.g., \(d_{ahead}=\lambda g_{ahead}\), where \(\lambda\) is the learning rate and \(g_{ahead}\) is the gradient with respect to \(w_{ahead}\)), \(v\) is the previous descent direction, and \(v^{new}\) is the corrected descent direction.

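Note that the Nesterov momentum correction uses \(d_{ahead}\), the descent direction evaluated at the look-ahead point \(w_{ahead}\), instead of \(d\), the descent direction evaluated at \(w\); this is what distinguishes it from the standard momentum update.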

In practice, we use the following implementation:

\begin{eqnarray*} v^{old} &=& v \\ v^{new} &=& \mu v^{old} - d \\ w^{new} &=& w - \mu v^{old} + (1 + \mu) v^{new} \end{eqnarray*}

where \(d\) is the descent direction with respect to \(w\).

The trick used in this implementation is to store \(w_{ahead}\) and rename it \(w\). Given a decaying descent direction (e.g., \(d=\lambda g_{ahead}\), where \(\lambda\) is a decaying learning rate), \(w_{ahead}\) is very close to \(w\). When an optimal solution \(w^{opt}\) is found, \(w_{ahead}=w^{opt}\) since \(d^{opt}=0\) (the velocity \(v\) then decays to zero, so \(w_{ahead}=w+\mu v\) coincides with \(w\)).
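The equivalence of the two forms can be checked directly. Writing \(x = w_{ahead} = w + \mu v\) for the stored variable, and substituting \(w = x - \mu v^{old}\) and \(w^{new} = w + v^{new}\) from the first set of equations, gives the following short derivation (added here for clarity):

\begin{eqnarray*} x^{new} &=& w^{new} + \mu v^{new} \\ &=& (x - \mu v^{old}) + v^{new} + \mu v^{new} \\ &=& x - \mu v^{old} + (1 + \mu) v^{new} \end{eqnarray*}

Renaming \(x\) as \(w\) yields the practical update, and because \(d_{ahead}\) is evaluated at \(x\), the descent direction is always taken at the stored variable.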

The get_corrected_descend_direction() method computes

\begin{eqnarray*} v^{old} &=& v \\ v^{new} &=& \mu v^{old} - d \end{eqnarray*}

and returns \( -\mu v^{old} + (1 + \mu) v^{new}\).

A good introduction to the momentum update can be found at http://cs231n.github.io/neural-networks-3/#sgd, where \(v\) is also called the velocity.

Definition at line 78 of file NesterovMomentumCorrection.h.

Inheritance diagram for NesterovMomentumCorrection (graph not reproduced): NesterovMomentumCorrection derives from MomentumCorrection, which derives from DescendCorrection.

Public Member Functions

 NesterovMomentumCorrection ()
 
virtual ~NesterovMomentumCorrection ()
 
virtual DescendPair get_corrected_descend_direction (float64_t negative_descend_direction, index_t idx)
 
virtual bool is_initialized ()
 
virtual void initialize_previous_direction (index_t len)
 
virtual void update_context (CMinimizerContext *context)
 
virtual void load_from_context (CMinimizerContext *context)
 
virtual float64_t get_previous_descend_direction (index_t idx)
 
virtual float64_t get_length_previous_descend_direction ()
 
virtual void set_correction_weight (float64_t weight)
 

Protected Attributes

SGVector< float64_t > m_previous_descend_direction
 
float64_t m_weight
 

Constructor & Destructor Documentation

NesterovMomentumCorrection ( )

Definition at line 82 of file NesterovMomentumCorrection.h.

virtual ~NesterovMomentumCorrection ( )
[virtual]

Definition at line 89 of file NesterovMomentumCorrection.h.

Member Function Documentation

virtual DescendPair get_corrected_descend_direction ( float64_t negative_descend_direction, index_t idx )
[virtual]

Get the corrected descent direction.

Parameters
negative_descend_direction: the negative descent direction
idx: the index of the direction
Returns
a DescendPair holding the corrected descent direction and the change applied to correct the descent direction

Implements MomentumCorrection.

Definition at line 97 of file NesterovMomentumCorrection.h.

virtual float64_t get_length_previous_descend_direction ( )
[virtual, inherited]

Get the length of the previous descent direction (velocity).

Returns
the length of the previous descent direction

Definition at line 143 of file MomentumCorrection.h.

virtual float64_t get_previous_descend_direction ( index_t  idx)
[virtual, inherited]

Get the previous descent direction (velocity) at the given index.

Parameters
idx: index of the previous descent direction
Returns
the previous descent direction at index idx

Definition at line 132 of file MomentumCorrection.h.

virtual void initialize_previous_direction ( index_t  len)
[virtual, inherited]

Initialize m_previous_descend_direction.

Parameters
len: the length of m_previous_descend_direction to be initialized

Reimplemented in AdaptMomentumCorrection.

Definition at line 73 of file MomentumCorrection.h.

virtual bool is_initialized ( )
[virtual, inherited]

Check whether m_previous_descend_direction is initialized.

Returns
whether m_previous_descend_direction is initialized

Reimplemented in AdaptMomentumCorrection.

Definition at line 64 of file MomentumCorrection.h.

virtual void load_from_context ( CMinimizerContext * context )
[virtual, inherited]

Load the given context object to restore mutable variables.

This method is called by DescendUpdaterWithCorrection::load_from_context(CMinimizerContext* context).

Parameters
context: a context object

Implements DescendCorrection.

Reimplemented in AdaptMomentumCorrection.

Definition at line 116 of file MomentumCorrection.h.

virtual void set_correction_weight ( float64_t  weight)
[virtual, inherited]

Set the weight used in the descent correction.

Parameters
weight: the weight

Reimplemented in AdaptMomentumCorrection.

Definition at line 74 of file DescendCorrection.h.

virtual void update_context ( CMinimizerContext * context )
[virtual, inherited]

Update a context object to store the mutable variables used in the descent update.

This method is called by DescendUpdaterWithCorrection::update_context().

Parameters
context: a context object

Implements DescendCorrection.

Reimplemented in AdaptMomentumCorrection.

Definition at line 98 of file MomentumCorrection.h.

Member Data Documentation

SGVector<float64_t> m_previous_descend_direction
[protected, inherited]

The previous descent direction (velocity), used in momentum methods.

Definition at line 149 of file MomentumCorrection.h.

float64_t m_weight
[protected, inherited]

Weight of the correction.

Definition at line 110 of file DescendCorrection.h.


The documentation for this class was generated from the following file:

SHOGUN Machine Learning Toolbox - Documentation