Skip to content

Commit c815849

Browse files
committed
Numba linalg: Fallback to objmode with complex inputs
1 parent 39265a2 commit c815849

File tree

1 file changed

+8
-15
lines changed

1 file changed

+8
-15
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pytensor import config
66
from pytensor.link.numba.dispatch import basic as numba_basic
77
from pytensor.link.numba.dispatch.basic import (
8+
generate_fallback_impl,
89
numba_funcify,
910
register_funcify_default_op_cache_key,
1011
)
@@ -44,12 +45,6 @@
4445
from pytensor.tensor.type import complex_dtypes, integer_dtypes
4546

4647

47-
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG = (
48-
"Complex dtype for {op} not supported in numba mode. "
49-
"If you need this functionality, please open an issue at: https://github.com/pymc-devs/pytensor"
50-
)
51-
52-
5348
@numba_funcify.register(Cholesky)
5449
def numba_funcify_Cholesky(op, node, **kwargs):
5550
"""
@@ -66,7 +61,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
6661
dtype = node.inputs[0].dtype
6762
out_dtype = node.outputs[0].dtype
6863
if dtype in complex_dtypes:
69-
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
64+
return generate_fallback_impl(op, node=node, **kwargs)
7065

7166
@numba_basic.numba_njit
7267
def cholesky(a):
@@ -124,7 +119,7 @@ def numba_funcify_LU(op, node, **kwargs):
124119

125120
dtype = node.inputs[0].dtype
126121
if dtype in complex_dtypes:
127-
NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
122+
return generate_fallback_impl(op, node=node, **kwargs)
128123

129124
@numba_basic.numba_njit
130125
def lu(a):
@@ -172,7 +167,7 @@ def numba_funcify_LUFactor(op, node, **kwargs):
172167
overwrite_a = op.overwrite_a
173168

174169
if dtype in complex_dtypes:
175-
NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
170+
return generate_fallback_impl(op, node=node)
176171

177172
@numba_basic.numba_njit
178173
def lu_factor(a):
@@ -228,7 +223,7 @@ def numba_funcify_Solve(op, node, **kwargs):
228223

229224
dtype = node.inputs[0].dtype
230225
if dtype in complex_dtypes:
231-
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
226+
return generate_fallback_impl(op, node=node)
232227

233228
if assume_a == "gen":
234229
solve_fn = _solve_gen
@@ -277,9 +272,7 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
277272

278273
dtype = node.inputs[0].dtype
279274
if dtype in complex_dtypes:
280-
raise NotImplementedError(
281-
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op="Solve Triangular")
282-
)
275+
return generate_fallback_impl(op, node=node)
283276

284277
@numba_basic.numba_njit
285278
def solve_triangular(a, b):
@@ -317,7 +310,7 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
317310
out_dtype = node.outputs[0].type.numpy_dtype
318311
c, b = node.inputs
319312
if c.dtype in complex_dtypes or b.dtype in complex_dtypes:
320-
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
313+
return generate_fallback_impl(op, node=node, **kwargs)
321314

322315
must_cast_c = c.type.numpy_dtype.kind in "ibu" or (
323316
c.type.numpy_dtype.itemsize < out_dtype.itemsize
@@ -368,7 +361,7 @@ def numba_funcify_QR(op, node, **kwargs):
368361

369362
dtype = node.inputs[0].dtype
370363
if dtype in complex_dtypes:
371-
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
364+
return generate_fallback_impl(op, node=node, **kwargs)
372365

373366
integer_input = dtype in integer_dtypes
374367
in_dtype = config.floatX if integer_input else dtype

0 commit comments

Comments
 (0)