|
8 | 8 | from pytensor.graph import Apply |
9 | 9 | from pytensor.graph.op import Op |
10 | 10 | from pytensor.tensor import TensorLike, as_tensor_variable |
11 | | -from pytensor.tensor.basic import join, split |
| 11 | +from pytensor.tensor.basic import expand_dims, join, split |
12 | 12 | from pytensor.tensor.math import prod |
13 | 13 | from pytensor.tensor.shape import ShapeValueType |
14 | 14 | from pytensor.tensor.type import tensor |
@@ -119,7 +119,7 @@ def perform(self, node, inputs, outputs): |
119 | 119 |
|
120 | 120 | def split_dims( |
121 | 121 | x: TensorLike, shape: ShapeValueType, axis: int | None = None |
122 | | -) -> Variable: |
| 122 | +) -> TensorVariable: |
123 | 123 | """Split a dimension of a tensor into multiple dimensions. |
124 | 124 |
|
125 | 125 | Parameters |
@@ -372,23 +372,31 @@ def pack( |
372 | 372 | f"Input {i} (zero indexed) to pack has {n_dim} dimensions, " |
373 | 373 | f"but axes={axes} assumes at least {min_axes} dimension{'s' if min_axes != 1 else ''}." |
374 | 374 | ) |
375 | | - |
376 | | - axis_after_packed_axes = n_dim - n_after |
377 | | - packed_shapes.append(input_tensor.shape[n_before:axis_after_packed_axes]) |
378 | | - |
379 | | - new_shape = ( |
380 | | - *input_tensor.shape[:n_before], |
381 | | - -1, |
382 | | - *input_tensor.shape[axis_after_packed_axes:], |
383 | | - ) |
384 | | - reshaped_tensors.append(input_tensor.reshape(new_shape)) |
385 | | - |
386 | | - # Using join_dims could look like this, but it does not insert extra shapes when needed. For example, it fails |
387 | | - # on pack(pt.tensor("x", shape=(3, )), pt.tensor("y", shape=(3, 3)), axes=0), because the first tensor needs to |
388 | | - # have its single dimension expanded before the join. |
389 | | - |
390 | | - # join_axes = {n_before, axis_after_packed_axes - 1} |
391 | | - # reshaped_tensors.append(join_dims(input_tensor, tuple(join_axes))) |
| 375 | + n_after_packed = n_dim - n_after |
| 376 | + packed_shapes.append(input_tensor.shape[n_before:n_after_packed]) |
| 377 | + |
| 378 | + if n_dim == min_axes: |
| 379 | + # If an input has the minimum number of axes, pack implicitly inserts a new axis based on the pattern |
| 380 | + # implied by the axes. If n_before == 0, the reshape would be (-1, ...), so we need to expand at axis 0. |
| 381 | + # If n_after == 0, the reshape would be (..., -1), so we need to expand at axis -1. If both are equal, |
| 382 | + # the reshape will occur in the center of the tensor. |
| 383 | + if n_before == 0: |
| 384 | + input_tensor = expand_dims(input_tensor, axis=0) |
| 385 | + elif n_after == 0: |
| 386 | + input_tensor = expand_dims(input_tensor, axis=-1) |
| 387 | + elif n_before == n_after: |
| 388 | + input_tensor = expand_dims(input_tensor, axis=n_before) |
| 389 | + |
| 390 | + reshaped_tensors.append(input_tensor) |
| 391 | + continue |
| 392 | + |
| 393 | + # The reshape we want is (shape[:before], -1, shape[n_after_packed:]). join_dims does (shape[:min(axes)], -1, |
| 394 | + # shape[max(axes)+1:]). So this will work if we choose axes=(n_before, n_after_packed - 1). Because of the |
| 395 | + # rules on the axes input, we will always have n_before <= n_after_packed - 1. A set is used here to cover the |
| 396 | + # corner case when n_before == n_after_packed - 1 (i.e., when there is only one axis to ravel --> do nothing). |
| 397 | + join_axes = {n_before, n_after_packed - 1} |
| 398 | + joined = join_dims(input_tensor, tuple(join_axes)) |
| 399 | + reshaped_tensors.append(joined) |
392 | 400 |
|
393 | 401 | return join(n_before, *reshaped_tensors), packed_shapes |
394 | 402 |
|
|
0 commit comments