Skip to content

Commit 1b74d21

Browse files
.wip
1 parent 5be83d1 commit 1b74d21

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

pytensor/tensor/extra_ops.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2259,7 +2259,9 @@ class Pack(OpFromGraph):
22592259

22602260

22612261
def pack(
2262-
*tensors: TensorVariable, axes: int | Sequence[int] | None = None
2262+
*tensors: TensorVariable,
2263+
axes: int | Sequence[int] | None = None,
2264+
preserve_axes: int | Sequence[int] | None = None,
22632265
) -> tuple[TensorVariable, list[tuple[TensorVariable]]]:
22642266
"""
22652267
Given a list of tensors of varying shapes and dimensions, ravels and concatenates them into a single 1d vector.
@@ -2270,10 +2272,11 @@ def pack(
22702272
Tensors to be packed. Tensors can have varying shapes and dimensions, but must have the same size along each
22712273
of the dimensions specified in the `axes` parameter.
22722274
axes: int or sequence of int, optional
2273-
Axes to be preserved. All other axes will be raveled (packed), and the output is the result of concatenating
2274-
on the new raveled dimension. If None, all axes will be raveled and joined. Axes can be either positive or
2275+
Axes to be raveled. All other axes will preserved, and the output is the result of concatenating
2276+
on the raveled dimension. If None, all axes will be raveled and joined. Axes can be either positive or
22752277
negative, but must be striclty increasing in both the positive and negative parts of the list. Negative axes
2276-
must come after positive axes.
2278+
must come after positive axes. Only one of `axes` or `preserve_axes` can be specified.
2279+
preserve_axes: int or sequence of int, optional
22772280
22782281
Returns
22792282
-------

0 commit comments

Comments
 (0)