
Commit 451f313

Address feedback
1 parent ef43a80 commit 451f313

2 files changed: 121 additions & 48 deletions

pytensor/tensor/shape_ops.py

Lines changed: 100 additions & 48 deletions
@@ -8,10 +8,12 @@
 from pytensor import Variable
 from pytensor.graph import Apply
 from pytensor.graph.op import Op
+from pytensor.graph.replace import _vectorize_node
 from pytensor.tensor import TensorLike, as_tensor_variable
 from pytensor.tensor.basic import (
+    atleast_1d,
     expand_dims,
-    get_underlying_scalar_constant_value,
+    get_scalar_constant_value,
     join,
     split,
 )
@@ -23,60 +25,73 @@
 
 
 class JoinDims(Op):
-    __props__ = ("axis",)
+    __props__ = (
+        "start_axis",
+        "n_axes",
+    )
     view_map = {0: [0]}
 
-    def __init__(self, axis: Sequence[int]):
-        if any(i < 0 for i in axis):
-            raise ValueError("JoinDims axis must be non-negative")
+    def __init__(self, input_ndims: int, start_axis: int | None, n_axes: int | None):
+        if start_axis < 0:
+            raise ValueError("JoinDims start_axis must be non-negative")
 
-        if len(axis) > 1 and np.diff(axis).max() > 1:
-            raise ValueError(
-                f"join_dims axis must be consecutive, got normalized axis: {axis}"
-            )
+        self.start_axis = start_axis
+        self.n_axes = n_axes
+        self.input_ndims = input_ndims
 
-        self.axis = axis
+        output_ndims = 1 if not start_axis else min(1, input_ndims - n_axes)
+
+        input_signature = ",".join(f"i{i}" for i in range(input_ndims))
+        output_signature = ",".join(f"o{i}" for i in range(output_ndims))
+
+        self.gufunc_signature = f"({input_signature})->({output_signature})"
+
+    @property
+    def axis_range(self):
+        return range(self.start_axis, self.start_axis + self.n_axes)
+
+    def output_shapes(self, input_shapes, joined_shape):
+        return (
+            *input_shapes[: self.start_axis],
+            joined_shape,
+            *input_shapes[self.start_axis + self.n_axes :],
+        )
 
     def make_node(self, x: Variable) -> Apply:  # type: ignore[override]
         static_shapes = x.type.shape
-        if x.type.ndim < max(self.axis) + 1:
+        if x.type.ndim != self.input_ndims:
             raise ValueError(
-                f"Input ndim {x.type.ndim} is less than the maximum axis {max(self.axis)} + 1"
+                f"Input ndim {x.type.ndim} is not equal to expected ndim {self.input_ndims}"
             )
+
+        axis_range = self.axis_range
+
         joined_shape = (
-            int(np.prod([static_shapes[i] for i in self.axis]))
-            if all(static_shapes[i] is not None for i in self.axis)
+            int(np.prod([static_shapes[i] for i in axis_range]))
+            if all(static_shapes[i] is not None for i in axis_range)
             else None
         )
 
-        output_shapes = (
-            *static_shapes[: min(self.axis)],
-            joined_shape,
-            *static_shapes[max(self.axis) + 1 :],
-        )
-
+        output_shapes = self.output_shapes(static_shapes, joined_shape)
        output_type = tensor(shape=output_shapes, dtype=x.type.dtype)
+
         return Apply(self, [x], [output_type])
 
     def infer_shape(self, fgraph, node, shapes):
         [input_shape] = shapes
-        joined_shape = prod([input_shape[i] for i in self.axis])
-        out_shape = (
-            *input_shape[: min(self.axis)],
-            joined_shape,
-            *input_shape[max(self.axis) + 1 :],
-        )
+        axis_range = self.axis_range
 
-        return [out_shape]
+        joined_shape = prod([input_shape[i] for i in axis_range])
+        return [self.output_shapes(input_shape, joined_shape)]
 
     def perform(self, node, inputs, outputs):
         (x,) = inputs
         (out,) = outputs
 
         output_shape = [
-            *x.shape[: min(self.axis)],
+            *x.shape[: self.start_axis],
             -1,
-            *x.shape[max(self.axis) + 1 :],
+            *x.shape[self.start_axis + self.n_axes :],
         ]
 
         out[0] = x.reshape(tuple(output_shape))
@@ -119,7 +134,22 @@ def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorVariable:
         return x
 
     axis = normalize_axis_tuple(axis, x.ndim)
-    return type_cast(TensorVariable, JoinDims(axis)(x))
+
+    if any(i < 0 for i in axis):
+        raise ValueError("join_dims axis must be non-negative")
+
+    if len(axis) > 1 and np.diff(axis).max() > 1:
+        raise ValueError(
+            f"join_dims axis must be consecutive, got normalized axis: {axis}"
+        )
+
+    start_axis = min(axis)
+    n_axes = len(axis)
+
+    return type_cast(
+        TensorVariable,
+        JoinDims(input_ndims=x.ndim, start_axis=start_axis, n_axes=n_axes)(x),
+    )
 
 
 class SplitDims(Op):
@@ -131,26 +161,24 @@ def __init__(self, axis: int | None = None):
             raise ValueError("SplitDims axis must be non-negative")
         self.axis = axis
 
-    def _make_output_shape(self, input_shape, shape):
-        [axis] = normalize_axis_tuple(self.axis, len(input_shape))
-        output_shapes = list(input_shape)
+    def make_node(self, x: Variable, shape: Variable) -> Apply:  # type: ignore[override]
+        if shape.type.dtype not in ("int8", "int16", "int32", "int64"):
+            raise TypeError("shape must be an integer tensor")
 
         def _get_constant_shape(x):
             try:
-                # get_underling_scalar_constant_value returns a numpy scalar, we need a python int
-                return get_underlying_scalar_constant_value(x).item()
+                return get_scalar_constant_value(x).item()
             except NotScalarConstantError:
                 return x
 
-        constant_shape = [_get_constant_shape(x) for x in shape]
-
-        return *output_shapes[:axis], *constant_shape, *output_shapes[axis + 1 :]
-
-    def make_node(self, x: Variable, shape: Variable) -> Apply:  # type: ignore[override]
-        if shape.type.dtype not in ("int8", "int16", "int32", "int64"):
-            raise TypeError("shape must be an integer tensor")
+        axis = self.axis
+        constant_shape = [_get_constant_shape(s) for s in shape]
 
-        output_shapes = self._make_output_shape(x.type.shape, shape)
+        output_shapes = [
+            *x.type.shape[:axis],
+            *constant_shape,
+            *x.type.shape[axis + 1 :],
+        ]
 
         output = tensor(
             shape=tuple([x if isinstance(x, int) else None for x in output_shapes]),
@@ -161,15 +189,37 @@ def make_node(self, x: Variable, shape: Variable) -> Apply:  # type: ignore[override]
     def infer_shape(self, fgraph, node, shapes):
         [input_shape, _] = shapes
         _, shape = node.inputs
-        output_shape = self._make_output_shape(input_shape, shape)
+        output_shapes = list(input_shape)
+        axis = self.axis
 
-        return [output_shape]
+        inferred_shape = [*output_shapes[:axis], *shape, *output_shapes[axis + 1 :]]
+        return [inferred_shape]
 
     def perform(self, node, inputs, outputs):
         (x, shape) = inputs
         (out,) = outputs
 
-        out[0] = x.reshape(self._make_output_shape(x.shape, shape))
+        output_shape = [
+            *x.shape[: self.axis],
+            *shape,
+            *x.shape[self.axis + 1 :],
+        ]
+
+        out[0] = x.reshape(output_shape)
+
+
+@_vectorize_node.register(SplitDims)
+def _vectorize_splitdims(op, node, x, shape):
+    from pytensor.tensor.blockwise import vectorize_node_fallback
+
+    old_x, _ = node.inputs
+    batched_ndims = x.type.ndim - old_x.type.ndim
+
+    if as_tensor_variable(shape).type.ndim != 1:
+        return vectorize_node_fallback(op, node, x, shape)
+
+    axis = op.axis
+    return split_dims(x, shape, axis=axis + batched_ndims).owner
 
 
 def split_dims(
@@ -223,7 +273,9 @@ def split_dims(
 
     [axis] = normalize_axis_tuple(axis, x.ndim)  # type: ignore[misc]
     shape = as_tensor_variable(shape)  # type: ignore[arg-type]
-    return type_cast(TensorVariable, SplitDims(axis)(x, shape))
+
+    split_op = SplitDims(axis=axis)
+    return type_cast(TensorVariable, split_op(x, shape))
 
 
 def _analyze_axes_list(axes) -> tuple[int, int, int]:
@@ -419,7 +471,7 @@ def pack(
     if all([n_before == 0, n_after == 0, min_axes == 0]):
         # Special case -- we're raveling everything
         packed_shapes = [t.shape for t in tensor_list]
-        reshaped_tensors = [t.ravel() for t in tensor_list]
+        reshaped_tensors = [atleast_1d(join_dims(t, None)) for t in tensor_list]
 
         return join(0, *reshaped_tensors), packed_shapes
 
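Usage sketch (not part of the commit): the static and runtime shapes below mirror the assertions added in tests/tensor/test_reshape_ops.py; the variable names and the explicit float64 dtype are illustrative assumptions.

import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.shape_ops import join_dims, split_dims

# join_dims collapses a run of consecutive axes into a single axis.
x = pt.tensor("x", shape=(3, 5), dtype="float64")
x_joined = join_dims(x, axis=(0, 1))
assert x_joined.type.shape == (15,)

# split_dims goes the other way: one axis is expanded into several.
y = pt.tensor("y", shape=(10,), dtype="float64")
y_split = split_dims(y, shape=(5, 2), axis=0)
assert y_split.type.shape == (5, 2)

# Both ops lower to a C-order reshape in perform(), so values follow suit.
x_val = np.arange(15.0).reshape(3, 5)
np.testing.assert_allclose(x_joined.eval({x: x_val}), x_val.reshape(15))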

tests/tensor/test_reshape_ops.py

Lines changed: 21 additions & 0 deletions
@@ -4,6 +4,7 @@
 import pytensor
 from pytensor import config, function
 from pytensor import tensor as pt
+from pytensor.graph import vectorize_graph
 from pytensor.tensor.shape_ops import (
     _analyze_axes_list,
     join_dims,
@@ -41,6 +42,16 @@ def test_join_dims():
     assert join_dims(x, axis=(1,)).eval({x: x_value}).shape == (2, 3, 4, 5)
     assert join_dims(x, axis=()).eval({x: x_value}).shape == (2, 3, 4, 5)
 
+    x = pt.tensor("x", shape=(3, 5))
+    x_joined = join_dims(x, axis=(0, 1))
+    x_batched = pt.tensor("x_batched", shape=(10, 3, 5))
+    x_joined_batched = vectorize_graph(x_joined, {x: x_batched})
+
+    assert x_joined_batched.type.shape == (10, 15)
+
+    x_batched_val = rng.normal(size=(10, 3, 5)).astype(config.floatX)
+    assert x_joined_batched.eval({x_batched: x_batched_val}).shape == (10, 15)
+
 
 @pytest.mark.parametrize(
     "axis, shape, expected_shape",
@@ -66,6 +77,16 @@ def test_split_dims(axis, shape, expected_shape):
     x_split_value = fn(x_value)
     np.testing.assert_allclose(x_split_value, x_value.reshape(expected_shape))
 
+    x = pt.tensor("x", shape=(10,))
+    x_split = split_dims(x, shape=(5, 2), axis=0)
+    x_batched = pt.tensor("x_batched", shape=(3, 10))
+    x_split_batched = vectorize_graph(x_split, {x: x_batched})
+
+    assert x_split_batched.type.shape == (3, 5, 2)
+
+    x_batched_val = rng.normal(size=(3, 10)).astype(config.floatX)
+    assert x_split_batched.eval({x_batched: x_batched_val}).shape == (3, 5, 2)
+
 
 def test_split_size_zero_shape():
     x = pt.tensor("x", shape=(1, 4, 6))
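A small round-trip sketch (not from the commit), assuming only the public split_dims and join_dims API exercised by the tests above; names, shapes, and the float64 dtype are illustrative.

import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.shape_ops import join_dims, split_dims

# Splitting an axis and then joining the same run of axes is a value-level
# no-op, since both ops reduce to C-order reshapes of the same buffer.
x = pt.tensor("x", shape=(4, 6), dtype="float64")
x_split = split_dims(x, shape=(2, 3), axis=1)  # (4, 6) -> (4, 2, 3)
x_round = join_dims(x_split, axis=(1, 2))      # (4, 2, 3) -> (4, 6)

x_val = np.arange(24.0).reshape(4, 6)
np.testing.assert_allclose(x_round.eval({x: x_val}), x_val)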
