8 | 8 | from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple |
9 | 9 | from numpy.lib.stride_tricks import as_strided |
10 | 10 |
| 11 | +from pytensor import config |
11 | 12 | from pytensor.graph.op import Op |
12 | 13 | from pytensor.link.numba.cache import ( |
13 | 14 | compile_numba_function_src, |
@@ -608,45 +609,62 @@ def numba_funcify_Dot(op, node, **kwargs): |
608 | 609 |     x, y = node.inputs |
609 | 610 |     [out] = node.outputs |
610 | 611 |
611 | | -    x_dtype = x.type.dtype |
612 | | -    y_dtype = y.type.dtype |
613 | | -    dot_dtype = f"float{max((32, out.type.numpy_dtype.itemsize * 8))}" |
614 | | -    out_dtype = out.type.dtype |
| 612 | +    x_dtype = x.type.numpy_dtype |
| 613 | +    y_dtype = y.type.numpy_dtype |
615 | 614 |
616 | | -    if x_dtype == dot_dtype and y_dtype == dot_dtype: |
| 615 | +    numba_dot_dtype = out_dtype = out.type.numpy_dtype |
| 616 | +    if out_dtype.kind not in "fc": |
| 617 | +        # Numba always returns non-integral outputs, so we need to cast to float |
| 618 | +        numba_dot_dtype = np.dtype( |
| 619 | +            f"float{max((32, out.type.numpy_dtype.itemsize * 8))}" |
| 620 | +        ) |
| 621 | +
| 622 | +    if config.compiler_verbose and not ( |
| 623 | +        x_dtype == y_dtype == out_dtype == numba_dot_dtype |
| 624 | +    ): |
| 625 | +        print(  # noqa: T201 |
| 626 | +            "Numba Dot requires a type casting of inputs and/or output: " |
| 627 | +            f"{x_dtype=}, {y_dtype=}, {out_dtype=}, {numba_dot_dtype=}" |
| 628 | +        ) |
| 629 | +
| 630 | +    if x_dtype == numba_dot_dtype and y_dtype == numba_dot_dtype: |
617 | 631 |
618 | 632 |         @numba_basic.numba_njit |
619 | 633 |         def dot(x, y): |
620 | 634 |             return np.asarray(np.dot(x, y)) |
621 | 635 |
622 | | -    elif x_dtype == dot_dtype and y_dtype != dot_dtype: |
| 636 | +    elif x_dtype == numba_dot_dtype and y_dtype != numba_dot_dtype: |
623 | 637 |
624 | 638 |         @numba_basic.numba_njit |
625 | 639 |         def dot(x, y): |
626 | | -            return np.asarray(np.dot(x, y.astype(dot_dtype))) |
| 640 | +            return np.asarray(np.dot(x, y.astype(numba_dot_dtype))) |
627 | 641 |
628 | | -    elif x_dtype != dot_dtype and y_dtype == dot_dtype: |
| 642 | +    elif x_dtype != numba_dot_dtype and y_dtype == numba_dot_dtype: |
629 | 643 |
630 | 644 |         @numba_basic.numba_njit |
631 | 645 |         def dot(x, y): |
632 | | -            return np.asarray(np.dot(x.astype(dot_dtype), y)) |
| 646 | +            return np.asarray(np.dot(x.astype(numba_dot_dtype), y)) |
633 | 647 |
634 | 648 |     else: |
635 | 649 |
636 | 650 |         @numba_basic.numba_njit |
637 | 651 |         def dot(x, y): |
638 | | -            return np.asarray(np.dot(x.astype(dot_dtype), y.astype(dot_dtype))) |
| 652 | +            return np.asarray( |
| 653 | +                np.dot(x.astype(numba_dot_dtype), y.astype(numba_dot_dtype)) |
| 654 | +            ) |
| 655 | +
| 656 | +    cache_version = 1 |
639 | 657 |
640 | | -    if out_dtype == dot_dtype: |
641 | | -        return dot |
| 658 | +    if out_dtype == numba_dot_dtype: |
| 659 | +        return dot, cache_version |
642 | 660 |
643 | 661 |     else: |
644 | 662 |
645 | 663 |         @numba_basic.numba_njit |
646 | 664 |         def dot_with_cast(x, y): |
647 | 665 |             return dot(x, y).astype(out_dtype) |
648 | 666 |
649 | | -        return dot_with_cast |
| 667 | +        return dot_with_cast, cache_version |
650 | 668 |
651 | 669 |
652 | 670 | @register_funcify_default_op_cache_key(BatchedDot) |
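
For reference, a minimal sketch of the dtype-selection rule this hunk implements (the select_numba_dot_dtype helper below is illustrative only, not part of the PR): when the output dtype is not float or complex, the dot is computed in a float dtype of at least 32 bits matching the output itemsize, and the result is cast back to out_dtype afterwards.

    import numpy as np

    def select_numba_dot_dtype(out_dtype: np.dtype) -> np.dtype:
        # Hypothetical helper mirroring the hunk above: non-float/complex
        # outputs (kind not in "fc") get a float compute dtype of at least
        # 32 bits, sized from the output itemsize; the op then casts the
        # result back to out_dtype.
        if out_dtype.kind in "fc":
            return out_dtype
        return np.dtype(f"float{max((32, out_dtype.itemsize * 8))}")

    assert select_numba_dot_dtype(np.dtype("int8")) == np.dtype("float32")
    assert select_numba_dot_dtype(np.dtype("int64")) == np.dtype("float64")
    assert select_numba_dot_dtype(np.dtype("float32")) == np.dtype("float32")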