@@ -64,11 +64,15 @@ def numba_funcify_Cholesky(op, node, **kwargs):
6464 on_error = op .on_error
6565
6666 dtype = node .inputs [0 ].dtype
67+ out_dtype = node .outputs [0 ].dtype
6768 if dtype in complex_dtypes :
6869 raise NotImplementedError (_COMPLEX_DTYPE_NOT_SUPPORTED_MSG .format (op = op ))
6970
7071 @numba_basic .numba_njit
7172 def cholesky (a ):
73+ if a .size == 0 :
74+ return np .zeros (a .shape , dtype = out_dtype )
75+
7276 if check_finite :
7377 if np .any (np .bitwise_or (np .isinf (a ), np .isnan (a ))):
7478 raise np .linalg .LinAlgError (
@@ -163,6 +167,7 @@ def lu(a):
163167@numba_funcify .register (LUFactor )
164168def numba_funcify_LUFactor (op , node , ** kwargs ):
165169 dtype = node .inputs [0 ].dtype
170+ out_dtype_np = node .outputs [0 ].type .numpy_dtype
166171 check_finite = op .check_finite
167172 overwrite_a = op .overwrite_a
168173
@@ -171,6 +176,12 @@ def numba_funcify_LUFactor(op, node, **kwargs):
171176
172177 @numba_basic .numba_njit
173178 def lu_factor (a ):
179+ if a .size == 0 :
180+ return (
181+ np .zeros (a .shape , dtype = out_dtype_np ),
182+ np .zeros (a .shape [0 ], dtype = "int32" ),
183+ )
184+
174185 if check_finite :
175186 if np .any (np .bitwise_or (np .isinf (a ), np .isnan (a ))):
176187 raise np .linalg .LinAlgError (
0 commit comments