Skip to content

Commit 76bcd86

Browse files
committed
Numba eigh: Cast to promised dtype
1 parent c815849 commit 76bcd86

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

pytensor/link/numba/dispatch/nlinalg.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,8 @@ def numba_funcify_Eig(op, node, **kwargs):
102102
def eig(x):
103103
if discrete_input:
104104
x = x.astype(w_dtype)
105-
return np.linalg.eig(x)
105+
w, v = np.linalg.eig(x)
106+
return w.astype(w_dtype), v.astype(w_dtype)
106107

107108
cache_version = 1
108109
return eig, cache_version

0 commit comments

Comments
 (0)