32 #ifndef DESCENDUPDATERWITHCORRECTION_H
33 #define DESCENDUPDATERWITHCORRECTION_H
70 REQUIRE(variable_reference.
vlen>0,
"variable_reference must set\n");
72 "The length of variable_reference (%d) and the length of gradient (%d) do not match\n",
73 variable_reference.
vlen,raw_negative_descend_direction.
vlen);
78 if(momentum_correction)
85 for(
index_t idx=0; idx<variable_reference.
vlen; idx++)
88 variable_reference[idx], raw_negative_descend_direction[idx], idx, learning_rate);
92 negative_descend_direction, idx);
93 variable_reference[idx]+=pair.descend_direction;
97 variable_reference[idx]-=negative_descend_direction;
112 REQUIRE(context,
"Context must set\n");
126 REQUIRE(context,
"Context must set\n");
The class is used to serialize and deserialize variables for the optimization framework.
DescendCorrection * m_correction
virtual void initialize_previous_direction(index_t len)
virtual void update_context(CMinimizerContext *context)=0
virtual DescendPair get_corrected_descend_direction(float64_t negative_descend_direction, index_t idx)=0
virtual void set_descend_correction(DescendCorrection *correction)
This is a base class for descent updates.
virtual float64_t get_negative_descend_direction(float64_t variable, float64_t raw_negative_descend_direction, index_t idx, float64_t learning_rate)=0
This is a base class for descent-based correction methods.
virtual void load_from_context(CMinimizerContext *context)
virtual void update_variable(SGVector< float64_t > variable_reference, SGVector< float64_t > raw_negative_descend_direction, float64_t learning_rate)
All classes and functions are contained in the shogun namespace.
This is a base class for momentum correction methods.
This is a base class for descent updates with descent-based correction.
virtual bool enables_descend_correction()
virtual ~DescendUpdaterWithCorrection()
virtual bool is_initialized()
virtual void load_from_context(CMinimizerContext *context)=0
virtual void update_context(CMinimizerContext *context)