Skip to content

Commit 510fc48

Browse files
Add unpack helper
1 parent 1077c17 commit 510fc48

File tree

2 files changed

+78
-1
lines changed

2 files changed

+78
-1
lines changed

pytensor/tensor/extra_ops.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from pytensor.scalar import upcast
2727
from pytensor.tensor import TensorLike, as_tensor_variable
2828
from pytensor.tensor import basic as ptb
29-
from pytensor.tensor.basic import alloc, as_tensor, join, second
29+
from pytensor.tensor.basic import alloc, as_tensor, join, second, split
3030
from pytensor.tensor.exceptions import NotScalarConstantError
3131
from pytensor.tensor.math import abs as pt_abs
3232
from pytensor.tensor.math import all as pt_all
@@ -2198,6 +2198,35 @@ def pack(*tensors: TensorLike, axes: Sequence[int] | int | None = None):
21982198
return join(n_before, *reshaped_tensors), packed_shapes
21992199

22002200

2201+
def unpack(packed_input, axes, packed_shapes):
    """Inverse of ``pack``: split a packed tensor back into its components.

    Parameters
    ----------
    packed_input : TensorVariable
        Tensor produced by ``pack`` (or one of compatible shape).
    axes : Sequence[int] | int | None
        The ``axes`` argument that was given to ``pack``. When ``None`` the
        input must be 1d and is split along axis 0.
    packed_shapes : list of shapes
        Per-component shapes returned by ``pack``; each entry lists the
        dimensions that were flattened into the split axis.

    Returns
    -------
    list of TensorVariable
        One tensor per entry of ``packed_shapes``, with the flattened
        dimensions restored via ``split_dims``.

    Raises
    ------
    ValueError
        If ``axes`` is ``None`` but the input is not 1d, or if ``axes``
        does not leave exactly one free dimension to split along.
    """
    if axes is None:
        if packed_input.ndim != 1:
            raise ValueError(
                "unpack can only be called with keep_axis=None for 1d inputs"
            )
        split_axis = 0
    else:
        axes = normalize_axis_tuple(axes, ndim=packed_input.ndim)
        try:
            # Exactly one dimension must remain outside `axes`; that is the
            # axis along which `pack` concatenated the components. Zero or
            # several leftover dims make the unpacking ambiguous.
            [split_axis] = (i for i in range(packed_input.ndim) if i not in axes)
        except ValueError as err:
            # Fixed typo in the original message ("that" -> "than").
            raise ValueError(
                "Unpack must have exactly one more dimension than implied by axes"
            ) from err

    # Each component occupies prod(shape) entries along the split axis.
    split_inputs = split(
        packed_input,
        splits_size=[prod(shape).astype(int) for shape in packed_shapes],
        n_splits=len(packed_shapes),
        axis=split_axis,
    )

    # Restore each component's original dimensions from its packed shape.
    return [
        split_dims(inp, shape, split_axis)
        for inp, shape in zip(split_inputs, packed_shapes, strict=True)
    ]
2229+
22012230
__all__ = [
22022231
"bartlett",
22032232
"bincount",

tests/tensor/test_extra_ops.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
split_dims,
4848
squeeze,
4949
to_one_hot,
50+
unpack,
5051
unravel_index,
5152
)
5253
from pytensor.tensor.type import (
@@ -1593,3 +1594,50 @@ def test_pack_basic(self):
15931594
packed_shape.eval(input_dict, on_unused_input="ignore"),
15941595
tensor.type.shape[1:-1],
15951596
)
1597+
1598+
@pytest.mark.parametrize("axes", [None, -1, (-2, -1)])
def test_pack_unpack_round_trip(self, axes):
    """Unpacking a packed set of tensors must reproduce the originals."""
    rng = np.random.default_rng()

    tensors = [
        pt.tensor("x", shape=(3, 5)),
        pt.tensor("y", shape=(3, 3, 5)),
        pt.tensor("z", shape=(1, 3, 5)),
    ]

    flat_packed, packed_shapes = pack(*tensors, axes=axes)
    recovered = unpack(flat_packed, axes=axes, packed_shapes=packed_shapes)

    fn = pytensor.function(tensors, recovered, mode="FAST_COMPILE")

    input_dict = {
        var.name: rng.normal(size=var.type.shape).astype(config.floatX)
        for var in tensors
    }
    results = fn(**input_dict)

    # Every recovered tensor must match the corresponding input value.
    for expected, actual in zip(input_dict.values(), results, strict=True):
        np.testing.assert_allclose(expected, actual)
1619+
1620+
1621+
def test_make_replacements_with_pack_unpack():
    """A loss over several inputs can be re-expressed, via graph_replace,
    as a function of a single flat packed vector produced by ``pack``."""
    rng = np.random.default_rng()

    x = pt.tensor("x", shape=())
    y = pt.tensor("y", shape=(5,))
    z = pt.tensor("z", shape=(3, 3))

    loss = (x + y.sum() + z.sum()) ** 2

    flat_packed, packed_shapes = pack(x, y, z, axes=None)
    new_input = flat_packed.type()
    new_outputs = unpack(new_input, axes=None, packed_shapes=packed_shapes)

    # Rewire the loss graph so it reads the unpacked views of the flat
    # vector instead of the original x, y, z inputs.
    replaced_loss = pytensor.graph.graph_replace(
        loss, dict(zip([x, y, z], new_outputs))
    )
    fn = pytensor.function([new_input, x, y, z], replaced_loss, mode="FAST_COMPILE")

    input_vals = [
        rng.normal(size=var.type.shape).astype(config.floatX) for var in [x, y, z]
    ]
    # Avoid shadowing the builtin `input` (original used `for input in ...`).
    flat_inputs = np.concatenate([val.ravel() for val in input_vals], axis=0)
    output_val = fn(flat_inputs, *input_vals)

    # x, y, z are unused after replacement; only flat_inputs should matter.
    assert np.allclose(output_val, sum(val.sum() for val in input_vals) ** 2)

0 commit comments

Comments
 (0)