1010from typing import Tuple
1111
1212import numba .core .event as ev
13+ from llvmlite .binding .value import ValueRef
1314from numba .core import errors , sigutils , types
1415from numba .core .compiler import CompileResult , Flags
1516from numba .core .compiler_lock import global_compiler_lock
1617from numba .core .dispatcher import Dispatcher , _FunctionCompiler
18+ from numba .core .funcdesc import PythonFunctionDescriptor
1719from numba .core .target_extension import dispatcher_registry , target_registry
1820from numba .core .types import void
1921from numba .core .typing .typeof import Purpose , typeof
2022
2123from numba_dpex import config , spirv_generator
24+ from numba_dpex .core .codegen import SPIRVCodeLibrary
2225from numba_dpex .core .exceptions import (
2326 ExecutionQueueInferenceError ,
2427 KernelHasReturnValueError ,
2528 UnsupportedKernelArgumentError ,
2629)
2730from numba_dpex .core .pipelines import kernel_compiler
28- from numba_dpex .core .targets .kernel_target import CompilationMode
31+ from numba_dpex .core .targets .kernel_target import (
32+ CompilationMode ,
33+ DpexKernelTargetContext ,
34+ )
2935from numba_dpex .core .types import DpnpNdArray
36+ from numba_dpex .core .utils import kernel_launcher as kl
3037
3138from .target import DPEX_KERNEL_EXP_TARGET_NAME , dpex_exp_kernel_target
3239
33- _KernelModule = namedtuple ("_KernelModule" , ["kernel_name" , "kernel_bitcode" ])
34-
3540_KernelCompileResult = namedtuple (
3641 "_KernelCompileResult" , CompileResult ._fields + ("kernel_device_ir_module" ,)
3742)
@@ -76,9 +81,14 @@ def check_queue_equivalence_of_args(
7681 )
7782
7883 def _compile_to_spirv (
79- self , kernel_library , kernel_fndesc , kernel_targetctx
84+ self ,
85+ kernel_library : SPIRVCodeLibrary ,
86+ kernel_fndesc : PythonFunctionDescriptor ,
87+ kernel_targetctx : DpexKernelTargetContext ,
8088 ):
81- kernel_func = kernel_library .get_function (kernel_fndesc .llvm_func_name )
89+ kernel_func : ValueRef = kernel_library .get_function (
90+ kernel_fndesc .llvm_func_name
91+ )
8292
8393 # Create a spir_kernel wrapper function
8494 kernel_fn = kernel_targetctx .prepare_spir_kernel (
@@ -103,11 +113,11 @@ def _compile_to_spirv(
103113 kernel_library .final_module ,
104114 kernel_library .final_module .as_bitcode (),
105115 )
106- return _KernelModule (
116+ return kl . SPIRVKernelModule (
107117 kernel_name = kernel_fn .name , kernel_bitcode = kernel_spirv_module
108118 )
109119
110- def compile (self , args , return_type ):
120+ def compile (self , args , return_type ) -> _KernelCompileResult :
111121 status , kcres = self ._compile_cached (args , return_type )
112122 if status :
113123 return kcres
@@ -160,8 +170,10 @@ def _compile_cached(
160170 self .targetoptions ["_compilation_mode" ]
161171 == CompilationMode .KERNEL
162172 ):
163- kernel_device_ir_module : _KernelModule = self ._compile_to_spirv (
164- cres .library , cres .fndesc , cres .target_context
173+ kernel_device_ir_module : kl .SPIRVKernelModule = (
174+ self ._compile_to_spirv (
175+ cres .library , cres .fndesc , cres .target_context
176+ )
165177 )
166178 else :
167179 kernel_device_ir_module = None
@@ -329,14 +341,17 @@ def cb_llvm(dur):
329341 # Add code to enable on disk caching of a binary spirv kernel.
330342 # Refer: https://github.com/IntelPython/numba-dpex/issues/1197
331343 self ._cache_misses [sig ] += 1
332- ev_details = {
333- "dispatcher" : self ,
334- "args" : args ,
335- "return_type" : return_type ,
336- }
337- with ev .trigger_event ("numba_dpex:compile" , data = ev_details ):
344+ with ev .trigger_event (
345+ "numba_dpex:compile" ,
346+ data = {
347+ "dispatcher" : self ,
348+ "args" : args ,
349+ "return_type" : return_type ,
350+ },
351+ ):
338352 try :
339- kcres : _KernelCompileResult = self ._compiler .compile (
353+ compiler : _KernelCompiler = self ._compiler
354+ kcres : _KernelCompileResult = compiler .compile (
340355 args , return_type
341356 )
342357 except errors .ForceLiteralArg as e :
0 commit comments