@@ -374,78 +374,95 @@ def get_queue(self, exec_queue: dpctl.SyclQueue) -> llvmir.Instruction:
374374 )
375375 return self .builder .load (sycl_queue_val )
376376
377- def _allocate_kernel_arg_array (self , num_kernel_args ):
378- """Allocates an array to store the LLVM Value for every kernel argument.
377+ def _allocate_array (
378+ self , numba_type : types .Type , size : int
379+ ) -> llvmir .Instruction :
380+ """Allocates an LLVM array of given type and size.
379381
380382 Args:
381- num_kernel_args (int): The number of kernel arguments that
382- determines the size of args array to allocate.
383+ numba_type: Numba type of the array elements.
384+ size: Number of elements to allocate.
383385
384- Returns: An LLVM IR value pointing to an array to store the kernel
385- arguments.
386+ Returns: An LLVM IR value pointing to the array.
386387 """
387- args_list = cgutils .alloca_once (
388+ return cgutils .alloca_once (
388389 self .builder ,
389- utils . LLVMTypes . byte_ptr_t ,
390- size = self .context .get_constant (types .uintp , num_kernel_args ),
390+ self . context . get_value_type ( numba_type ) ,
391+ size = self .context .get_constant (types .uintp , size ),
391392 )
392393
393- return args_list
394+ def _populate_array_from_python_list (
395+ self ,
396+ numba_type : types .Type ,
397+ py_array : list [llvmir .Instruction ],
398+ ll_array : llvmir .Instruction ,
399+ force_cast : bool = False ,
400+ ):
401+ """Populates LLVM values from an input Python list into an LLVM array.
402+
403+ Args:
404+ numba_type: Numba type the values are cast to when force_cast is set.
405+ py_array: Python list of LLVM IR values to store into the array.
406+ ll_array: LLVM IR value that represents the array to populate.
407+ force_cast: whether to bitcast each value to the provided type.
408+ """
409+ for idx , ll_value in enumerate (py_array ):
410+ ll_array_dst = self .builder .gep (
411+ ll_array ,
412+ [self .context .get_constant (types .int32 , idx )],
413+ )
414+ # The bitcast may be redundant, but it is harmless.
415+ if force_cast :
416+ ll_value = self .builder .bitcast (
417+ ll_value ,
418+ self .context .get_value_type (numba_type ),
419+ )
420+ self .builder .store (ll_value , ll_array_dst )
394421
395- def _allocate_kernel_arg_ty_array (self , num_kernel_args ):
396- """Allocates an array to store the LLVM Value for the typenum for
397- every kernel argument.
422+ def _create_ll_from_py_list (
423+ self ,
424+ numba_type : types .Type ,
425+ list_of_ll_values : list [llvmir .Instruction ],
426+ force_cast : bool = False ,
427+ ) -> llvmir .Instruction :
428+ """Allocates an LLVM IR array of the same size as the input python list
429+ of LLVM IR Values and populates the array with the LLVM Values in the
430+ list.
398431
399432 Args:
400- num_kernel_args (int): The number of kernel arguments that
401- determines the size of args array to allocate.
433+ numba_type: Numba type of the array elements.
434+ list_of_ll_values: Python list of LLVM IR values to populate.
435+ force_cast: whether to bitcast each value to the provided type.
402436
403- Returns: An LLVM IR value pointing to an array to store the kernel
404- arguments typenums as defined in dpctl.
437+ Returns: An LLVM IR value pointing to the array.
405438 """
406- args_ty_list = cgutils .alloca_once (
407- self .builder ,
408- utils .LLVMTypes .int32_t ,
409- size = self .context .get_constant (types .uintp , num_kernel_args ),
439+ ll_array = self ._allocate_array (numba_type , len (list_of_ll_values ))
440+ self ._populate_array_from_python_list (
441+ numba_type , list_of_ll_values , ll_array , force_cast
410442 )
411443
412- return args_ty_list
444+ return ll_array
413445
414446 def _create_sycl_range (self , idx_range ):
415- """Allocate a size_t[3] array to store the extents of a sycl::range.
447+ """Allocate an array to store the extents of a sycl::range.
416448
417449 Sycl supports up to 3-dimensional ranges and as such the array is
418450 statically sized to length three. Only the elements that store an actual
419451 range value are populated based on the size of the idx_range argument.
420452
421453 """
422- intp_t = utils .get_llvm_type (context = self .context , type = types .intp )
423- intp_ptr_t = utils .get_llvm_ptr_type (intp_t )
424- num_dim = len (idx_range )
454+ int64_range = [
455+ self .builder .sext (rext , utils .LLVMTypes .int64_t )
456+ if rext .type != utils .LLVMTypes .int64_t
457+ else rext
458+ for rext in idx_range
459+ ]
425460
426- # form the global range
427- range_list = cgutils .alloca_once (
428- self .builder ,
429- utils .get_llvm_type (context = self .context , type = types .uintp ),
430- size = self .context .get_constant (types .uintp , MAX_SIZE_OF_SYCL_RANGE ),
431- )
432-
433- for i in range (num_dim ):
434- rext = idx_range [i ]
435- if rext .type != utils .LLVMTypes .int64_t :
436- rext = self .builder .sext (rext , utils .LLVMTypes .int64_t )
437-
438- # we reverse the global range to account for how sycl and opencl
439- # range differs
440- self .builder .store (
441- rext ,
442- self .builder .gep (
443- range_list ,
444- [self .context .get_constant (types .uintp , (num_dim - 1 ) - i )],
445- ),
446- )
461+ # we reverse the global range to account for how sycl and opencl
462+ # range differs
463+ int64_range .reverse ()
447464
448- return self .builder . bitcast ( range_list , intp_ptr_t )
465+ return self ._create_ll_from_py_list ( types . uintp , int64_range )
449466
450467 def set_kernel (self , sycl_kernel_ref : llvmir .Instruction ):
451468 """Sets kernel to the argument list."""
@@ -597,10 +614,14 @@ def set_arguments(
597614 )
598615
599616 # Create LLVM values for the kernel args list and kernel arg types list
600- args_list = self ._allocate_kernel_arg_array (num_flattened_kernel_args )
617+ args_list = self ._allocate_array (
618+ types .voidptr ,
619+ num_flattened_kernel_args ,
620+ )
601621
602- args_ty_list = self ._allocate_kernel_arg_ty_array (
603- num_flattened_kernel_args
622+ args_ty_list = self ._allocate_array (
623+ types .int32 ,
624+ num_flattened_kernel_args ,
604625 )
605626
606627 kernel_args_ptrs = []
@@ -624,20 +645,17 @@ def set_arguments(
624645 types .uintp , num_flattened_kernel_args
625646 )
626647
627- def _extract_arguments_from_tuple (
648+ def _extract_llvm_values_from_tuple (
628649 self ,
629- ty_kernel_args_tuple : UniTuple ,
630- ll_kernel_args_tuple : llvmir .Instruction ,
650+ ll_tuple : llvmir .Instruction ,
631651 ) -> list [llvmir .Instruction ]:
632652 """Extracts the LLVM IR values from an LLVM tuple into a Python list."""
633653
634- kernel_args = []
635- for pos in range (len (ty_kernel_args_tuple )):
636- kernel_args .append (
637- self .builder .extract_value (ll_kernel_args_tuple , pos )
638- )
654+ llvm_values = []
655+ for pos in range (len (ll_tuple .type )):
656+ llvm_values .append (self .builder .extract_value (ll_tuple , pos ))
639657
640- return kernel_args
658+ return llvm_values
641659
642660 def set_arguments_form_tuple (
643661 self ,
@@ -647,9 +665,7 @@ def set_arguments_form_tuple(
647665 """Sets flattened kernel args, kernel arg types and number of those
648666 arguments to the argument list based on the arguments stored in tuple.
649667 """
650- kernel_args = self ._extract_arguments_from_tuple (
651- ty_kernel_args_tuple , ll_kernel_args_tuple
652- )
668+ kernel_args = self ._extract_llvm_values_from_tuple (ll_kernel_args_tuple )
653669 self .set_arguments (ty_kernel_args_tuple , kernel_args )
654670
655671 def set_dependant_event_list (self , dep_events : list [llvmir .Instruction ]):
@@ -708,22 +724,7 @@ def _allocate_meminfo_array(
708724 )
709725 ]
710726
711- meminfo_list = cgutils .alloca_once (
712- self .builder ,
713- utils .get_llvm_type (context = self .context , type = types .voidptr ),
714- size = self .context .get_constant (types .uintp , len (meminfos )),
715- )
716-
717- for meminfo_num , meminfo in enumerate (meminfos ):
718- meminfo_arg_dst = self .builder .gep (
719- meminfo_list ,
720- [self .context .get_constant (types .int32 , meminfo_num )],
721- )
722- meminfo_ptr = self .builder .bitcast (
723- meminfo ,
724- utils .get_llvm_type (context = self .context , type = types .voidptr ),
725- )
726- self .builder .store (meminfo_ptr , meminfo_arg_dst )
727+ meminfo_list = self ._create_ll_from_py_list (types .voidptr , meminfos )
727728
728729 return len (meminfos ), meminfo_list
729730
0 commit comments