Skip to content

Commit 3b722ce

Browse files
committed
Numba Dot: Handle complex inputs
1 parent 3ff7603 commit 3b722ce

File tree

2 files changed

+50
-13
lines changed

2 files changed

+50
-13
lines changed

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple
99
from numpy.lib.stride_tricks import as_strided
1010

11+
from pytensor import config
1112
from pytensor.graph.op import Op
1213
from pytensor.link.numba.cache import (
1314
compile_numba_function_src,
@@ -608,45 +609,62 @@ def numba_funcify_Dot(op, node, **kwargs):
608609
x, y = node.inputs
609610
[out] = node.outputs
610611

611-
x_dtype = x.type.dtype
612-
y_dtype = y.type.dtype
613-
dot_dtype = f"float{max((32, out.type.numpy_dtype.itemsize * 8))}"
614-
out_dtype = out.type.dtype
612+
x_dtype = x.type.numpy_dtype
613+
y_dtype = y.type.numpy_dtype
615614

616-
if x_dtype == dot_dtype and y_dtype == dot_dtype:
615+
numba_dot_dtype = out_dtype = out.type.numpy_dtype
616+
if out_dtype.kind not in "fc":
617+
# Numba alawys returns non-integral outputs, we need to cast to float
618+
numba_dot_dtype = np.dtype(
619+
f"float{max((32, out.type.numpy_dtype.itemsize * 8))}"
620+
)
621+
622+
if config.compiler_verbose and not (
623+
x_dtype == y_dtype == out_dtype == numba_dot_dtype
624+
):
625+
print( # noqa: T201
626+
"Numba Dot requires a type casting of inputs and/or output: "
627+
f"{x_dtype=}, {y_dtype=}, {out_dtype=}, {numba_dot_dtype=}"
628+
)
629+
630+
if x_dtype == numba_dot_dtype and y_dtype == numba_dot_dtype:
617631

618632
@numba_basic.numba_njit
619633
def dot(x, y):
620634
return np.asarray(np.dot(x, y))
621635

622-
elif x_dtype == dot_dtype and y_dtype != dot_dtype:
636+
elif x_dtype == numba_dot_dtype and y_dtype != numba_dot_dtype:
623637

624638
@numba_basic.numba_njit
625639
def dot(x, y):
626-
return np.asarray(np.dot(x, y.astype(dot_dtype)))
640+
return np.asarray(np.dot(x, y.astype(numba_dot_dtype)))
627641

628-
elif x_dtype != dot_dtype and y_dtype == dot_dtype:
642+
elif x_dtype != numba_dot_dtype and y_dtype == numba_dot_dtype:
629643

630644
@numba_basic.numba_njit
631645
def dot(x, y):
632-
return np.asarray(np.dot(x.astype(dot_dtype), y))
646+
return np.asarray(np.dot(x.astype(numba_dot_dtype), y))
633647

634648
else:
635649

636650
@numba_basic.numba_njit
637651
def dot(x, y):
638-
return np.asarray(np.dot(x.astype(dot_dtype), y.astype(dot_dtype)))
652+
return np.asarray(
653+
np.dot(x.astype(numba_dot_dtype), y.astype(numba_dot_dtype))
654+
)
655+
656+
cache_version = 1
639657

640-
if out_dtype == dot_dtype:
641-
return dot
658+
if out_dtype == numba_dot_dtype:
659+
return dot, cache_version
642660

643661
else:
644662

645663
@numba_basic.numba_njit
646664
def dot_with_cast(x, y):
647665
return dot(x, y).astype(out_dtype)
648666

649-
return dot_with_cast
667+
return dot_with_cast, cache_version
650668

651669

652670
@register_funcify_default_op_cache_key(BatchedDot)

tests/link/numba/test_elemwise.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,25 @@ def test_dimshuffle(self, c_contiguous, benchmark):
718718
(pt.vector(dtype="int16"), rng.random(size=(2,)).astype(np.int16)),
719719
(pt.vector(dtype="uint8"), rng.random(size=(2,)).astype(np.uint8)),
720720
),
721+
# Viewing the array with 2 last dimensions as complex128 means
722+
# the first entry will be real part and the second entry the imaginary part
723+
(
724+
(
725+
pt.matrix(dtype="complex128"),
726+
rng.random(size=(5, 4, 2)).view("complex128").squeeze(-1),
727+
),
728+
(
729+
pt.matrix(dtype="complex128"),
730+
rng.random(size=(4, 3, 2)).view("complex128").squeeze(-1),
731+
),
732+
),
733+
(
734+
(pt.matrix(dtype="int64"), rng.random(size=(5, 4)).astype("int64")),
735+
(
736+
pt.matrix(dtype="complex128"),
737+
rng.random(size=(4, 3, 2)).view("complex128").squeeze(-1),
738+
),
739+
),
721740
],
722741
)
723742
def test_Dot(x, y):

0 commit comments

Comments
 (0)