|
25 | 25 | from pytensor.graph.type import Type |
26 | 26 | from pytensor.link.numba.dispatch import basic as numba_basic |
27 | 27 | from pytensor.link.numba.dispatch.basic import ( |
| 28 | + _filter_numba_warnings, |
28 | 29 | cache_key_for_constant, |
29 | 30 | numba_funcify_and_cache_key, |
30 | 31 | ) |
@@ -455,14 +456,46 @@ def test_scalar_return_value_conversion(): |
455 | 456 | assert isinstance(x_fn(1.0), np.ndarray) |
456 | 457 |
|
457 | 458 |
|
458 | | -@pytest.mark.filterwarnings("error") |
459 | | -def test_cache_warning_suppressed(): |
460 | | - x = pt.vector("x", shape=(5,), dtype="float64") |
461 | | - out = pt.psi(x) * 2 |
462 | | - fn = function([x], out, mode="NUMBA") |
463 | | - |
464 | | - x_test = np.random.uniform(size=5) |
465 | | - np.testing.assert_allclose(fn(x_test), scipy.special.psi(x_test) * 2) |
| 459 | +class TestNumbaWarnings: |
| 460 | + def setup_method(self, method): |
| 461 | + # Pytest messes up with the package filters, reenable here for testing |
| 462 | + _filter_numba_warnings() |
| 463 | + |
| 464 | + @pytest.mark.filterwarnings("error") |
| 465 | + def test_cache_pointer_func_warning_suppressed(self): |
| 466 | + x = pt.vector("x", shape=(5,), dtype="float64") |
| 467 | + out = pt.psi(x) * 2 |
| 468 | + fn = function([x], out, mode="NUMBA") |
| 469 | + |
| 470 | + x_test = np.random.uniform(size=5) |
| 471 | + np.testing.assert_allclose(fn(x_test), scipy.special.psi(x_test) * 2) |
| 472 | + |
| 473 | + @pytest.mark.filterwarnings("error") |
| 474 | + def test_cache_large_global_array_warning_suppressed(self): |
| 475 | + rng = np.random.default_rng(458) |
| 476 | + large_constant = rng.normal(size=(100000, 5)) |
| 477 | + |
| 478 | + x = pt.vector("x", shape=(5,), dtype="float64") |
| 479 | + out = x * large_constant |
| 480 | + fn = function([x], out, mode="NUMBA") |
| 481 | + |
| 482 | + x_test = rng.uniform(size=5) |
| 483 | + np.testing.assert_allclose(fn(x_test), x_test * large_constant) |
| 484 | + |
| 485 | + @pytest.mark.filterwarnings("error") |
| 486 | + def test_contiguous_array_dot_warning_suppressed(self): |
| 487 | + A = pt.matrix("A") |
| 488 | + b = pt.vector("b") |
| 489 | + out = pt.dot(A, b[:, None]) |
| 490 | + # Cached functions won't reemit the warning, so we have to disable it |
| 491 | + with config.change_flags(numba__cache=False): |
| 492 | + fn = function([A, b], out, mode="NUMBA") |
| 493 | + |
| 494 | + A_test = np.ones((5, 5)) |
| 495 | + # Numba actually warns even on contiguous arrays: https://github.com/numba/numba/issues/10086 |
| 496 | + # But either way we don't want this warning for users as they have little control over strides |
| 497 | + b_test = np.ones((10,))[::2] |
| 498 | + np.testing.assert_allclose(fn(A_test, b_test), np.dot(A_test, b_test[:, None])) |
466 | 499 |
|
467 | 500 |
|
468 | 501 | @pytest.mark.parametrize("mode", ("default", "trust_input", "direct")) |
|
0 commit comments