Skip to content

Commit 39265a2

Browse files
committed
Numba linalg: Handle empty inputs
1 parent d24c3a3 commit 39265a2

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,15 @@ def numba_funcify_Cholesky(op, node, **kwargs):
6464
on_error = op.on_error
6565

6666
dtype = node.inputs[0].dtype
67+
out_dtype = node.outputs[0].dtype
6768
if dtype in complex_dtypes:
6869
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
6970

7071
@numba_basic.numba_njit
7172
def cholesky(a):
73+
if a.size == 0:
74+
return np.zeros(a.shape, dtype=out_dtype)
75+
7276
if check_finite:
7377
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
7478
raise np.linalg.LinAlgError(
@@ -163,6 +167,7 @@ def lu(a):
163167
@numba_funcify.register(LUFactor)
164168
def numba_funcify_LUFactor(op, node, **kwargs):
165169
dtype = node.inputs[0].dtype
170+
out_dtype_np = node.outputs[0].type.numpy_dtype
166171
check_finite = op.check_finite
167172
overwrite_a = op.overwrite_a
168173

@@ -171,6 +176,12 @@ def numba_funcify_LUFactor(op, node, **kwargs):
171176

172177
@numba_basic.numba_njit
173178
def lu_factor(a):
179+
if a.size == 0:
180+
return (
181+
np.zeros(a.shape, dtype=out_dtype_np),
182+
np.zeros(a.shape[0], dtype="int32"),
183+
)
184+
174185
if check_finite:
175186
if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
176187
raise np.linalg.LinAlgError(

0 commit comments

Comments
 (0)