1313)
1414
1515from numba_dpex import config
16- from numba_dpex .core .datamodel .models import dpex_data_model_manager as dpex_dmm
1716from numba_dpex .core .parfors .reduction_helper import (
1817 ReductionHelper ,
1918 ReductionKernelVariables ,
2019)
2120from numba_dpex .core .utils .kernel_launcher import KernelLaunchIRBuilder
21+ from numba_dpex .dpctl_iface import libsyclinterface_bindings as sycl
22+ from numba_dpex .core .datamodel .models import (
23+ dpex_data_model_manager as kernel_dmm ,
24+ )
2225
2326from ..exceptions import UnsupportedParforError
2427from ..types .dpnp_ndarray_type import DpnpNdArray
25- from .kernel_builder import create_kernel_for_parfor
28+ from .kernel_builder import ParforKernel , create_kernel_for_parfor
2629from .reduction_kernel_builder import (
2730 create_reduction_main_kernel_for_parfor ,
2831 create_reduction_remainder_kernel_for_parfor ,
2932)
3033
31- _KernelArgs = namedtuple (
32- "_KernelArgs" ,
33- ["num_flattened_args" , "arg_vals" , "arg_types" ],
34- )
35-
3634
3735# A global list of kernels to keep the objects alive indefinitely.
3836keep_alive_kernels = []
@@ -68,11 +66,8 @@ def _getvar(lowerer, x):
6866 var_val = lowerer .varmap [x ]
6967
7068 if var_val :
71- if not isinstance (var_val .type , llvmir .PointerType ):
72- with lowerer .builder .goto_entry_block ():
73- var_val_ptr = lowerer .builder .alloca (var_val .type )
74- lowerer .builder .store (var_val , var_val_ptr )
75- return var_val_ptr
69+ if isinstance (var_val .type , llvmir .PointerType ):
70+ return lowerer .builder .load (var_val )
7671 else :
7772 return var_val
7873 else :
@@ -91,75 +86,15 @@ class ParforLowerImpl:
9186 for a parfor and submits it to a queue.
9287 """
9388
94- def _build_kernel_arglist (self , kernel_fn , lowerer , kernel_builder ):
95- """Creates local variables for all the arguments and the argument types
96- that are passes to the kernel function.
97-
98- Args:
99- kernel_fn: Kernel function to be launched.
100- lowerer: The Numba lowerer used to generate the LLVM IR
101-
102- Raises:
103- AssertionError: If the LLVM IR Value for an argument defined in
104- Numba IR is not found.
105- """
106- num_flattened_args = 0
107-
108- # Compute number of args to be passed to the kernel. Note that the
109- # actual number of kernel arguments is greater than the count of
110- # kernel_fn.kernel_args as arrays get flattened.
111- for arg_type in kernel_fn .kernel_arg_types :
112- if isinstance (arg_type , DpnpNdArray ):
113- datamodel = dpex_dmm .lookup (arg_type )
114- num_flattened_args += datamodel .flattened_field_count
115- elif arg_type == types .complex64 or arg_type == types .complex128 :
116- num_flattened_args += 2
117- else :
118- num_flattened_args += 1
119-
120- # Create LLVM values for the kernel args list and kernel arg types list
121- args_list = kernel_builder .allocate_kernel_arg_array (num_flattened_args )
122- args_ty_list = kernel_builder .allocate_kernel_arg_ty_array (
123- num_flattened_args
124- )
125- callargs_ptrs = []
126- for arg in kernel_fn .kernel_args :
127- callargs_ptrs .append (_getvar (lowerer , arg ))
128-
129- kernel_builder .populate_kernel_args_and_args_ty_arrays (
130- kernel_argtys = kernel_fn .kernel_arg_types ,
131- callargs_ptrs = callargs_ptrs ,
132- args_list = args_list ,
133- args_ty_list = args_ty_list ,
134- )
135-
136- return _KernelArgs (
137- num_flattened_args = num_flattened_args ,
138- arg_vals = args_list ,
139- arg_types = args_ty_list ,
140- )
141-
142- def _submit_parfor_kernel (
89+ def _loop_ranges (
14390 self ,
14491 lowerer ,
145- kernel_fn ,
14692 loop_ranges ,
14793 ):
148- """
149- Adds a call to submit a kernel function into the function body of the
150- current Numba JIT compiled function.
151- """
152- # Ensure that the Python arguments are kept alive for the duration of
153- # the kernel execution
154- keep_alive_kernels .append (kernel_fn .kernel )
155- kernel_builder = KernelLaunchIRBuilder (lowerer .context , lowerer .builder )
156-
157- ptr_to_queue_ref = kernel_builder .get_queue (exec_queue = kernel_fn .queue )
158- args = self ._build_kernel_arglist (kernel_fn , lowerer , kernel_builder )
159-
16094 # Create a global range over which to submit the kernel based on the
16195 # loop_ranges of the parfor
16296 global_range = []
97+
16398 # SYCL ranges can have at max 3 dimension. If the parfor is of a higher
16499 # dimension then the indexing for the higher dimensions is done inside
165100 # the kernel.
@@ -173,48 +108,19 @@ def _submit_parfor_kernel(
173108 "non-unit strides are not yet supported."
174109 )
175110 global_range .append (stop )
176-
111+ # For now the local_range is always an empty list as numba_dpex always
112+ # submits kernels generated for parfor nodes as range kernels.
113+ # The provision is kept here if in future there is newer functionality
114+ # to submit these kernels as ndrange.
177115 local_range = []
178116
179- kernel_ref_addr = kernel_fn .kernel .addressof_ref ()
180- kernel_ref = lowerer .builder .inttoptr (
181- lowerer .context .get_constant (types .uintp , kernel_ref_addr ),
182- cgutils .voidptr_t ,
183- )
184- curr_queue_ref = lowerer .builder .load (ptr_to_queue_ref )
185-
186- # Submit a synchronous kernel
187- kernel_builder .submit_sycl_kernel (
188- sycl_kernel_ref = kernel_ref ,
189- sycl_queue_ref = curr_queue_ref ,
190- total_kernel_args = args .num_flattened_args ,
191- arg_list = args .arg_vals ,
192- arg_ty_list = args .arg_types ,
193- global_range = global_range ,
194- local_range = local_range ,
195- )
196-
197- # At this point we can free the DPCTLSyclQueueRef (curr_queue)
198- kernel_builder .free_queue (ptr_to_sycl_queue_ref = ptr_to_queue_ref )
117+ return global_range , local_range
199118
200- def _submit_reduction_main_parfor_kernel (
119+ def _reduction_ranges (
201120 self ,
202121 lowerer ,
203- kernel_fn ,
204122 reductionHelper = None ,
205123 ):
206- """
207- Adds a call to submit the main kernel of a parfor reduction into the
208- function body of the current Numba JIT compiled function.
209- """
210- # Ensure that the Python arguments are kept alive for the duration of
211- # the kernel execution
212- keep_alive_kernels .append (kernel_fn .kernel )
213- kernel_builder = KernelLaunchIRBuilder (lowerer .context , lowerer .builder )
214-
215- ptr_to_queue_ref = kernel_builder .get_queue (exec_queue = kernel_fn .queue )
216-
217- args = self ._build_kernel_arglist (kernel_fn , lowerer , kernel_builder )
218124 # Create a global range over which to submit the kernel based on the
219125 # loop_ranges of the parfor
220126 global_range = []
@@ -228,75 +134,63 @@ def _submit_reduction_main_parfor_kernel(
228134 _load_range (lowerer , reductionHelper .work_group_size )
229135 )
230136
231- kernel_ref_addr = kernel_fn .kernel .addressof_ref ()
232- kernel_ref = lowerer .builder .inttoptr (
233- lowerer .context .get_constant (types .uintp , kernel_ref_addr ),
234- cgutils .voidptr_t ,
235- )
236- curr_queue_ref = lowerer .builder .load (ptr_to_queue_ref )
237-
238- # Submit a synchronous kernel
239- kernel_builder .submit_sycl_kernel (
240- sycl_kernel_ref = kernel_ref ,
241- sycl_queue_ref = curr_queue_ref ,
242- total_kernel_args = args .num_flattened_args ,
243- arg_list = args .arg_vals ,
244- arg_ty_list = args .arg_types ,
245- global_range = global_range ,
246- local_range = local_range ,
247- )
137+ return global_range , local_range
248138
249- # At this point we can free the DPCTLSyclQueueRef (curr_queue)
250- kernel_builder .free_queue (ptr_to_sycl_queue_ref = ptr_to_queue_ref )
139+ def _remainder_ranges (self , lowerer ):
140+ # Create a global range over which to submit the kernel based on the
141+ # loop_ranges of the parfor
142+ global_range = []
251143
252- def _submit_reduction_remainder_parfor_kernel (
144+ stop = _load_range (lowerer , 1 )
145+
146+ global_range .append (stop )
147+
148+ local_range = []
149+
150+ return global_range , local_range
151+
152+ def _submit_parfor_kernel (
253153 self ,
254154 lowerer ,
255- kernel_fn ,
155+ kernel_fn : ParforKernel ,
156+ global_range ,
157+ local_range ,
256158 ):
257159 """
258- Adds a call to submit the remainder kernel of a parfor reduction into
259- the function body of the current Numba JIT compiled function.
160+ Adds a call to submit a kernel function into the function body of the
161+ current Numba JIT compiled function.
260162 """
261163 # Ensure that the Python arguments are kept alive for the duration of
262164 # the kernel execution
263165 keep_alive_kernels .append (kernel_fn .kernel )
166+ kl_builder = KernelLaunchIRBuilder (
167+ lowerer .context , lowerer .builder , kernel_dmm
168+ )
264169
265- kernel_builder = KernelLaunchIRBuilder (lowerer .context , lowerer .builder )
266-
267- ptr_to_queue_ref = kernel_builder .get_queue (exec_queue = kernel_fn .queue )
268-
269- args = self ._build_kernel_arglist (kernel_fn , lowerer , kernel_builder )
270- # Create a global range over which to submit the kernel based on the
271- # loop_ranges of the parfor
272- global_range = []
273-
274- stop = _load_range (lowerer , 1 )
170+ queue_ref = kl_builder .get_queue (exec_queue = kernel_fn .queue )
275171
276- global_range . append ( stop )
277-
278- local_range = []
172+ kernel_args = []
173+ for arg in kernel_fn . kernel_args :
174+ kernel_args . append ( _getvar ( lowerer , arg ))
279175
280176 kernel_ref_addr = kernel_fn .kernel .addressof_ref ()
281177 kernel_ref = lowerer .builder .inttoptr (
282178 lowerer .context .get_constant (types .uintp , kernel_ref_addr ),
283179 cgutils .voidptr_t ,
284180 )
285- curr_queue_ref = lowerer .builder .load (ptr_to_queue_ref )
286-
287- # Submit a synchronous kernel
288- kernel_builder .submit_sycl_kernel (
289- sycl_kernel_ref = kernel_ref ,
290- sycl_queue_ref = curr_queue_ref ,
291- total_kernel_args = args .num_flattened_args ,
292- arg_list = args .arg_vals ,
293- arg_ty_list = args .arg_types ,
294- global_range = global_range ,
295- local_range = local_range ,
181+
182+ kl_builder .set_kernel (kernel_ref )
183+ kl_builder .set_queue (queue_ref )
184+ kl_builder .set_range (global_range , local_range )
185+ kl_builder .set_arguments (
186+ kernel_fn .kernel_arg_types , kernel_args = kernel_args
296187 )
188+ kl_builder .set_dependant_event_list (dep_events = [])
189+ event_ref = kl_builder .submit ()
297190
298- # At this point we can free the DPCTLSyclQueueRef (curr_queue)
299- kernel_builder .free_queue (ptr_to_sycl_queue_ref = ptr_to_queue_ref )
191+ sycl .dpctl_event_wait (lowerer .builder , event_ref )
192+ sycl .dpctl_event_delete (lowerer .builder , event_ref )
193+ sycl .dpctl_queue_delete (lowerer .builder , queue_ref )
300194
301195 def _reduction_codegen (
302196 self ,
@@ -360,10 +254,15 @@ def _reduction_codegen(
360254 parfor_reddict ,
361255 )
362256
363- self ._submit_reduction_main_parfor_kernel (
257+ global_range , local_range = self ._reduction_ranges (
258+ lowerer , reductionHelperList [0 ]
259+ )
260+
261+ self ._submit_parfor_kernel (
364262 lowerer ,
365263 parfor_kernel ,
366- reductionHelperList [0 ],
264+ global_range ,
265+ local_range ,
367266 )
368267
369268 parfor_kernel = create_reduction_remainder_kernel_for_parfor (
@@ -376,9 +275,13 @@ def _reduction_codegen(
376275 reductionHelperList ,
377276 )
378277
379- self ._submit_reduction_remainder_parfor_kernel (
278+ global_range , local_range = self ._remainder_ranges (lowerer )
279+
280+ self ._submit_parfor_kernel (
380281 lowerer ,
381282 parfor_kernel ,
283+ global_range ,
284+ local_range ,
382285 )
383286
384287 reductionKernelVar .copy_final_sum_to_host (parfor_kernel )
@@ -492,11 +395,14 @@ def _lower_parfor_as_kernel(self, lowerer, parfor):
492395 # FIXME: Make the exception more informative
493396 raise UnsupportedParforError
494397
398+ global_range , local_range = self ._loop_ranges (lowerer , loop_ranges )
399+
495400 # Finally submit the kernel
496401 self ._submit_parfor_kernel (
497402 lowerer ,
498403 parfor_kernel ,
499- loop_ranges ,
404+ global_range ,
405+ local_range ,
500406 )
501407
502408 # TODO: free the kernel at this point
0 commit comments