SHOGUN
4.1.0
首页
相关页面
模块
类
文件
文件列表
文件成员
全部
类
命名空间
文件
函数
变量
类型定义
枚举
枚举值
友元
宏定义
组
页
src
shogun
optimization
MomentumCorrection.h
浏览该文件的文档.
1
/*
2
* Copyright (c) The Shogun Machine Learning Toolbox
3
* Written (w) 2015 Wu Lin
4
* All rights reserved.
5
*
6
* Redistribution and use in source and binary forms, with or without
7
* modification, are permitted provided that the following conditions are met:
8
*
9
* 1. Redistributions of source code must retain the above copyright notice, this
10
* list of conditions and the following disclaimer.
11
* 2. Redistributions in binary form must reproduce the above copyright notice,
12
* this list of conditions and the following disclaimer in the documentation
13
* and/or other materials provided with the distribution.
14
*
15
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
19
* ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
22
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25
*
26
* The views and conclusions contained in the software and documentation are those
27
* of the authors and should not be interpreted as representing official policies,
28
* either expressed or implied, of the Shogun Development Team.
29
*
30
*/
31
32
#ifndef MOMEMTUMCORRECTION_H
33
#define MOMEMTUMCORRECTION_H
34
#include <
shogun/lib/config.h
>
35
#include <
shogun/lib/SGVector.h
>
36
#include <
shogun/optimization/MinimizerContext.h
>
37
#include <
shogun/optimization/DescendCorrection.h
>
38
namespace
shogun
39
{
46
class
MomentumCorrection
:
public
DescendCorrection
47
{
48
public
:
49
50
/* Constructor */
51
MomentumCorrection
()
52
:
DescendCorrection
()
53
{
54
init();
55
}
56
57
/* Destructor */
58
virtual
~MomentumCorrection
() {};
59
64
virtual
bool
is_initialized
()
65
{
66
return
m_previous_descend_direction
.
vlen
>0;
67
}
68
73
virtual
void
initialize_previous_direction
(
index_t
len)
74
{
75
REQUIRE
(len>0,
"The length (%d) must be positive\n"
, len);
76
m_previous_descend_direction
=
SGVector<float64_t>
(len);
77
m_previous_descend_direction
.
set_const
(0.0);
78
}
79
87
virtual
DescendPair
get_corrected_descend_direction
(
float64_t
negative_descend_direction,
88
index_t
idx)=0;
89
98
virtual
void
update_context
(
CMinimizerContext
* context)
99
{
100
REQUIRE
(context,
"context must set\n"
);
101
SGVector<float64_t>
value(
m_previous_descend_direction
.
vlen
);
102
std::copy(
m_previous_descend_direction
.
vector
,
103
m_previous_descend_direction
.
vector
+
m_previous_descend_direction
.
vlen
,
104
value.
vector
);
105
std::string key=
"MomentumCorrection::m_previous_descend_direction"
;
106
context->
save_data
(key, value);
107
}
108
116
virtual
void
load_from_context
(
CMinimizerContext
* context)
117
{
118
REQUIRE
(context,
"context must set\n"
);
119
std::string key=
"MomentumCorrection::m_previous_descend_direction"
;
120
SGVector<float64_t>
value=context->
get_data_sgvector_float64
(key);
121
m_previous_descend_direction
=
SGVector<float64_t>
(value.
vlen
);
122
std::copy(value.
vector
, value.
vector
+value.
vlen
,
123
m_previous_descend_direction
.
vector
);
124
}
125
132
virtual
float64_t
get_previous_descend_direction
(
index_t
idx)
133
{
134
REQUIRE
(idx>=0 && idx<
m_previous_descend_direction
.
vlen
,
135
"Index (%d) is invalid\n"
, idx);
136
return
m_previous_descend_direction
[idx];
137
}
138
143
virtual
float64_t
get_length_previous_descend_direction
()
144
{
145
return
m_previous_descend_direction
.
vlen
;
146
}
147
protected
:
149
SGVector<float64_t>
m_previous_descend_direction
;
150
151
private
:
152
/* Init */
153
void
init()
154
{
155
m_previous_descend_direction=
SGVector<float64_t>
();
156
}
157
};
158
159
}
160
#endif
shogun::CMinimizerContext::save_data
virtual void save_data(const std::string &key, SGVector< float64_t > value)
Definition:
MinimizerContext.h:74
shogun::MomentumCorrection::get_previous_descend_direction
virtual float64_t get_previous_descend_direction(index_t idx)
Definition:
MomentumCorrection.h:132
index_t
int32_t index_t
Definition:
common.h:62
shogun::CMinimizerContext
The class is used to serialize and deserialize variables for the optimization framework.
Definition:
MinimizerContext.h:45
shogun::MomentumCorrection::get_length_previous_descend_direction
virtual float64_t get_length_previous_descend_direction()
Definition:
MomentumCorrection.h:143
shogun::MomentumCorrection::update_context
virtual void update_context(CMinimizerContext *context)
Definition:
MomentumCorrection.h:98
config.h
REQUIRE
#define REQUIRE(x,...)
Definition:
SGIO.h:206
shogun::MomentumCorrection::initialize_previous_direction
virtual void initialize_previous_direction(index_t len)
Definition:
MomentumCorrection.h:73
shogun::CMinimizerContext::get_data_sgvector_float64
virtual SGVector< float64_t > get_data_sgvector_float64(const std::string &key)
Definition:
MinimizerContext.h:113
shogun::MomentumCorrection::~MomentumCorrection
virtual ~MomentumCorrection()
Definition:
MomentumCorrection.h:58
shogun::SGVector::vlen
index_t vlen
Definition:
SGVector.h:494
shogun::SGVector::vector
T * vector
Definition:
SGVector.h:492
shogun::SGVector< float64_t >
float64_t
double float64_t
Definition:
common.h:50
shogun::MomentumCorrection::m_previous_descend_direction
SGVector< float64_t > m_previous_descend_direction
Definition:
MomentumCorrection.h:149
MinimizerContext.h
shogun::DescendCorrection
This is a base class for descend-based correction methods.
Definition:
DescendCorrection.h:57
shogun::MomentumCorrection::load_from_context
virtual void load_from_context(CMinimizerContext *context)
Definition:
MomentumCorrection.h:116
shogun
All classes and functions are contained in the shogun namespace.
Definition:
class_list.h:18
shogun::MomentumCorrection
This is a base class for momentum correction methods.
Definition:
MomentumCorrection.h:46
DescendCorrection.h
shogun::MomentumCorrection::is_initialized
virtual bool is_initialized()
Definition:
MomentumCorrection.h:64
SGVector.h
shogun::MomentumCorrection::get_corrected_descend_direction
virtual DescendPair get_corrected_descend_direction(float64_t negative_descend_direction, index_t idx)=0
shogun::SGVector::set_const
void set_const(T const_elem)
Definition:
SGVector.cpp:152
shogun::MomentumCorrection::MomentumCorrection
MomentumCorrection()
Definition:
MomentumCorrection.h:51
SHOGUN
机器学习工具包 - 项目文档