SHOGUN  4.2.0
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
AmariIndex.cpp
Go to the documentation of this file.
1 #include "AmariIndex.h"
2 
3 
6 
7 using namespace shogun;
8 using namespace Eigen;
9 
11 {
12  Map<MatrixXd> W(SGW.matrix,SGW.num_rows,SGW.num_cols);
13  Map<MatrixXd> A(SGA.matrix,SGA.num_rows,SGA.num_cols);
14 
15  REQUIRE(W.rows() == W.cols(), "amari_index - W must be square\n")
16  REQUIRE(A.rows() == A.cols(), "amari_index - A must be square\n")
17  REQUIRE(W.rows() == A.rows(), "amari_index - A and W must be the same size\n")
18  REQUIRE(W.rows() >= 2, "amari_index - input must be at least 2x2\n")
19 
20  // normalizing both mixing matrices
21  if (standardize)
22  {
23  for (int r = 0; r < W.rows(); r++)
24  {
25  W.row(r).normalize();
26  if (W.row(r).maxCoeff() < -1*W.row(r).minCoeff())
27  W.row(r) *= -1;
28  }
29 
30  A = A.inverse();
31  for (int r = 0; r < A.rows(); r++)
32  {
33  A.row(r).normalize();
34  if (A.row(r).maxCoeff() < -1*A.row(r).minCoeff())
35  A.row(r) *= -1;
36  }
37  A = A.inverse();
38 
39  bool swap = false;
40  do
41  {
42  swap = false;
43  for (int j = 1; j < A.cols(); j++)
44  {
45  if (A(0,j) < A(0,j-1))
46  {
47  A.col(j).swap(A.col(j-1));
48  swap = true;
49  }
50  }
51 
52  } while(swap);
53  }
54 
55  // calculating the permutation matrix
56  MatrixXd P = (W * A).cwiseAbs();
57  int k = P.rows();
58 
59  // summing the error in the permutation matrix
60  MatrixXd E1(k,k);
61  for (int r = 0; r < k; r++)
62  E1.row(r) = P.row(r) / P.row(r).maxCoeff();
63 
64  float64_t row_error = (E1.rowwise().sum().array()-1).sum();
65 
66  MatrixXd E2(k,k);
67  for (int c = 0; c < k; c++)
68  E2.col(c) = P.col(c) / P.col(c).maxCoeff();
69 
70  float64_t col_error = (E2.colwise().sum().array()-1).sum();
71 
72  return 1.0 / (float)(2*k*(k-1)) * (row_error + col_error);
73 
74 }
Definition: SGMatrix.h:20
#define REQUIRE(x,...)
Definition: SGIO.h:206
index_t num_cols
Definition: SGMatrix.h:376
index_t num_rows
Definition: SGMatrix.h:374
float64_t amari_index(SGMatrix< float64_t > SGW, SGMatrix< float64_t > SGA, bool standardize)
function amari_index
Definition: AmariIndex.cpp:10
double float64_t
Definition: common.h:50
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18

SHOGUN Machine Learning Toolbox - Documentation