SHOGUN  4.1.0
 全部  命名空间 文件 函数 变量 类型定义 枚举 枚举值 友元 宏定义  
GPUMatrix.cpp
浏览该文件的文档.
1 /*
2  * Copyright (c) 2014, Shogun Toolbox Foundation
3  * All rights reserved.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are met:
7 
8  * 1. Redistributions of source code must retain the above copyright notice,
9  * this list of conditions and the following disclaimer.
10  *
11  * 2. Redistributions in binary form must reproduce the above copyright notice,
12  * this list of conditions and the following disclaimer in the documentation
13  * and/or other materials provided with the distribution.
14  *
15  * 3. Neither the name of the copyright holder nor the names of its
16  * contributors may be used to endorse or promote products derived from this
17  * software without specific prior written permission.
18 
19  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22  * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
23  * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
24  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
25  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
26  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
28  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
29  * POSSIBILITY OF SUCH DAMAGE.
30  *
31  * Written (W) 2014 Khaled Nasr
32  */
33 
34 #include <shogun/lib/config.h>
35 
36 #ifdef HAVE_VIENNACL
37 #ifdef HAVE_CXX11
38 
39 #include <shogun/lib/GPUMatrix.h>
40 #include <viennacl/matrix.hpp>
41 
42 #ifdef HAVE_EIGEN3
44 #endif
45 
46 #include <shogun/lib/SGMatrix.h>
47 
48 namespace shogun
49 {
50 
51 template <class T>
52 CGPUMatrix<T>::CGPUMatrix()
53 {
54  init();
55 }
56 
57 template <class T>
58 CGPUMatrix<T>::CGPUMatrix(index_t nrows, index_t ncols) : matrix(new VCLMemoryArray())
59 {
60  init();
61 
62  num_rows = nrows;
63  num_cols = ncols;
64 
65  viennacl::backend::memory_create(*matrix, sizeof(T)*num_rows*num_cols,
66  viennacl::context());
67 }
68 
69 template <class T>
70 CGPUMatrix<T>::CGPUMatrix(std::shared_ptr<VCLMemoryArray> mem, index_t nrows, index_t ncols,
71  index_t mem_offset)
72 {
73  init();
74 
75  matrix = mem;
76  num_rows = nrows;
77  num_cols = ncols;
78  offset = mem_offset;
79 }
80 
81 template <class T>
82 CGPUMatrix<T>::CGPUMatrix(const SGMatrix< T >& cpu_mat) : matrix(new VCLMemoryArray())
83 {
84  init();
85 
86  num_rows = cpu_mat.num_rows;
87  num_cols = cpu_mat.num_cols;
88 
89  viennacl::backend::memory_create(*matrix, sizeof(T)*num_rows*num_cols,
90  viennacl::context());
91 
92  viennacl::backend::memory_write(*matrix, 0, num_rows*num_cols*sizeof(T),
93  cpu_mat.matrix);
94 }
95 
96 #ifdef HAVE_EIGEN3
97 template <class T>
98 CGPUMatrix<T>::CGPUMatrix(const EigenMatrixXt& cpu_mat)
99 : matrix(new VCLMemoryArray())
100 {
101  init();
102 
103  num_rows = cpu_mat.rows();
104  num_cols = cpu_mat.cols();
105 
106  viennacl::backend::memory_create(*matrix, sizeof(T)*num_rows*num_cols,
107  viennacl::context());
108 
109  viennacl::backend::memory_write(*matrix, 0, num_rows*num_cols*sizeof(T),
110  cpu_mat.data());
111 }
112 
113 template <class T>
114 CGPUMatrix<T>::operator EigenMatrixXt() const
115 {
116  EigenMatrixXt cpu_mat(num_rows, num_cols);
117 
118  viennacl::backend::memory_read(*matrix, offset*sizeof(T), num_rows*num_cols*sizeof(T),
119  cpu_mat.data());
120 
121  return cpu_mat;
122 }
123 #endif
124 
125 template <class T>
126 CGPUMatrix<T>::operator SGMatrix<T>() const
127 {
128  SGMatrix<T> cpu_mat(num_rows, num_cols);
129 
130  viennacl::backend::memory_read(*matrix, offset*sizeof(T), num_rows*num_cols*sizeof(T),
131  cpu_mat.matrix);
132 
133  return cpu_mat;
134 }
135 
136 template <class T>
137 typename CGPUMatrix<T>::VCLMatrixBase CGPUMatrix<T>::vcl_matrix()
138 {
139  return VCLMatrixBase(*matrix,num_rows, offset, 1, num_rows, num_cols, 0, 1, num_cols);
140 }
141 
142 template <class T>
143 void CGPUMatrix<T>::display_matrix(const char* name) const
144 {
145  ((SGMatrix<T>)*this).display_matrix(name);
146 }
147 
148 template <class T>
149 void CGPUMatrix<T>::zero()
150 {
151  vcl_matrix().clear();
152 }
153 
154 template <class T>
155 void CGPUMatrix<T>::set_const(T value)
156 {
157  VCLMatrixBase m = vcl_matrix();
158  viennacl::linalg::matrix_assign(m, value);
159 }
160 
161 template <class T>
162 viennacl::const_entry_proxy<T> CGPUMatrix<T>::operator()(index_t i, index_t j) const
163 {
164  return viennacl::const_entry_proxy<T>(offset+i+j*num_rows, *matrix);
165 }
166 
167 template <class T>
168 viennacl::entry_proxy< T > CGPUMatrix<T>::operator()(index_t i, index_t j)
169 {
170  return viennacl::entry_proxy<T>(offset+i+j*num_rows, *matrix);
171 }
172 
173 template <class T>
174 viennacl::const_entry_proxy< T > CGPUMatrix<T>::operator[](index_t index) const
175 {
176  return viennacl::const_entry_proxy<T>(offset+index, *matrix);
177 }
178 
179 template <class T>
180 viennacl::entry_proxy< T > CGPUMatrix<T>::operator[](index_t index)
181 {
182  return viennacl::entry_proxy<T>(offset+index, *matrix);
183 }
184 
185 template <class T>
186 void CGPUMatrix<T>::init()
187 {
188  num_rows = 0;
189  num_cols = 0;
190  offset = 0;
191 }
192 
193 template class CGPUMatrix<char>;
194 template class CGPUMatrix<uint8_t>;
195 template class CGPUMatrix<int16_t>;
196 template class CGPUMatrix<uint16_t>;
197 template class CGPUMatrix<int32_t>;
198 template class CGPUMatrix<uint32_t>;
199 template class CGPUMatrix<int64_t>;
200 template class CGPUMatrix<uint64_t>;
201 template class CGPUMatrix<float32_t>;
202 template class CGPUMatrix<float64_t>;
203 }
204 
205 #endif // HAVE_CXX11
206 #endif // HAVE_VIENNACL
int32_t index_t
Definition: common.h:62
All classes and functions are contained in the shogun namespace.
Definition: class_list.h:18

SHOGUN 机器学习工具包 - 项目文档