35 #ifndef __OPENCL_UTIL_H__
36 #define __OPENCL_UTIL_H__
40 #include <viennacl/ocl/backend.hpp>
41 #include <viennacl/ocl/kernel.hpp>
42 #include <viennacl/ocl/program.hpp>
43 #include <viennacl/ocl/utils.hpp>
44 #include <viennacl/tools/tools.hpp>
50 #if defined(HAVE_CXX0X) || defined(HAVE_CXX11)
51 #include <initializer_list>
53 #endif // defined(HAVE_CXX0X) || defined(HAVE_CXX11)
61 namespace implementation
69 std::string get_type_string()
71 return viennacl::ocl::type_to_string<T>::apply();
81 std::string generate_kernel_preamble(std::string kernel_name)
83 std::string type_string = get_type_string<T>();
85 std::string source =
"";
86 viennacl::ocl::append_double_precision_pragma<T>(viennacl::ocl::current_context(), source);
87 source.append(
"#define DATATYPE " + type_string +
"\n");
88 source.append(
"#define KERNEL_NAME " + kernel_name +
"\n");
89 source.append(
"#define WORK_GROUP_SIZE_1D " + std::to_string(OCL_WORK_GROUP_SIZE_1D) +
"\n");
90 source.append(
"#define WORK_GROUP_SIZE_2D " + std::to_string(OCL_WORK_GROUP_SIZE_2D) +
"\n");
96 inline bool kernel_exists(std::string kernel_name)
98 return viennacl::ocl::current_context().has_program(kernel_name);
102 inline viennacl::ocl::kernel& get_kernel(std::string kernel_name)
104 return viennacl::ocl::current_context().get_program(kernel_name).get_kernel(kernel_name);
108 inline viennacl::ocl::kernel& compile_kernel(std::string kernel_name, std::string source)
110 viennacl::ocl::program & prog =
111 viennacl::ocl::current_context().add_program(source, kernel_name);
113 return prog.get_kernel(kernel_name);
117 inline uint32_t align_to_multiple_1d(uint32_t n)
119 return viennacl::tools::align_to_multiple<uint32_t>(n, OCL_WORK_GROUP_SIZE_1D);
123 inline uint32_t align_to_multiple_2d(uint32_t n)
125 return viennacl::tools::align_to_multiple<uint32_t>(n, OCL_WORK_GROUP_SIZE_2D);
// Creates (or fetches) an OpenCL kernel that applies a one-argument
// operation to each element of a vector.
//
// kernel_name: name under which the kernel is looked up and compiled.
// operation:   OpenCL source text used as the body of
//              `inline DATATYPE operation(DATATYPE element)`.
//
// The generated kernel takes (vec, size, vec_offset, result, result_offset),
// uses get_global_id(0) as the element index i, and stores
// operation(vec[i+vec_offset]) into result[i+result_offset].
//
// NOTE(review): this listing is incomplete. The embedded line numbers skip
// 143, 152-154, 158, 160-161, 163-166 and everything after 169, so the
// template header that declares T, the braces, the delimiters of the kernel
// string, any use of `size` inside the kernel, and the return statement are
// not visible here. Confirm against the original header before editing.
141 viennacl::ocl::kernel& generate_single_arg_elementwise_kernel(
142 std::string kernel_name, std::string operation)
// Reuse the kernel if a program with this name is already in the context.
144 if (ocl::kernel_exists(kernel_name))
145 return ocl::get_kernel(kernel_name);
// Start from the shared preamble (DATATYPE, KERNEL_NAME, work group sizes).
147 std::string source = ocl::generate_kernel_preamble<T>(kernel_name);
// Wrap the caller-supplied code in a function the kernel can call.
149 source.append(
"inline DATATYPE operation(DATATYPE element)\n{\n");
150 source.append(operation);
151 source.append(
"\n}\n");
155 __kernel void KERNEL_NAME(
156 __global DATATYPE* vec, int size, int vec_offset,
157 __global DATATYPE* result, int result_offset)
159 int i = get_global_id(0);
162 result[i+result_offset] = operation(vec[i+vec_offset]);
// Compile under kernel_name and fix the local work size of dimension 0.
167 viennacl::ocl::kernel& kernel = ocl::compile_kernel(kernel_name, source);
169 kernel.local_work_size(0, OCL_WORK_GROUP_SIZE_1D);
// Creates (or fetches) an OpenCL kernel that applies a two-argument
// operation to corresponding elements of two vectors.
//
// kernel_name: name under which the kernel is looked up and compiled.
// operation:   OpenCL source text used as the body of
//              `inline DATATYPE operation(DATATYPE element1, DATATYPE element2)`.
//
// The generated kernel takes (vec1, size, vec1_offset, vec2, vec2_offset,
// result, result_offset), uses get_global_id(0) as the element index i, and
// stores operation(vec1[i+vec1_offset], vec2[i+vec2_offset]) into
// result[i+result_offset].
//
// NOTE(review): this listing is incomplete. The embedded line numbers skip
// 190, 199-201, 206, 208-209, 212-215 and everything after 218, so the
// template header that declares T, the braces, the delimiters of the kernel
// string, any use of `size` inside the kernel, and the return statement are
// not visible here. Confirm against the original header before editing.
188 viennacl::ocl::kernel& generate_two_arg_elementwise_kernel(
189 std::string kernel_name, std::string operation)
// Reuse the kernel if a program with this name is already in the context.
191 if (ocl::kernel_exists(kernel_name))
192 return ocl::get_kernel(kernel_name);
// Start from the shared preamble (DATATYPE, KERNEL_NAME, work group sizes).
194 std::string source = ocl::generate_kernel_preamble<T>(kernel_name);
// Wrap the caller-supplied code in a function the kernel can call.
196 source.append(
"inline DATATYPE operation(DATATYPE element1, DATATYPE element2)\n{\n");
197 source.append(operation);
198 source.append(
"\n}\n");
202 __kernel void KERNEL_NAME(
203 __global DATATYPE* vec1, int size, int vec1_offset,
204 __global DATATYPE* vec2, int vec2_offset,
205 __global DATATYPE* result, int result_offset)
207 int i = get_global_id(0);
210 result[i+result_offset] =
211 operation(vec1[i+vec1_offset], vec2[i+vec2_offset]);
// Compile under kernel_name and fix the local work size of dimension 0.
216 viennacl::ocl::kernel& kernel = ocl::compile_kernel(kernel_name, source);
218 kernel.local_work_size(0, OCL_WORK_GROUP_SIZE_1D);
/** Replaces every occurrence of one substring with another.
 *
 * The search runs left to right and resumes right after each inserted
 * replacement, so text introduced by `to` is never scanned again (replacing
 * "a" with "aa" terminates).
 *
 * @param str string to operate on (taken by value, the caller's copy is
 * left untouched)
 * @param from substring to search for; if empty, str is returned unchanged
 * @param to text to put in place of each occurrence of from
 * @return str with all occurrences of from replaced by to
 */
inline std::string replace_all(std::string str,
		const std::string& from, const std::string& to)
{
	// An empty pattern is found at every position, including right after
	// each insertion, so the loop below would never finish.
	if (from.empty())
		return str;

	std::string::size_type start_pos = 0;
	while ((start_pos = str.find(from, start_pos)) != std::string::npos)
	{
		str.replace(start_pos, from.length(), to);
		// Skip the text just inserted so it is not matched again.
		start_pos += to.length();
	}

	return str;
}
245 #if defined(HAVE_CXX0X) || defined(HAVE_CXX11)
262 inline std::string format(
const char* str, std::initializer_list<shogun::linalg::ocl::Parameter> params)
264 std::string fmt(str);
265 for (
auto i=params.begin(); i!=params.end(); ++i)
266 fmt=replace_all(fmt,
"{"+i->m_name+
"}", *i);
267 return fmt.append(
"\n");
269 #endif // defined(HAVE_CXX0X) || defined(HAVE_CXX11)
279 #endif // HAVE_VIENNACL
281 #endif // __OPENCL_UTIL_H__
All classes and functions are contained in the shogun namespace.