35 #ifndef __OPENCL_UTIL_H__ 36 #define __OPENCL_UTIL_H__ 38 #include <shogun/lib/config.h> 40 #include <viennacl/ocl/backend.hpp> 41 #include <viennacl/ocl/kernel.hpp> 42 #include <viennacl/ocl/program.hpp> 43 #include <viennacl/ocl/utils.hpp> 44 #include <viennacl/tools/tools.hpp> 50 #if defined(HAVE_CXX0X) || defined(HAVE_CXX11) 51 #include <initializer_list> 53 #endif // defined(HAVE_CXX0X) || defined(HAVE_CXX11) 61 namespace implementation
69 std::string get_type_string()
71 return viennacl::ocl::type_to_string<T>::apply();
81 std::string generate_kernel_preamble(std::string kernel_name)
83 std::string type_string = get_type_string<T>();
85 std::string source =
"";
86 viennacl::ocl::append_double_precision_pragma<T>(viennacl::ocl::current_context(), source);
87 source.append(
"#define DATATYPE " + type_string +
"\n");
88 source.append(
"#define KERNEL_NAME " + kernel_name +
"\n");
89 source.append(
"#define WORK_GROUP_SIZE_1D " + std::to_string(OCL_WORK_GROUP_SIZE_1D) +
"\n");
90 source.append(
"#define WORK_GROUP_SIZE_2D " + std::to_string(OCL_WORK_GROUP_SIZE_2D) +
"\n");
96 inline bool kernel_exists(std::string kernel_name)
98 return viennacl::ocl::current_context().has_program(kernel_name);
102 inline viennacl::ocl::kernel& get_kernel(std::string kernel_name)
104 return viennacl::ocl::current_context().get_program(kernel_name).get_kernel(kernel_name);
108 inline viennacl::ocl::kernel& compile_kernel(std::string kernel_name, std::string source)
110 viennacl::ocl::program & prog =
111 viennacl::ocl::current_context().add_program(source, kernel_name);
113 return prog.get_kernel(kernel_name);
117 inline uint32_t align_to_multiple_1d(uint32_t n)
119 return viennacl::tools::align_to_multiple<uint32_t>(n, OCL_WORK_GROUP_SIZE_1D);
123 inline uint32_t align_to_multiple_2d(uint32_t n)
125 return viennacl::tools::align_to_multiple<uint32_t>(n, OCL_WORK_GROUP_SIZE_2D);
141 viennacl::ocl::kernel& generate_single_arg_elementwise_kernel(
142 std::string kernel_name, std::string operation)
144 if (ocl::kernel_exists(kernel_name))
145 return ocl::get_kernel(kernel_name);
147 std::string source = ocl::generate_kernel_preamble<T>(kernel_name);
149 source.append(
"inline DATATYPE operation(DATATYPE element)\n{\n");
150 source.append(operation);
151 source.append(
"\n}\n");
155 __kernel void KERNEL_NAME( 156 __global DATATYPE* vec, int size, int vec_offset, 157 __global DATATYPE* result, int result_offset) 159 int i = get_global_id(0); 162 result[i+result_offset] = operation(vec[i+vec_offset]); 167 viennacl::ocl::kernel& kernel = ocl::compile_kernel(kernel_name, source); 169 kernel.local_work_size(0, OCL_WORK_GROUP_SIZE_1D); 188 viennacl::ocl::kernel& generate_two_arg_elementwise_kernel(
189 std::string kernel_name, std::string operation)
191 if (ocl::kernel_exists(kernel_name))
192 return ocl::get_kernel(kernel_name);
194 std::string source = ocl::generate_kernel_preamble<T>(kernel_name);
196 source.append(
"inline DATATYPE operation(DATATYPE element1, DATATYPE element2)\n{\n");
197 source.append(operation);
198 source.append(
"\n}\n");
202 __kernel void KERNEL_NAME( 203 __global DATATYPE* vec1, int size, int vec1_offset, 204 __global DATATYPE* vec2, int vec2_offset, 205 __global DATATYPE* result, int result_offset) 207 int i = get_global_id(0); 210 result[i+result_offset] = 211 operation(vec1[i+vec1_offset], vec2[i+vec2_offset]); 216 viennacl::ocl::kernel& kernel = ocl::compile_kernel(kernel_name, source); 218 kernel.local_work_size(0, OCL_WORK_GROUP_SIZE_1D); 234 inline std::string replace_all(std::string str,
const std::string& from,
const std::string& to)
237 while ((start_pos=str.find(from, start_pos))!=std::string::npos)
239 str.replace(start_pos, from.length(), to);
240 start_pos+=to.length();
245 #if defined(HAVE_CXX0X) || defined(HAVE_CXX11) 262 inline std::string format(
const char* str, std::initializer_list<shogun::linalg::ocl::Parameter> params)
264 std::string fmt(str);
265 for (
auto i=params.begin(); i!=params.end(); ++i)
266 fmt=replace_all(fmt,
"{"+i->m_name+
"}", *i);
267 return fmt.append(
"\n");
269 #endif // defined(HAVE_CXX0X) || defined(HAVE_CXX11) 279 #endif // HAVE_VIENNACL 281 #endif // __OPENCL_UTIL_H__
All classes and functions are contained in the shogun namespace.