SHOGUN 4.1.0
AdaptMomentumCorrection Class Reference

Detailed Description

This class implements the adaptive momentum correction method.

A standard momentum correction performs an update based on a momentum (e.g., \(\mu\)), a previous descent direction (e.g., \(v\)), and a current descent direction (e.g., \(d\)).
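In symbols, one common per-element form of this standard update is shown below. It is a sketch, not a formula stated on this page: the sign convention (adding \(d\)) and the parameter vector \(x\) are assumptions introduced only for illustration.

\[
v \leftarrow \mu\, v + d, \qquad x \leftarrow x + v
\]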

The idea of the adaptive momentum correction method is as follows: if the signs of the last two momentum corrections differ, the current descent direction is discounted; if the signs are the same, the method raises the current descent direction. See get_corrected_descend_direction() for details.
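The sign rule can be sketched as a per-element rate update. This is a minimal sketch under stated assumptions, not Shogun's code: the helper name adapt_rate_step is hypothetical, the additive raise by adapt_rate and the multiplicative discount by (1 - adapt_rate) are assumptions (the page only says "raised" and "discounted"), and clamping the element-wise rate to [rate_min, rate_max] is an assumed reading of the bounds passed to set_adapt_rate(). The resulting rate \(r\) would then scale the current direction, roughly \(v \leftarrow \mu v + r d\).

```cpp
#include <algorithm>
#include <cassert>
#include <limits>

// Hypothetical helper (not part of Shogun) that updates one element-wise rate.
// Assumed rule: if the previous and current corrections have the same sign,
// raise the rate additively by adapt_rate; otherwise (signs differ, or one of
// them is zero) discount it multiplicatively by (1 - adapt_rate).
// The result is clamped to [rate_min, rate_max]; rate_max defaults to +infinity.
double adapt_rate_step(double rate, double previous_correction,
                       double current_correction, double adapt_rate,
                       double rate_min = 0.0,
                       double rate_max = std::numeric_limits<double>::infinity())
{
    assert(adapt_rate >= 0.0 && adapt_rate < 1.0);
    assert(rate_min <= rate_max);
    if (previous_correction * current_correction > 0.0)
        rate += adapt_rate;          // same sign: raise the rate
    else
        rate *= 1.0 - adapt_rate;    // different sign (or zero): discount it
    return std::min(std::max(rate, rate_min), rate_max);
}
```

For example, with adapt_rate = 0.05 a rate of 1.0 becomes 1.05 after two corrections of the same sign and 0.95 after two of opposite sign. Raising additively but discounting multiplicatively keeps a positive rate positive while still shrinking it quickly when the corrections oscillate.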

Definition at line 50 of file AdaptMomentumCorrection.h.

Inheritance diagram for AdaptMomentumCorrection (base classes: MomentumCorrection, DescendCorrection)

Public Member Functions

AdaptMomentumCorrection ()
virtual void set_momentum_correction (MomentumCorrection *correction)
virtual ~AdaptMomentumCorrection ()
virtual bool is_initialized ()
virtual void set_correction_weight (float64_t weight)
virtual void initialize_previous_direction (index_t len)
virtual DescendPair get_corrected_descend_direction (float64_t negative_descend_direction, index_t idx)
virtual void update_context (CMinimizerContext *context)
virtual void load_from_context (CMinimizerContext *context)
virtual void set_adapt_rate (float64_t adapt_rate, float64_t rate_min=0.0, float64_t rate_max=CMath::INFTY)
virtual void set_init_descend_rate (float64_t init_descend_rate)
virtual float64_t get_previous_descend_direction (index_t idx)
virtual float64_t get_length_previous_descend_direction ()
 

Protected Attributes

SGVector<float64_t> m_descend_rate
MomentumCorrection* m_momentum_correction
float64_t m_adapt_rate
float64_t m_rate_min
float64_t m_rate_max
float64_t m_init_descend_rate
SGVector<float64_t> m_previous_descend_direction
float64_t m_weight
 

Constructor & Destructor Documentation

AdaptMomentumCorrection ( )

Definition at line 55 of file AdaptMomentumCorrection.h.

virtual ~AdaptMomentumCorrection ( )
virtual

Definition at line 73 of file AdaptMomentumCorrection.h.

Member Function Documentation

virtual DescendPair get_corrected_descend_direction ( float64_t negative_descend_direction, index_t idx )
virtual

Get the corrected descent direction for the element at index idx.

Parameters
negative_descend_direction: the negative descent direction
idx: the index of the element in the direction
Returns
a DescendPair holding the corrected descent direction and the change introduced by the correction

Implements MomentumCorrection.

Definition at line 111 of file AdaptMomentumCorrection.h.
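The delegation implied by set_momentum_correction() can be sketched as a small wrapper around a standard momentum method. Everything here is hypothetical and simplified: PairSketch stands in for DescendPair with assumed fields descend_direction and delta, negative_descend_direction is assumed to be the negated descent step (e.g., a learning rate times the gradient), and scaling that step by the element-wise rate before delegating is an assumption about where the rate enters. The rate update itself is omitted; set_rate() is a hypothetical setter.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for DescendPair; field names are assumptions.
struct PairSketch
{
    double descend_direction; // corrected direction (new velocity)
    double delta;             // change of the velocity caused by this step
};

// Hypothetical interface standing in for MomentumCorrection.
class MomentumSketch
{
public:
    virtual ~MomentumSketch() {}
    virtual PairSketch correct(double negative_descend_direction, std::size_t idx) = 0;
};

// Plain momentum per element: v <- mu * v - g, i.e. v <- mu * v + d with d = -g.
class StandardMomentumSketch : public MomentumSketch
{
public:
    StandardMomentumSketch(double mu, std::size_t len) : m_mu(mu), m_v(len, 0.0) {}

    PairSketch correct(double g, std::size_t idx) override
    {
        assert(idx < m_v.size());
        const double old_v = m_v[idx];
        m_v[idx] = m_mu * old_v - g;
        PairSketch pair = {m_v[idx], m_v[idx] - old_v};
        return pair;
    }

private:
    double m_mu;
    std::vector<double> m_v; // previous direction (velocity), starts at zero
};

// Wrapper in the spirit of set_momentum_correction(): scales the incoming
// step by a per-element rate, then delegates to the wrapped method.
class AdaptWrapperSketch
{
public:
    AdaptWrapperSketch(MomentumSketch* inner, std::size_t len, double init_rate = 1.0)
        : m_inner(inner), m_rate(len, init_rate) {}

    PairSketch correct(double g, std::size_t idx)
    {
        assert(m_inner != nullptr && idx < m_rate.size());
        return m_inner->correct(m_rate[idx] * g, idx);
    }

    void set_rate(std::size_t idx, double rate)
    {
        assert(idx < m_rate.size());
        m_rate[idx] = rate;
    }

private:
    MomentumSketch* m_inner;    // not owned by this sketch
    std::vector<double> m_rate; // element-wise rate
};
```

Holding a pointer to the inner method, rather than re-implementing momentum, is what would let an adaptive rule be combined with different standard momentum variants. In this sketch the pointer is not owned, so the caller must keep the inner object alive.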

virtual float64_t get_length_previous_descend_direction ( )
virtual, inherited

Get the length of the previous descent direction (velocity).

Returns
the length of the previous descent direction

Definition at line 143 of file MomentumCorrection.h.

virtual float64_t get_previous_descend_direction ( index_t  idx)
virtual, inherited

Get the element of the previous descent direction (velocity) at the given index.

Parameters
idx: index into the previous descent direction
Returns
the previous descent direction at index idx

Definition at line 132 of file MomentumCorrection.h.

virtual void initialize_previous_direction ( index_t  len)
virtual

Initialize m_previous_descend_direction.

Parameters
len: the length of m_previous_descend_direction to be initialized

Reimplemented from MomentumCorrection.

Definition at line 99 of file AdaptMomentumCorrection.h.
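The bookkeeping described by initialize_previous_direction() and the two inherited getters can be sketched with std::vector in place of SGVector. This is a hypothetical class, not Shogun code. Two assumptions: "length" means the element count (matching the len argument), and initialization also sizes an element-wise rate vector filled with the initial rate (default 1.0, per set_init_descend_rate()). The length getter returns double only to mirror the documented float64_t return type.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch of the per-element state of an adaptive momentum
// correction; member names mirror the documented ones, but this is not Shogun code.
class AdaptMomentumStateSketch
{
public:
    explicit AdaptMomentumStateSketch(double init_descend_rate = 1.0)
        : m_init_descend_rate(init_descend_rate) {}

    // Assumption: zero velocity and every rate set to the initial rate.
    void initialize_previous_direction(std::size_t len)
    {
        m_previous_descend_direction.assign(len, 0.0);
        m_descend_rate.assign(len, m_init_descend_rate);
    }

    double get_previous_descend_direction(std::size_t idx) const
    {
        assert(idx < m_previous_descend_direction.size());
        return m_previous_descend_direction[idx];
    }

    // Assumption: "length" is the number of elements, not a norm.
    double get_length_previous_descend_direction() const
    {
        return static_cast<double>(m_previous_descend_direction.size());
    }

    double get_descend_rate(std::size_t idx) const
    {
        assert(idx < m_descend_rate.size());
        return m_descend_rate[idx];
    }

private:
    double m_init_descend_rate;
    std::vector<double> m_previous_descend_direction;
    std::vector<double> m_descend_rate;
};
```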

virtual bool is_initialized ( )
virtual

Is the standard momentum method initialized?

Returns
whether the standard method is initialized

Reimplemented from MomentumCorrection.

Definition at line 79 of file AdaptMomentumCorrection.h.

virtual void load_from_context ( CMinimizerContext * context)
virtual

Load the given context object to restore mutable variables.

This method is called by DescendUpdaterWithCorrection::load_from_context(CMinimizerContext* context).

Parameters
context: a context object

Reimplemented from MomentumCorrection.

Definition at line 173 of file AdaptMomentumCorrection.h.

virtual void set_adapt_rate ( float64_t adapt_rate, float64_t rate_min = 0.0, float64_t rate_max = CMath::INFTY )
virtual

Set the adaptive rate used in this method, together with its bounds.

Parameters
adapt_rate: the rate used to discount or raise the current descent direction (see get_corrected_descend_direction())
rate_min: minimum of the rate
rate_max: maximum of the rate

Definition at line 191 of file AdaptMomentumCorrection.h.

virtual void set_correction_weight ( float64_t  weight)
virtual

Set the weight (momentum) for the standard momentum method.

Parameters
weight: the momentum

Reimplemented from DescendCorrection.

Definition at line 89 of file AdaptMomentumCorrection.h.

virtual void set_init_descend_rate ( float64_t  init_descend_rate)
virtual

Set the initial rate used to discount or raise the current descent direction.

Parameters
init_descend_rate: the initial rate (default 1.0)

Definition at line 206 of file AdaptMomentumCorrection.h.

virtual void set_momentum_correction ( MomentumCorrection * correction)
virtual

Set a standard momentum method.

Parameters
correction: a standard momentum method (e.g., StandardMomentumCorrection)

Definition at line 65 of file AdaptMomentumCorrection.h.

virtual void update_context ( CMinimizerContext * context)
virtual

Update a context object to store the mutable variables used in the descent update.

This method is called by DescendUpdaterWithCorrection::update_context().

Parameters
context: a context object

Reimplemented from MomentumCorrection.

Definition at line 153 of file AdaptMomentumCorrection.h.
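The save/restore contract of update_context() and load_from_context() can be illustrated with a round trip. CMinimizerContext's real API is not shown on this page, so the sketch substitutes a hypothetical string-keyed map (ContextSketch) with made-up key names. Which variables the real class stores is also an assumption; here they are the velocity and the element-wise rate.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for CMinimizerContext: a string-keyed store of vectors.
// The key names below are illustrative assumptions, not Shogun's.
using ContextSketch = std::map<std::string, std::vector<double>>;

struct AdaptMomentumMutableState
{
    std::vector<double> previous_descend_direction; // velocity per element
    std::vector<double> descend_rate;               // adaptive rate per element

    // Plays the role of update_context(): copy the mutable state out.
    void update_context(ContextSketch& context) const
    {
        context["previous_descend_direction"] = previous_descend_direction;
        context["descend_rate"] = descend_rate;
    }

    // Plays the role of load_from_context(): restore the mutable state.
    // std::map::at throws std::out_of_range if a key is missing.
    void load_from_context(const ContextSketch& context)
    {
        previous_descend_direction = context.at("previous_descend_direction");
        descend_rate = context.at("descend_rate");
        assert(previous_descend_direction.size() == descend_rate.size());
    }
};
```

One benefit of keeping mutable state in an external context object is that a minimizer can checkpoint and resume an optimization run without serializing the correction object itself.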

Member Data Documentation

float64_t m_adapt_rate
protected

the adapt rate

Definition at line 217 of file AdaptMomentumCorrection.h.

SGVector<float64_t> m_descend_rate
protected

element-wise rate used to discount or raise the current descent direction

Definition at line 213 of file AdaptMomentumCorrection.h.

float64_t m_init_descend_rate
protected

the initial rate

Definition at line 223 of file AdaptMomentumCorrection.h.

MomentumCorrection* m_momentum_correction
protected

the standard momentum method

Definition at line 215 of file AdaptMomentumCorrection.h.

SGVector<float64_t> m_previous_descend_direction
protected, inherited

the previous descent direction (velocity) used in momentum methods

Definition at line 149 of file MomentumCorrection.h.

float64_t m_rate_max
protected

the maximum of the adapt rate

Definition at line 221 of file AdaptMomentumCorrection.h.

float64_t m_rate_min
protected

the minimum of the adapt rate

Definition at line 219 of file AdaptMomentumCorrection.h.

float64_t m_weight
protected, inherited

weight of the correction

Definition at line 110 of file DescendCorrection.h.


The documentation for this class was generated from the following file:
AdaptMomentumCorrection.h

SHOGUN Machine Learning Toolbox - Documentation