@@ -2259,7 +2259,9 @@ class Pack(OpFromGraph):
22592259
22602260
22612261def pack (
2262- * tensors : TensorVariable , axes : int | Sequence [int ] | None = None
2262+ * tensors : TensorVariable ,
2263+ axes : int | Sequence [int ] | None = None ,
2264+ preserve_axes : int | Sequence [int ] | None = None ,
22632265) -> tuple [TensorVariable , list [tuple [TensorVariable ]]]:
22642266 """
22652267 Given a list of tensors of varying shapes and dimensions, ravels and concatenates them into a single 1d vector.
@@ -2270,10 +2272,11 @@ def pack(
22702272 Tensors to be packed. Tensors can have varying shapes and dimensions, but must have the same size along each
22712273 of the dimensions specified in the `axes` parameter.
22722274 axes: int or sequence of int, optional
2273- Axes to be preserved . All other axes will be raveled (packed) , and the output is the result of concatenating
2274- on the new raveled dimension. If None, all axes will be raveled and joined. Axes can be either positive or
2275+ Axes to be raveled . All other axes will preserved , and the output is the result of concatenating
2276+ on the raveled dimension. If None, all axes will be raveled and joined. Axes can be either positive or
22752277 negative, but must be striclty increasing in both the positive and negative parts of the list. Negative axes
2276- must come after positive axes.
2278+ must come after positive axes. Only one of `axes` or `preserve_axes` can be specified.
2279+ preserve_axes: int or sequence of int, optional
22772280
22782281 Returns
22792282 -------
0 commit comments