Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 34 additions & 15 deletions pytensor/link/numba/dispatch/compile_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pytensor.compile.mode import NUMBA
from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
from pytensor.ifelse import IfElse
from pytensor.link.numba.cache import compile_numba_function_src
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
numba_funcify_and_cache_key,
Expand Down Expand Up @@ -103,30 +104,48 @@ def deepcopy(x):
@register_funcify_default_op_cache_key(IfElse)
def numba_funcify_IfElse(op, **kwargs):
n_outs = op.n_outs
as_view = op.as_view

if n_outs > 1:
if n_outs == 1:

@numba_basic.numba_njit
def ifelse(cond, *args):
if cond:
res = args[:n_outs]
else:
res = args[n_outs:]
def ifelse(cond, x_true, x_false):
arr = x_true if cond else x_false
return arr if as_view else arr.copy()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On second thought I guess we can get rid of the special case and stay with the codegen for every case now


return res
cache_version = 3
return ifelse, cache_version

true_names = [f"t{i}" for i in range(n_outs)]
false_names = [f"f{i}" for i in range(n_outs)]
arg_list = ", ".join(true_names + false_names)

# Build return expressions
if as_view:
true_returns = ", ".join(true_names)
false_returns = ", ".join(false_names)
else:
true_returns = ", ".join(f"{name}.copy()" for name in true_names)
false_returns = ", ".join(f"{name}.copy()" for name in false_names)

# Build the code for the function
func_src = f"""
def ifelse_codegen(cond, {arg_list}):
if cond:
return ({true_returns})
else:
return ({false_returns})
"""

@numba_basic.numba_njit
def ifelse(cond, *args):
if cond:
res = args[:n_outs]
else:
res = args[n_outs:]
# Compile the generated source code into a Python function
ifelse_py = compile_numba_function_src(func_src, "ifelse_codegen", globals())

# JIT-compile using numba
ifelse_numba = numba_basic.numba_njit(ifelse_py)

return res[0]
cache_version = 3
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You probably needed to bump a few times, but for the PR we should only bump once. You can erase your previous cache with pytensor-cache purge for local testing

Suggested change
cache_version = 3
cache_version = 1


return ifelse
return ifelse_numba, cache_version


@register_funcify_and_cache_key(CheckAndRaise)
Expand Down
51 changes: 50 additions & 1 deletion tests/link/numba/test_compile_ops.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import pytest

from pytensor import OpFromGraph, config, function, ifelse
from pytensor import Mode, OpFromGraph, config, function, ifelse
from pytensor import tensor as pt
from pytensor.compile import ViewOp
from pytensor.raise_op import assert_op
Expand Down Expand Up @@ -153,3 +153,52 @@ def test_check_and_raise():
out = assert_op(x.sum(), np.array(True))

compare_numba_and_py([x], out, [x_test_value])


def test_ifelse_single_output():
x = pt.vector("x")
out = ifelse(x.sum() > 0, x, x)

fn = function([x], out, mode=Mode("numba", optimizer=None))

x_test = np.zeros((5,))
res = fn(x_test)

# Returned array should not be the input (must be a copy)
assert res is not x_test
assert np.array_equal(res, x_test)


def test_ifelse_multiple_outputs():
x = pt.vector("x")
y = pt.vector("y")
out1, out2 = ifelse(x.sum() > 0, (x, y), (y, x))

fn = function([x, y], [out1, out2], mode=Mode("numba", optimizer=None))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Parametrize this and the single output test to test inplace and not inplace. You can create IfElse inplace manually like IfElse(as_view=True|False, n_outs=2), and pass accept_inplace=Truetofunction`.

We want to make sure that r1 is a, r2 is b in that case. Right now we are never testing the inplace mode


a = np.ones(3)
b = np.zeros(3)

r1, r2 = fn(a, b)

assert np.array_equal(r1, a)
assert np.array_equal(r2, b)
assert r1 is not a
assert r2 is not b


def test_ifelse_false_branch():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can merge this test with the previous ones. Just eval the function twice, in a way that triggers the different branches.

x = pt.vector("x")
y = pt.vector("y")

out = ifelse(x.sum() > 0, x, y)

fn = function([x, y], out, mode=Mode("numba", optimizer=None))

a = np.zeros(3)
b = np.arange(3)

res = fn(a, b)

assert np.array_equal(res, b)
assert res is not b
Loading