Skip to content

Commit 934ae9c

Browse files
Add docstrings
1 parent 25bce5e commit 934ae9c

File tree

1 file changed

+113
-3
lines changed

1 file changed

+113
-3
lines changed

pytensor/tensor/shape_ops.py

Lines changed: 113 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from pytensor.tensor.math import prod
1313
from pytensor.tensor.shape import ShapeValueType
1414
from pytensor.tensor.type import tensor
15-
from pytensor.tensor.variable import TensorConstant
15+
from pytensor.tensor.variable import TensorConstant, TensorVariable
1616

1717

1818
class JoinDims(Op):
@@ -269,7 +269,89 @@ def find_gaps(s):
269269
return n_before, n_after, min_axes
270270

271271

272-
def pack(*tensors: TensorLike, axes: Sequence[int] | int | None = None):
272+
def pack(
273+
*tensors: TensorLike, axes: Sequence[int] | int | None = None
274+
) -> tuple[TensorVariable, list[ShapeValueType]]:
275+
"""
276+
Combine multiple tensors by preserving the specified axes and raveling the rest into a single axis.
277+
278+
Parameters
279+
----------
280+
*tensors : TensorLike
281+
Input tensors to be packed.
282+
axes : int, sequence of int, or None, optional
283+
Axes to preserve during packing. If None, all axes are raveled. See the Notes section for the rules.
284+
285+
Returns
286+
-------
287+
packed_tensor : TensorLike
288+
The packed tensor with specified axes preserved and others raveled.
289+
packed_shapes : list of ShapeValueType
290+
A list containing the shapes of the raveled dimensions for each input tensor.
291+
292+
Notes
293+
-----
294+
The `axes` parameter determines which axes are preserved during packing. Axes can be specified using positive or
295+
negative indices, but must follow these rules:
296+
- If axes is None, all axes are raveled.
297+
- If a single integer is provided, it can be positive or negative, and can take any value up to the smallest
298+
number of dimensions among the input tensors.
299+
- If a list is provided, it can be all positive, all negative, or a combination of positive and negative.
300+
- Positive axes must be contiguous and start from 0.
301+
- Negative axes must be contiguous and end at -1.
302+
- If positive and negative axes are combined, positive axes must come before negative axes, and both 0 and -1
303+
must be included.
304+
305+
Examples
306+
--------
307+
The easiest way to understand pack is through examples. The simplest case is using axes=None, which is equivalent
308+
to ``join(0, *[t.ravel() for t in tensors])``:
309+
310+
.. code-block:: python
311+
import pytensor.tensor as pt
312+
313+
x = pt.tensor("x", shape=(2, 3))
314+
y = pt.tensor("y", shape=(4, 5, 6))
315+
316+
packed_tensor, packed_shapes = pt.pack(x, y, axes=None)
317+
# packed_tensor has shape (6 + 120,) == (126,)
318+
# packed_shapes is [(2, 3), (4, 5, 6)]
319+
320+
If we want to preserve a single axis, we can use either positive or negative indexing. Notice that all tensors
321+
must have the same size along the preserved axis. For example, using axes=0:
322+
323+
.. code-block:: python
324+
import pytensor.tensor as pt
325+
326+
x = pt.tensor("x", shape=(2, 3))
327+
y = pt.tensor("y", shape=(2, 5, 6))
328+
packed_tensor, packed_shapes = pt.pack(x, y, axes=0)
329+
# packed_tensor has shape (2, 3 + 30) == (2, 33)
330+
# packed_shapes is [(3,), (5, 6)]
331+
332+
333+
Using negative indexing we can preserve the last two axes:
334+
335+
.. code-block:: python
336+
import pytensor.tensor as pt
337+
338+
x = pt.tensor("x", shape=(4, 2, 3))
339+
y = pt.tensor("y", shape=(5, 2, 3))
340+
packed_tensor, packed_shapes = pt.pack(x, y, axes=(-2, -1))
341+
# packed_tensor has shape (4 + 5, 2, 3) == (9, 2, 3)
342+
# packed_shapes is [(4,), (5,
343+
344+
Or using a mix of positive and negative axes, we can preserve the first and last axes:
345+
346+
.. code-block:: python
347+
import pytensor.tensor as pt
348+
349+
x = pt.tensor("x", shape=(2, 4, 3))
350+
y = pt.tensor("y", shape=(2, 5, 3))
351+
packed_tensor, packed_shapes = pt.pack(x, y, axes=(0, -1))
352+
# packed_tensor has shape (2, 4 + 5, 3) == (2, 9, 3)
353+
# packed_shapes is [(4,), (5,)]
354+
"""
273355
n_before, n_after, min_axes = _analyze_axes_list(axes)
274356

275357
if all([n_before == 0, n_after == 0, min_axes == 0]):
@@ -311,7 +393,35 @@ def pack(*tensors: TensorLike, axes: Sequence[int] | int | None = None):
311393
return join(n_before, *reshaped_tensors), packed_shapes
312394

313395

314-
def unpack(packed_input, axes, packed_shapes):
396+
def unpack(
397+
packed_input: TensorLike,
398+
axes: int | Sequence[int] | None,
399+
packed_shapes: list[ShapeValueType],
400+
) -> list[TensorVariable]:
401+
"""
402+
Unpack a packed tensor into multiple tensors by splitting along the specified axes and reshaping.
403+
404+
The unpacking process reverses the packing operation, restoring the original shapes of the input tensors. `axes`
405+
corresponds to the axes that were preserved during packing, and `packed_shapes` contains the shapes of the raveled
406+
dimensions for each output tensor (that is, the shapes that were destroyed during packing).
407+
408+
The signature of unpack is such that the same `axes` should be passed to both `pack` and `unpack` to create a
409+
"round-trip" operation. For details on the rules for `axes`, see the documentation for `pack`.
410+
411+
Parameters
412+
----------
413+
packed_input : TensorLike
414+
The packed tensor to be unpacked.
415+
axes : int, sequence of int, or None
416+
Axes that were preserved during packing. If None, the input is assumed to be 1D and axis 0 is used.
417+
packed_shapes : list of ShapeValueType
418+
A list containing the shapes of the raveled dimensions for each output tensor.
419+
420+
Returns
421+
-------
422+
unpacked_tensors : list of TensorLike
423+
A list of unpacked tensors with their original shapes restored.
424+
"""
315425
if axes is None:
316426
if packed_input.ndim != 1:
317427
raise ValueError(

0 commit comments

Comments
 (0)