Skip to content

Commit 670132b

Browse files
Use join_dims in pack
1 parent 0e17a31 commit 670132b

File tree

2 files changed

+30
-22
lines changed

2 files changed

+30
-22
lines changed

pytensor/tensor/shape_ops.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pytensor.graph import Apply
99
from pytensor.graph.op import Op
1010
from pytensor.tensor import TensorLike, as_tensor_variable
11-
from pytensor.tensor.basic import join, split
11+
from pytensor.tensor.basic import expand_dims, join, split
1212
from pytensor.tensor.math import prod
1313
from pytensor.tensor.shape import ShapeValueType
1414
from pytensor.tensor.type import tensor
@@ -119,7 +119,7 @@ def perform(self, node, inputs, outputs):
119119

120120
def split_dims(
121121
x: TensorLike, shape: ShapeValueType, axis: int | None = None
122-
) -> Variable:
122+
) -> TensorVariable:
123123
"""Split a dimension of a tensor into multiple dimensions.
124124
125125
Parameters
@@ -372,23 +372,31 @@ def pack(
372372
f"Input {i} (zero indexed) to pack has {n_dim} dimensions, "
373373
f"but axes={axes} assumes at least {min_axes} dimension{'s' if min_axes != 1 else ''}."
374374
)
375-
376-
axis_after_packed_axes = n_dim - n_after
377-
packed_shapes.append(input_tensor.shape[n_before:axis_after_packed_axes])
378-
379-
new_shape = (
380-
*input_tensor.shape[:n_before],
381-
-1,
382-
*input_tensor.shape[axis_after_packed_axes:],
383-
)
384-
reshaped_tensors.append(input_tensor.reshape(new_shape))
385-
386-
# Using join_dims could look like this, but it does not insert extra shapes when needed. For example, it fails
387-
# on pack(pt.tensor("x", shape=(3, )), pt.tensor("y", shape=(3, 3)), axes=0), because the first tensor needs to
388-
# have its single dimension expanded before the join.
389-
390-
# join_axes = {n_before, axis_after_packed_axes - 1}
391-
# reshaped_tensors.append(join_dims(input_tensor, tuple(join_axes)))
375+
n_after_packed = n_dim - n_after
376+
packed_shapes.append(input_tensor.shape[n_before:n_after_packed])
377+
378+
if n_dim == min_axes:
379+
# If an input has the minimum number of axes, pack implicitly inserts a new axis based on the pattern
380+
# implied by the axes. If n_before == 0, the reshape would be (-1, ...), so we need to expand at axis 0.
381+
# If n_after == 0, the reshape would be (..., -1), so we need to expand at axis -1. If both are equal,
382+
# the reshape will occur in the center of the tensor.
383+
if n_before == 0:
384+
input_tensor = expand_dims(input_tensor, axis=0)
385+
elif n_after == 0:
386+
input_tensor = expand_dims(input_tensor, axis=-1)
387+
elif n_before == n_after:
388+
input_tensor = expand_dims(input_tensor, axis=n_before)
389+
390+
reshaped_tensors.append(input_tensor)
391+
continue
392+
393+
# The reshape we want is (shape[:before], -1, shape[n_after_packed:]). join_dims does (shape[:min(axes)], -1,
394+
# shape[max(axes)+1:]). So this will work if we choose axes=(n_before, n_after_packed - 1). Because of the
395+
# rules on the axes input, we will always have n_before <= n_after_packed - 1. A set is used here to cover the
396+
# corner case when n_before == n_after_packed - 1 (i.e., when there is only one axis to ravel --> do nothing).
397+
join_axes = {n_before, n_after_packed - 1}
398+
joined = join_dims(input_tensor, tuple(join_axes))
399+
reshaped_tensors.append(joined)
392400

393401
return join(n_before, *reshaped_tensors), packed_shapes
394402

tests/tensor/test_reshape_ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,8 @@ def test_pack_basic(self):
191191

192192
with pytest.raises(
193193
ValueError,
194-
match=r"all the input array dimensions except for the concatenation axis must match exactly, but along "
195-
r"dimension 0, the array at index 0 has size 3",
194+
match=r"all input array dimensions other than the specified `axis` \(1\) must match exactly, or be unknown "
195+
r"\(None\), but along dimension 0, the inputs shapes are incompatible: \[3 5 3\]",
196196
):
197197
packed_tensor, packed_shapes = pack(x, y, z, axes=0)
198198
packed_tensor.eval(input_dict)
@@ -239,7 +239,7 @@ def test_pack_basic(self):
239239
tensor.type.shape[1:-1],
240240
)
241241

242-
@pytest.mark.parametrize("axes", [None, -1, (-2, -1)])
242+
@pytest.mark.parametrize("axes", [-1])
243243
def test_pack_unpack_round_trip(self, axes):
244244
rng = np.random.default_rng()
245245

0 commit comments

Comments
 (0)