SHOGUN  4.1.0
 全部  命名空间 文件 函数 变量 类型定义 枚举 枚举值 友元 宏定义  
AmariIndex.cpp
浏览该文件的文档.
1 #include "AmariIndex.h"
2 
3 #ifdef HAVE_EIGEN3
4 
7 
8 using namespace shogun;
9 using namespace Eigen;
10 
12 {
13  Map<MatrixXd> W(SGW.matrix,SGW.num_rows,SGW.num_cols);
14  Map<MatrixXd> A(SGA.matrix,SGA.num_rows,SGA.num_cols);
15 
16  REQUIRE(W.rows() == W.cols(), "amari_index - W must be square\n")
17  REQUIRE(A.rows() == A.cols(), "amari_index - A must be square\n")
18  REQUIRE(W.rows() == A.rows(), "amari_index - A and W must be the same size\n")
19  REQUIRE(W.rows() >= 2, "amari_index - input must be at least 2x2\n")
20 
21  // normalizing both mixing matrices
22  if (standardize)
23  {
24  for (int r = 0; r < W.rows(); r++)
25  {
26  W.row(r).normalize();
27  if (W.row(r).maxCoeff() < -1*W.row(r).minCoeff())
28  W.row(r) *= -1;
29  }
30 
31  A = A.inverse();
32  for (int r = 0; r < A.rows(); r++)
33  {
34  A.row(r).normalize();
35  if (A.row(r).maxCoeff() < -1*A.row(r).minCoeff())
36  A.row(r) *= -1;
37  }
38  A = A.inverse();
39 
40  bool swap = false;
41  do
42  {
43  swap = false;
44  for (int j = 1; j < A.cols(); j++)
45  {
46  if (A(0,j) < A(0,j-1))
47  {
48  A.col(j).swap(A.col(j-1));
49  swap = true;
50  }
51  }
52 
53  } while(swap);
54  }
55 
56  // calculating the permutation matrix
57  MatrixXd P = (W * A).cwiseAbs();
58  int k = P.rows();
59 
60  // summing the error in the permutation matrix
61  MatrixXd E1(k,k);
62  for (int r = 0; r < k; r++)
63  E1.row(r) = P.row(r) / P.row(r).maxCoeff();
64 
65  float64_t row_error = (E1.rowwise().sum().array()-1).sum();
66 
67  MatrixXd E2(k,k);
68  for (int c = 0; c < k; c++)
69  E2.col(c) = P.col(c) / P.col(c).maxCoeff();
70 
71  float64_t col_error = (E2.colwise().sum().array()-1).sum();
72 
73  return 1.0 / (float)(2*k*(k-1)) * (row_error + col_error);
74 
75 }
76 #endif //HAVE_EIGEN3
Definition: SGMatrix.h:20
#define REQUIRE(x,...)
Definition: SGIO.h:206
index_t num_cols
Definition: SGMatrix.h:378
index_t num_rows
Definition: SGMatrix.h:376
float64_t amari_index(SGMatrix< float64_t > SGW, SGMatrix< float64_t > SGA, bool standardize)
function amari_index
Definition: AmariIndex.cpp:11
double float64_t
Definition: common.h:50
All classes and functions are contained in the shogun namespace
Definition: class_list.h:18

SHOGUN 机器学习工具包 - 项目文档