Skip to content

Commit 04f09ac

Browse files
Respond to feedback
1 parent 1247bf3 commit 04f09ac

File tree

1 file changed

+34
-6
lines changed

1 file changed

+34
-6
lines changed

pytensor/tensor/shape_ops.py

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,20 @@
1616

1717

1818
class JoinDims(Op):
19+
__props__ = ("axis",)
20+
view_map = {0: [0]}
21+
1922
def __init__(self, axis: Sequence[int] | int | None = None):
23+
if (isinstance(axis, int) and axis < 0) or (
24+
isinstance(axis, Iterable) and any(i < 0 for i in axis)
25+
):
26+
raise ValueError("JoinDims axis must be non-negative")
27+
28+
if len(axis) > 1 and np.diff(axis).max() > 1:
29+
raise ValueError(
30+
f"join_dims axis must be consecutive, got normalized axis: {axis}"
31+
)
32+
2033
self.axis = axis
2134

2235
def make_node(self, x: Variable) -> Apply:
@@ -36,6 +49,17 @@ def make_node(self, x: Variable) -> Apply:
3649
output_type = tensor(shape=output_shapes, dtype=x.type.dtype)
3750
return Apply(self, [x], [output_type])
3851

52+
def infer_shape(self, fgraph, node, shapes):
53+
[input_shape] = shapes
54+
joined_shape = prod([input_shape[i] for i in self.axis])
55+
out_shape = (
56+
*input_shape[: min(self.axis)],
57+
joined_shape,
58+
*input_shape[max(self.axis) + 1 :],
59+
)
60+
61+
return [out_shape]
62+
3963
def perform(self, node, inputs, outputs):
4064
(x,) = inputs
4165
(out,) = outputs
@@ -82,16 +106,13 @@ def join_dims(x: Variable, axis: Sequence[int] | int | None = None) -> Variable:
82106
return x
83107

84108
axis = normalize_axis_tuple(axis, x.ndim)
85-
86-
if len(axis) > 1 and np.diff(axis).max() > 1:
87-
raise ValueError(
88-
f"join_dims axis must be consecutive, got normalized axis: {axis}"
89-
)
90-
91109
return JoinDims(axis)(x)
92110

93111

94112
class SplitDims(Op):
113+
__props__ = ("axis",)
114+
view_map = {0: [0]}
115+
95116
def __init__(self, axis: int | None = None):
96117
self.axis = axis
97118

@@ -110,6 +131,13 @@ def make_node(self, x: Variable, shape: Variable) -> Apply:
110131
)
111132
return Apply(self, [x, as_tensor_variable(shape)], [output])
112133

134+
def infer_shape(self, fgraph, node, shapes):
135+
[input_shape, _] = shapes
136+
_, shape = node.inputs
137+
output_shape = self._make_output_shape(input_shape, shape)
138+
139+
return [output_shape]
140+
113141
def perform(self, node, inputs, outputs):
114142
(x, shape) = inputs
115143
(out,) = outputs

0 commit comments

Comments
 (0)