SHOGUN  4.1.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
FFDiag.cpp
Go to the documentation of this file.
2 
3 #ifdef HAVE_EIGEN3
4 
5 #include <shogun/base/init.h>
6 
9 
10 using namespace shogun;
11 using namespace Eigen;
12 
13 void getW(float64_t *C, int *ptN, int *ptK, float64_t *W);
14 
16  double eps, int itermax)
17 {
18  int n = C0.dims[0];
19  int K = C0.dims[2];
20 
21  index_t * C_dims = SG_MALLOC(index_t, 3);
22  C_dims[0] = C0.dims[0];
23  C_dims[1] = C0.dims[1];
24  C_dims[2] = C0.dims[2];
25  SGNDArray<float64_t> C(C_dims,3);
26  memcpy(C.array, C0.array, C0.dims[0]*C0.dims[1]*C0.dims[2]*sizeof(float64_t));
27 
29  if (V0.num_rows == n && V0.num_cols == n)
30  V = V0.clone();
31  else
33 
34  MatrixXd Id(n,n); Id.setIdentity();
35  Map<MatrixXd> EV(V.matrix,n,n);
36 
37  float64_t inum = 0;
38  float64_t df = 1;
39  std::vector<float64_t> crit;
40  while (df > eps && inum < itermax)
41  {
42  MatrixXd W = MatrixXd::Zero(n,n);
43 
44  getW(C.get_matrix(0),
45  &n, &K,
46  W.data());
47 
48  W.transposeInPlace();
49  int e = CMath::ceil(log2(W.array().abs().rowwise().sum().maxCoeff()));
50  int s = std::max(0,e-1);
51  W /= pow(2,s);
52 
53  EV = (Id+W) * EV;
54  MatrixXd d = MatrixXd::Zero(EV.rows(),EV.cols());
55  d.diagonal() = VectorXd::Ones(EV.diagonalSize()).cwiseQuotient((EV * EV.transpose()).diagonal().cwiseSqrt());
56  EV = d * EV;
57 
58  for (int i = 0; i < K; i++)
59  {
60  Map<MatrixXd> Ci(C.get_matrix(i), n, n);
61  Map<MatrixXd> C0i(C0.get_matrix(i), n, n);
62  Ci = EV * C0i * EV.transpose();
63  }
64 
65  float64_t f = 0;
66  for (int i = 0; i < K; i++)
67  {
68  Map<MatrixXd> C0i(C0.get_matrix(i), n, n);
69  MatrixXd F = EV * C0i * EV.transpose();
70  f += (F.transpose() * F).diagonal().sum() - F.array().pow(2).matrix().diagonal().sum();
71  }
72 
73  crit.push_back(f);
74 
75  if (inum > 1)
76  df = CMath::abs(crit[inum-1]-crit[inum]);
77 
78  inum++;
79  }
80 
81  if (inum == itermax)
82  SG_SERROR("Convergence not reached\n")
83 
84  return V;
85 
86 }
87 
88 void getW(float64_t *C, int *ptN, int *ptK, float64_t *W)
89 {
90  int N=*ptN;
91  int K=*ptK;
92  int auxij,auxji,auxii,auxjj;
93  float64_t z[N][N];
94  float64_t y[N][N];
95 
96  for (int i = 0; i < N; i++)
97  {
98  for (int j = 0; j < N; j++)
99  {
100  z[i][j] = 0;
101  y[i][j] = 0;
102  }
103  }
104 
105  for (int i = 0; i < N; i++)
106  {
107  for (int j = 0; j < N; j++)
108  {
109  for (int k = 0; k < K; k++)
110  {
111  auxij = N*N*k+N*i+j;
112  auxji = N*N*k+N*j+i;
113  auxii = N*N*k+N*i+i;
114  auxjj = N*N*k+N*j+j;
115  z[i][j] += C[auxii]*C[auxjj];
116  y[i][j] += 0.5*C[auxjj]*(C[auxij]+C[auxji]);
117  }
118  }
119  }
120 
121  for (int i = 0; i < N-1; i++)
122  {
123  for (int j = i+1; j < N; j++)
124  {
125  auxij = N*i+j;
126  auxji = N*j+i;
127  W[auxij] = (z[j][i]*y[j][i] - z[i][i]*y[i][j])/(z[j][j]*z[i][i]-z[i][j]*z[i][j]);
128  W[auxji] = (z[i][j]*y[i][j] - z[j][j]*y[j][i])/(z[j][j]*z[i][i]-z[i][j]*z[i][j]);
129  }
130  }
131 
132  return;
133 }
134 #endif //HAVE_EIGEN3
int32_t index_t
Definition: common.h:62
static float64_t ceil(float64_t d)
Definition: Math.h:416
static SGMatrix< float64_t > diagonalize(SGNDArray< float64_t > C, SGMatrix< float64_t > V0=SGMatrix< float64_t >(NULL, 0, 0, false), double eps=CMath::MACHINE_EPSILON, int itermax=200)
Definition: FFDiag.cpp:15
Definition: SGMatrix.h:20
SGMatrix< T > clone()
Definition: SGMatrix.cpp:260
index_t num_cols
Definition: SGMatrix.h:378
index_t num_rows
Definition: SGMatrix.h:376
T * get_matrix(index_t matIdx) const
Definition: SGNDArray.h:72
double float64_t
Definition: common.h:50
void getW(float64_t *C, int *ptN, int *ptK, float64_t *W)
Definition: FFDiag.cpp:88
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
index_t * dims
Definition: SGNDArray.h:177
#define SG_SERROR(...)
Definition: SGIO.h:179
Matrix::Scalar max(Matrix m)
Definition: Redux.h:66
static SGMatrix< T > create_identity_matrix(index_t size, T scale)
static T abs(T a)
Definition: Math.h:179

SHOGUN Machine Learning Toolbox - Documentation