@@ -2267,17 +2267,113 @@ def pack(
22672267 Parameters
22682268 ----------
22692269 tensors: TensorVariable
2270- Tensors to be packed into a single vector.
2270+ Tensors to be packed. Tensors can have varying shapes and dimensions, but must have the same size along each
2271+ of the dimensions specified in the `axes` parameter.
22712272 axes: int or sequence of int, optional
2272- Axes to be concatenated. All other axes will be raveled (packed) and joined. If None, all axes will be raveled
2273- and joined.
2273+ Axes to be preserved. All other axes will be raveled (packed), and the output is the result of concatenating
2274+ on the new raveled dimension. If None, all axes will be raveled and joined. Axes can be either positive or
2275+ negative, but must be striclty increasing in both the positive and negative parts of the list. Negative axes
2276+ must come after positive axes.
22742277
22752278 Returns
22762279 -------
22772280 flat_tensor: TensorVariable
22782281 A new symbolic variable representing the concatenated 1d vector of all tensor inputs
22792282 packed_shapes: list of tuples of TensorVariable
22802283 A list of tuples, where each tuple contains the symbolic shape of the original tensors.
2284+
2285+ Notes
2286+ -----
2287+ This function is a helper for joining tensors of varying shapes into a single tenor. This is done by choosing a
2288+ list of axes to concatenate, and raveling all other axes. The resulting tensor are then concatenated along the
2289+ raveled axis. The original shapes of the tensors are also returned, so that they can be unpacked later.
2290+
2291+ The `axes` parameter determines which dimensions are *not* raveled. The requested axes must exist in all input
2292+ tensors, but there are otherwwise no restrictions on the shapes or dimensions of the input tensors. For example, if
2293+ `axes=[0]`, then the first dimension of each tensor is preserved, and all other dimensions are raveled:
2294+
2295+ .. code-block:: python
2296+
2297+ import pytensor.tensor as pt
2298+
2299+ x = pt.tensor("x", shape=(2, 3, 4))
2300+ y = pt.tensor("y", shape=(2, 5))
2301+ packed_output, shapes = pack(x, y, axes=0)
2302+ # packed_output will have shape (2, 3 * 4 + 5) = (2, 17)
2303+
2304+ Since axes = 0, the first dimension of both `x` and `y` is preserved. This first example is equivalent to a simple
2305+ reshape and concat operation:
2306+
2307+ .. code-block:: python
2308+
2309+ x_reshaped = x.reshape(2, -1) # shape (2, 12)
2310+ y_reshaped = y.reshape(2, -1) # shape (2, 5)
2311+ packed_output = pt.concatenate(
2312+ [x_reshaped, y_reshaped], axis=1
2313+ ) # shape (2, 17)
2314+
2315+ `axes` can also be negative, in which case the axes are counted from the end of the tensor shape. For example,
2316+ if `axes=[-1]`, then the last dimension of each tensor is preserved, and all other dimensions are raveled:
2317+
2318+ .. code-block:: python
2319+
2320+ import pytensor.tensor as pt
2321+
2322+ x = pt.tensor("x", shape=(3, 4, 7))
2323+ y = pt.tensor("y", shape=(6, 2, 1, 7))
2324+ packed_output, shapes = pack(x, y, axes=-1)
2325+ # packed_output will have shape (3 * 4 + 6 * 2 * 1, 7) = (24, 7)
2326+
2327+ The most important restriction of `axes` is that there can be at most one "hole" in the axes list. A hole is
2328+ defined as a missing axis in the sequence of axes. The easiest way to define a hole is by using both positive
2329+ and negative axes together. For example, `axes=[0, -1]` has a hole between the first and last axes. In this case,
2330+ the first and last dimensions of each tensor are preserved, and all other dimensions are raveled:
2331+
2332+ .. code-block:: python
2333+
2334+ import pytensor.tensor as pt
2335+
2336+ x = pt.tensor("x", shape=(2, 3, 2, 3, 7))
2337+ y = pt.tensor("y", shape=(2, 6, 7))
2338+ packed_output, shapes = pack(x, y, axes=[0, -1])
2339+ # packed_output will have shape (2, 3 * 2 * 3 + 6, 7) = (2, 24, 7)
2340+
2341+ Multiple explicit holes are not allowed. For example, `axes = [0, 2, -1]` is illegal because there are two holes,
2342+ one between axes 0 and 2, and another between axes 2 and -1.
2343+
2344+ Implicit holes are also possible when using only positive or only negative axes. `axes = [0]` already has an
2345+ implicit hole to the right of axis 0. `axes = [2, 3]` has two implicit holes, one to the left of axis 2, and another
2346+ to the right. This is illegal, since there are two holes. However, `axes = [2, 3]` can be made legal if we interpret
2347+ axis 3 as the last axis (-1), which closes the right implicit hole. The interpretation requires that at least one
2348+ input tensor has exactly 4 dimensions:
2349+
2350+ .. code-block:: python
2351+
2352+ import pytensor.tensor as pt
2353+
2354+ x = pt.tensor("x", shape=(5, 2, 3, 4))
2355+ y = pt.tensor("y", shape=(2, 3, 4))
2356+ packed_output, shapes = pack(x, y, axes=[2, 3])
2357+ # packed_output will have shape (5 * 2 + 2, 3, 4) = (12, 3, 4)
2358+
2359+ Note here that `y` has only 3 dimensions, so axis 3 is interpreted as -1, the last axis. If no input has 4
2360+ dimensions, or if any input has more than 4 dimensions, an error is raised in this case.
2361+
2362+ Negative axes have similar rules regarding implicit holes. `axes = [-1]` has an implicit hole to the left of
2363+ axis -1. `axes = [-3, -2]` has two implicit holes. To arrive at a valid interpretation, we take -3 to be axis 0,
2364+ which closes the left implicit hole. This requires that at least one input tensor has exactly 3 dimensions:
2365+
2366+ .. code-block:: python
2367+
2368+ import pytensor.tensor as pt
2369+
2370+ x = pt.tensor("x", shape=(2, 3, 4))
2371+ y = pt.tensor("y", shape=(6, 4))
2372+ packed_output, shapes = pack(x, y, axes=[-3, -2])
2373+ # packed_output will have shape (2 + 6, 3, 4) = (8, 3, 4)
2374+
2375+ Similarly to the previous example, if no input has 3 dimensions, or if any input has more than 3 dimensions, an
2376+ error would be raised in this example.
22812377 """
22822378 if not tensors :
22832379 raise ValueError ("Cannot pack an empty list of tensors." )
@@ -2316,6 +2412,7 @@ def pack(
23162412 inputs = tensors ,
23172413 outputs = [packed_output_tensor , * packed_output_shapes ],
23182414 name = "Pack{axes=" + str (axes ) + "}" ,
2415+ inline = True ,
23192416 )
23202417
23212418 outputs = pack_op (* tensors )
0 commit comments