SHOGUN
v3.0.0
Main Page
Related Pages
Modules
Classes
Files
File List
File Members
All
Classes
Namespaces
Files
Functions
Variables
Typedefs
Enumerations
Enumerator
Friends
Macros
Groups
Pages
src
shogun
multiclass
MCLDA.h
Go to the documentation of this file.
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 2013 Kevin Hughes
 * Copyright (C) 2013 Kevin Hughes
 *
 * Thanks to Fernando José Iglesias García (shogun)
 * and Matthieu Perrot (scikit-learn)
 */
#ifndef _MCLDA_H__
15
#define _MCLDA_H__
16
17
#include <
shogun/lib/config.h
>
18
19
#ifdef HAVE_EIGEN3
20
21
#include <
shogun/features/DotFeatures.h
>
22
#include <
shogun/features/DenseFeatures.h
>
23
#include <
shogun/machine/NativeMulticlassMachine.h
>
24
#include <
shogun/lib/SGNDArray.h
>
25
26
namespace
shogun
27
{
28
29
//#define DEBUG_MCLDA
30
39
class
CMCLDA
:
public
CNativeMulticlassMachine
40
{
41
public
:
42
MACHINE_PROBLEM_TYPE
(
PT_MULTICLASS
)
43
44
49
CMCLDA
(
float64_t
tolerance = 1e-4,
bool
store_cov =
false
);
50
58
CMCLDA
(
CDenseFeatures<float64_t>
* traindat,
CLabels
* trainlab,
float64_t
tolerance = 1e-4,
bool
store_cov =
false
);
59
60
virtual
~CMCLDA
();
61
67
virtual
CMulticlassLabels
*
apply_multiclass
(
CFeatures
* data=NULL);
68
73
inline
void
set_tolerance
(
float64_t
tolerance) { m_tolerance = tolerance; }
74
79
inline
bool
get_tolerance
() {
return
m_tolerance; }
80
85
virtual
EMachineType
get_classifier_type
() {
return
CT_LDA
; }
// for now add to machine typers properly later
86
91
virtual
void
set_features
(
CDotFeatures
* feat)
92
{
93
if
(feat->
get_feature_class
() !=
C_DENSE
||
94
feat->
get_feature_type
() !=
F_DREAL
)
95
SG_ERROR
(
"MCLDA requires SIMPLE REAL valued features\n"
)
96
97
SG_REF
(feat);
98
SG_UNREF
(m_features);
99
m_features = feat;
100
}
101
106
virtual
CDotFeatures
*
get_features
() {
SG_REF
(m_features);
return
m_features; }
107
112
virtual
const
char
*
get_name
()
const
{
return
"MCLDA"
; }
113
120
inline
SGVector< float64_t >
get_mean
(int32_t c)
const
121
{
122
return
SGVector< float64_t >
(m_means.
get_column_vector
(c), m_dim,
false
);
123
}
124
129
inline
SGMatrix< float64_t >
get_cov
()
const
130
{
131
return
m_cov;
132
}
133
134
protected
:
141
virtual
bool
train_machine
(
CFeatures
* data = NULL);
142
143
private
:
144
void
init();
145
146
void
cleanup();
147
148
private
:
150
CDotFeatures
* m_features;
151
153
float64_t
m_tolerance;
154
156
bool
m_store_cov;
157
159
int32_t m_num_classes;
160
162
int32_t m_dim;
163
167
SGMatrix< float64_t >
m_cov;
168
170
SGMatrix< float64_t >
m_means;
171
173
SGVector< float64_t >
m_xbar;
174
176
int32_t m_rank;
177
179
SGMatrix< float64_t >
m_scalings;
180
182
SGMatrix< float64_t >
m_coef;
183
185
SGVector< float64_t >
m_intercept;
186
187
};
/* class MCLDA */
188
}
/* namespace shogun */
189
190
#endif
/* HAVE_EIGEN3 */
191
#endif
/* _MCLDA_H__ */
SHOGUN
Machine Learning Toolbox - Documentation