diff --git a/pytensor/link/numba/dispatch/compile_ops.py b/pytensor/link/numba/dispatch/compile_ops.py index 8eb73d0111..f869e49f8a 100644 --- a/pytensor/link/numba/dispatch/compile_ops.py +++ b/pytensor/link/numba/dispatch/compile_ops.py @@ -10,6 +10,7 @@ from pytensor.compile.mode import NUMBA from pytensor.compile.ops import DeepCopyOp, TypeCastingOp from pytensor.ifelse import IfElse +from pytensor.link.numba.cache import compile_numba_function_src from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import ( numba_funcify_and_cache_key, @@ -103,30 +104,38 @@ def deepcopy(x): @register_funcify_default_op_cache_key(IfElse) def numba_funcify_IfElse(op, **kwargs): n_outs = op.n_outs + as_view = op.as_view - if n_outs > 1: - - @numba_basic.numba_njit - def ifelse(cond, *args): - if cond: - res = args[:n_outs] - else: - res = args[n_outs:] - - return res + true_names = [f"t{i}" for i in range(n_outs)] + false_names = [f"f{i}" for i in range(n_outs)] + arg_list = ", ".join(true_names + false_names) + # Build return expressions + if as_view: + true_returns = ", ".join(true_names) + false_returns = ", ".join(false_names) else: + true_returns = ", ".join(f"{name}.copy()" for name in true_names) + false_returns = ", ".join(f"{name}.copy()" for name in false_names) + + # Build the code for the function + func_src = f""" +def ifelse_codegen(cond, {arg_list}): + if cond: + return ({true_returns}) + else: + return ({false_returns}) +""" + + # Compile the generated source code into a Python function + ifelse_py = compile_numba_function_src(func_src, "ifelse_codegen", globals()) - @numba_basic.numba_njit - def ifelse(cond, *args): - if cond: - res = args[:n_outs] - else: - res = args[n_outs:] + # JIT-compile using numba + ifelse_numba = numba_basic.numba_njit(ifelse_py) - return res[0] + cache_version = 1 - return ifelse + return ifelse_numba, cache_version @register_funcify_and_cache_key(CheckAndRaise) diff --git a/tests/link/numba/test_compile_ops.py b/tests/link/numba/test_compile_ops.py index b51b359a08..7216cfb0bc 100644 --- a/tests/link/numba/test_compile_ops.py +++ b/tests/link/numba/test_compile_ops.py @@ -1,9 +1,10 @@ import numpy as np import pytest -from pytensor import OpFromGraph, config, function, ifelse +from pytensor import Mode, OpFromGraph, config, function, ifelse from pytensor import tensor as pt from pytensor.compile import ViewOp +from pytensor.ifelse import IfElse from pytensor.raise_op import assert_op from tests.link.numba.test_basic import compare_numba_and_py @@ -153,3 +154,74 @@ def test_check_and_raise(): out = assert_op(x.sum(), np.array(True)) compare_numba_and_py([x], out, [x_test_value]) + + +@pytest.mark.parametrize("as_view", [True, False]) +def test_ifelse_single_output(as_view): + x = pt.vector("x") + + op = IfElse(as_view=as_view, n_outs=1) + out = op(x.sum() > 0, [x], [x])[0] # returns tuple/list + + fn = function([x], out, mode=Mode("numba", optimizer=None), accept_inplace=True) + + # FALSE branch + a = np.zeros(3) + res_false = fn(a) + + assert np.array_equal(res_false, a) + if as_view: + assert res_false is a + else: + assert res_false is not a + + # TRUE branch + b = np.ones(3) + res_true = fn(b) + + assert np.array_equal(res_true, b) + if as_view: + assert res_true is b + else: + assert res_true is not b + + +@pytest.mark.parametrize("as_view", [True, False]) +def test_ifelse_multiple_outputs(as_view): + x = pt.vector("x") + y = pt.vector("y") + + op = IfElse(as_view=as_view, n_outs=2) + out1, out2 = op(x.sum() > 0, x, y, y, x) + + fn = function( + [x, y], [out1, out2], mode=Mode("numba", optimizer=None), accept_inplace=True + ) + + # TRUE branch + a = np.ones(3) + b = np.zeros(3) + r1_true, r2_true = fn(a, b) + + assert np.array_equal(r1_true, a) + assert np.array_equal(r2_true, b) + if as_view: + assert r1_true is a + assert r2_true is b + else: + assert r1_true is not a + assert r2_true is not b + + # FALSE branch + a2 = np.zeros(3) + b2 = np.arange(3) + r1_false, r2_false = fn(a2, b2) + + assert np.array_equal(r1_false, b2) + assert np.array_equal(r2_false, a2) + if as_view: + assert r1_false is b2 + assert r2_false is a2 + else: + assert r1_false is not b2 + assert r2_false is not a2