diff --git a/pytensor/tensor/__init__.py b/pytensor/tensor/__init__.py
index 5bce6b8b92..175ca783bf 100644
--- a/pytensor/tensor/__init__.py
+++ b/pytensor/tensor/__init__.py
@@ -142,6 +142,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:
     specify_broadcastable,
     specify_shape,
 )
+from pytensor.tensor.shape_ops import *
 
 # We import as `_shared` instead of `shared` to avoid confusion between
 # `pytensor.shared` and `tensor._shared`.
diff --git a/pytensor/tensor/shape_ops.py b/pytensor/tensor/shape_ops.py
new file mode 100644
index 0000000000..9dbdabf248
--- /dev/null
+++ b/pytensor/tensor/shape_ops.py
@@ -0,0 +1,522 @@
+from collections.abc import Iterable, Sequence
+from itertools import pairwise
+from typing import cast as type_cast
+
+import numpy as np
+from numpy.lib._array_utils_impl import normalize_axis_tuple
+
+from pytensor import Variable
+from pytensor.graph import Apply
+from pytensor.graph.op import Op
+from pytensor.tensor import TensorLike, as_tensor_variable
+from pytensor.tensor.basic import (
+    expand_dims,
+    get_underlying_scalar_constant_value,
+    join,
+    split,
+)
+from pytensor.tensor.exceptions import NotScalarConstantError
+from pytensor.tensor.math import prod
+from pytensor.tensor.shape import ShapeValueType
+from pytensor.tensor.type import tensor
+from pytensor.tensor.variable import TensorVariable
+
+
+class JoinDims(Op):
+    __props__ = ("axis",)
+    view_map = {0: [0]}
+
+    def __init__(self, axis: Sequence[int]):
+        if any(i < 0 for i in axis):
+            raise ValueError("JoinDims axis must be non-negative")
+
+        if len(axis) > 1 and np.diff(axis).max() > 1:
+            raise ValueError(
+                f"join_dims axis must be consecutive, got normalized axis: {axis}"
+            )
+
+        self.axis = axis
+
+    def make_node(self, x: Variable) -> Apply:  # type: ignore[override]
+        static_shapes = x.type.shape
+        if x.type.ndim < max(self.axis) + 1:
+            raise ValueError(
+                f"Input ndim {x.type.ndim} is less than the maximum axis {max(self.axis)} + 1"
+            )
+        joined_shape = (
+            int(np.prod([static_shapes[i] for i in self.axis]))
+            if all(static_shapes[i] is not None for i in self.axis)
+            else None
+        )
+
+        output_shapes = (
+            *static_shapes[: min(self.axis)],
+            joined_shape,
+            *static_shapes[max(self.axis) + 1 :],
+        )
+
+        output_type = tensor(shape=output_shapes, dtype=x.type.dtype)
+        return Apply(self, [x], [output_type])
+
+    def infer_shape(self, fgraph, node, shapes):
+        [input_shape] = shapes
+        joined_shape = prod([input_shape[i] for i in self.axis])
+        out_shape = (
+            *input_shape[: min(self.axis)],
+            joined_shape,
+            *input_shape[max(self.axis) + 1 :],
+        )
+
+        return [out_shape]
+
+    def perform(self, node, inputs, outputs):
+        (x,) = inputs
+        (out,) = outputs
+
+        output_shape = [
+            *x.shape[: min(self.axis)],
+            -1,
+            *x.shape[max(self.axis) + 1 :],
+        ]
+
+        out[0] = x.reshape(tuple(output_shape))
+
+
+def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorVariable:
+    """Join consecutive dimensions of a tensor into a single dimension.
+
+    Parameters
+    ----------
+    x : Variable
+        The input tensor.
+    axis : int or sequence of int, optional
+        The dimensions to join. If None, all dimensions are joined.
+
+    Returns
+    -------
+    joined_x : Variable
+        The reshaped tensor with joined dimensions.
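+
+    Notes
+    -----
+    For a non-empty ``axis``, ``join_dims(x, axis)`` is equivalent to the single reshape
+    ``x.reshape((*x.shape[: min(axis)], -1, *x.shape[max(axis) + 1 :]))``: the selected consecutive
+    dimensions are collapsed into one, while all other dimensions are left untouched.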
+
+    Examples
+    --------
+    >>> import pytensor.tensor as pt
+    >>> x = pt.tensor("x", shape=(2, 3, 4, 5))
+    >>> y = pt.join_dims(x, axis=(1, 2))
+    >>> y.type.shape
+    (2, 12, 5)
+    """
+    x = as_tensor_variable(x)
+
+    if axis is None:
+        axis = list(range(x.ndim))
+    elif isinstance(axis, int):
+        axis = [axis]
+    elif not isinstance(axis, list | tuple):
+        raise TypeError("axis must be an int, a list/tuple of ints, or None")
+
+    if not axis:
+        # The user passed an empty list/tuple, so we return the input as is
+        return x
+
+    axis = normalize_axis_tuple(axis, x.ndim)
+    return type_cast(TensorVariable, JoinDims(axis)(x))
+
+
+class SplitDims(Op):
+    __props__ = ("axis",)
+    view_map = {0: [0]}
+
+    def __init__(self, axis: int | None = None):
+        if axis is not None and axis < 0:
+            raise ValueError("SplitDims axis must be non-negative")
+        self.axis = axis
+
+    def _make_output_shape(self, input_shape, shape):
+        [axis] = normalize_axis_tuple(self.axis, len(input_shape))
+        output_shapes = list(input_shape)
+
+        def _get_constant_shape(x):
+            try:
+                # get_underlying_scalar_constant_value returns a numpy scalar, we need a python int
+                return get_underlying_scalar_constant_value(x).item()
+            except NotScalarConstantError:
+                return x
+
+        constant_shape = [_get_constant_shape(x) for x in shape]
+
+        return *output_shapes[:axis], *constant_shape, *output_shapes[axis + 1 :]
+
+    def make_node(self, x: Variable, shape: Variable) -> Apply:  # type: ignore[override]
+        if shape.type.dtype not in ("int8", "int16", "int32", "int64"):
+            raise TypeError("shape must be an integer tensor")
+
+        output_shapes = self._make_output_shape(x.type.shape, shape)
+
+        output = tensor(
+            shape=tuple([x if isinstance(x, int) else None for x in output_shapes]),
+            dtype=x.type.dtype,
+        )
+        return Apply(self, [x, shape], [output])
+
+    def infer_shape(self, fgraph, node, shapes):
+        [input_shape, _] = shapes
+        _, shape = node.inputs
+        output_shape = self._make_output_shape(input_shape, shape)
+
+        return [output_shape]
+
+    def perform(self, node, inputs, outputs):
+        (x, shape) = inputs
+        (out,) = outputs
+
+        out[0] = x.reshape(self._make_output_shape(x.shape, shape))
+
+
+def split_dims(
+    x: TensorLike,
+    shape: ShapeValueType | Sequence[ShapeValueType],
+    axis: int | None = None,
+) -> TensorVariable:
+    """Split a dimension of a tensor into multiple dimensions.
+
+    Parameters
+    ----------
+    x : TensorLike
+        The input tensor.
+    shape : int or sequence of int
+        The new shape to split the specified dimension into.
+    axis : int, optional
+        The dimension to split. If None, the input is assumed to be 1D and axis 0 is used.
+
+    Returns
+    -------
+    split_x : Variable
+        The reshaped tensor with split dimensions.
+
+    Examples
+    --------
+    >>> import pytensor.tensor as pt
+    >>> x = pt.tensor("x", shape=(6, 4, 6))
+    >>> y = pt.split_dims(x, shape=(2, 3), axis=0)
+    >>> y.type.shape
+    (2, 3, 4, 6)
+    """
+    x = as_tensor_variable(x)
+
+    if axis is None:
+        if x.ndim != 1:
+            raise ValueError(
+                "split_dims can only be called with axis=None for 1d inputs"
+            )
+        axis = 0
+
+    if isinstance(shape, int):
+        shape = [shape]
+    else:
+        shape = list(shape)  # type: ignore[arg-type]
+
+    if not shape:
+        # If we get an empty shape, there is potentially a dummy dimension at the requested axis. This happens, for
+        # example, when splitting a packed tensor that had its dims expanded before packing
+        # (e.g. when packing shapes (3,) and (3, 3) to (3, 4)).
+        return type_cast(TensorVariable, x.squeeze(axis=axis))
+
+    [axis] = normalize_axis_tuple(axis, x.ndim)  # type: ignore[misc]
+    shape = as_tensor_variable(shape)  # type: ignore[arg-type]
+    return type_cast(TensorVariable, SplitDims(axis)(x, shape))
+
+
+def _analyze_axes_list(axes) -> tuple[int, int, int]:
+    """
+    Analyze the provided axes list to determine how many axes come before and after the interval to be raveled, as
+    well as the minimum number of axes that the inputs must have.
+
+    The rules are:
+    - Axes must be strictly increasing in both the positive and negative parts of the list.
+    - Negative axes must come after positive axes.
+    - There can be at most one "hole" in the axes list, which can be either an implicit hole at an endpoint
+      (e.g. [0, 1] or [-2, -1]) or an explicit hole between the positive and negative parts (e.g. [0, -1]).
+
+    Returns
+    -------
+    n_axes_before: int
+        The number of axes before the interval to be raveled.
+    n_axes_after: int
+        The number of axes after the interval to be raveled.
+    min_axes: int
+        The minimum number of axes that the inputs must have.
+    """
+    if axes is None:
+        return 0, 0, 0
+
+    if isinstance(axes, int):
+        axes = [axes]
+    elif not isinstance(axes, Iterable):
+        raise TypeError("axes must be an int, an iterable of ints, or None")
+
+    axes = list(axes)
+
+    if len(axes) == 0:
+        raise ValueError("axes=[] is ambiguous; use None to ravel all")
+
+    if len(set(axes)) != len(axes):
+        raise ValueError("axes must have no duplicates")
+
+    first_negative_idx = next((i for i, a in enumerate(axes) if a < 0), len(axes))
+    positive_axes = list(axes[:first_negative_idx])
+    negative_axes = list(axes[first_negative_idx:])
+
+    if not all(a < 0 for a in negative_axes):
+        raise ValueError("Negative axes must come after positive")
+
+    def strictly_increasing(s):
+        return all(b > a for a, b in pairwise(s))
+
+    if positive_axes and not strictly_increasing(positive_axes):
+        raise ValueError("Axes must be strictly increasing in the positive part")
+    if negative_axes and not strictly_increasing(negative_axes):
+        raise ValueError("Axes must be strictly increasing in the negative part")
+
+    def find_gaps(s):
+        """Return positions where b - a > 1."""
+        return [i for i, (a, b) in enumerate(pairwise(s)) if b - a > 1]
+
+    pos_gaps = find_gaps(positive_axes)
+    neg_gaps = find_gaps(negative_axes)
+
+    if pos_gaps:
+        raise ValueError("Positive axes must be contiguous")
+    if neg_gaps:
+        raise ValueError("Negative axes must be contiguous")
+
+    if positive_axes and positive_axes[0] != 0:
+        raise ValueError(
+            "If positive axes are provided, the first positive axis must be 0 to avoid ambiguity. To ravel indices "
+            "starting from the front, use negative axes only."
+        )
+
+    if negative_axes and negative_axes[-1] != -1:
+        raise ValueError(
+            "If negative axes are provided, the last negative axis must be -1 to avoid ambiguity. To ravel indices "
+            "up to the end, use positive axes only."
+        )
+
+    positive_only = positive_axes and not negative_axes
+    negative_only = negative_axes and not positive_axes
+
+    if positive_only:
+        n_before = len(positive_axes)
+        n_after = 0
+        min_axes = n_before
+
+        return n_before, n_after, min_axes
+
+    elif negative_only:
+        n_before = 0
+        n_after = len(negative_axes)
+        min_axes = n_after
+
+        return n_before, n_after, min_axes
+
+    else:
+        n_before = len(positive_axes)
+        n_after = len(negative_axes)
+        min_axes = n_before + n_after
+
+        return n_before, n_after, min_axes
+
+
+def pack(
+    *tensors: TensorLike, axes: Sequence[int] | int | None = None
+) -> tuple[TensorVariable, list[ShapeValueType]]:
+    """
+    Combine multiple tensors by preserving the specified axes and raveling the rest into a single axis.
+
+    Parameters
+    ----------
+    *tensors : TensorLike
+        Input tensors to be packed.
+    axes : int, sequence of int, or None, optional
+        Axes to preserve during packing. If None, all axes are raveled. See the Notes section for the rules.
+
+    Returns
+    -------
+    packed_tensor : TensorLike
+        The packed tensor with specified axes preserved and others raveled.
+    packed_shapes : list of ShapeValueType
+        A list containing the shapes of the raveled dimensions for each input tensor.
+
+    Notes
+    -----
+    The `axes` parameter determines which axes are preserved during packing. Axes can be specified using positive or
+    negative indices, but must follow these rules:
+
+    - If axes is None, all axes are raveled.
+    - If a single integer is provided, it can be positive or negative, and can take any value up to the smallest
+      number of dimensions among the input tensors.
+    - If a list is provided, it can be all positive, all negative, or a combination of positive and negative.
+    - Positive axes must be contiguous and start from 0.
+    - Negative axes must be contiguous and end at -1.
+    - If positive and negative axes are combined, positive axes must come before negative axes, and both 0 and -1
+      must be included.
+
+    Examples
+    --------
+    The easiest way to understand pack is through examples. The simplest case is using axes=None, which is equivalent
+    to ``join(0, *[t.ravel() for t in tensors])``:
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+
+        x = pt.tensor("x", shape=(2, 3))
+        y = pt.tensor("y", shape=(4, 5, 6))
+
+        packed_tensor, packed_shapes = pt.pack(x, y, axes=None)
+        # packed_tensor has shape (6 + 120,) == (126,)
+        # packed_shapes is [(2, 3), (4, 5, 6)]
+
+    If we want to preserve a single axis, we can use either positive or negative indexing. Notice that all tensors
+    must have the same size along the preserved axis. For example, using axes=0:
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+
+        x = pt.tensor("x", shape=(2, 3))
+        y = pt.tensor("y", shape=(2, 5, 6))
+        packed_tensor, packed_shapes = pt.pack(x, y, axes=0)
+        # packed_tensor has shape (2, 3 + 30) == (2, 33)
+        # packed_shapes is [(3,), (5, 6)]
+
+    Using negative indexing we can preserve the last two axes:
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+
+        x = pt.tensor("x", shape=(4, 2, 3))
+        y = pt.tensor("y", shape=(5, 2, 3))
+        packed_tensor, packed_shapes = pt.pack(x, y, axes=(-2, -1))
+        # packed_tensor has shape (4 + 5, 2, 3) == (9, 2, 3)
+        # packed_shapes is [(4,), (5,)]
+
+    Or using a mix of positive and negative axes, we can preserve the first and last axes:
+
+    .. code-block:: python
+
+        import pytensor.tensor as pt
+
+        x = pt.tensor("x", shape=(2, 4, 3))
+        y = pt.tensor("y", shape=(2, 5, 3))
+        packed_tensor, packed_shapes = pt.pack(x, y, axes=(0, -1))
+        # packed_tensor has shape (2, 4 + 5, 3) == (2, 9, 3)
+        # packed_shapes is [(4,), (5,)]
+    """
+    tensor_list = [as_tensor_variable(t) for t in tensors]
+
+    n_before, n_after, min_axes = _analyze_axes_list(axes)
+
+    reshaped_tensors: list[TensorVariable] = []
+    packed_shapes: list[ShapeValueType] = []
+
+    if all([n_before == 0, n_after == 0, min_axes == 0]):
+        # Special case -- we're raveling everything
+        packed_shapes = [t.shape for t in tensor_list]
+        reshaped_tensors = [t.ravel() for t in tensor_list]
+
+        return join(0, *reshaped_tensors), packed_shapes
+
+    for i, input_tensor in enumerate(tensor_list):
+        n_dim = input_tensor.ndim
+
+        if n_dim < min_axes:
+            raise ValueError(
+                f"Input {i} (zero indexed) to pack has {n_dim} dimensions, "
+                f"but axes={axes} assumes at least {min_axes} dimension{'s' if min_axes != 1 else ''}."
+            )
+        n_after_packed = n_dim - n_after
+        packed_shapes.append(input_tensor.shape[n_before:n_after_packed])
+
+        if n_dim == min_axes:
+            # If an input has the minimum number of axes, pack implicitly inserts a new axis based on the pattern
+            # implied by the axes. If n_before == 0, the reshape would be (-1, ...), so we need to expand at axis 0.
+            # If n_after == 0, the reshape would be (..., -1), so we need to expand at axis -1. If both are equal,
+            # the reshape will occur in the center of the tensor.
+            if n_before == 0:
+                input_tensor = expand_dims(input_tensor, axis=0)
+            elif n_after == 0:
+                input_tensor = expand_dims(input_tensor, axis=-1)
+            elif n_before == n_after:
+                input_tensor = expand_dims(input_tensor, axis=n_before)
+
+            reshaped_tensors.append(input_tensor)
+            continue
+
+        # The reshape we want is (shape[:n_before], -1, shape[n_after_packed:]). join_dims does (shape[:min(axes)], -1,
+        # shape[max(axes)+1:]). So this will work if we choose axes=(n_before, n_after_packed - 1). Because of the
+        # rules on the axes input, we will always have n_before <= n_after_packed - 1. A set is used here to cover the
+        # corner case when n_before == n_after_packed - 1 (i.e., when there is only one axis to ravel --> do nothing).
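+        # For example, with axes=(0, -1) and a 4d input, n_before = 1 and n_after_packed = 3, so join_dims ravels
+        # axes 1 and 2; with a 3d input, n_before == n_after_packed - 1 == 1 and the shape is left unchanged.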
+        join_axes = {n_before, n_after_packed - 1}
+        joined = join_dims(input_tensor, tuple(join_axes))
+        reshaped_tensors.append(joined)
+
+    return join(n_before, *reshaped_tensors), packed_shapes
+
+
+def unpack(
+    packed_input: TensorLike,
+    axes: int | Sequence[int] | None,
+    packed_shapes: list[ShapeValueType],
+) -> list[TensorVariable]:
+    """
+    Unpack a packed tensor into multiple tensors by splitting along the specified axes and reshaping.
+
+    The unpacking process reverses the packing operation, restoring the original shapes of the input tensors. `axes`
+    corresponds to the axes that were preserved during packing, and `packed_shapes` contains the shapes of the raveled
+    dimensions for each output tensor (that is, the shapes that were destroyed during packing).
+
+    The signature of unpack is such that the same `axes` should be passed to both `pack` and `unpack` to create a
+    "round-trip" operation. For details on the rules for `axes`, see the documentation for `pack`.
+
+    Parameters
+    ----------
+    packed_input : TensorLike
+        The packed tensor to be unpacked.
+    axes : int, sequence of int, or None
+        Axes that were preserved during packing. If None, the input is assumed to be 1D and axis 0 is used.
+    packed_shapes : list of ShapeValueType
+        A list containing the shapes of the raveled dimensions for each output tensor.
+
+    Returns
+    -------
+    unpacked_tensors : list of TensorLike
+        A list of unpacked tensors with their original shapes restored.
+    """
+    packed_input = as_tensor_variable(packed_input)
+
+    if axes is None:
+        if packed_input.ndim != 1:
+            raise ValueError(
+                "unpack can only be called with axes=None for 1d inputs"
+            )
+        split_axis = 0
+    else:
+        axes = normalize_axis_tuple(axes, ndim=packed_input.ndim)
+        try:
+            [split_axis] = (i for i in range(packed_input.ndim) if i not in axes)
+        except ValueError as err:
+            raise ValueError(
+                "Unpack must have exactly one more dimension than implied by axes"
+            ) from err
+
+    split_inputs = split(
+        packed_input,
+        splits_size=[prod(shape).astype(int) for shape in packed_shapes],
+        n_splits=len(packed_shapes),
+        axis=split_axis,
+    )
+
+    return [
+        split_dims(inp, shape, split_axis)
+        for inp, shape in zip(split_inputs, packed_shapes, strict=True)
+    ]
+
+
+__all__ = ["join_dims", "pack", "split_dims", "unpack"]
diff --git a/tests/tensor/test_reshape_ops.py b/tests/tensor/test_reshape_ops.py
new file mode 100644
index 0000000000..4a22042afa
--- /dev/null
+++ b/tests/tensor/test_reshape_ops.py
@@ -0,0 +1,262 @@
+import numpy as np
+import pytest
+
+import pytensor
+from pytensor import config, function
+from pytensor import tensor as pt
+from pytensor.tensor.shape_ops import (
+    _analyze_axes_list,
+    join_dims,
+    pack,
+    split_dims,
+    unpack,
+)
+
+
+def test_join_dims():
+    rng = np.random.default_rng()
+
+    x = pt.tensor("x", shape=(2, 3, 4, 5))
+    assert join_dims(x, axis=(0, 1)).type.shape == (6, 4, 5)
+    assert join_dims(x, axis=(1, 2)).type.shape == (2, 12, 5)
+    assert join_dims(x, axis=(-1, -2)).type.shape == (2, 3, 20)
+
+    assert join_dims(x, axis=()).type.shape == (2, 3, 4, 5)
+    assert join_dims(x, axis=(2,)).type.shape == (2, 3, 4, 5)
+
+    with pytest.raises(
+        ValueError,
+        match=r"join_dims axis must be consecutive, got normalized axis: \(0, 2\)",
+    ):
+        _ = join_dims(x, axis=(0, 2)).type.shape == (8, 3, 5)
+
+    x_joined = join_dims(x, axis=(1, 2))
+    x_value = rng.normal(size=(2, 3, 4, 5)).astype(config.floatX)
+
+    fn = function([x], x_joined, mode="FAST_COMPILE")
+
+    x_joined_value = fn(x_value)
+    np.testing.assert_allclose(x_joined_value, x_value.reshape(2, 12, 5))
+
+    assert join_dims(x, axis=(1,)).eval({x: x_value}).shape == (2, 3, 4, 5)
+    assert join_dims(x, axis=()).eval({x: x_value}).shape == (2, 3, 4, 5)
+
+
+@pytest.mark.parametrize(
+    "axis, shape, expected_shape",
+    [
+        (0, pt.as_tensor([2, 3]), (2, 3, 4, 6)),
+        (2, [2, 3], (6, 4, 2, 3)),
+        (-1, 6, (6, 4, 6)),
+    ],
+    ids=["tensor", "list", "integer"],
+)
+def test_split_dims(axis, shape, expected_shape):
+    rng = np.random.default_rng()
+
+    x = pt.tensor("x", shape=(6, 4, 6))
+    x_split = split_dims(x, axis=axis, shape=shape)
+    assert x_split.type.shape == expected_shape
+
+    x_split = split_dims(x, axis=axis, shape=shape)
+    x_value = rng.normal(size=(6, 4, 6)).astype(config.floatX)
+
+    fn = function([x], x_split, mode="FAST_COMPILE")
+
+    x_split_value = fn(x_value)
+    np.testing.assert_allclose(x_split_value, x_value.reshape(expected_shape))
+
+
+def test_split_size_zero_shape():
+    x = pt.tensor("x", shape=(1, 4, 6))
+    x_split = split_dims(x, axis=0, shape=pt.as_tensor(np.zeros((0,))))
+    assert x_split.type.shape == (4, 6)
+
+    x_value = np.empty((1, 4, 6), dtype=config.floatX)
+
+    fn = function([x], x_split, mode="FAST_COMPILE")
+
+    x_split_value = fn(x_value)
+    np.testing.assert_allclose(x_split_value, x_value.squeeze(0))
+
+
+def test_make_replacements_with_pack_unpack():
+    rng = np.random.default_rng()
+
+    x = pt.tensor("x", shape=())
+    y = pt.tensor("y", shape=(5,))
+    z = pt.tensor("z", shape=(3, 3))
+
+    loss = (x + y.sum() + z.sum()) ** 2
+
+    flat_packed, packed_shapes = pack(x, y, z, axes=None)
+    new_input = flat_packed.type()
+    new_outputs = unpack(new_input, axes=None, packed_shapes=packed_shapes)
+
+    loss = pytensor.graph.graph_replace(loss, dict(zip([x, y, z], new_outputs)))
+    fn = pytensor.function([new_input, x, y, z], loss, mode="FAST_COMPILE")
+
+    input_vals = [
+        rng.normal(size=(var.type.shape)).astype(config.floatX) for var in [x, y, z]
+    ]
+    flat_inputs = np.concatenate([input.ravel() for input in input_vals], axis=0)
+    output_val = fn(flat_inputs, *input_vals)
+
+    assert np.allclose(output_val, sum([input.sum() for input in input_vals]) ** 2)
+
+
+class TestPack:
+    @pytest.mark.parametrize(
+        "axes, expected",
+        [
+            (None, [0, 0, 0]),  # '*'
+            ([0, 1], [2, 0, 2]),  # 'i j *'
+            ([-1], [0, 1, 1]),  # '* k'
+            ([-2, -1], [0, 2, 2]),  # '* i j'
+            ([0, -1], [1, 1, 2]),  # 'i * k'
+            ([0, 1, 2, -1], [3, 1, 4]),  # 'i j k * l'
+        ],
+        ids=[
+            "ravel_all",
+            "keep_first_two",
+            "keep_last",
+            "ravel_start",
+            "first_and_last",
+            "complex_case",
+        ],
+    )
+    def test_analyze_axes_list_valid(self, axes, expected):
+        outputs = _analyze_axes_list(axes)
+        names = ["n_before", "n_after", "min_axes"]
+        for out, exp, name in zip(outputs, expected, names, strict=True):
+            assert out == exp, f"Expected {exp}, got {out} for {name}"
+
+    def test_analyze_axes_list_invalid(self):
+        # Positive only but not contiguous
+        with pytest.raises(ValueError, match="Positive axes must be contiguous"):
+            _analyze_axes_list([1, 3])
+
+        # Negative only but not contiguous
+        with pytest.raises(ValueError, match="Negative axes must be contiguous"):
+            _analyze_axes_list([-3, -1])
+
+        # Mixed up positive and negative
+        with pytest.raises(ValueError, match="Negative axes must come after positive"):
+            _analyze_axes_list([0, 1, -2, 4])
+
+        # Duplicate axes
+        with pytest.raises(ValueError, match="axes must have no duplicates"):
+            _analyze_axes_list([0, 0])
+
+        # Not monotonic
+        with pytest.raises(ValueError, match="Axes must be strictly increasing"):
+            _analyze_axes_list([0, 2, 1])
+
+        # Negative before positive
+        with pytest.raises(ValueError, match="Negative axes must come after positive"):
+            _analyze_axes_list([-1, 0])
+
+    def test_pack_basic(self):
+        # rng = np.random.default_rng()
+        x = pt.tensor("x", shape=())
+        y = pt.tensor("y", shape=(5,))
+        z = pt.tensor("z", shape=(3, 3))
+
+        input_dict = {
+            variable.name: np.zeros(variable.type.shape, dtype=config.floatX)
+            for variable in [x, y, z]
+        }
+
+        # Simple case, reduce all axes, equivalent to einops '*'
+        packed_tensor, packed_shapes = pack(x, y, z, axes=None)
+        assert packed_tensor.type.shape == (15,)
+        for tensor, packed_shape in zip([x, y, z], packed_shapes):
+            assert packed_shape.type.shape == (tensor.ndim,)
+            np.testing.assert_allclose(
+                packed_shape.eval(input_dict, on_unused_input="ignore"),
+                tensor.type.shape,
+            )
+
+        # To preserve an axis, all inputs need at least one dimension, and the preserved axis has to agree.
+        # x is scalar, so pack will raise:
+        with pytest.raises(
+            ValueError,
+            match=r"Input 0 \(zero indexed\) to pack has 0 dimensions, but axes=0 assumes at least 1 dimension\.",
+        ):
+            pack(x, y, z, axes=0)
+
+        # With valid x, pack should still raise, because the axis of concatenation doesn't agree across all inputs
+        x = pt.tensor("x", shape=(3,))
+        input_dict["x"] = np.zeros((3,), dtype=config.floatX)
+
+        with pytest.raises(
+            ValueError,
+            match=r"all input array dimensions other than the specified `axis` \(1\) must match exactly, or be unknown "
+            r"\(None\), but along dimension 0, the inputs shapes are incompatible: \[3 5 3\]",
+        ):
+            packed_tensor, packed_shapes = pack(x, y, z, axes=0)
+            packed_tensor.eval(input_dict)
+
+        # Valid case, preserve first axis, equivalent to einops 'i *'
+        y = pt.tensor("y", shape=(3, 5))
+        z = pt.tensor("z", shape=(3, 3, 3))
+        packed_tensor, packed_shapes = pack(x, y, z, axes=0)
+        input_dict = {
+            variable.name: np.zeros(variable.type.shape, dtype=config.floatX)
+            for variable in [x, y, z]
+        }
+        assert packed_tensor.type.shape == (3, 15)
+        for tensor, packed_shape in zip([x, y, z], packed_shapes):
+            assert packed_shape.type.shape == (tensor.ndim - 1,)
+            np.testing.assert_allclose(
+                packed_shape.eval(input_dict, on_unused_input="ignore"),
+                tensor.type.shape[1:],
+            )
+
+        # More complex case, preserve last axis implicitly, equivalent to einops 'i * k'. This introduces a max
+        # dimension condition on the input shapes
+        x = pt.tensor("x", shape=(3, 2))
+        y = pt.tensor("y", shape=(3, 5, 2))
+        z = pt.tensor("z", shape=(3, 1, 7, 5, 2))
+
+        with pytest.raises(
+            ValueError,
+            match=r"Positive axes must be contiguous",
+        ):
+            pack(x, y, z, axes=[0, 3])
+
+        z = pt.tensor("z", shape=(3, 1, 7, 2))
+        packed_tensor, packed_shapes = pack(x, y, z, axes=[0, -1])
+        input_dict = {
+            variable.name: np.zeros(variable.type.shape, dtype=config.floatX)
+            for variable in [x, y, z]
+        }
+        assert packed_tensor.type.shape == (3, 13, 2)
+        for tensor, packed_shape in zip([x, y, z], packed_shapes):
+            assert packed_shape.type.shape == (tensor.ndim - 2,)
+            np.testing.assert_allclose(
+                packed_shape.eval(input_dict, on_unused_input="ignore"),
+                tensor.type.shape[1:-1],
+            )
+
+    @pytest.mark.parametrize("axes", [-1])
+    def test_pack_unpack_round_trip(self, axes):
+        rng = np.random.default_rng()
+
+        x = pt.tensor("x", shape=(3, 5))
+        y = pt.tensor("y", shape=(3, 3, 5))
+        z = pt.tensor("z", shape=(1, 3, 5))
+
+        flat_packed, packed_shapes = pack(x, y, z, axes=axes)
+        new_outputs = unpack(flat_packed, axes=axes, packed_shapes=packed_shapes)
+
+        fn = pytensor.function([x, y, z], new_outputs, mode="FAST_COMPILE")
+
+        input_dict = {
+            var.name: rng.normal(size=var.type.shape).astype(config.floatX)
+            for var in [x, y, z]
+        }
+        output_vals = fn(**input_dict)
+
+        for input_val, output_val in zip(input_dict.values(), output_vals, strict=True):
+            np.testing.assert_allclose(input_val, output_val)