SHOGUN  4.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
JADiag.cpp
Go to the documentation of this file.
2 
3 
4 #include <shogun/base/init.h>
5 
8 
9 using namespace shogun;
10 using namespace Eigen;
11 
// Forward declaration of the weighted joint-diagonalization sweep. It is
// defined later in this file and called once per iteration of the loop in
// diagonalize().
12 void jadiagw(float64_t c[], float64_t w[], int *ptn, int *ptm, float64_t a[],
13  float64_t *logdet, float64_t *decr, float64_t *result);
14 
16  double eps, int itermax)
17 {
// Joint approximate diagonalization of the L slices of C: repeatedly calls
// jadiagw() and returns the accumulated transform V.
// NOTE(review): this excerpt is missing listing line 15 (start of the
// signature) and listing lines 40 and 44, so V's declaration and the body of
// the else-branch below are not visible here.
// Assumes C is laid out as d x d x L; dims[1] is never read or checked.
18  int d = C.dims[0];
19  int L = C.dims[2];
20 
21  // check that the input matrices are pos def
// Each slice is decomposed with the general EigenSolver. An error is raised
// if any diagonal entry of its pseudo-eigenvalue matrix is negative.
// NOTE(review): "< 0" lets zero eigenvalues through, so this enforces
// semi-definiteness rather than the positive-definiteness the message
// names. For symmetric input, SelfAdjointEigenSolver is the usual choice.
22  for (int i = 0; i < L; i++)
23  {
24  Map<MatrixXd> Ci(C.get_matrix(i),d,d);
25 
26  EigenSolver<MatrixXd> eig;
27  eig.compute(Ci);
28 
29  MatrixXd D = eig.pseudoEigenvalueMatrix();
30 
31  for (int j = 0; j < d; j++)
32  {
33  if (D(j,j) < 0)
34  {
35  SG_SERROR("Input Matrix %d is not Positive-definite\n", i)
36  }
37  }
38  }
39 
// The initial transform is a copy of V0 when V0 is d x d.
// NOTE(review): the else-branch body (listing line 44) is missing here. The
// symbol list for this file includes create_identity_matrix, which is
// presumably the fallback; confirm against the full source.
41  if (V0.num_rows == d && V0.num_cols == d)
42  V = V0.clone();
43  else
45 
// Every slice gets the same unit weight.
46  VectorXd w(L);
47  w.setOnes();
48 
// Copy the slices side by side into one d x (d*L) buffer. jadiagw() rewrites
// this buffer in place, so C itself is left untouched.
49  MatrixXd ctot(d, d*L);
50  for (int i = 0; i < L; i++)
51  {
52  Map<MatrixXd> Ci(C.get_matrix(i),d,d);
53  ctot.block(0,i*d,d,d) = Ci;
54  }
55 
// NOTE(review): the meaning of the initial logdet constant is not documented
// here; confirm where it comes from before changing it.
56  int iter = 0;
57  float64_t decr = 1;
58  float64_t logdet = log(5.184e17);
59  float64_t result = 0;
60  std::vector<float64_t> crit;
// Run sweeps until decr falls to eps or itermax sweeps have been done.
61  while (decr > eps && iter < itermax)
62  {
// NOTE(review): jadiagw() adds 2*sumweigh*log(det) to logdet. A non-positive
// det yields NaN or -inf, and neither compares equal to 0, so this "is NA"
// guard likely never fires. std::isnan / std::isfinite would catch it.
63  if(logdet == 0)// is NA
64  {
65  SG_SERROR("log det does not exist\n")
66  break;
67  }
68 
// One sweep: transforms ctot and V.matrix in place and writes logdet, decr
// and result.
69  jadiagw(ctot.data(),
70  w.data(),
71  &d, &L,
72  V.matrix,
73  &logdet,
74  &decr,
75  &result);
76 
// NOTE(review): crit is appended to but never read in this function.
77  crit.push_back(result);
78  iter = iter + 1;
79  }
80 
// NOTE(review): this also fires when decr reached eps on exactly the last
// allowed iteration.
81  if (iter == itermax)
82  SG_SERROR("Convergence not reached\n")
83 
84  return V;
85 
86 }
87 
88 void jadiagw(float64_t c[], float64_t w[], int *ptn, int *ptm, float64_t a[],
89  float64_t *logdet, float64_t *decr, float64_t *result)
90 {
91  int n = *ptn;
92  int m = *ptm;
93  //int i1,j1;
94  int n2 = n*n, mn2 = m*n2,
95  i, ic, ii, ij, j, jc, jj, k, k0;
96  float64_t sumweigh, p2, q1, p, q,
97  alpha, beta, gamma, a12, a21, /*tiny,*/ det;
98  register float64_t tmp1, tmp2, tmp, weigh;
99 
100  for (sumweigh = 0, i = 0; i < m; i++)
101  sumweigh += w[i];
102 
103  det = 1;
104  *decr = 0;
105 
106  for (i = 1, ic = n; i < n ; i++, ic += n)
107  {
108  for (j = jc = 0; j < i; j++, jc += n)
109  {
110  ii = i + ic;
111  jj = j + jc;
112  ij = i + jc;
113 
114  for (q1 = p2 = p = q = 0, k0 = k = 0; k0 < m; k0++, k += n2)
115  {
116  weigh = w[k0];
117  tmp1 = c[ii+k];
118  tmp2 = c[jj+k];
119  tmp = c[ij+k];
120  p += weigh*tmp/tmp1;
121  q += weigh*tmp/tmp2;
122  q1 += weigh*tmp1/tmp2;
123  p2 += weigh*tmp2/tmp1;
124  }
125 
126  q1 /= sumweigh;
127  p2 /= sumweigh;
128  p /= sumweigh;
129  q /= sumweigh;
130  beta = 1 - p2*q1;// p1 = q2 = 1
131 
132  if (q1 <= p2)// the same as q1*q2 <= p1*p2
133  {
134  alpha = p2*q - p;// q2 = 1
135 
136  if (fabs(alpha) - beta < 10e-20)// beta <= 0 always
137  {
138  beta = -1;
139  gamma = p/p2;
140  }
141  else
142  {
143  gamma = - (p*beta + alpha)/p2;// p1 = 1
144  }
145 
146  *decr += sumweigh*(p*p - alpha*alpha/beta)/p2;
147  }
148  else
149  {
150  gamma = p*q1 - q;// p1 = 1
151 
152  if (fabs(gamma) - beta < 10e-20)// beta <= 0 always
153  {
154  beta = -1;
155  alpha = q/q1;
156  }
157  else
158  {
159  alpha = - (q*beta + gamma)/q1;// q2 = 1
160  }
161 
162  *decr += sumweigh*(q*q - gamma*gamma/beta)/q1;
163  }
164 
165  tmp = (beta - sqrt(beta*beta - 4*alpha*gamma))/2;
166  a12 = gamma/tmp;
167  a21 = alpha/tmp;
168 
169  for (k = 0; k < mn2; k += n2)
170  {
171  for (ii = i, jj = j; ii < ij; ii += n, jj += n)
172  {
173  tmp = c[ii+k];
174  c[ii+k] += a12*c[jj+k];
175  c[jj+k] += a21*tmp;
176  }// at exit ii = ij = i + jc
177 
178  tmp = c[i+ic+k];
179  c[i+ic+k] += a12*(2*c[ij+k] + a12*c[jj+k]);
180  c[jj+k] += a21*c[ij+k];
181  c[ij+k] += a21*tmp;// = element of index j,i
182 
183  for (; ii < ic; ii += n, jj++)
184  {
185  tmp = c[ii+k];
186  c[ii+k] += a12*c[jj+k];
187  c[jj+k] += a21*tmp;
188  }
189 
190  for (; ++ii, ++jj < jc+n; )
191  {
192  tmp = c[ii+k];
193  c[ii+k] += a12*c[jj+k];
194  c[jj+k] += a21*tmp;
195  }
196 
197  }
198 
199  for (k = 0; k < n2; k += n)
200  {
201  tmp = a[i+k];
202  a[i+k] += a12*a[j+k];
203  a[j+k] += a21*tmp;
204  }
205 
206  det *= 1 - a12*a21;// compute determinant
207  }
208  }
209 
210  *logdet += 2*sumweigh*log(det);
211 
212  for (tmp = 0, k0 = k = 0; k0 < m; k0++, k += n2)
213  {
214  for (det = 1, ii = 0; ii < n2; ii += n+1)
215  {
216  det *= c[ii+k];
217  tmp += w[k0]*log(det);
218  }
219  }
220 
221  *result = tmp - *logdet;
222 
223  return;
224 }
void jadiagw(float64_t c[], float64_t w[], int *ptn, int *ptm, float64_t a[], float64_t *logdet, float64_t *decr, float64_t *result)
Definition: JADiag.cpp:88
Definition: SGMatrix.h:20
SGMatrix< T > clone()
Definition: SGMatrix.cpp:256
index_t num_cols
Definition: SGMatrix.h:376
static SGMatrix< float64_t > diagonalize(SGNDArray< float64_t > C, SGMatrix< float64_t > V0=SGMatrix< float64_t >(NULL, 0, 0, false), double eps=CMath::MACHINE_EPSILON, int itermax=200)
Definition: JADiag.cpp:15
index_t num_rows
Definition: SGMatrix.h:374
T * get_matrix(index_t matIdx) const
Definition: SGNDArray.h:72
double float64_t
Definition: common.h:50
All classes and functions are contained in the shogun namespace
Definition: class_list.h:18
index_t * dims
Definition: SGNDArray.h:177
#define SG_SERROR(...)
Definition: SGIO.h:179
static SGMatrix< T > create_identity_matrix(index_t size, T scale)

SHOGUN Machine Learning Toolbox - Documentation