REQUIRE(learning_rate>0,
	"Learning rate (%f) must be positive\n", learning_rate);

REQUIRE(epsilon>=0,
	"Epsilon (%f) must be non-negative\n", epsilon);

REQUIRE(decay_factor>0.0 && decay_factor<=1.0,
	"Decay factor (%f) for first moment must be in (0,1]\n", decay_factor);

REQUIRE(decay_factor>0.0 && decay_factor<=1.0,
	"Decay factor (%f) for second moment must be in (0,1]\n", decay_factor);
void AdamUpdater::init()

SG_ADD(&m_gradient_second_moment,
	"AdamUpdater__m_gradient_second_moment", ...);
REQUIRE(variable_reference.vlen==raw_negative_descend_direction.vlen, "");
SGVector< float64_t > m_gradient_first_moment
virtual void set_second_moment_decay_factor(float64_t decay_factor)
float64_t m_decay_factor_first_moment
float64_t m_log_learning_rate
virtual void set_learning_rate(float64_t learning_rate)
SGVector< float64_t > m_gradient_second_moment
float64_t m_log_scale_pre_iteration
virtual float64_t get_negative_descend_direction(float64_t variable, float64_t gradient, index_t idx, float64_t learning_rate)
virtual void update_variable(SGVector< float64_t > variable_reference, SGVector< float64_t > raw_negative_descend_direction, float64_t learning_rate)
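update_variable and get_negative_descend_direction together apply the Adam rule per coordinate. The following is a minimal, self-contained sketch of that standard rule, not the exact Shogun implementation: it uses plain doubles instead of float64_t, std:: math instead of CMath, and folds the bias correction into the scale applied to the learning rate.

#include <cmath>
#include <cstdint>

// Sketch of the standard Adam step for one coordinate. first_moment and
// second_moment play the roles of entries of m_gradient_first_moment and
// m_gradient_second_moment, beta1/beta2 of the two decay factors, and t of
// m_iteration_counter (starting at 1).
double adam_negative_descend_direction(double gradient, double& first_moment,
	double& second_moment, double beta1, double beta2, double learning_rate,
	double epsilon, std::int64_t t)
{
	first_moment  = beta1*first_moment  + (1.0 - beta1)*gradient;          // m_t
	second_moment = beta2*second_moment + (1.0 - beta2)*gradient*gradient; // v_t
	// Bias-corrected scale: learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
	double scale = learning_rate*std::sqrt(1.0 - std::pow(beta2, (double)t))
		/ (1.0 - std::pow(beta1, (double)t));
	// The step that the caller subtracts from the variable.
	return scale*first_moment / (std::sqrt(second_moment) + epsilon);
}

Applying the sqrt(1 - beta2^t)/(1 - beta1^t) factor to the learning rate is the usual reformulation of Adam's bias correction; it is equivalent (up to a rescaled epsilon) to correcting the two moment estimates separately.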
All classes and functions are contained in the shogun namespace.
This is a base class for descend updates with descend-based correction.
static float64_t exp(float64_t x)
static float64_t log(float64_t v)
virtual void update_variable(SGVector< float64_t > variable_reference, SGVector< float64_t > raw_negative_descend_direction, float64_t learning_rate)
static float32_t sqrt(float32_t x)
virtual void set_first_moment_decay_factor(float64_t decay_factor)
virtual void set_epsilon(float64_t epsilon)
static int32_t pow(bool x, int32_t n)
int64_t m_iteration_counter
void set_const(T const_elem)
float64_t m_decay_factor_second_moment