SHOGUN  v3.0.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Groups Pages
JADiag.cpp
Go to the documentation of this file.
1 #ifdef HAVE_EIGEN3
2 
4 
5 #include <shogun/base/init.h>
6 
9 
10 using namespace shogun;
11 using namespace Eigen;
12 
13 void jadiagw(float64_t c[], float64_t w[], int *ptn, int *ptm, float64_t a[],
14  float64_t *logdet, float64_t *decr, float64_t *result);
15 
17  double eps, int itermax)
18 {
19  int d = C.dims[0];
20  int L = C.dims[2];
21 
22  // check that the input matrices are pos def
23  for (int i = 0; i < L; i++)
24  {
25  Map<MatrixXd> Ci(C.get_matrix(i),d,d);
26 
27  EigenSolver<MatrixXd> eig;
28  eig.compute(Ci);
29 
30  MatrixXd D = eig.pseudoEigenvalueMatrix();
31 
32  for (int j = 0; j < d; j++)
33  {
34  if (D(j,j) < 0)
35  {
36  SG_SERROR("Input Matrix %d is not Positive-definite\n", i)
37  }
38  }
39  }
40 
42  if (V0.num_rows == d && V0.num_cols == d)
43  V = V0.clone();
44  else
46 
47  VectorXd w(L);
48  w.setOnes();
49 
50  MatrixXd ctot(d, d*L);
51  for (int i = 0; i < L; i++)
52  {
53  Map<MatrixXd> Ci(C.get_matrix(i),d,d);
54  ctot.block(0,i*d,d,d) = Ci;
55  }
56 
57  int iter = 0;
58  float64_t decr = 1;
59  float64_t logdet = log(5.184e17);
60  float64_t result = 0;
61  std::vector<float64_t> crit;
62  while (decr > eps && iter < itermax)
63  {
64  if(logdet == 0)// is NA
65  {
66  SG_SERROR("log det does not exist\n")
67  break;
68  }
69 
70  jadiagw(ctot.data(),
71  w.data(),
72  &d, &L,
73  V.matrix,
74  &logdet,
75  &decr,
76  &result);
77 
78  crit.push_back(result);
79  iter = iter + 1;
80  }
81 
82  if (iter == itermax)
83  SG_SERROR("Convergence not reached\n")
84 
85  return V;
86 
87 }
88 
89 void jadiagw(float64_t c[], float64_t w[], int *ptn, int *ptm, float64_t a[],
90  float64_t *logdet, float64_t *decr, float64_t *result)
91 {
92  int n = *ptn;
93  int m = *ptm;
94  //int i1,j1;
95  int n2 = n*n, mn2 = m*n2,
96  i, ic, ii, ij, j, jc, jj, k, k0;
97  float64_t sumweigh, p2, q1, p, q,
98  alpha, beta, gamma, a12, a21, /*tiny,*/ det;
99  register float64_t tmp1, tmp2, tmp, weigh;
100 
101  for (sumweigh = 0, i = 0; i < m; i++)
102  sumweigh += w[i];
103 
104  det = 1;
105  *decr = 0;
106 
107  for (i = 1, ic = n; i < n ; i++, ic += n)
108  {
109  for (j = jc = 0; j < i; j++, jc += n)
110  {
111  ii = i + ic;
112  jj = j + jc;
113  ij = i + jc;
114 
115  for (q1 = p2 = p = q = 0, k0 = k = 0; k0 < m; k0++, k += n2)
116  {
117  weigh = w[k0];
118  tmp1 = c[ii+k];
119  tmp2 = c[jj+k];
120  tmp = c[ij+k];
121  p += weigh*tmp/tmp1;
122  q += weigh*tmp/tmp2;
123  q1 += weigh*tmp1/tmp2;
124  p2 += weigh*tmp2/tmp1;
125  }
126 
127  q1 /= sumweigh;
128  p2 /= sumweigh;
129  p /= sumweigh;
130  q /= sumweigh;
131  beta = 1 - p2*q1;// p1 = q2 = 1
132 
133  if (q1 <= p2)// the same as q1*q2 <= p1*p2
134  {
135  alpha = p2*q - p;// q2 = 1
136 
137  if (fabs(alpha) - beta < 10e-20)// beta <= 0 always
138  {
139  beta = -1;
140  gamma = p/p2;
141  }
142  else
143  {
144  gamma = - (p*beta + alpha)/p2;// p1 = 1
145  }
146 
147  *decr += sumweigh*(p*p - alpha*alpha/beta)/p2;
148  }
149  else
150  {
151  gamma = p*q1 - q;// p1 = 1
152 
153  if (fabs(gamma) - beta < 10e-20)// beta <= 0 always
154  {
155  beta = -1;
156  alpha = q/q1;
157  }
158  else
159  {
160  alpha = - (q*beta + gamma)/q1;// q2 = 1
161  }
162 
163  *decr += sumweigh*(q*q - gamma*gamma/beta)/q1;
164  }
165 
166  tmp = (beta - sqrt(beta*beta - 4*alpha*gamma))/2;
167  a12 = gamma/tmp;
168  a21 = alpha/tmp;
169 
170  for (k = 0; k < mn2; k += n2)
171  {
172  for (ii = i, jj = j; ii < ij; ii += n, jj += n)
173  {
174  tmp = c[ii+k];
175  c[ii+k] += a12*c[jj+k];
176  c[jj+k] += a21*tmp;
177  }// at exit ii = ij = i + jc
178 
179  tmp = c[i+ic+k];
180  c[i+ic+k] += a12*(2*c[ij+k] + a12*c[jj+k]);
181  c[jj+k] += a21*c[ij+k];
182  c[ij+k] += a21*tmp;// = element of index j,i
183 
184  for (; ii < ic; ii += n, jj++)
185  {
186  tmp = c[ii+k];
187  c[ii+k] += a12*c[jj+k];
188  c[jj+k] += a21*tmp;
189  }
190 
191  for (; ++ii, ++jj < jc+n; )
192  {
193  tmp = c[ii+k];
194  c[ii+k] += a12*c[jj+k];
195  c[jj+k] += a21*tmp;
196  }
197 
198  }
199 
200  for (k = 0; k < n2; k += n)
201  {
202  tmp = a[i+k];
203  a[i+k] += a12*a[j+k];
204  a[j+k] += a21*tmp;
205  }
206 
207  det *= 1 - a12*a21;// compute determinant
208  }
209  }
210 
211  *logdet += 2*sumweigh*log(det);
212 
213  for (tmp = 0, k0 = k = 0; k0 < m; k0++, k += n2)
214  {
215  for (det = 1, ii = 0; ii < n2; ii += n+1)
216  {
217  det *= c[ii+k];
218  tmp += w[k0]*log(det);
219  }
220  }
221 
222  *result = tmp - *logdet;
223 
224  return;
225 }
226 #endif //HAVE_EIGEN3

SHOGUN Machine Learning Toolbox - Documentation