1212from llvmlite import ir as llvmir
1313from numba .core import cgutils , types
1414from numba .core .cpu import CPUContext
15- from numba .core .types .containers import UniTuple
15+ from numba .core .types .containers import Tuple , UniTuple
1616from numba .core .types .functions import Dispatcher
1717from numba .extending import intrinsic
1818
@@ -51,13 +51,15 @@ def _submit_kernel_async(
5151 typingctx ,
5252 ty_kernel_fn : Dispatcher ,
5353 ty_index_space : Union [RangeType , NdRangeType ],
54+ ty_dependent_events : UniTuple ,
5455 ty_kernel_args_tuple : UniTuple ,
5556):
5657 """Generates IR code for call_kernel_async dpjit function."""
5758 return _submit_kernel (
5859 typingctx ,
5960 ty_kernel_fn ,
6061 ty_index_space ,
62+ ty_dependent_events ,
6163 ty_kernel_args_tuple ,
6264 sync = False ,
6365 )
@@ -75,15 +77,17 @@ def _submit_kernel_sync(
7577 typingctx ,
7678 ty_kernel_fn ,
7779 ty_index_space ,
80+ None ,
7881 ty_kernel_args_tuple ,
7982 sync = True ,
8083 )
8184
8285
83- def _submit_kernel (
84- typingctx , # pylint: disable=W0613
86+ def _submit_kernel ( # pylint: disable=too-many-arguments
87+ typingctx , # pylint: disable=unused-argument
8588 ty_kernel_fn : Dispatcher ,
8689 ty_index_space : Union [RangeType , NdRangeType ],
90+ ty_dependent_events : UniTuple ,
8791 ty_kernel_args_tuple : UniTuple ,
8892 sync : bool ,
8993):
@@ -106,7 +110,21 @@ def _submit_kernel(
106110 ty_event = DpctlSyclEvent ()
107111 ty_return = types .Tuple ([ty_event , ty_event ])
108112
109- sig = ty_return (ty_kernel_fn , ty_index_space , ty_kernel_args_tuple )
113+ if ty_dependent_events is not None :
114+ if not isinstance (ty_dependent_events , UniTuple ) and not isinstance (
115+ ty_dependent_events , Tuple
116+ ):
117+ raise ValueError ("dependent events must be passed as a tuple" )
118+
119+ sig = ty_return (
120+ ty_kernel_fn ,
121+ ty_index_space ,
122+ ty_dependent_events ,
123+ ty_kernel_args_tuple ,
124+ )
125+ else :
126+ sig = ty_return (ty_kernel_fn , ty_index_space , ty_kernel_args_tuple )
127+
110128 kernel_sig = types .void (* ty_kernel_args_tuple )
111129 # ty_kernel_fn is type specific to exact function, so we can get function
112130 # directly from type and compile it. Thats why we don't need to get it in
@@ -123,8 +141,14 @@ def codegen(
123141 ):
124142 ty_index_space : Union [RangeType , NdRangeType ] = sig .args [1 ]
125143 ll_index_space : llvmir .Instruction = llargs [1 ]
126- ty_kernel_args_tuple : UniTuple = sig .args [2 ]
127- ll_kernel_args_tuple : llvmir .Instruction = llargs [2 ]
144+ ty_kernel_args_tuple : UniTuple = sig .args [- 1 ]
145+ ll_kernel_args_tuple : llvmir .Instruction = llargs [- 1 ]
146+
147+ if len (llargs ) == 4 :
148+ ty_dependent_events : UniTuple = sig .args [2 ]
149+ ll_dependent_events : llvmir .Instruction = llargs [2 ]
150+ else :
151+ ty_dependent_events = None
128152
129153 kl_builder = kl .KernelLaunchIRBuilder (
130154 cgctx ,
@@ -140,7 +164,13 @@ def codegen(
140164 )
141165 kl_builder .set_queue_from_arguments ()
142166 kl_builder .set_kernel_from_spirv (kernel_module )
143- kl_builder .set_dependant_event_list ([])
167+ if ty_dependent_events is None :
168+ kl_builder .set_dependent_events ([])
169+ else :
170+ kl_builder .set_dependent_events_from_tuple (
171+ ty_dependent_events ,
172+ ll_dependent_events ,
173+ )
144174 device_event_ref = kl_builder .submit ()
145175
146176 if not sync :
@@ -185,7 +215,10 @@ def call_kernel(kernel_fn, index_space, *kernel_args) -> None:
185215
186216@dpjit
187217def call_kernel_async (
188- kernel_fn , index_space , * kernel_args
218+ kernel_fn ,
219+ index_space ,
220+ dependent_events : list [dpctl .SyclEvent ],
221+ * kernel_args
189222) -> tuple [dpctl .SyclEvent , dpctl .SyclEvent ]:
190223 """Calls a numba_dpex.kernel decorated function from CPython or from another
191224 dpjit function. Kernel execution happens in asyncronous way, so the thread
@@ -210,5 +243,6 @@ def call_kernel_async(
210243 return _submit_kernel_async ( # pylint: disable=E1120
211244 kernel_fn ,
212245 index_space ,
246+ dependent_events ,
213247 kernel_args ,
214248 )
0 commit comments