SHOGUN  6.0.0
AmariIndex.cpp
Go to the documentation of this file.
1 #include "AmariIndex.h"
2
3
6
7 using namespace shogun;
8 using namespace Eigen;
9
11 {
12  Map<MatrixXd> W(SGW.matrix,SGW.num_rows,SGW.num_cols);
13  Map<MatrixXd> A(SGA.matrix,SGA.num_rows,SGA.num_cols);
14
15  REQUIRE(W.rows() == W.cols(), "amari_index - W must be square\n")
16  REQUIRE(A.rows() == A.cols(), "amari_index - A must be square\n")
17  REQUIRE(W.rows() == A.rows(), "amari_index - A and W must be the same size\n")
18  REQUIRE(W.rows() >= 2, "amari_index - input must be at least 2x2\n")
19
20  // normalizing both mixing matrices
21  if (standardize)
22  {
23  for (int r = 0; r < W.rows(); r++)
24  {
25  W.row(r).normalize();
26  if (W.row(r).maxCoeff() < -1*W.row(r).minCoeff())
27  W.row(r) *= -1;
28  }
29
30  A = A.inverse();
31  for (int r = 0; r < A.rows(); r++)
32  {
33  A.row(r).normalize();
34  if (A.row(r).maxCoeff() < -1*A.row(r).minCoeff())
35  A.row(r) *= -1;
36  }
37  A = A.inverse();
38
39  bool swap = false;
40  do
41  {
42  swap = false;
43  for (int j = 1; j < A.cols(); j++)
44  {
45  if (A(0,j) < A(0,j-1))
46  {
47  A.col(j).swap(A.col(j-1));
48  swap = true;
49  }
50  }
51
52  } while(swap);
53  }
54
55  // calculating the permutation matrix
56  MatrixXd P = (W * A).cwiseAbs();
57  int k = P.rows();
58
59  // summing the error in the permutation matrix
60  MatrixXd E1(k,k);
61  for (int r = 0; r < k; r++)
62  E1.row(r) = P.row(r) / P.row(r).maxCoeff();
63
64  float64_t row_error = (E1.rowwise().sum().array()-1).sum();
65
66  MatrixXd E2(k,k);
67  for (int c = 0; c < k; c++)
68  E2.col(c) = P.col(c) / P.col(c).maxCoeff();
69
70  float64_t col_error = (E2.colwise().sum().array()-1).sum();
71
72  return 1.0 / (float)(2*k*(k-1)) * (row_error + col_error);
73
74 }
Definition: SGMatrix.h:24
#define REQUIRE(x,...)
Definition: SGIO.h:205
index_t num_cols
Definition: SGMatrix.h:465
index_t num_rows
Definition: SGMatrix.h:463
float64_t amari_index(SGMatrix< float64_t > SGW, SGMatrix< float64_t > SGA, bool standardize)
function amari_index
Definition: AmariIndex.cpp:10
double float64_t
Definition: common.h:60
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18
T sum(const Container< T > &a, bool no_diag=false)

SHOGUN Machine Learning Toolbox - Documentation