2020from numba .core .types import void
2121from numba .core .typing .typeof import Purpose , typeof
2222
23- from numba_dpex import config , spirv_generator
23+ from numba_dpex import config , numba_sem_version , spirv_generator
2424from numba_dpex .core .codegen import SPIRVCodeLibrary
2525from numba_dpex .core .exceptions import (
2626 ExecutionQueueInferenceError ,
@@ -220,8 +220,6 @@ class KernelDispatcher(Dispatcher):
220220 targetdescr = dpex_exp_kernel_target
221221 _fold_args = False
222222
223- Dispatcher ._impl_kinds ["kernel" ] = _KernelCompiler
224-
225223 def __init__ (
226224 self ,
227225 pyfunc ,
@@ -240,12 +238,27 @@ def __init__(
240238
241239 self ._kernel_name = pyfunc .__name__
242240
243- super ().__init__ (
244- py_func = pyfunc ,
245- locals = local_vars_to_numba_types ,
246- impl_kind = "kernel" ,
247- targetoptions = targetoptions ,
248- pipeline_class = pipeline_class ,
241+ if numba_sem_version < (0 , 59 , 0 ):
242+ super ().__init__ (
243+ py_func = pyfunc ,
244+ locals = local_vars_to_numba_types ,
245+ impl_kind = "direct" ,
246+ targetoptions = targetoptions ,
247+ pipeline_class = pipeline_class ,
248+ )
249+ else :
250+ super ().__init__ (
251+ py_func = pyfunc ,
252+ locals = local_vars_to_numba_types ,
253+ targetoptions = targetoptions ,
254+ pipeline_class = pipeline_class ,
255+ )
256+ self ._compiler = _KernelCompiler (
257+ pyfunc ,
258+ self .targetdescr ,
259+ targetoptions ,
260+ local_vars_to_numba_types ,
261+ pipeline_class ,
249262 )
250263
251264 def typeof_pyval (self , val ):
0 commit comments