31 #ifndef MATRIX_PRODUCT_IMPL_H_
32 #define MATRIX_PRODUCT_IMPL_H_
41 #include <viennacl/linalg/prod.hpp>
42 #include <viennacl/matrix.hpp>
43 #endif // HAVE_VIENNACL
51 namespace implementation
// Generic matrix_product declaration, parameterized on a Backend enum value
// and a matrix type. Only a declaration of compute() is visible here; the
// per-backend behavior lives in explicit specializations such as
// matrix_product<Backend::VIENNACL, Matrix>.
57 template <enum Backend,
class Matrix>
// Element type, taken from the matrix type's Scalar typedef.
61 typedef typename Matrix::Scalar
T;
// Computes op(A) * op(B) into C, where op() transposes its argument when the
// corresponding transpose_A / transpose_B flag is set.
// overwrite: the visible specializations have an assigning (=) branch and an
// accumulating (+=) branch; presumably overwrite == true replaces C and
// false adds the product to C's existing contents — TODO confirm.
73 static void compute(Matrix A, Matrix B, Matrix C,
74 bool transpose_A,
bool transpose_B,
bool overwrite);
// Specialization of matrix_product whose struct header line is not visible in
// this listing. It multiplies through Eigen expressions (A_eig / B_eig / C_eig),
// so it is presumably the Eigen3 backend — TODO confirm.
79 template <
class Matrix>
// Element type, taken from the matrix type's Scalar typedef.
83 typedef typename Matrix::Scalar
T;
// Returning overload: delegates to the in-place overload with overwrite = true,
// writing into retMatrix (whose construction is not visible in this listing).
// NOTE(review): confirm retMatrix is sized
// (transpose_A ? A.num_cols : A.num_rows) x (transpose_B ? B.num_rows : B.num_cols)
// and that any dimension check also honors the transpose flags; sizing it
// A.num_rows x B.num_cols is only correct when neither flag is set.
100 bool transpose_A,
bool transpose_B)
108 compute(A, B, retMatrix, transpose_A, transpose_B,
true);
// In-place overload: computes op(A) * op(B), with op() = transpose() when the
// matching flag is set. A_eig / B_eig / C_eig are presumably Eigen views over
// A / B / C (their declarations are not visible here) — TODO confirm.
124 bool transpose_A,
bool transpose_B,
bool overwrite)
// Assigning chain: C_eig is replaced by the product. Presumably taken when
// overwrite is true; the guarding condition is not visible in this listing.
132 if (transpose_A && transpose_B)
133 C_eig = A_eig.transpose() * B_eig.transpose();
135 else if (transpose_A)
136 C_eig = A_eig.transpose() * B_eig;
138 else if (transpose_B)
139 C_eig = A_eig * B_eig.transpose();
// No-transpose case (its `else` line is not visible here).
142 C_eig = A_eig * B_eig;
// Accumulating chain: the product is added to C_eig's existing contents.
146 if (transpose_A && transpose_B)
147 C_eig += A_eig.transpose() * B_eig.transpose();
149 else if (transpose_A)
150 C_eig += A_eig.transpose() * B_eig;
152 else if (transpose_B)
153 C_eig += A_eig * B_eig.transpose();
// No-transpose case (its `else` line is not visible here).
156 C_eig += A_eig * B_eig;
// Specialization of matrix_product for the ViennaCL backend; operates on
// CGPUMatrix<T> and multiplies via viennacl::linalg::prod.
164 template <
class Matrix>
165 struct matrix_product<Backend::VIENNACL, Matrix>
// Element type, taken from the matrix type's Scalar typedef.
168 typedef typename Matrix::Scalar
T;
// The returning overload yields a newly constructed CGPUMatrix<T>.
171 typedef CGPUMatrix<T> ReturnType;
// Returning overload: checks that A and B are initialized and that their
// dimensions agree, constructs retMatrix, then delegates to the in-place
// overload with overwrite = true.
// NOTE(review): the dimension check and the retMatrix shape below ignore
// transpose_A / transpose_B, while the in-place overload does multiply
// trans(A) / trans(B) when the flags are set. For transposed inputs the check
// should presumably be
//   (transpose_A ? A.num_rows : A.num_cols) == (transpose_B ? B.num_cols : B.num_rows)
// and the result shape
//   (transpose_A ? A.num_cols : A.num_rows) x (transpose_B ? B.num_rows : B.num_cols).
// TODO fix (the REQUIRE message arguments would need the same adjustment).
181 static ReturnType
compute(CGPUMatrix<T> A, CGPUMatrix<T> B,
182 bool transpose_A,
bool transpose_B)
184 REQUIRE(A.matrix,
"Matrix A is not initialized!\n");
185 REQUIRE(B.matrix,
"Matrix B is not initialized!\n");
186 REQUIRE(A.num_cols == B.num_rows,
"Number of columns for A (%d) and "
187 "number of rows for B (%d) should be equal!\n", A.num_cols, B.num_rows);
189 ReturnType retMatrix(A.num_rows, B.num_cols);
190 compute(A, B, retMatrix, transpose_A, transpose_B,
true);
// In-place overload: computes prod(op(A), op(B)) into C.vcl_matrix(), with
// op() = viennacl::trans() when the matching flag is set. C.vcl_matrix() is
// assigned to, so it presumably returns a reference to C's underlying
// viennacl matrix — TODO confirm.
205 static void compute(CGPUMatrix<T> A, CGPUMatrix<T> B, CGPUMatrix<T> C,
206 bool transpose_A,
bool transpose_B,
bool overwrite)
// Assigning chain: C is replaced by the product. Presumably taken when
// overwrite is true; the guarding condition is not visible in this listing.
210 if (transpose_A && transpose_B)
211 C.vcl_matrix() = viennacl::linalg::prod(
212 viennacl::trans(A.vcl_matrix()), viennacl::trans(B.vcl_matrix()));
214 else if (transpose_A)
215 C.vcl_matrix() = viennacl::linalg::prod(
216 viennacl::trans(A.vcl_matrix()), B.vcl_matrix());
218 else if (transpose_B)
219 C.vcl_matrix() = viennacl::linalg::prod(
220 A.vcl_matrix(), viennacl::trans(B.vcl_matrix()));
// No-transpose case (its `else` line is not visible here).
223 C.vcl_matrix() = viennacl::linalg::prod(A.vcl_matrix(), B.vcl_matrix());
// Accumulating chain: the product is added to C's existing contents.
227 if (transpose_A && transpose_B)
228 C.vcl_matrix() += viennacl::linalg::prod(
229 viennacl::trans(A.vcl_matrix()), viennacl::trans(B.vcl_matrix()));
231 else if (transpose_A)
232 C.vcl_matrix() += viennacl::linalg::prod(
233 viennacl::trans(A.vcl_matrix()), B.vcl_matrix());
235 else if (transpose_B)
236 C.vcl_matrix() += viennacl::linalg::prod(
237 A.vcl_matrix(), viennacl::trans(B.vcl_matrix()));
// No-transpose case (its `else` line is not visible here).
240 C.vcl_matrix() += viennacl::linalg::prod(A.vcl_matrix(), B.vcl_matrix());
245 #endif // HAVE_VIENNACL
252 #endif // MATRIX_PRODUCT_IMPL_H_
static void compute(Matrix A, Matrix B, Matrix C, bool transpose_A, bool transpose_B, bool overwrite)
static ReturnType compute(SGMatrix< T > A, SGMatrix< T > B, bool transpose_A, bool transpose_B)
All classes and functions are contained in the shogun namespace.
Eigen::Matrix< T, Eigen::Dynamic, Eigen::Dynamic > MatrixXt
static void compute(SGMatrix< T > A, SGMatrix< T > B, SGMatrix< T > C, bool transpose_A, bool transpose_B, bool overwrite)