Skip to content

Commit 370b172

Browse files
authored
Numba Pow: Fix failure with discrete integer exponents (#1758)
Workaround for: numba/numba#9554
1 parent ae499a4 commit 370b172

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

pytensor/link/numba/dispatch/scalar.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
Composite,
2323
Identity,
2424
Mul,
25+
Pow,
2526
Reciprocal,
2627
ScalarOp,
2728
Second,
@@ -165,6 +166,23 @@ def {binary_op_name}({input_signature}):
165166
return nary_fn
166167

167168

169+
@register_funcify_and_cache_key(Pow)
170+
def numba_funcify_Pow(op, node, **kwargs):
171+
pow_dtype = node.inputs[1].type.dtype
172+
if pow_dtype.startswith("int"):
173+
# Numba power fails when exponents are non 64-bit discrete integers and fasthmath=True
174+
# https://github.com/numba/numba/issues/9554
175+
176+
def pow(x, y):
177+
return x ** np.asarray(y, dtype=np.int64).item()
178+
else:
179+
180+
def pow(x, y):
181+
return x**y
182+
183+
return numba_basic.numba_njit(pow), scalar_op_cache_key(op)
184+
185+
168186
@register_funcify_and_cache_key(Add)
169187
def numba_funcify_Add(op, node, **kwargs):
170188
nary_add_fn = binary_to_nary_func(node.inputs, "add", "+")

tests/link/numba/test_scalar.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,16 @@ def test_Softplus(dtype):
189189
)
190190

191191

192+
def test_discrete_power():
193+
# Test we don't fail to compile power with discrete exponents due to https://github.com/numba/numba/issues/9554
194+
x = pt.scalar("x", dtype="float64")
195+
exponent = pt.scalar("exponent", dtype="int8")
196+
out = pt.power(x, exponent)
197+
compare_numba_and_py(
198+
[x, exponent], [out], [np.array(0.5), np.array(2, dtype="int8")]
199+
)
200+
201+
192202
def test_cython_obj_mode_fallback():
193203
"""Test that unsupported cython signatures fallback to obj-mode"""
194204

0 commit comments

Comments
 (0)