Skip to content

Conversation

@jessegrabowski
Copy link
Member

@jessegrabowski jessegrabowski commented Aug 10, 2025

Description

Adds pt.pack and pt.unpack helpers, roughly conforming to the einops functions of the same name.

These helps are for situations where we have a ragged list of inputs that need to be raveled into a single flat list for some intermediate step. This occurs in places like optimization.

Example usage:

x = pt.tensor("x", shape=shapes[0])
y = pt.tensor("y", shape=shapes[1])
z = pt.tensor("z", shape=shapes[2])

flat_params, packed_shapes = pt.pack(x, y, z)

Unpack simply undoes the computation, although there's norewrite to ensure pt.unpack(*pt.pack(*inputs)) is the identity function:

x, y, z = pt.unpack(flat_params, packed_shapes)

The use-case I forsee is creating replacement for a function of the inputs we're packing, for example:

loss = (x + y.sum() + z.sum()) ** 2

flat_packed, packed_shapes = pack(x, y, z)
new_input = flat_packed.type()
new_outputs = unpack(new_input, packed_shapes)

loss = pytensor.graph.graph_replace(loss, dict(zip([x, y, z], new_outputs)))
fn = pytensor.function([new_input], loss)

Note that the final compiled function depends only on new_input, only because the shapes of the 3 packed variables were statically known. This leads to my design choices section:

  1. I decided to work with the static shapes directly if they are available. This means that pack will eagerly return a list of integer shapes as packed_shapes if possible. If not possible, they will be symbolic shapes. This is maybe an anti-pattern -- we might prefer a rewrite to handle this later, but it seemed easy enough to do eagerly.
  2. I didn't add support for batch dims. This is left to the user to do himself using pt.vectorize.
  3. The einops API has arguments to support packing/unpacking on arbitrary subsets of dimensions. I didn't do this, because I couldn't think of a use-case that a user couldn't get himself using DimShuffle and vectorize.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1578.org.readthedocs.build/en/1578/

@jessegrabowski jessegrabowski added the enhancement New feature or request label Aug 10, 2025
@jessegrabowski
Copy link
Member Author

The pack -> type -> unpack -> replace pattern might be common enough to merit it's own helper. PyMC has tools for doing this, for example, in RaveledArray and DictToArrayBijector, that could be replaced with appropriate symbolic operations.

One other thing I forgot to mention is that this will all fail on inputs with shape 0, since that will ruin the prod(shape) used to get the shape of the flat output. Not sure what to do in that case.

@jessegrabowski jessegrabowski force-pushed the pack-unpack branch 2 times, most recently from da89b9d to 9ead211 Compare August 10, 2025 08:31
@codecov
Copy link

codecov bot commented Aug 10, 2025

Codecov Report

❌ Patch coverage is 94.02985% with 8 lines in your changes missing coverage. Please review.
✅ Project coverage is 81.74%. Comparing base (1f9a67b) to head (0b86851).

Files with missing lines Patch % Lines
pytensor/tensor/extra_ops.py 94.02% 4 Missing and 4 partials ⚠️

❌ Your patch check has failed because the patch coverage (94.02%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1578      +/-   ##
==========================================
+ Coverage   81.70%   81.74%   +0.03%     
==========================================
  Files         246      246              
  Lines       53632    53763     +131     
  Branches     9438     9462      +24     
==========================================
+ Hits        43820    43946     +126     
- Misses       7330     7333       +3     
- Partials     2482     2484       +2     
Files with missing lines Coverage Δ
pytensor/tensor/extra_ops.py 88.88% <94.02%> (+0.95%) ⬆️

... and 2 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@ricardoV94
Copy link
Member

ricardoV94 commented Aug 10, 2025

  1. Better to have the same types as return, static shape to constant is introduced during rewrites already

2 and 3. I would really like to have these, it's what I needed for the batched_dot_to_core rewrites.This isn't a simple case of vectorize because the dims I want to pack are both on the left and right of other dims

@ricardoV94
Copy link
Member

I am inclined to making this a core op and not just a helper. It obliviates most uses of reshape and it's much easier to reason about, not having to worry about pesky -1 or whether the reshape shape comes from the original input shapes or not.

That would pretty much address #883

We could use OFG and/or specialize to reshape/split later. It need also not be done in this PR. It's an implementation detail as far as the user is concerned.

@jessegrabowski
Copy link
Member Author

jessegrabowski commented Nov 2, 2025

I am inclined to making this a core op and not just a helper. It obliviates most uses of reshape and it's much easier to reason about, not having to worry about pesky -1 or whether the reshape shape comes from the original input shapes or not.

I pushed a commit that adds a "feature complete" Pack Op. It can do everything that einops.pack can do, and more. I'd say it's a bit on the overly complex side, but that's on brand for me. The API I cooked up could maybe be simplified.

Basically, you pack by selecting which axes you don't want to ravel. Axes should be None, int, or tuple[int].

If None, it's the same as packed_tensor = pt.join(0, *[var.ravel() for var in input_vars]):

x = pt.tensor("x", shape=())
y = pt.tensor("y", shape=(5,))
z = pt.tensor("z", shape=(3, 3))
packed_tensor, packed_shapes = pt.pack(x, y, z, axes=None)
packed_tensor.type.shape # (15,)

Once you pass in an integers, your inputs need to have the same shape on the dimensions of concatenation. All dimensions that are in a "hole" of the provided axes is raveled and joined. What's a hole? You can have explicit or implicit holes. An "explicit" hole is a gap in the integers you provide. For example, [0, -1] has an explicit hole: all dimensions except the first and last will raveled and joined.

x = pt.tensor("x", shape=(5, 3))
y = pt.tensor("y", shape=(5, 2, 4, 3))
z = pt.tensor("z", shape=(5, 6, 3))
packed_tensor, packed_shapes = pt.pack(x, y, z, axes=[0,-1])
packed_tensor.type.shape # (5, 13, 3)

axes=[0,3] also looks like an explicit hole -- dimensions 1 and 2 will be raveled and joined. But there's also an implicit hole, because there could be dimensions beyond 3. That makes 2 holes, which is an invalid pack. To resolve this, we assume that 3 is the maximum dimension size -- you need to pass at least one tensor with ndims==4 (since we want to concate on axis=3), and no input can have more. Under those conditions, axes=[0, 3] is treated the same as axes=[0, -1]:

x = pt.tensor('x', shape=(2, 6))
y = pt.tensor('y', shape=(2, 3, 7, 6))
z = pt.tensor('z', shape=(2, 10, 6))
packed_tensor, packed_shapes = pt.pack(x, y, z, axes=[0,3])
packed_tensor.type.shape # (2, 32, 6)

Minimum size is also enforced -- we couldn't have passed in x = pt.tensor('x', shape=(2,)) for example, because we need at least 2 dimensions to concatenate, because we asked for 2 pack axes.

I could imagine this being more strict -- axes=[0, 3] could also imply that all inputs are ndim==4. The reason I didn't do it this way is because I wanted to match the feel of einops. axis=[0, 3] feels like it should correspond to something like i * j, with at most 2 dimensions inside the ellipsis.

Pack could probably also be an OpFromGraph, there's nothing special going on in the perform method. That would be better because then we get the gradients for free.

@ricardoV94
Copy link
Member

What if we make 1) axis be the ones to ravel and 2) you have to specify axis per input? You can have a single number / list of ints but then only valid if all inputs have the same ndim?

@ricardoV94
Copy link
Member

ricardoV94 commented Nov 2, 2025

Ok 1) is not intuitive, but 2 may still have some merit? It's strictly more powerful and maybe less magical to specify the axes for each input. It also means we can normalize them to be positive and I think simplify the code analysis?

In your use cases this would be terrible UX?

@ricardoV94
Copy link
Member

BTW I'm not bashing on the idea. On the contrary I quite like it. I just lost some of the context on the PR and I'm being lazy about getting it back.

One thing that would be nice to prove the API is to refactor the batch dot rewrite (dot to batched matmul or whatever is called) to use pack. This is my motivating case where I wanted this functionality.

@jessegrabowski
Copy link
Member Author

The main use-case I have in mind for this is in optimize/pymc where we get parameters in arbitrary shapes, but we want to pack them into a single vector and do a replacement of the original variables with elements of that vector (see the test_make_replacements_with_pack_unpack test for what I have in mind).

Having to specify the number of dims ahead of time in that case doesn't work, because we don't know what the user will give us.

@ricardoV94
Copy link
Member

Why do you need to specify dims before you get to see the inputs?

new_outputs = unpack(new_input, packed_shapes)

loss = pytensor.graph.graph_replace(loss, dict(zip([x, y, z], new_outputs)))
fn = pytensor.function([new_input, x, y, z], loss, mode="FAST_COMPILE")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess that for this example to make sense you want to rewrite the graph after the replace to get rid of the dependency on x, y, z since shapes are static

@ricardoV94
Copy link
Member

ricardoV94 commented Nov 8, 2025

Here is the snippet from our discussion:

API

# Axis flavor
## Single variable
assert A.shape == (4, 3, 12)
B = pt.unpack(A, axis=-1, shapes=[(2, 6)])
assert B.shape == (4, 3, 2, 6)
[A], [(2, 6)] = pt.pack([B], axis_list=[(-2, -1)])

## Multiple variables
assert A.shape == (4, 3, 12)
b, c = pt.unpack(A, axis=-1, shapes=[(2, 3), (6,)])
assert b.shape == (4, 3, 2, 3)
assert c.shape == (4, 3, 6)
A, [(2, 3), (6,)] = pt.pack([b, c], axis_list=[(-2, -1), (-1,)])

## Vectorize_graph (partial axis)
assert b.shape == (4, 3, 2, 3)
assert c.shape == (4, 3, 6)
A, [(2, 3), (6,)] = pt.pack([b, c], axis_list=[(-2, -1), (-1,)])
assert A.shape == (4, 3, 12)
bA = vectorize(A, {b: bb, c:bc})
bA, [(2, 3), (6,)] = pt.pack([b, c], axis_list=[(-2, -1), (-1,)])
assert bA.shape == (?, 4, 3, 12)

## Vectorize graph (all axis)
assert b.shape == (4, 3, 2, 3)
assert c.shape == (4, 3, 6)
A, [(4, 3, 2, 3), (4, 3, 6,)] = pt.pack([b, c], axis_list=[(0, 1, 2, 3), (0, 1, 2)])
assert A.shape == (144,)
bA = vectorize(A, {b: bb, c:bc})
bA, [(4, 3, 2, 3), (4, 3, 6)] = pt.pack([b, c], axis_list=[(1, 2, 3, 4), (1, 2, 3)])
assert bA.shape == (?, 144)


# Keep axis flavor
## Single variable
assert A.shape == (4, 3, 12)
B = pt.unpack(A, keep_axis=(0, 1), shapes=[(2, 6)])
assert B.shape == (4, 3, 2, 6)
[A_again], [(2, 6)] = pt.pack([B], keep_axis=(0, 1))

##  Multiple variables
assert A.shape == (4, 3, 12)
b, c = pt.unpack(A, keep_axis=(0, 1), shapes=[(2, 3), (6,)])
assert b.shape == (4, 3, 2, 3)
assert c.shape == (4, 3, 6)
A, [(2, 3), (6,)] = pt.pack([b, c], keep_axis=(0, 1))

## Vectorize_graph (partial axis)
assert b.shape == (4, 3, 2, 3)
assert c.shape == (4, 3, 6)
A, [(2, 3), (6,)] = pt.pack([b, c], keep_axis=(0, 1))
assert A.shape == (4, 3, 12)
bA = vectorize(A, {b: bb, c:bc})
bA, [(2, 3), (6,)] = pt.pack([b, c], keep_axis=(0, 1, 2))
assert bA.shape == (?, 4, 3, 12)

## Vectorize graph (all axis)
assert b.shape == (4, 3, 2, 3)
assert c.shape == (4, 3, 6)
A, [(4, 3, 2, 3), (4, 3, 6,)] = pt.pack([b, c], keep_axis=None | ())
assert A.shape == (144,)
bA = vectorize(A, {b: bb, c:bc})
bA, [(4, 3, 2, 3), (4, 3, 6)] = pt.pack([b, c], keep_axis=(0,))
assert bA.shape == (?, 144)

Implementation

def join_dims(x: Variable, axis: Sequence[int] | int | None = None) -> Variable:
    axis = normalize_axis_tuple(x.ndim, axis)
    if np.diff(axis).max() > 1:
        raise ValueError(f"join_dims axis must be consecutive, got normalized axis: {axis}")
    x_shape = tuple(x.shape)
    return x.reshape((*x_shape[:min(axis)], -1, *x_shape[max(axis)+1:]))
    
def split_dims(x, axis: int | None = None, shape: ShapeVariable) -> Variable:
    if axis is None:
        if x.ndim != 1:
            raise ValueError("split_dims can only be called with axis=None for 1d inputs")
        axis = 0
    new_shape = list(x.shape)
    new_shape[axis] = shape
    return x.reshape(tuple(new_shape))

## Axis flavor
def pack(inputs: Sequence[Variable], axis_list: Sequence[Sequence[int], | None] | None) -> Variable, Sequence[ShapeVariable]:
    if axis_list is None:
        axis_list = [None] * len(inputs)
    packed_inputs, packed_shapes = zip(*[join_dims(inp, axis) for inp, axis in zip(inputs, axis_list, strict=True)])
    join_axis = min(normalize_axis_tuple(inputs[0].ndim, axis_list[0]))
    packed_input = pt.join(packed_inputs, axis=axis)
    return packed_input, packed_shapes
  
def unpack(packed_input: Variable, axis: int | None = None, unpacked_shapes: Sequence[ShapeVariable]) -> list[Variable]:
    if axis is None:
      if packed_input.ndim != 1:
          raise ValueError("unpack can only be called with axis=None for 1d inputs")
      axis = 0
    split_inputs = split(packed_inputs, n_splits=len(shapes), axis=axis)
    return [split_dims(inp, axis, shape) for inp, shape in zip(split_inputs, shapes, strict=True)]

## Keep axis flavor
def pack(inputs, keep_axis):
    packed_inputs, packed_shapes = zip(*[
        join_dims(
            inp, 
            [
                i for i in range(inp.ndim) 
                if i not in normalize_axis_tuple(inp.ndim, keep_axis)
            ]
        )
        for inp in inputs
    ])
    join_axis = min(normalize_axis_tuple(inputs[0].ndim, keep_axis))
    packed_input = pt.join(packed_inputs, axis=axis)
    return packed_input, packed_shapes
  
def unpack(packed_input, keep_axis, unpacked_shapes):
    if keep_axis is None:
        if packed_input.ndim != 1:
            raise ValueError("unpack can only be called with keep_axis=None for 1d inputs")
        split_axis = 0
    else:
        keep_axis = normalize_axis_tuple(unpacked_input, keep_axis)
        try:
            [split_axis] = (i for i in range(packed_input.ndim) if i not in keep_axis)
        except ValueError as err
            raise ValueError("Unpack must have exactly one more dimension that implied by keep_axis") from err
        
    split_inputs = split(packed_inputs, n_splits=len(shapes), axis=split_axis[0])
        
    return [split_dims(inp, split_axis, shape) for inp, shape in zip(split_inputs, shapes, strict=True)]

I think keep_axis API is nicer for users as you can reuse the same axis in the pack/unpack definitions. It's quite more complex internally, though :(

Did I miss something with my implementation of pack with keep_axis? Doesn't seem to need any extra error handling besides whatever the underlying Ops already do. Unpack needed something to make sure you have exactly one dimension not specified in keep_axis

@jessegrabowski jessegrabowski force-pushed the pack-unpack branch 2 times, most recently from f6a0e12 to 80020e4 Compare November 27, 2025 00:16
@jessegrabowski
Copy link
Member Author

jessegrabowski commented Nov 27, 2025

@ricardoV94 I soft rebooted this PR, starting from join_dims and split_dims. Let me know how those look to you. Next commit will bring back pack/unpack.

The only thing you missed with unpack_dims was that the incoming shapes can (should?) also be a list, so we need to insert them into the old shape list.

@jessegrabowski jessegrabowski force-pushed the pack-unpack branch 2 times, most recently from bc2cde3 to 510fc48 Compare November 27, 2025 02:45
@jessegrabowski
Copy link
Member Author

I created a new file, reshape_ops, and moved join_dims, split_dims, pack, and unpack there. I also created new Ops: JoinDims and SplitDims. I haven't done the rewrites yet, because I'm not confident that these are what you had in mind.

pack still doesn't use join_dims, because of an important corner case where pack needs to create new dimensions. I left a note in the code, but the case is something like:

x = pt.tensor('x', shape=(3,))
y = pt.tensor('y', shape=(3, 5))
z, shapes = pt.pack([x, y], axes=0)

To pack this, x needs to have a 1 inserted on the right-hand side. The current code does x.reshape(3, -1), which does the insertion. Using join_dims(x, 0) will return (3, ), which then errors out during join.

I'll keep trying, but I'm struggling to see a way to cover all cases.

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good, ops still need some work. right now you're breaking inplace API by not stating the outputs are views

(out,) = outputs

output_shape = [
*x.shape[: min(self.axis)],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

parametrize by first axis + number of axes instead? Seems sufficient

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer the axis list. It strikes me as more natural.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It has no purpose? The user facing function can use axis but why the op?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It has no purpose? The user facing function can use axis but why the op?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could also be from axis - to axis. Does it make pack code more simple for instance?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note you can still str repr and add an Op property axis to build you the axis, but internally it seems to only complicate the logic? For the vectorize, having start + length or start - end would make things much easier, just shift by the start (and if there's an end) by number of batch_ndim

self.axis = axis

def _make_output_shape(self, input_shape, shape):
[axis] = normalize_axis_tuple(self.axis, len(input_shape))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't it already normalized by not allowing negative?

def _get_constant_shape(x):
try:
# get_underling_scalar_constant_value returns a numpy scalar, we need a python int
return get_underlying_scalar_constant_value(x).item()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you don't need underlying because it has to be scalar, so use get_scalar_constant_value

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this doesn't make sense for perform either, it's a graph extraction operation

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this doesn't make sense for perform either, it's a graph extraction operation

def infer_shape(self, fgraph, node, shapes):
[input_shape, _] = shapes
_, shape = node.inputs
output_shape = self._make_output_shape(input_shape, shape)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you shouldn't use graph analysis in the symbolic infer shape just use the provided inputs

if all([n_before == 0, n_after == 0, min_axes == 0]):
# Special case -- we're raveling everything
packed_shapes = [t.shape for t in tensor_list]
reshaped_tensors = [t.ravel() for t in tensor_list]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why revel instead of join dims? It's still a reshape under the hood. May seem like a simple one, but once we vectorize becomes already ugly

# shape[max(axes)+1:]). So this will work if we choose axes=(n_before, n_after_packed - 1). Because of the
# rules on the axes input, we will always have n_before <= n_after_packed - 1. A set is used here to cover the
# corner case when n_before == n_after_packed - 1 (i.e., when there is only one axis to ravel --> do nothing).
join_axes = {n_before, n_after_packed - 1}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this be a range?

from pytensor.tensor.variable import TensorVariable


class JoinDims(Op):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add vectorize node already for the new ops?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Implement pack/unpack Ops

2 participants