Implement pack/unpack helpers #1578
base: main
Conversation
Force-pushed from 36422cd to cf633e7
The pack -> type -> unpack -> replace pattern might be common enough to merit its own helper. PyMC has tools for doing this, for example, in …

One other thing I forgot to mention is that this will all fail on inputs with shape 0, since that will ruin the …
Force-pushed from da89b9d to 9ead211
Codecov Report

❌ Patch coverage is 94.02%. Your patch check has failed because the patch coverage (94.02%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #1578      +/-   ##
==========================================
+ Coverage   81.70%   81.74%   +0.03%
==========================================
  Files         246      246
  Lines       53632    53763     +131
  Branches     9438     9462      +24
==========================================
+ Hits        43820    43946     +126
- Misses       7330     7333       +3
- Partials     2482     2484       +2
```
2 and 3: I would really like to have these, it's what I needed for the batched_dot_to_core rewrites. This isn't a simple case of vectorize, because the dims I want to pack are both on the left and right of other dims.
I am inclined to make this a core op and not just a helper. It obviates most uses of reshape and it's much easier to reason about, not having to worry about pesky -1 or whether the reshape shape comes from the original input shapes or not. That would pretty much address #883. We could use OFG and/or specialize to reshape/split later. It also need not be done in this PR. It's an implementation detail as far as the user is concerned.
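To make the reshape point concrete, a small sketch; the `pt.pack` call follows the signature discussed later in this thread, so treat it as illustrative rather than final:

```python
import pytensor.tensor as pt

x = pt.tensor("x", shape=(5, 2, 4, 3))

# Today: merge the middle dims with reshape, keeping track of -1 and the
# surrounding shapes by hand
merged = x.reshape((x.shape[0], -1, x.shape[-1]))

# With a dedicated op/helper, the kept axes are stated directly and the raveled
# shapes come back alongside the result (assumed signature from this thread)
merged2, shapes = pt.pack(x, axes=(0, -1))
```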
Force-pushed from 860a7ab to fcbd0af
Force-pushed from fcbd0af to 5788333
I pushed a commit that adds a "feature complete" …

Basically, you pack by selecting which axes you don't want to ravel. Axes should be None, int, or tuple[int]. If None, it's the same as …

```python
x = pt.tensor("x", shape=())
y = pt.tensor("y", shape=(5,))
z = pt.tensor("z", shape=(3, 3))

packed_tensor, packed_shapes = pt.pack(x, y, z, axes=None)
packed_tensor.type.shape  # (15,)
```

Once you pass in integers, your inputs need to have the same shape on the dimensions of concatenation. All dimensions that are in a "hole" of the provided axes are raveled and joined. What's a hole? You can have explicit or implicit holes. An "explicit" hole is a gap in the integers you provide. For example:

```python
x = pt.tensor("x", shape=(5, 3))
y = pt.tensor("y", shape=(5, 2, 4, 3))
z = pt.tensor("z", shape=(5, 6, 3))

packed_tensor, packed_shapes = pt.pack(x, y, z, axes=[0, -1])
packed_tensor.type.shape  # (5, 13, 3)
```

An "implicit" hole shows up when an input has fewer dimensions than the requested axes imply, as in:

```python
x = pt.tensor("x", shape=(2, 6))
y = pt.tensor("y", shape=(2, 3, 7, 6))
z = pt.tensor("z", shape=(2, 10, 6))

packed_tensor, packed_shapes = pt.pack(x, y, z, axes=[0, 3])
packed_tensor.type.shape  # (2, 32, 6)
```

Minimum size is also enforced -- we couldn't have passed in …

I could imagine this being more strict -- …
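For the round trip, something like the following should recover the originals from the axes=None example above, assuming unpack takes the packed tensor plus the packed_shapes returned by pack (as in the tests later in this thread):

```python
# The recovered variables are expected to have the original shapes (), (5,), (3, 3)
x_new, y_new, z_new = pt.unpack(packed_tensor, packed_shapes)
```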
What if we make (1) axis be the ones to ravel, and (2) you have to specify axis per input? You could have a single number / list of ints, but then it's only valid if all inputs have the same ndim?
Ok, (1) is not intuitive, but (2) may still have some merit? It's strictly more powerful and maybe less magical to specify the axes for each input. It also means we can normalize them to be positive and, I think, simplify the code analysis? Would this be terrible UX in your use cases?
BTW I'm not bashing the idea. On the contrary, I quite like it. I just lost some of the context on the PR and I'm being lazy about getting it back. One thing that would be nice to prove the API is to refactor the batched dot rewrite (dot to batched matmul, or whatever it's called) to use pack. This is my motivating case where I wanted this functionality.
The main use-case I have in mind for this is in optimize/pymc, where we get parameters in arbitrary shapes, but we want to pack them into a single vector and do a replacement of the original variables with elements of that vector (see the …). Having to specify the number of dims ahead of time in that case doesn't work, because we don't know what the user will give us.
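A minimal sketch of that workflow; the helper name `as_flat_objective` is hypothetical, and the pack/unpack signatures follow the tests below rather than any final API:

```python
import pytensor
import pytensor.tensor as pt


def as_flat_objective(params, objective):
    """Rewrite `objective` as a function of one flat vector of parameters."""
    # axes=None ravels everything, so nothing about the users' ndims/shapes
    # needs to be known ahead of time
    packed, packed_shapes = pt.pack(*params, axes=None)
    flat = packed.type("flat_params")
    replacements = dict(zip(params, pt.unpack(flat, packed_shapes)))
    return flat, pytensor.graph.graph_replace(objective, replacements)


# Example: parameters of arbitrary, user-chosen shapes
mu = pt.tensor("mu", shape=(3,))
sigma = pt.tensor("sigma", shape=(2, 2))
loss = (mu.sum() + sigma.sum()) ** 2

flat, flat_loss = as_flat_objective([mu, sigma], loss)
# Compiling on the flat vector alone works here because the shapes are static
fn = pytensor.function([flat], flat_loss)
```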
Why do you need to specify dims before you get to see the inputs? |
tests/tensor/test_extra_ops.py

```python
new_outputs = unpack(new_input, packed_shapes)

loss = pytensor.graph.graph_replace(loss, dict(zip([x, y, z], new_outputs)))
fn = pytensor.function([new_input, x, y, z], loss, mode="FAST_COMPILE")
```
I guess that for this example to make sense you want to rewrite the graph after the replace to get rid of the dependency on x, y, z since shapes are static
Here is the snippet from our discussion:

API

```python
# Axis flavor

## Single variable
assert A.shape == (4, 3, 12)
B = pt.unpack(A, axis=-1, shapes=[(2, 6)])
assert B.shape == (4, 3, 2, 6)
[A], [(2, 6)] = pt.pack([B], axis_list=[(-2, -1)])

## Multiple variables
assert A.shape == (4, 3, 12)
b, c = pt.unpack(A, axis=-1, shapes=[(2, 3), (6,)])
assert b.shape == (4, 3, 2, 3)
assert c.shape == (4, 3, 6)
A, [(2, 3), (6,)] = pt.pack([b, c], axis_list=[(-2, -1), (-1,)])

## Vectorize_graph (partial axis)
assert b.shape == (4, 3, 2, 3)
assert c.shape == (4, 3, 6)
A, [(2, 3), (6,)] = pt.pack([b, c], axis_list=[(-2, -1), (-1,)])
assert A.shape == (4, 3, 12)
bA = vectorize(A, {b: bb, c: bc})
bA, [(2, 3), (6,)] = pt.pack([b, c], axis_list=[(-2, -1), (-1,)])
assert bA.shape == (?, 4, 3, 12)

## Vectorize graph (all axis)
assert b.shape == (4, 3, 2, 3)
assert c.shape == (4, 3, 6)
A, [(4, 3, 2, 3), (4, 3, 6)] = pt.pack([b, c], axis_list=[(0, 1, 2, 3), (0, 1, 2)])
assert A.shape == (144,)
bA = vectorize(A, {b: bb, c: bc})
bA, [(4, 3, 2, 3), (4, 3, 6)] = pt.pack([b, c], axis_list=[(1, 2, 3, 4), (1, 2, 3)])
assert bA.shape == (?, 144)

# Keep axis flavor

## Single variable
assert A.shape == (4, 3, 12)
B = pt.unpack(A, keep_axis=(0, 1), shapes=[(2, 6)])
assert B.shape == (4, 3, 2, 6)
[A_again], [(2, 6)] = pt.pack([B], keep_axis=(0, 1))

## Multiple variables
assert A.shape == (4, 3, 12)
b, c = pt.unpack(A, keep_axis=(0, 1), shapes=[(2, 3), (6,)])
assert b.shape == (4, 3, 2, 3)
assert c.shape == (4, 3, 6)
A, [(2, 3), (6,)] = pt.pack([b, c], keep_axis=(0, 1))

## Vectorize_graph (partial axis)
assert b.shape == (4, 3, 2, 3)
assert c.shape == (4, 3, 6)
A, [(2, 3), (6,)] = pt.pack([b, c], keep_axis=(0, 1))
assert A.shape == (4, 3, 12)
bA = vectorize(A, {b: bb, c: bc})
bA, [(2, 3), (6,)] = pt.pack([b, c], keep_axis=(0, 1, 2))
assert bA.shape == (?, 4, 3, 12)

## Vectorize graph (all axis)
assert b.shape == (4, 3, 2, 3)
assert c.shape == (4, 3, 6)
A, [(4, 3, 2, 3), (4, 3, 6)] = pt.pack([b, c], keep_axis=None | ())
assert A.shape == (144,)
bA = vectorize(A, {b: bb, c: bc})
bA, [(4, 3, 2, 3), (4, 3, 6)] = pt.pack([b, c], keep_axis=(0,))
assert bA.shape == (?, 144)
```

Implementation

```python
def join_dims(x: Variable, axis: Sequence[int] | int | None = None) -> Variable:
    axis = normalize_axis_tuple(x.ndim, axis)
    if np.diff(axis).max() > 1:
        raise ValueError(f"join_dims axis must be consecutive, got normalized axis: {axis}")
    x_shape = tuple(x.shape)
    return x.reshape((*x_shape[:min(axis)], -1, *x_shape[max(axis) + 1:]))


def split_dims(x, axis: int | None, shape: ShapeVariable) -> Variable:
    if axis is None:
        if x.ndim != 1:
            raise ValueError("split_dims can only be called with axis=None for 1d inputs")
        axis = 0
    new_shape = list(x.shape)
    new_shape[axis] = shape
    return x.reshape(tuple(new_shape))


## Axis flavor
def pack(inputs: Sequence[Variable], axis_list: Sequence[Sequence[int] | None] | None) -> tuple[Variable, Sequence[ShapeVariable]]:
    if axis_list is None:
        axis_list = [None] * len(inputs)
    packed_inputs, packed_shapes = zip(*[join_dims(inp, axis) for inp, axis in zip(inputs, axis_list, strict=True)])
    join_axis = min(normalize_axis_tuple(inputs[0].ndim, axis_list[0]))
    packed_input = pt.join(packed_inputs, axis=axis)
    return packed_input, packed_shapes


def unpack(packed_input: Variable, axis: int | None, unpacked_shapes: Sequence[ShapeVariable]) -> list[Variable]:
    if axis is None:
        if packed_input.ndim != 1:
            raise ValueError("unpack can only be called with axis=None for 1d inputs")
        axis = 0
    split_inputs = split(packed_inputs, n_splits=len(shapes), axis=axis)
    return [split_dims(inp, axis, shape) for inp, shape in zip(split_inputs, shapes, strict=True)]


## Keep axis flavor
def pack(inputs, keep_axis):
    packed_inputs, packed_shapes = zip(*[
        join_dims(
            inp,
            [
                i for i in range(inp.ndim)
                if i not in normalize_axis_tuple(inp.ndim, keep_axis)
            ],
        )
        for inp in inputs
    ])
    join_axis = min(normalize_axis_tuple(inputs[0].ndim, keep_axis))
    packed_input = pt.join(packed_inputs, axis=axis)
    return packed_input, packed_shapes


def unpack(packed_input, keep_axis, unpacked_shapes):
    if keep_axis is None:
        if packed_input.ndim != 1:
            raise ValueError("unpack can only be called with keep_axis=None for 1d inputs")
        split_axis = 0
    else:
        keep_axis = normalize_axis_tuple(unpacked_input, keep_axis)
        try:
            [split_axis] = (i for i in range(packed_input.ndim) if i not in keep_axis)
        except ValueError as err:
            raise ValueError("Unpack must have exactly one more dimension than implied by keep_axis") from err
    split_inputs = split(packed_inputs, n_splits=len(shapes), axis=split_axis[0])
    return [split_dims(inp, split_axis, shape) for inp, shape in zip(split_inputs, shapes, strict=True)]
```

I think … Did I miss something with my implementation of …?
Force-pushed from f6a0e12 to 80020e4
@ricardoV94 I soft rebooted this PR, starting from …

The only thing you missed with …
Force-pushed from bc2cde3 to 510fc48
I created a new file, …

```python
x = pt.tensor('x', shape=(3,))
y = pt.tensor('y', shape=(3, 5))

z, shapes = pt.pack([x, y], axes=0)
```

To pack this, … I'll keep trying, but I'm struggling to see a way to cover all cases.
ricardoV94 left a comment:

Looks good, the ops still need some work. Right now you're breaking the inplace API by not stating that the outputs are views.
```python
(out,) = outputs

output_shape = [
    *x.shape[: min(self.axis)],
```
parametrize by first axis + number of axes instead? Seems sufficient
I prefer the axis list. It strikes me as more natural.
It has no purpose? The user facing function can use axis but why the op?
Could also be from-axis / to-axis. Does it make the pack code simpler, for instance?
Note you can still str repr and add an Op property axis to build the axis for you, but internally it seems to only complicate the logic? For the vectorize, having start + length or start/end would make things much easier: just shift the start (and, if there's an end, the end) by the number of batch_ndim.
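A toy sketch of that point, using hypothetical names (`JoinDimsParams`, `start_axis`, `n_axes` are not the PR's actual attributes):

```python
from dataclasses import dataclass


# Hypothetical parametrization: store the first raveled axis plus how many
# consecutive axes get raveled.  Vectorization prepends batch dims on the left,
# so the only update needed is shifting start_axis.
@dataclass(frozen=True)
class JoinDimsParams:
    start_axis: int  # first axis to ravel, normalized to be non-negative
    n_axes: int      # number of consecutive axes raveled together


def vectorize_params(params: JoinDimsParams, batch_ndim: int) -> JoinDimsParams:
    # e.g. start_axis=1, n_axes=2 becomes start_axis=2 after adding one batch dim
    return JoinDimsParams(params.start_axis + batch_ndim, params.n_axes)
```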
Force-pushed from 670132b to f7644b6
```python
self.axis = axis

def _make_output_shape(self, input_shape, shape):
    [axis] = normalize_axis_tuple(self.axis, len(input_shape))
```
isn't it already normalized by not allowing negative?
```python
def _get_constant_shape(x):
    try:
        # get_underlying_scalar_constant_value returns a numpy scalar, we need a python int
        return get_underlying_scalar_constant_value(x).item()
```
you don't need underlying because it has to be scalar, so use get_scalar_constant_value
this doesn't make sense for perform either, it's a graph extraction operation
```python
def infer_shape(self, fgraph, node, shapes):
    [input_shape, _] = shapes
    _, shape = node.inputs
    output_shape = self._make_output_shape(input_shape, shape)
```
you shouldn't use graph analysis in the symbolic infer_shape, just use the provided inputs
```python
if all([n_before == 0, n_after == 0, min_axes == 0]):
    # Special case -- we're raveling everything
    packed_shapes = [t.shape for t in tensor_list]
    reshaped_tensors = [t.ravel() for t in tensor_list]
```
why ravel instead of join_dims? It's still a reshape under the hood. May seem like a simple one, but once we vectorize it already becomes ugly
```python
# shape[max(axes)+1:]). So this will work if we choose axes=(n_before, n_after_packed - 1). Because of the
# rules on the axes input, we will always have n_before <= n_after_packed - 1. A set is used here to cover the
# corner case when n_before == n_after_packed - 1 (i.e., when there is only one axis to ravel --> do nothing).
join_axes = {n_before, n_after_packed - 1}
```
shouldn't this be a range?
```python
from pytensor.tensor.variable import TensorVariable


class JoinDims(Op):
```
can we add vectorize node already for the new ops?
Description
Adds pt.pack and pt.unpack helpers, roughly conforming to the einops functions of the same name. These helpers are for situations where we have a ragged list of inputs that need to be raveled into a single flat list for some intermediate step. This occurs in places like optimization.
Example usage:
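A minimal sketch of the intended usage, assuming the variadic signature used in the review comments above (the final API may differ):

```python
import pytensor.tensor as pt

x = pt.tensor("x", shape=())
y = pt.tensor("y", shape=(5,))
z = pt.tensor("z", shape=(3, 3))

packed, packed_shapes = pt.pack(x, y, z, axes=None)
packed.type.shape  # (15,)

outputs = pt.unpack(packed, packed_shapes)  # tensors with the original shapes
```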
Unpack simply undoes the computation, although there's no rewrite to ensure pt.unpack(*pt.pack(*inputs)) is the identity function.

The use-case I foresee is creating replacements for a function of the inputs we're packing, for example:
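A sketch of that pattern, reconstructed from the test shown earlier in the thread; the only change is that the compiled function takes just new_input:

```python
import pytensor
import pytensor.tensor as pt

x = pt.tensor("x", shape=())
y = pt.tensor("y", shape=(5,))
z = pt.tensor("z", shape=(3, 3))
loss = ((x + 1) ** 2).sum() + (y**2).sum() + (z**2).sum()

packed, packed_shapes = pt.pack(x, y, z, axes=None)
new_input = packed.type("new_input")
new_outputs = pt.unpack(new_input, packed_shapes)

loss = pytensor.graph.graph_replace(loss, dict(zip([x, y, z], new_outputs)))
fn = pytensor.function([new_input], loss)
```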
Note that the final compiled function depends only on new_input, but only because the shapes of the 3 packed variables were statically known. This leads to my design choices section:

- pack will eagerly return a list of integer shapes as packed_shapes if possible. If not possible, they will be symbolic shapes. This is maybe an anti-pattern -- we might prefer a rewrite to handle this later, but it seemed easy enough to do eagerly.
- … pt.vectorize.
- The einops API has arguments to support packing/unpacking on arbitrary subsets of dimensions. I didn't do this, because I couldn't think of a use-case that a user couldn't get himself using DimShuffle and vectorize.

Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1578.org.readthedocs.build/en/1578/