|
12 | 12 | from pytensor.tensor.math import prod |
13 | 13 | from pytensor.tensor.shape import ShapeValueType |
14 | 14 | from pytensor.tensor.type import tensor |
15 | | -from pytensor.tensor.variable import TensorConstant |
| 15 | +from pytensor.tensor.variable import TensorConstant, TensorVariable |
16 | 16 |
|
17 | 17 |
|
18 | 18 | class JoinDims(Op): |
@@ -269,7 +269,89 @@ def find_gaps(s): |
269 | 269 | return n_before, n_after, min_axes |
270 | 270 |
|
271 | 271 |
|
272 | | -def pack(*tensors: TensorLike, axes: Sequence[int] | int | None = None): |
| 272 | +def pack( |
| 273 | + *tensors: TensorLike, axes: Sequence[int] | int | None = None |
| 274 | +) -> tuple[TensorVariable, list[ShapeValueType]]: |
| 275 | + """ |
| 276 | + Combine multiple tensors by preserving the specified axes and raveling the rest into a single axis. |
| 277 | +
|
| 278 | + Parameters |
| 279 | + ---------- |
| 280 | + *tensors : TensorLike |
| 281 | + Input tensors to be packed. |
| 282 | + axes : int, sequence of int, or None, optional |
| 283 | + Axes to preserve during packing. If None, all axes are raveled. See the Notes section for the rules. |
| 284 | +
|
| 285 | + Returns |
| 286 | + ------- |
| 287 | + packed_tensor : TensorLike |
| 288 | + The packed tensor with specified axes preserved and others raveled. |
| 289 | + packed_shapes : list of ShapeValueType |
| 290 | + A list containing the shapes of the raveled dimensions for each input tensor. |
| 291 | +
|
| 292 | + Notes |
| 293 | + ----- |
| 294 | + The `axes` parameter determines which axes are preserved during packing. Axes can be specified using positive or |
| 295 | + negative indices, but must follow these rules: |
| 296 | + - If axes is None, all axes are raveled. |
| 297 | + - If a single integer is provided, it can be positive or negative, and can take any value up to the smallest |
| 298 | + number of dimensions among the input tensors. |
| 299 | + - If a list is provided, it can be all positive, all negative, or a combination of positive and negative. |
| 300 | + - Positive axes must be contiguous and start from 0. |
| 301 | + - Negative axes must be contiguous and end at -1. |
| 302 | + - If positive and negative axes are combined, positive axes must come before negative axes, and both 0 and -1 |
| 303 | + must be included. |
| 304 | +
|
| 305 | + Examples |
| 306 | + -------- |
| 307 | + The easiest way to understand pack is through examples. The simplest case is using axes=None, which is equivalent |
| 308 | + to ``join(0, *[t.ravel() for t in tensors])``: |
| 309 | +
|
| 310 | + .. code-block:: python |
| 311 | + import pytensor.tensor as pt |
| 312 | +
|
| 313 | + x = pt.tensor("x", shape=(2, 3)) |
| 314 | + y = pt.tensor("y", shape=(4, 5, 6)) |
| 315 | +
|
| 316 | + packed_tensor, packed_shapes = pt.pack(x, y, axes=None) |
| 317 | + # packed_tensor has shape (6 + 120,) == (126,) |
| 318 | + # packed_shapes is [(2, 3), (4, 5, 6)] |
| 319 | +
|
| 320 | + If we want to preserve a single axis, we can use either positive or negative indexing. Notice that all tensors |
| 321 | + must have the same size along the preserved axis. For example, using axes=0: |
| 322 | +
|
| 323 | + .. code-block:: python |
| 324 | + import pytensor.tensor as pt |
| 325 | +
|
| 326 | + x = pt.tensor("x", shape=(2, 3)) |
| 327 | + y = pt.tensor("y", shape=(2, 5, 6)) |
| 328 | + packed_tensor, packed_shapes = pt.pack(x, y, axes=0) |
| 329 | + # packed_tensor has shape (2, 3 + 30) == (2, 33) |
| 330 | + # packed_shapes is [(3,), (5, 6)] |
| 331 | +
|
| 332 | +
|
| 333 | + Using negative indexing we can preserve the last two axes: |
| 334 | +
|
| 335 | + .. code-block:: python |
| 336 | + import pytensor.tensor as pt |
| 337 | +
|
| 338 | + x = pt.tensor("x", shape=(4, 2, 3)) |
| 339 | + y = pt.tensor("y", shape=(5, 2, 3)) |
| 340 | + packed_tensor, packed_shapes = pt.pack(x, y, axes=(-2, -1)) |
| 341 | + # packed_tensor has shape (4 + 5, 2, 3) == (9, 2, 3) |
| 342 | + # packed_shapes is [(4,), (5, |
| 343 | +
|
| 344 | + Or using a mix of positive and negative axes, we can preserve the first and last axes: |
| 345 | +
|
| 346 | + .. code-block:: python |
| 347 | + import pytensor.tensor as pt |
| 348 | +
|
| 349 | + x = pt.tensor("x", shape=(2, 4, 3)) |
| 350 | + y = pt.tensor("y", shape=(2, 5, 3)) |
| 351 | + packed_tensor, packed_shapes = pt.pack(x, y, axes=(0, -1)) |
| 352 | + # packed_tensor has shape (2, 4 + 5, 3) == (2, 9, 3) |
| 353 | + # packed_shapes is [(4,), (5,)] |
| 354 | + """ |
273 | 355 | n_before, n_after, min_axes = _analyze_axes_list(axes) |
274 | 356 |
|
275 | 357 | if all([n_before == 0, n_after == 0, min_axes == 0]): |
@@ -311,7 +393,35 @@ def pack(*tensors: TensorLike, axes: Sequence[int] | int | None = None): |
311 | 393 | return join(n_before, *reshaped_tensors), packed_shapes |
312 | 394 |
|
313 | 395 |
|
314 | | -def unpack(packed_input, axes, packed_shapes): |
| 396 | +def unpack( |
| 397 | + packed_input: TensorLike, |
| 398 | + axes: int | Sequence[int] | None, |
| 399 | + packed_shapes: list[ShapeValueType], |
| 400 | +) -> list[TensorVariable]: |
| 401 | + """ |
| 402 | + Unpack a packed tensor into multiple tensors by splitting along the specified axes and reshaping. |
| 403 | +
|
| 404 | + The unpacking process reverses the packing operation, restoring the original shapes of the input tensors. `axes` |
| 405 | + corresponds to the axes that were preserved during packing, and `packed_shapes` contains the shapes of the raveled |
| 406 | + dimensions for each output tensor (that is, the shapes that were destroyed during packing). |
| 407 | +
|
| 408 | + The signature of unpack is such that the same `axes` should be passed to both `pack` and `unpack` to create a |
| 409 | + "round-trip" operation. For details on the rules for `axes`, see the documentation for `pack`. |
| 410 | +
|
| 411 | + Parameters |
| 412 | + ---------- |
| 413 | + packed_input : TensorLike |
| 414 | + The packed tensor to be unpacked. |
| 415 | + axes : int, sequence of int, or None |
| 416 | + Axes that were preserved during packing. If None, the input is assumed to be 1D and axis 0 is used. |
| 417 | + packed_shapes : list of ShapeValueType |
| 418 | + A list containing the shapes of the raveled dimensions for each output tensor. |
| 419 | +
|
| 420 | + Returns |
| 421 | + ------- |
| 422 | + unpacked_tensors : list of TensorLike |
| 423 | + A list of unpacked tensors with their original shapes restored. |
| 424 | + """ |
315 | 425 | if axes is None: |
316 | 426 | if packed_input.ndim != 1: |
317 | 427 | raise ValueError( |
|
0 commit comments