Skip to content

Commit 6ad3a19

Browse files
committed
.Numba linalg: handle empty inputs and dtypes carefully
1 parent c4070fe commit 6ad3a19

File tree

14 files changed

+234
-175
lines changed

14 files changed

+234
-175
lines changed

pytensor/link/numba/dispatch/linalg/decomposition/cholesky.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from pytensor.link.numba.dispatch.linalg._LAPACK import (
77
_LAPACK,
8-
_get_underlying_float,
98
int_ptr_to_val,
109
val_to_int_ptr,
1110
)
@@ -26,7 +25,7 @@ def cholesky_impl(A, lower=0, overwrite_a=False, check_finite=True):
2625
ensure_lapack()
2726
_check_scipy_linalg_matrix(A, "cholesky")
2827
dtype = A.dtype
29-
w_type = _get_underlying_float(dtype)
28+
3029
numba_potrf = _LAPACK().numba_xpotrf(dtype)
3130

3231
def impl(A, lower=False, overwrite_a=False, check_finite=True):
@@ -53,7 +52,7 @@ def impl(A, lower=False, overwrite_a=False, check_finite=True):
5352
numba_potrf(
5453
UPLO,
5554
N,
56-
A_copy.view(w_type).ctypes,
55+
A_copy.ctypes,
5756
LDA,
5857
INFO,
5958
)

pytensor/link/numba/dispatch/linalg/decomposition/lu_factor.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from pytensor.link.numba.dispatch.linalg._LAPACK import (
1010
_LAPACK,
11-
_get_underlying_float,
1211
int_ptr_to_val,
1312
val_to_int_ptr,
1413
)
@@ -40,7 +39,6 @@ def getrf_impl(
4039
ensure_lapack()
4140
_check_scipy_linalg_matrix(A, "getrf")
4241
dtype = A.dtype
43-
w_type = _get_underlying_float(dtype)
4442
numba_getrf = _LAPACK().numba_xgetrf(dtype)
4543

4644
def impl(
@@ -59,7 +57,7 @@ def impl(
5957
IPIV = np.empty(_N, dtype=np.int32) # type: ignore
6058
INFO = val_to_int_ptr(0)
6159

62-
numba_getrf(M, N, A_copy.view(w_type).ctypes, LDA, IPIV.ctypes, INFO)
60+
numba_getrf(M, N, A_copy.ctypes, LDA, IPIV.ctypes, INFO)
6361

6462
return A_copy, IPIV, int_ptr_to_val(INFO)
6563

0 commit comments

Comments
 (0)