31 #ifndef MATRIX_PRODUCT_IMPL_H_
32 #define MATRIX_PRODUCT_IMPL_H_
43 #include <viennacl/linalg/prod.hpp>
44 #include <viennacl/matrix.hpp>
45 #endif // HAVE_VIENNACL
53 namespace implementation
// Generic template parameterized on a Backend enumerator and a Matrix type;
// presumably the primary `matrix_product` template (the `struct` line and its
// braces are not present in this extract).
// NOTE(review): the leading integers on some lines are source-listing line
// numbers fused into the text, not C++ tokens — strip them and restore the
// missing lines from the upstream header before compiling.
59 template <enum Backend,
class Matrix>
// T aliases the scalar (element) type exposed by the Matrix type.
63 typedef typename Matrix::Scalar
T;
// Declaration only: takes operands A and B, a third matrix C, two transpose
// flags and an `overwrite` flag. Presumably C receives the product; the
// definitions visible further down assign to or accumulate into C.
75 static void compute(Matrix A, Matrix B, Matrix C,
76 bool transpose_A,
bool transpose_B,
bool overwrite);
// Specialization fragment whose bodies operate on C_eig / A_eig / B_eig.
// NOTE(review): presumably the Eigen3 backend specialization (an
// `#endif // HAVE_EIGEN3` follows this block); the `struct` line, braces and
// the start of both `compute` signatures are missing from this extract.
82 template <
class Matrix>
// T aliases the scalar (element) type exposed by the Matrix type.
86 typedef typename Matrix::Scalar
T;
// Tail of the overload that takes only the two transpose flags.
103 bool transpose_A,
bool transpose_B)
// Forwards to the six-argument overload with `true` as the last argument
// (the `overwrite` position), targeting retMatrix. retMatrix is declared on a
// line not shown here.
111 compute(A, B, retMatrix, transpose_A, transpose_B,
true);
// Tail of the six-argument overload (two transpose flags plus `overwrite`).
127 bool transpose_A,
bool transpose_B,
bool overwrite)
// Assignment group: C_eig is set to the product, transposing A_eig and/or
// B_eig according to the flags.
// NOTE(review): the line selecting this group (presumably `if (overwrite)`)
// and the final `else` before the plain product are absent from the extract;
// as written here the last assignment would run unconditionally — confirm
// against upstream.
135 if (transpose_A && transpose_B)
136 C_eig = A_eig.transpose() * B_eig.transpose();
138 else if (transpose_A)
139 C_eig = A_eig.transpose() * B_eig;
141 else if (transpose_B)
142 C_eig = A_eig * B_eig.transpose();
145 C_eig = A_eig * B_eig;
// Accumulation group: same four transpose cases, but the product is added to
// the existing contents of C_eig (`+=`) instead of replacing them.
// NOTE(review): the separating `else` lines are likewise missing here.
149 if (transpose_A && transpose_B)
150 C_eig += A_eig.transpose() * B_eig.transpose();
152 else if (transpose_A)
153 C_eig += A_eig.transpose() * B_eig;
155 else if (transpose_B)
156 C_eig += A_eig * B_eig.transpose();
159 C_eig += A_eig * B_eig;
163 #endif // HAVE_EIGEN3
// Specialization of matrix_product for Backend::VIENNACL.
// NOTE(review): the opening brace of the struct is not present in this
// extract, and the leading integers are listing line numbers, not code.
168 template <
class Matrix>
169 struct matrix_product<Backend::VIENNACL, Matrix>
// T aliases the scalar (element) type exposed by the Matrix type.
172 typedef typename Matrix::Scalar
T;
// Type returned by the four-argument compute overload: a CGPUMatrix of T.
175 typedef CGPUMatrix<T> ReturnType;
// Product overload that allocates its own result: validates A and B, builds a
// ReturnType sized A.num_rows x B.num_cols, and delegates to the six-argument
// overload.
// NOTE(review): braces and the return statement are missing from this extract.
185 static ReturnType
compute(CGPUMatrix<T> A, CGPUMatrix<T> B,
186 bool transpose_A,
bool transpose_B)
// Both operands must have a non-null `matrix` member.
188 REQUIRE(A.matrix,
"Matrix A is not initialized!\n");
189 REQUIRE(B.matrix,
"Matrix B is not initialized!\n");
// Inner-dimension check on the untransposed shapes.
// NOTE(review): neither this check nor the result shape below takes
// transpose_A / transpose_B into account; for non-square operands with a
// transpose flag set the compared dimensions and the allocated size look
// wrong — confirm intended behavior.
190 REQUIRE(A.num_cols == B.num_rows,
"Number of columns for A (%d) and "
191 "number of rows for B (%d) should be equal!\n", A.num_cols, B.num_rows);
// Result buffer; the final `true` below is the `overwrite` argument of the
// six-argument overload.
193 ReturnType retMatrix(A.num_rows, B.num_cols);
194 compute(A, B, retMatrix, transpose_A, transpose_B,
true);
// Product overload that writes into a caller-supplied C using
// viennacl::linalg::prod on the operands' vcl_matrix() views, wrapping an
// operand in viennacl::trans when its transpose flag is set.
// NOTE(review): braces, the line choosing between the two groups below
// (presumably `if (overwrite)` / `else`) and the final `else` of each chain
// are missing from this extract — restore from upstream.
209 static void compute(CGPUMatrix<T> A, CGPUMatrix<T> B, CGPUMatrix<T> C,
210 bool transpose_A,
bool transpose_B,
bool overwrite)
// Assignment group: C's contents are replaced by the product.
214 if (transpose_A && transpose_B)
215 C.vcl_matrix() = viennacl::linalg::prod(
216 viennacl::trans(A.vcl_matrix()), viennacl::trans(B.vcl_matrix()));
218 else if (transpose_A)
219 C.vcl_matrix() = viennacl::linalg::prod(
220 viennacl::trans(A.vcl_matrix()), B.vcl_matrix());
222 else if (transpose_B)
223 C.vcl_matrix() = viennacl::linalg::prod(
224 A.vcl_matrix(), viennacl::trans(B.vcl_matrix()));
227 C.vcl_matrix() = viennacl::linalg::prod(A.vcl_matrix(), B.vcl_matrix());
// Accumulation group: the product is added to C's existing contents (`+=`).
231 if (transpose_A && transpose_B)
232 C.vcl_matrix() += viennacl::linalg::prod(
233 viennacl::trans(A.vcl_matrix()), viennacl::trans(B.vcl_matrix()));
235 else if (transpose_A)
236 C.vcl_matrix() += viennacl::linalg::prod(
237 viennacl::trans(A.vcl_matrix()), B.vcl_matrix());
239 else if (transpose_B)
240 C.vcl_matrix() += viennacl::linalg::prod(
241 A.vcl_matrix(), viennacl::trans(B.vcl_matrix()));
244 C.vcl_matrix() += viennacl::linalg::prod(A.vcl_matrix(), B.vcl_matrix());
249 #endif // HAVE_VIENNACL
256 #endif // MATRIX_PRODUCT_IMPL_H_
static void compute(Matrix A, Matrix B, Matrix C, bool transpose_A, bool transpose_B, bool overwrite)
static ReturnType compute(SGMatrix< T > A, SGMatrix< T > B, bool transpose_A, bool transpose_B)
All classes and functions are contained in the shogun namespace.
Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic > MatrixXt
static void compute(SGMatrix< T > A, SGMatrix< T > B, SGMatrix< T > C, bool transpose_A, bool transpose_B, bool overwrite)