55from pytensor import config
66from pytensor .link .numba .dispatch import basic as numba_basic
77from pytensor .link .numba .dispatch .basic import (
8+ generate_fallback_impl ,
89 numba_funcify ,
910 register_funcify_default_op_cache_key ,
1011)
4445from pytensor .tensor .type import complex_dtypes , integer_dtypes
4546
4647
47- _COMPLEX_DTYPE_NOT_SUPPORTED_MSG = (
48- "Complex dtype for {op} not supported in numba mode. "
49- "If you need this functionality, please open an issue at: https://github.com/pymc-devs/pytensor"
50- )
51-
52-
5348@numba_funcify .register (Cholesky )
5449def numba_funcify_Cholesky (op , node , ** kwargs ):
5550 """
@@ -66,7 +61,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
6661 dtype = node .inputs [0 ].dtype
6762 out_dtype = node .outputs [0 ].dtype
6863 if dtype in complex_dtypes :
69- raise NotImplementedError ( _COMPLEX_DTYPE_NOT_SUPPORTED_MSG . format ( op = op ) )
64+ return generate_fallback_impl ( op , node = node , ** kwargs )
7065
7166 @numba_basic .numba_njit
7267 def cholesky (a ):
@@ -124,7 +119,7 @@ def numba_funcify_LU(op, node, **kwargs):
124119
125120 dtype = node .inputs [0 ].dtype
126121 if dtype in complex_dtypes :
127- NotImplementedError ( _COMPLEX_DTYPE_NOT_SUPPORTED_MSG . format ( op = op ) )
122+ return generate_fallback_impl ( op , node = node , ** kwargs )
128123
129124 @numba_basic .numba_njit
130125 def lu (a ):
@@ -172,7 +167,7 @@ def numba_funcify_LUFactor(op, node, **kwargs):
172167 overwrite_a = op .overwrite_a
173168
174169 if dtype in complex_dtypes :
175- NotImplementedError ( _COMPLEX_DTYPE_NOT_SUPPORTED_MSG . format ( op = op ) )
170+ return generate_fallback_impl ( op , node = node )
176171
177172 @numba_basic .numba_njit
178173 def lu_factor (a ):
@@ -228,7 +223,7 @@ def numba_funcify_Solve(op, node, **kwargs):
228223
229224 dtype = node .inputs [0 ].dtype
230225 if dtype in complex_dtypes :
231- raise NotImplementedError ( _COMPLEX_DTYPE_NOT_SUPPORTED_MSG . format ( op = op ) )
226+ return generate_fallback_impl ( op , node = node )
232227
233228 if assume_a == "gen" :
234229 solve_fn = _solve_gen
@@ -277,9 +272,7 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
277272
278273 dtype = node .inputs [0 ].dtype
279274 if dtype in complex_dtypes :
280- raise NotImplementedError (
281- _COMPLEX_DTYPE_NOT_SUPPORTED_MSG .format (op = "Solve Triangular" )
282- )
275+ return generate_fallback_impl (op , node = node )
283276
284277 @numba_basic .numba_njit
285278 def solve_triangular (a , b ):
@@ -317,7 +310,7 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
317310 out_dtype = node .outputs [0 ].type .numpy_dtype
318311 c , b = node .inputs
319312 if c .dtype in complex_dtypes or b .dtype in complex_dtypes :
320- raise NotImplementedError ( _COMPLEX_DTYPE_NOT_SUPPORTED_MSG . format ( op = op ) )
313+ return generate_fallback_impl ( op , node = node , ** kwargs )
321314
322315 must_cast_c = c .type .numpy_dtype .kind in "ibu" or (
323316 c .type .numpy_dtype .itemsize < out_dtype .itemsize
@@ -368,7 +361,7 @@ def numba_funcify_QR(op, node, **kwargs):
368361
369362 dtype = node .inputs [0 ].dtype
370363 if dtype in complex_dtypes :
371- raise NotImplementedError ( _COMPLEX_DTYPE_NOT_SUPPORTED_MSG . format ( op = op ) )
364+ return generate_fallback_impl ( op , node = node , ** kwargs )
372365
373366 integer_input = dtype in integer_dtypes
374367 in_dtype = config .floatX if integer_input else dtype
0 commit comments