-
Notifications
You must be signed in to change notification settings - Fork 149
Fix non-inplace IfElse on numba mode #1765
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -10,6 +10,7 @@ | |||||
| from pytensor.compile.mode import NUMBA | ||||||
| from pytensor.compile.ops import DeepCopyOp, TypeCastingOp | ||||||
| from pytensor.ifelse import IfElse | ||||||
| from pytensor.link.numba.cache import compile_numba_function_src | ||||||
| from pytensor.link.numba.dispatch import basic as numba_basic | ||||||
| from pytensor.link.numba.dispatch.basic import ( | ||||||
| numba_funcify_and_cache_key, | ||||||
|
|
@@ -103,30 +104,48 @@ def deepcopy(x): | |||||
| @register_funcify_default_op_cache_key(IfElse) | ||||||
| def numba_funcify_IfElse(op, **kwargs): | ||||||
| n_outs = op.n_outs | ||||||
| as_view = op.as_view | ||||||
|
|
||||||
| if n_outs > 1: | ||||||
| if n_outs == 1: | ||||||
|
|
||||||
| @numba_basic.numba_njit | ||||||
| def ifelse(cond, *args): | ||||||
| if cond: | ||||||
| res = args[:n_outs] | ||||||
| else: | ||||||
| res = args[n_outs:] | ||||||
| def ifelse(cond, x_true, x_false): | ||||||
| arr = x_true if cond else x_false | ||||||
| return arr if as_view else arr.copy() | ||||||
|
|
||||||
| return res | ||||||
| cache_version = 3 | ||||||
| return ifelse, cache_version | ||||||
|
|
||||||
| true_names = [f"t{i}" for i in range(n_outs)] | ||||||
| false_names = [f"f{i}" for i in range(n_outs)] | ||||||
| arg_list = ", ".join(true_names + false_names) | ||||||
|
|
||||||
| # Build return expressions | ||||||
| if as_view: | ||||||
| true_returns = ", ".join(true_names) | ||||||
| false_returns = ", ".join(false_names) | ||||||
| else: | ||||||
| true_returns = ", ".join(f"{name}.copy()" for name in true_names) | ||||||
| false_returns = ", ".join(f"{name}.copy()" for name in false_names) | ||||||
|
|
||||||
| # Build the code for the function | ||||||
| func_src = f""" | ||||||
| def ifelse_codegen(cond, {arg_list}): | ||||||
| if cond: | ||||||
| return ({true_returns}) | ||||||
| else: | ||||||
| return ({false_returns}) | ||||||
| """ | ||||||
|
|
||||||
| @numba_basic.numba_njit | ||||||
| def ifelse(cond, *args): | ||||||
| if cond: | ||||||
| res = args[:n_outs] | ||||||
| else: | ||||||
| res = args[n_outs:] | ||||||
| # Compile the generated source code into a Python function | ||||||
| ifelse_py = compile_numba_function_src(func_src, "ifelse_codegen", globals()) | ||||||
|
|
||||||
| # JIT-compile using numba | ||||||
| ifelse_numba = numba_basic.numba_njit(ifelse_py) | ||||||
|
|
||||||
| return res[0] | ||||||
| cache_version = 3 | ||||||
|
||||||
| cache_version = 3 | |
| cache_version = 1 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,7 @@ | ||
| import numpy as np | ||
| import pytest | ||
|
|
||
| from pytensor import OpFromGraph, config, function, ifelse | ||
| from pytensor import Mode, OpFromGraph, config, function, ifelse | ||
| from pytensor import tensor as pt | ||
| from pytensor.compile import ViewOp | ||
| from pytensor.raise_op import assert_op | ||
|
|
@@ -153,3 +153,52 @@ def test_check_and_raise(): | |
| out = assert_op(x.sum(), np.array(True)) | ||
|
|
||
| compare_numba_and_py([x], out, [x_test_value]) | ||
|
|
||
|
|
||
| def test_ifelse_single_output(): | ||
| x = pt.vector("x") | ||
| out = ifelse(x.sum() > 0, x, x) | ||
|
|
||
| fn = function([x], out, mode=Mode("numba", optimizer=None)) | ||
|
|
||
| x_test = np.zeros((5,)) | ||
| res = fn(x_test) | ||
|
|
||
| # Returned array should not be the input (must be a copy) | ||
| assert res is not x_test | ||
| assert np.array_equal(res, x_test) | ||
|
|
||
|
|
||
| def test_ifelse_multiple_outputs(): | ||
| x = pt.vector("x") | ||
| y = pt.vector("y") | ||
| out1, out2 = ifelse(x.sum() > 0, (x, y), (y, x)) | ||
|
|
||
| fn = function([x, y], [out1, out2], mode=Mode("numba", optimizer=None)) | ||
|
||
|
|
||
| a = np.ones(3) | ||
| b = np.zeros(3) | ||
|
|
||
| r1, r2 = fn(a, b) | ||
|
|
||
| assert np.array_equal(r1, a) | ||
| assert np.array_equal(r2, b) | ||
| assert r1 is not a | ||
| assert r2 is not b | ||
|
|
||
|
|
||
| def test_ifelse_false_branch(): | ||
|
||
| x = pt.vector("x") | ||
| y = pt.vector("y") | ||
|
|
||
| out = ifelse(x.sum() > 0, x, y) | ||
|
|
||
| fn = function([x, y], out, mode=Mode("numba", optimizer=None)) | ||
|
|
||
| a = np.zeros(3) | ||
| b = np.arange(3) | ||
|
|
||
| res = fn(a, b) | ||
|
|
||
| assert np.array_equal(res, b) | ||
| assert res is not b | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On second thought I guess we can get rid of the special case and stay with the codegen for every case now