@@ -374,78 +374,95 @@ def get_queue(self, exec_queue: dpctl.SyclQueue) -> llvmir.Instruction:
374374 )
375375 return self .builder .load (sycl_queue_val )
376376
def _allocate_array(
    self, numba_type: types.Type, size: int
) -> llvmir.Instruction:
    """Allocate an LLVM array with ``size`` elements of ``numba_type``.

    Args:
        numba_type: Numba type of a single element; the context converts
            it to the LLVM value type used for the allocation.
        size: Number of elements the array must hold.

    Returns: An LLVM IR value pointing to the array.
    """
    element_ll_type = self.context.get_value_type(numba_type)
    element_count = self.context.get_constant(types.uintp, size)

    return cgutils.alloca_once(
        self.builder, element_ll_type, size=element_count
    )

def _populate_array_from_python_list(
    self,
    numba_type: types.Type,
    py_array: list[llvmir.Instruction],
    ll_array: llvmir.Instruction,
    force_cast: bool = False,
):
    """Store every LLVM value of a Python list into an LLVM array.

    The value at position ``i`` of ``py_array`` is written to element ``i``
    of ``ll_array``.

    Args:
        numba_type: Numba type of the array elements; only consulted when
            ``force_cast`` is set.
        py_array: Python list of LLVM IR values to copy into the array.
        ll_array: LLVM IR value pointing to the destination array.
        force_cast: Whether each value is bitcast to the LLVM value type
            of ``numba_type`` before it is stored.
    """
    for position, element in enumerate(py_array):
        index = self.context.get_constant(types.int32, position)
        slot_ptr = self.builder.gep(ll_array, [index])

        if force_cast:
            # The bitcast may be redundant when the value already has the
            # target type, but it does no harm.
            target_type = self.context.get_value_type(numba_type)
            element = self.builder.bitcast(element, target_type)

        self.builder.store(element, slot_ptr)
421+
def _create_ll_from_py_list(
    self,
    numba_type: types.Type,
    list_of_ll_values: list[llvmir.Instruction],
    force_cast: bool = False,
) -> llvmir.Instruction:
    """Build an LLVM IR array out of a Python list of LLVM IR values.

    The array is allocated with exactly as many elements as the list has
    and is then filled with the list's values in order.

    Args:
        numba_type: Numba type of the array elements.
        list_of_ll_values: LLVM IR values to place into the array.
        force_cast: Whether each value is bitcast to the element type
            before it is stored.

    Returns: An LLVM IR value pointing to the array.
    """
    element_count = len(list_of_ll_values)
    ll_array = self._allocate_array(numba_type, element_count)

    self._populate_array_from_python_list(
        numba_type, list_of_ll_values, ll_array, force_cast
    )

    return ll_array
413445
def _create_sycl_range(self, idx_range):
    """Allocate an array to store the extents of a sycl::range.

    Every extent whose LLVM type is not already a 64-bit integer is
    sign-extended to one. The resulting array holds one element per entry
    of ``idx_range``, stored in reverse order.

    Returns: An LLVM IR value pointing to the array of extents.
    """
    int64_t = utils.LLVMTypes.int64_t

    extents = []
    for extent in idx_range:
        if extent.type != int64_t:
            extent = self.builder.sext(extent, int64_t)
        extents.append(extent)

    # The extents are stored back to front to account for how the sycl and
    # opencl range conventions differ.
    return self._create_ll_from_py_list(types.uintp, extents[::-1])
449466
450467 def set_kernel (self , sycl_kernel_ref : llvmir .Instruction ):
451468 """Sets kernel to the argument list."""
@@ -597,10 +614,14 @@ def set_arguments(
597614 )
598615
599616 # Create LLVM values for the kernel args list and kernel arg types list
600- args_list = self ._allocate_kernel_arg_array (num_flattened_kernel_args )
617+ args_list = self ._allocate_array (
618+ types .voidptr ,
619+ num_flattened_kernel_args ,
620+ )
601621
602- args_ty_list = self ._allocate_kernel_arg_ty_array (
603- num_flattened_kernel_args
622+ args_ty_list = self ._allocate_array (
623+ types .int32 ,
624+ num_flattened_kernel_args ,
604625 )
605626
606627 kernel_args_ptrs = []
@@ -624,20 +645,17 @@ def set_arguments(
624645 types .uintp , num_flattened_kernel_args
625646 )
626647
def _extract_llvm_values_from_tuple(
    self,
    ll_tuple: llvmir.Instruction,
) -> list[llvmir.Instruction]:
    """Unpack an LLVM tuple value into a Python list of LLVM IR values.

    The number of elements is taken from the length of the tuple's LLVM
    type; one ``extract_value`` instruction is emitted per element.
    """
    element_count = len(ll_tuple.type)

    return [
        self.builder.extract_value(ll_tuple, index)
        for index in range(element_count)
    ]
641659
642660 def set_arguments_form_tuple (
643661 self ,
@@ -647,27 +665,45 @@ def set_arguments_form_tuple(
647665 """Sets flattened kernel args, kernel arg types and number of those
648666 arguments to the argument list based on the arguments stored in tuple.
649667 """
650- kernel_args = self ._extract_arguments_from_tuple (
651- ty_kernel_args_tuple , ll_kernel_args_tuple
652- )
668+ kernel_args = self ._extract_llvm_values_from_tuple (ll_kernel_args_tuple )
653669 self .set_arguments (ty_kernel_args_tuple , kernel_args )
654670
def set_dependent_events(self, dep_events: list[llvmir.Instruction]):
    """Store the dependent events and their count in the argument list.

    Args:
        dep_events: LLVM IR values of the events; they are copied into a
            newly allocated LLVM array of ``voidptr`` elements.
    """
    arguments = self.arguments

    arguments.dep_events = self._create_ll_from_py_list(
        types.voidptr, dep_events
    )
    arguments.dep_events_len = self.context.get_constant(
        types.uintp, len(dep_events)
    )

def set_dependent_events_from_tuple(
    self,
    ty_dependent_events: UniTuple,
    ll_dependent_events: llvmir.Instruction,
):
    """Set dependent events from a tuple represented by LLVM IR.

    An empty tuple results in an empty dependent event list. Otherwise
    every tuple element is viewed through a struct proxy of the tuple's
    first element type, and its ``event_ref`` field is used as the event.

    Args:
        ty_dependent_events: Numba type of the tuple of events.
        ll_dependent_events: LLVM IR value of the tuple holding the
            events' data models.
    """
    if len(ty_dependent_events) == 0:
        self.set_dependent_events([])
        return

    ty_event = ty_dependent_events[0]

    event_refs = [
        cgutils.create_struct_proxy(ty_event)(
            self.context,
            self.builder,
            value=ll_event,
        ).event_ref
        for ll_event in self._extract_llvm_values_from_tuple(
            ll_dependent_events
        )
    ]

    self.set_dependent_events(event_refs)
671707
672708 def submit (self ) -> llvmir .Instruction :
673709 """Submits kernel by calling sycl.dpctl_queue_submit_range or
@@ -708,22 +744,7 @@ def _allocate_meminfo_array(
708744 )
709745 ]
710746
711- meminfo_list = cgutils .alloca_once (
712- self .builder ,
713- utils .get_llvm_type (context = self .context , type = types .voidptr ),
714- size = self .context .get_constant (types .uintp , len (meminfos )),
715- )
716-
717- for meminfo_num , meminfo in enumerate (meminfos ):
718- meminfo_arg_dst = self .builder .gep (
719- meminfo_list ,
720- [self .context .get_constant (types .int32 , meminfo_num )],
721- )
722- meminfo_ptr = self .builder .bitcast (
723- meminfo ,
724- utils .get_llvm_type (context = self .context , type = types .voidptr ),
725- )
726- self .builder .store (meminfo_ptr , meminfo_arg_dst )
747+ meminfo_list = self ._create_ll_from_py_list (types .voidptr , meminfos )
727748
728749 return len (meminfos ), meminfo_list
729750
0 commit comments