55"""Module that contains numba style wrapper around sycl kernel submit."""
66
77from dataclasses import dataclass
8+ from functools import cached_property
9+ from typing import NamedTuple , Union
810
911import dpctl
1012from llvmlite import ir as llvmir
1113from llvmlite .ir .builder import IRBuilder
1214from numba .core import cgutils , types
1315from numba .core .cpu import CPUContext
1416from numba .core .datamodel import DataModelManager
17+ from numba .core .types .containers import UniTuple
1518
1619from numba_dpex import config , utils
20+ from numba_dpex .core .exceptions import UnreachableError
1721from numba_dpex .core .runtime .context import DpexRTContext
1822from numba_dpex .core .types import DpnpNdArray
23+ from numba_dpex .core .types .range_types import NdRangeType , RangeType
1924from numba_dpex .dpctl_iface import libsyclinterface_bindings as sycl
2025from numba_dpex .dpctl_iface ._helpers import numba_type_to_dpctl_typenum
26+ from numba_dpex .utils import create_null_ptr
2127
2228MAX_SIZE_OF_SYCL_RANGE = 3
2329
2430
31+ # TODO: probably not best place for it. Should be in kernel_dispatcher once we
32+ # get merge experimental. Right now it will cause cyclic import
33+ class SPIRVKernelModule (NamedTuple ):
34+ """Represents SPIRV binary code and function name in this binary"""
35+
36+ kernel_name : str
37+ kernel_bitcode : bytes
38+
39+
2540@dataclass
2641class _KernelLaunchIRArguments : # pylint: disable=too-many-instance-attributes
2742 """List of kernel launch arguments used in sycl.dpctl_queue_submit_range and
@@ -62,6 +77,22 @@ def to_list(self):
6277 return res
6378
6479
80+ @dataclass
81+ class _KernelLaunchIRCachedArguments :
82+ """Arguments that are being used in KernelLaunchIRBuilder that are either
83+ intermediate structure of the KernelLaunchIRBuilder like llvm IR array
84+ stored as a python array of llvm IR values or llvm IR values that may be
85+ used as an input for builder functions.
86+
87+ Main goal is to prevent passing same argument during build process several
88+ times and to avoid passing output of the builder as an argument for another
89+ build method."""
90+
91+ arg_list : list [llvmir .Instruction ] = None
92+ arg_ty_list : list [types .Type ] = None
93+ device_event_ref : llvmir .Instruction = None
94+
95+
6596class KernelLaunchIRBuilder :
6697 """
6798 KernelLaunchIRBuilder(lowerer, cres)
@@ -86,9 +117,16 @@ def __init__(
86117 """
87118 self .context = context
88119 self .builder = builder
89- self .rtctx = DpexRTContext (self .context )
90120 self .arguments = _KernelLaunchIRArguments ()
121+ self .cached_arguments = _KernelLaunchIRCachedArguments ()
91122 self .kernel_dmm = kernel_dmm
123+ self ._cleanups = []
124+
125+ @cached_property
126+ def dpexrt (self ):
127+ """Dpex runtime context."""
128+
129+ return DpexRTContext (self .context )
92130
93131 def _build_nullptr (self ):
94132 """Builds the LLVM IR to represent a null pointer.
@@ -329,7 +367,7 @@ def get_queue(self, exec_queue: dpctl.SyclQueue) -> llvmir.Instruction:
329367 # Store the queue returned by DPEXRTQueue_CreateFromFilterString in a
330368 # local variable
331369 self .builder .store (
332- self .rtctx .get_queue_from_filter_string (
370+ self .dpexrt .get_queue_from_filter_string (
333371 builder = self .builder , device = device
334372 ),
335373 sycl_queue_val ,
@@ -413,10 +451,76 @@ def set_kernel(self, sycl_kernel_ref: llvmir.Instruction):
413451 """Sets kernel to the argument list."""
414452 self .arguments .sycl_kernel_ref = sycl_kernel_ref
415453
454+ def set_kernel_from_spirv (self , kernel_module : SPIRVKernelModule ):
455+ """Sets kernel to the argument list from the SPIRV bytecode.
456+
457+ It pastes bytecode as a constant string and create kernel bundle from it
458+ using SYCL API. It caches kernel, so it won't be sent to device second
459+ time.
460+ """
461+ # Inserts a global constant byte string in the current LLVM module to
462+ # store the passed in SPIR-V binary blob.
463+ queue_ref = self .arguments .sycl_queue_ref
464+
465+ kernel_bc_byte_str = self .context .insert_const_bytes (
466+ self .builder .module ,
467+ bytes = kernel_module .kernel_bitcode ,
468+ )
469+
470+ kernel_name = self .context .insert_const_string (
471+ self .builder .module , kernel_module .kernel_name
472+ )
473+
474+ context_ref = sycl .dpctl_queue_get_context (self .builder , queue_ref )
475+ device_ref = sycl .dpctl_queue_get_device (self .builder , queue_ref )
476+
477+ # build_or_get_kernel steals reference to context and device cause it
478+ # needs to keep them alive for keys.
479+ kernel_ref = self .dpexrt .build_or_get_kernel (
480+ self .builder ,
481+ [
482+ context_ref ,
483+ device_ref ,
484+ llvmir .Constant (
485+ llvmir .IntType (64 ), hash (kernel_module .kernel_bitcode )
486+ ),
487+ kernel_bc_byte_str ,
488+ llvmir .Constant (
489+ llvmir .IntType (64 ), len (kernel_module .kernel_bitcode )
490+ ),
491+ self .builder .load (create_null_ptr (self .builder , self .context )),
492+ kernel_name ,
493+ ],
494+ )
495+
496+ self ._cleanups .append (self ._clean_kernel_ref )
497+ self .set_kernel (kernel_ref )
498+
499+ def _clean_kernel_ref (self ):
500+ sycl .dpctl_kernel_delete (self .builder , self .arguments .sycl_kernel_ref )
501+ self .arguments .sycl_kernel_ref = None
502+
416503 def set_queue (self , sycl_queue_ref : llvmir .Instruction ):
417504 """Sets queue to the argument list."""
418505 self .arguments .sycl_queue_ref = sycl_queue_ref
419506
507+ def set_queue_from_arguments (
508+ self ,
509+ ):
510+ """Sets the sycl queue from the first DpnpNdArray argument provided
511+ earlier."""
512+ queue_ref = get_queue_from_llvm_values (
513+ self .context ,
514+ self .builder ,
515+ self .cached_arguments .arg_ty_list ,
516+ self .cached_arguments .arg_list ,
517+ )
518+
519+ if queue_ref is None :
520+ raise ValueError ("There are no arguments that contain queue" )
521+
522+ self .set_queue (queue_ref )
523+
420524 def set_range (
421525 self ,
422526 global_range : list ,
@@ -430,10 +534,52 @@ def set_range(
430534 types .uintp , len (global_range )
431535 )
432536
537+ def set_range_from_indexer (
538+ self ,
539+ ty_indexer_arg : Union [RangeType , NdRangeType ],
540+ ll_index_arg : llvmir .BaseStructType ,
541+ ):
542+ """Returns two lists of LLVM IR Values that hold the unboxed extents of
543+ a Python Range or NdRange object.
544+ """
545+ ndim = ty_indexer_arg .ndim
546+ global_range_extents = []
547+ local_range_extents = []
548+ indexer_datamodel = self .context .data_model_manager .lookup (
549+ ty_indexer_arg
550+ )
551+
552+ if isinstance (ty_indexer_arg , RangeType ):
553+ for dim_num in range (ndim ):
554+ dim_pos = indexer_datamodel .get_field_position (
555+ "dim" + str (dim_num )
556+ )
557+ global_range_extents .append (
558+ self .builder .extract_value (ll_index_arg , dim_pos )
559+ )
560+ elif isinstance (ty_indexer_arg , NdRangeType ):
561+ for dim_num in range (ndim ):
562+ gdim_pos = indexer_datamodel .get_field_position (
563+ "gdim" + str (dim_num )
564+ )
565+ global_range_extents .append (
566+ self .builder .extract_value (ll_index_arg , gdim_pos )
567+ )
568+ ldim_pos = indexer_datamodel .get_field_position (
569+ "ldim" + str (dim_num )
570+ )
571+ local_range_extents .append (
572+ self .builder .extract_value (ll_index_arg , ldim_pos )
573+ )
574+ else :
575+ raise UnreachableError
576+
577+ self .set_range (global_range_extents , local_range_extents )
578+
433579 def set_arguments (
434580 self ,
435- ty_kernel_args : list ,
436- kernel_args : list ,
581+ ty_kernel_args : list [ types . Type ] ,
582+ kernel_args : list [ llvmir . Instruction ] ,
437583 ):
438584 """Sets flattened kernel args, kernel arg types and number of those
439585 arguments to the argument list."""
@@ -443,6 +589,9 @@ def set_arguments(
443589 "DPEX-DEBUG: Populating kernel args and arg type arrays.\n " ,
444590 )
445591
592+ self .cached_arguments .arg_ty_list = ty_kernel_args
593+ self .cached_arguments .arg_list = kernel_args
594+
446595 num_flattened_kernel_args = self ._get_num_flattened_kernel_args (
447596 kernel_argtys = ty_kernel_args ,
448597 )
@@ -475,6 +624,34 @@ def set_arguments(
475624 types .uintp , num_flattened_kernel_args
476625 )
477626
627+ def _extract_arguments_from_tuple (
628+ self ,
629+ ty_kernel_args_tuple : UniTuple ,
630+ ll_kernel_args_tuple : llvmir .Instruction ,
631+ ) -> list [llvmir .Instruction ]:
632+ """Extracts LLVM IR values from llvm tuple into python array."""
633+
634+ kernel_args = []
635+ for pos in range (len (ty_kernel_args_tuple )):
636+ kernel_args .append (
637+ self .builder .extract_value (ll_kernel_args_tuple , pos )
638+ )
639+
640+ return kernel_args
641+
642+ def set_arguments_form_tuple (
643+ self ,
644+ ty_kernel_args_tuple : UniTuple ,
645+ ll_kernel_args_tuple : llvmir .Instruction ,
646+ ):
647+ """Sets flattened kernel args, kernel arg types and number of those
648+ arguments to the argument list based on the arguments stored in tuple.
649+ """
650+ kernel_args = self ._extract_arguments_from_tuple (
651+ ty_kernel_args_tuple , ll_kernel_args_tuple
652+ )
653+ self .set_arguments (ty_kernel_args_tuple , kernel_args )
654+
478655 def set_dependant_event_list (self , dep_events : list [llvmir .Instruction ]):
479656 """Sets dependant events to the argument list."""
480657 if self .arguments .dep_events is not None :
@@ -499,11 +676,86 @@ def submit(self) -> llvmir.Instruction:
499676 args = self .arguments .to_list ()
500677
501678 if self .arguments .local_range is None :
502- eref = sycl .dpctl_queue_submit_range (self .builder , * args )
679+ event_ref = sycl .dpctl_queue_submit_range (self .builder , * args )
503680 else :
504- eref = sycl .dpctl_queue_submit_ndrange (self .builder , * args )
681+ event_ref = sycl .dpctl_queue_submit_ndrange (self .builder , * args )
682+
683+ self .cached_arguments .device_event_ref = event_ref
505684
506- return eref
685+ for cleanup in self ._cleanups :
686+ cleanup ()
687+
688+ return event_ref
689+
690+ def _allocate_meminfo_array (
691+ self ,
692+ ) -> tuple [int , list [llvmir .Instruction ]]:
693+ """Allocates an LLVM array value to store each memory info from all
694+ kernel arguments. The array is the populated with the LLVM value for
695+ every meminfo of the kernel arguments.
696+ """
697+ kernel_args = self .cached_arguments .arg_list
698+ kernel_argtys = self .cached_arguments .arg_ty_list
699+
700+ meminfos = []
701+ for arg_num , argtype in enumerate (kernel_argtys ):
702+ llvm_val = kernel_args [arg_num ]
703+
704+ meminfos += [
705+ meminfo
706+ for ty , meminfo in self .context .nrt .get_meminfos (
707+ self .builder , argtype , llvm_val
708+ )
709+ ]
710+
711+ meminfo_list = cgutils .alloca_once (
712+ self .builder ,
713+ utils .get_llvm_type (context = self .context , type = types .voidptr ),
714+ size = self .context .get_constant (types .uintp , len (meminfos )),
715+ )
716+
717+ for meminfo_num , meminfo in enumerate (meminfos ):
718+ meminfo_arg_dst = self .builder .gep (
719+ meminfo_list ,
720+ [self .context .get_constant (types .int32 , meminfo_num )],
721+ )
722+ meminfo_ptr = self .builder .bitcast (
723+ meminfo ,
724+ utils .get_llvm_type (context = self .context , type = types .voidptr ),
725+ )
726+ self .builder .store (meminfo_ptr , meminfo_arg_dst )
727+
728+ return len (meminfos ), meminfo_list
729+
730+ def acquire_meminfo_and_submit_release (
731+ self ,
732+ ) -> llvmir .Instruction :
733+ """Schedule sycl host task to release nrt meminfo of the arguments used
734+ to run job. Use it to keep arguments alive during kernel execution."""
735+ queue_ref = self .arguments .sycl_queue_ref
736+ event_ref = self .cached_arguments .device_event_ref
737+
738+ total_meminfos , meminfo_list = self ._allocate_meminfo_array ()
739+
740+ event_ref_ptr = self .builder .alloca (event_ref .type )
741+ self .builder .store (event_ref , event_ref_ptr )
742+
743+ status_ptr = cgutils .alloca_once (
744+ self .builder , self .context .get_value_type (types .uint64 )
745+ )
746+ host_eref = self .dpexrt .acquire_meminfo_and_schedule_release (
747+ self .builder ,
748+ [
749+ self .context .nrt .get_nrt_api (self .builder ),
750+ queue_ref ,
751+ meminfo_list ,
752+ self .context .get_constant (types .uintp , total_meminfos ),
753+ event_ref_ptr ,
754+ self .context .get_constant (types .uintp , 1 ),
755+ status_ptr ,
756+ ],
757+ )
758+ return host_eref
507759
508760 def _get_num_flattened_kernel_args (
509761 self ,
@@ -571,3 +823,26 @@ def _populate_kernel_args_and_args_ty_arrays(
571823 kernel_arg_num ,
572824 )
573825 kernel_arg_num += 1
826+
827+
828+ def get_queue_from_llvm_values (
829+ ctx : CPUContext ,
830+ builder : IRBuilder ,
831+ ty_kernel_args : list [types .Type ],
832+ ll_kernel_args : list [llvmir .Instruction ],
833+ ):
834+ """
835+ Get the sycl queue from the first DpnpNdArray argument. Prior passes
836+ before lowering make sure that compute-follows-data is enforceable
837+ for a specific call to a kernel. As such, at the stage of lowering
838+ the queue from the first DpnpNdArray argument can be extracted.
839+ """
840+ for arg_num , argty in enumerate (ty_kernel_args ):
841+ if isinstance (argty , DpnpNdArray ):
842+ llvm_val = ll_kernel_args [arg_num ]
843+ datamodel = ctx .data_model_manager .lookup (argty )
844+ sycl_queue_attr_pos = datamodel .get_field_position ("sycl_queue" )
845+ queue_ref = builder .extract_value (llvm_val , sycl_queue_attr_pos )
846+ break
847+
848+ return queue_ref
0 commit comments