Skip to content

Commit 5be83d1

Browse files
Docs
1 parent 4d20c4f commit 5be83d1

File tree

1 file changed

+100
-3
lines changed

1 file changed

+100
-3
lines changed

pytensor/tensor/extra_ops.py

Lines changed: 100 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2267,17 +2267,113 @@ def pack(
22672267
Parameters
22682268
----------
22692269
tensors: TensorVariable
2270-
Tensors to be packed into a single vector.
2270+
Tensors to be packed. Tensors can have varying shapes and dimensions, but must have the same size along each
2271+
of the dimensions specified in the `axes` parameter.
22712272
axes: int or sequence of int, optional
2272-
Axes to be concatenated. All other axes will be raveled (packed) and joined. If None, all axes will be raveled
2273-
and joined.
2273+
Axes to be preserved. All other axes will be raveled (packed), and the output is the result of concatenating
2274+
on the new raveled dimension. If None, all axes will be raveled and joined. Axes can be either positive or
2275+
negative, but must be striclty increasing in both the positive and negative parts of the list. Negative axes
2276+
must come after positive axes.
22742277
22752278
Returns
22762279
-------
22772280
flat_tensor: TensorVariable
22782281
A new symbolic variable representing the concatenated 1d vector of all tensor inputs
22792282
packed_shapes: list of tuples of TensorVariable
22802283
A list of tuples, where each tuple contains the symbolic shape of the original tensors.
2284+
2285+
Notes
2286+
-----
2287+
This function is a helper for joining tensors of varying shapes into a single tenor. This is done by choosing a
2288+
list of axes to concatenate, and raveling all other axes. The resulting tensor are then concatenated along the
2289+
raveled axis. The original shapes of the tensors are also returned, so that they can be unpacked later.
2290+
2291+
The `axes` parameter determines which dimensions are *not* raveled. The requested axes must exist in all input
2292+
tensors, but there are otherwwise no restrictions on the shapes or dimensions of the input tensors. For example, if
2293+
`axes=[0]`, then the first dimension of each tensor is preserved, and all other dimensions are raveled:
2294+
2295+
.. code-block:: python
2296+
2297+
import pytensor.tensor as pt
2298+
2299+
x = pt.tensor("x", shape=(2, 3, 4))
2300+
y = pt.tensor("y", shape=(2, 5))
2301+
packed_output, shapes = pack(x, y, axes=0)
2302+
# packed_output will have shape (2, 3 * 4 + 5) = (2, 17)
2303+
2304+
Since axes = 0, the first dimension of both `x` and `y` is preserved. This first example is equivalent to a simple
2305+
reshape and concat operation:
2306+
2307+
.. code-block:: python
2308+
2309+
x_reshaped = x.reshape(2, -1) # shape (2, 12)
2310+
y_reshaped = y.reshape(2, -1) # shape (2, 5)
2311+
packed_output = pt.concatenate(
2312+
[x_reshaped, y_reshaped], axis=1
2313+
) # shape (2, 17)
2314+
2315+
`axes` can also be negative, in which case the axes are counted from the end of the tensor shape. For example,
2316+
if `axes=[-1]`, then the last dimension of each tensor is preserved, and all other dimensions are raveled:
2317+
2318+
.. code-block:: python
2319+
2320+
import pytensor.tensor as pt
2321+
2322+
x = pt.tensor("x", shape=(3, 4, 7))
2323+
y = pt.tensor("y", shape=(6, 2, 1, 7))
2324+
packed_output, shapes = pack(x, y, axes=-1)
2325+
# packed_output will have shape (3 * 4 + 6 * 2 * 1, 7) = (24, 7)
2326+
2327+
The most important restriction of `axes` is that there can be at most one "hole" in the axes list. A hole is
2328+
defined as a missing axis in the sequence of axes. The easiest way to define a hole is by using both positive
2329+
and negative axes together. For example, `axes=[0, -1]` has a hole between the first and last axes. In this case,
2330+
the first and last dimensions of each tensor are preserved, and all other dimensions are raveled:
2331+
2332+
.. code-block:: python
2333+
2334+
import pytensor.tensor as pt
2335+
2336+
x = pt.tensor("x", shape=(2, 3, 2, 3, 7))
2337+
y = pt.tensor("y", shape=(2, 6, 7))
2338+
packed_output, shapes = pack(x, y, axes=[0, -1])
2339+
# packed_output will have shape (2, 3 * 2 * 3 + 6, 7) = (2, 24, 7)
2340+
2341+
Multiple explicit holes are not allowed. For example, `axes = [0, 2, -1]` is illegal because there are two holes,
2342+
one between axes 0 and 2, and another between axes 2 and -1.
2343+
2344+
Implicit holes are also possible when using only positive or only negative axes. `axes = [0]` already has an
2345+
implicit hole to the right of axis 0. `axes = [2, 3]` has two implicit holes, one to the left of axis 2, and another
2346+
to the right. This is illegal, since there are two holes. However, `axes = [2, 3]` can be made legal if we interpret
2347+
axis 3 as the last axis (-1), which closes the right implicit hole. The interpretation requires that at least one
2348+
input tensor has exactly 4 dimensions:
2349+
2350+
.. code-block:: python
2351+
2352+
import pytensor.tensor as pt
2353+
2354+
x = pt.tensor("x", shape=(5, 2, 3, 4))
2355+
y = pt.tensor("y", shape=(2, 3, 4))
2356+
packed_output, shapes = pack(x, y, axes=[2, 3])
2357+
# packed_output will have shape (5 * 2 + 2, 3, 4) = (12, 3, 4)
2358+
2359+
Note here that `y` has only 3 dimensions, so axis 3 is interpreted as -1, the last axis. If no input has 4
2360+
dimensions, or if any input has more than 4 dimensions, an error is raised in this case.
2361+
2362+
Negative axes have similar rules regarding implicit holes. `axes = [-1]` has an implicit hole to the left of
2363+
axis -1. `axes = [-3, -2]` has two implicit holes. To arrive at a valid interpretation, we take -3 to be axis 0,
2364+
which closes the left implicit hole. This requires that at least one input tensor has exactly 3 dimensions:
2365+
2366+
.. code-block:: python
2367+
2368+
import pytensor.tensor as pt
2369+
2370+
x = pt.tensor("x", shape=(2, 3, 4))
2371+
y = pt.tensor("y", shape=(6, 4))
2372+
packed_output, shapes = pack(x, y, axes=[-3, -2])
2373+
# packed_output will have shape (2 + 6, 3, 4) = (8, 3, 4)
2374+
2375+
Similarly to the previous example, if no input has 3 dimensions, or if any input has more than 3 dimensions, an
2376+
error would be raised in this example.
22812377
"""
22822378
if not tensors:
22832379
raise ValueError("Cannot pack an empty list of tensors.")
@@ -2316,6 +2412,7 @@ def pack(
23162412
inputs=tensors,
23172413
outputs=[packed_output_tensor, *packed_output_shapes],
23182414
name="Pack{axes=" + str(axes) + "}",
2415+
inline=True,
23192416
)
23202417

23212418
outputs = pack_op(*tensors)

0 commit comments

Comments
 (0)