Skip to content

Commit 6549887

Browse files
committed
Numba int_to_float: Remove buggy helper
* It did not handle complex values correctly * It increased compile time with the nested function
1 parent dfe7b44 commit 6549887

File tree

3 files changed

+75
-61
lines changed

3 files changed

+75
-61
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -224,36 +224,6 @@ def codegen(context, builder, signature, args):
224224
return sig, codegen
225225

226226

227-
def int_to_float_fn(inputs, out_dtype):
228-
"""Create a Numba function that converts integer and boolean ``ndarray``s to floats."""
229-
230-
if (
231-
all(inp.type.dtype == out_dtype for inp in inputs)
232-
and np.dtype(out_dtype).kind == "f"
233-
):
234-
235-
@numba_njit(inline="always")
236-
def inputs_cast(x):
237-
return x
238-
239-
elif any(i.type.numpy_dtype.kind in "uib" for i in inputs):
240-
args_dtype = np.dtype(f"f{out_dtype.itemsize}")
241-
242-
@numba_njit(inline="always")
243-
def inputs_cast(x):
244-
return x.astype(args_dtype)
245-
246-
else:
247-
args_dtype_sz = max(_arg.type.numpy_dtype.itemsize for _arg in inputs)
248-
args_dtype = np.dtype(f"f{args_dtype_sz}")
249-
250-
@numba_njit(inline="always")
251-
def inputs_cast(x):
252-
return x.astype(args_dtype)
253-
254-
return inputs_cast
255-
256-
257227
@singledispatch
258228
def numba_typify(data, dtype=None, **kwargs):
259229
return data

pytensor/link/numba/dispatch/nlinalg.py

Lines changed: 58 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
import numpy as np
55

66
import pytensor.link.numba.dispatch.basic as numba_basic
7+
from pytensor import config
78
from pytensor.link.numba.dispatch.basic import (
89
get_numba_type,
9-
int_to_float_fn,
1010
register_funcify_default_op_cache_key,
1111
)
1212
from pytensor.tensor.nlinalg import (
@@ -26,65 +26,88 @@ def numba_funcify_SVD(op, node, **kwargs):
2626
compute_uv = op.compute_uv
2727
out_dtype = np.dtype(node.outputs[0].dtype)
2828

29-
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
29+
discrete_input = node.inputs[0].type.numpy_dtype.kind in "ibu"
30+
if discrete_input and config.compiler_verbose:
31+
print("SVD requires casting discrete input to float") # noqa: T201
3032

3133
if not compute_uv:
3234

3335
@numba_basic.numba_njit
3436
def svd(x):
35-
_, ret, _ = np.linalg.svd(inputs_cast(x), full_matrices)
37+
if discrete_input:
38+
x = x.astype(out_dtype)
39+
_, ret, _ = np.linalg.svd(x, full_matrices)
3640
return ret
3741

3842
else:
3943

4044
@numba_basic.numba_njit
4145
def svd(x):
42-
return np.linalg.svd(inputs_cast(x), full_matrices)
46+
if discrete_input:
47+
x = x.astype(out_dtype)
48+
return np.linalg.svd(x, full_matrices)
4349

44-
return svd
50+
cache_version = 1
51+
return svd, cache_version
4552

4653

4754
@register_funcify_default_op_cache_key(Det)
4855
def numba_funcify_Det(op, node, **kwargs):
4956
out_dtype = node.outputs[0].type.numpy_dtype
50-
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
57+
discrete_input = node.inputs[0].type.numpy_dtype.kind in "ibu"
58+
if discrete_input and config.compiler_verbose:
59+
print("Det requires casting discrete input to float") # noqa: T201
5160

5261
@numba_basic.numba_njit
5362
def det(x):
54-
return np.array(np.linalg.det(inputs_cast(x))).astype(out_dtype)
63+
if discrete_input:
64+
x = x.astype(out_dtype)
65+
return np.array(np.linalg.det(x), dtype=out_dtype)
5566

56-
return det
67+
cache_version = 1
68+
return det, cache_version
5769

5870

5971
@register_funcify_default_op_cache_key(SLogDet)
6072
def numba_funcify_SLogDet(op, node, **kwargs):
61-
out_dtype_1 = node.outputs[0].type.numpy_dtype
62-
out_dtype_2 = node.outputs[1].type.numpy_dtype
73+
out_dtype_sign = node.outputs[0].type.numpy_dtype
74+
out_dtype_det = node.outputs[1].type.numpy_dtype
6375

64-
inputs_cast = int_to_float_fn(node.inputs, out_dtype_1)
76+
discrete_input = node.inputs[0].type.numpy_dtype.kind in "ibu"
77+
if discrete_input and config.compiler_verbose:
78+
print("SLogDet requires casting discrete input to float") # noqa: T201
6579

6680
@numba_basic.numba_njit
6781
def slogdet(x):
68-
sign, det = np.linalg.slogdet(inputs_cast(x))
82+
if discrete_input:
83+
x = x.astype(out_dtype_det)
84+
sign, det = np.linalg.slogdet(x)
6985
return (
70-
np.array(sign).astype(out_dtype_1),
71-
np.array(det).astype(out_dtype_2),
86+
np.array(sign, dtype=out_dtype_sign),
87+
np.array(det, dtype=out_dtype_det),
7288
)
7389

74-
return slogdet
90+
cache_version = 1
91+
return slogdet, cache_version
7592

7693

7794
@register_funcify_default_op_cache_key(Eig)
7895
def numba_funcify_Eig(op, node, **kwargs):
7996
w_dtype = node.outputs[0].type.numpy_dtype
80-
inputs_cast = int_to_float_fn(node.inputs, w_dtype)
97+
non_complex_input = node.inputs[0].type.numpy_dtype.kind != "c"
98+
if non_complex_input and config.compiler_verbose:
99+
print("Eig requires casting input to complex") # noqa: T201
81100

82101
@numba_basic.numba_njit
83102
def eig(x):
84-
w, v = np.linalg.eig(inputs_cast(x))
103+
if non_complex_input:
104+
# Even floats are better cast to complex, otherwise numba may raise
105+
# ValueError: eig() argument must not cause a domain change.
106+
x = x.astype(w_dtype)
107+
w, v = np.linalg.eig(x)
85108
return w.astype(w_dtype), v.astype(w_dtype)
86109

87-
cache_version = 1
110+
cache_version = 2
88111
return eig, cache_version
89112

90113

@@ -125,22 +148,32 @@ def eigh(x):
125148
@register_funcify_default_op_cache_key(MatrixInverse)
126149
def numba_funcify_MatrixInverse(op, node, **kwargs):
127150
out_dtype = node.outputs[0].type.numpy_dtype
128-
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
151+
discrete_input = node.inputs[0].type.numpy_dtype.kind in "ibu"
152+
if discrete_input and config.compiler_verbose:
153+
print("MatrixInverse requires casting discrete input to float") # noqa: T201
129154

130155
@numba_basic.numba_njit
131156
def matrix_inverse(x):
132-
return np.linalg.inv(inputs_cast(x)).astype(out_dtype)
157+
if discrete_input:
158+
x = x.astype(out_dtype)
159+
return np.linalg.inv(x)
133160

134-
return matrix_inverse
161+
cache_version = 1
162+
return matrix_inverse, cache_version
135163

136164

137165
@register_funcify_default_op_cache_key(MatrixPinv)
138166
def numba_funcify_MatrixPinv(op, node, **kwargs):
139167
out_dtype = node.outputs[0].type.numpy_dtype
140-
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
168+
discrete_input = node.inputs[0].type.numpy_dtype.kind in "ibu"
169+
if discrete_input and config.compiler_verbose:
170+
print("MatrixPinv requires casting discrete input to float") # noqa: T201
141171

142172
@numba_basic.numba_njit
143-
def matrixpinv(x):
144-
return np.linalg.pinv(inputs_cast(x)).astype(out_dtype)
173+
def matrix_pinv(x):
174+
if discrete_input:
175+
x = x.astype(out_dtype)
176+
return np.linalg.pinv(x)
145177

146-
return matrixpinv
178+
cache_version = 1
179+
return matrix_pinv, cache_version

tests/link/numba/test_nlinalg.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import pytest
55

66
import pytensor.tensor as pt
7-
from pytensor import config
87
from pytensor.tensor import nlinalg
98
from tests.link.numba.test_basic import compare_numba_and_py
109

@@ -52,23 +51,35 @@ def test_Det_SLogDet(op, dtype):
5251
)
5352

5453

55-
@pytest.mark.parametrize("input_dtype", ["float", "int"])
54+
@pytest.mark.parametrize("input_dtype", ["int64", "float64", "complex128"])
5655
@pytest.mark.parametrize("symmetric", [True, False], ids=["symmetric", "general"])
5756
def test_Eig(input_dtype, symmetric):
58-
x = pt.dmatrix("x")
59-
if input_dtype == "float":
60-
x_val = rng.normal(size=(3, 3)).astype(config.floatX)
57+
x = pt.matrix("x", dtype=input_dtype)
58+
if x.type.numpy_dtype.kind in "fc":
59+
x_val = rng.normal(size=(3, 3)).astype(input_dtype)
6160
else:
6261
x_val = rng.integers(1, 10, size=(3, 3)).astype("int64")
6362

6463
if symmetric:
6564
x_val = x_val + x_val.T
6665

66+
def assert_fn(x, y):
67+
# eig can return equivalent values with some sign flips depending on impl, allow for that
68+
np.testing.assert_allclose(np.abs(x), np.abs(y), strict=True)
69+
6770
g = nlinalg.eig(x)
68-
compare_numba_and_py(
71+
_, [eigen_values, eigen_vectors] = compare_numba_and_py(
6972
graph_inputs=[x],
7073
graph_outputs=g,
7174
test_inputs=[x_val],
75+
assert_fn=assert_fn,
76+
)
77+
# Check eig is correct
78+
np.testing.assert_allclose(
79+
x_val @ eigen_vectors,
80+
eigen_vectors @ np.diag(eigen_values),
81+
atol=1e-7,
82+
rtol=1e-5,
7283
)
7384

7485

0 commit comments

Comments
 (0)