Skip to content

Commit 25bce5e

Browse files
Clean up SplitDims._make_output_shape
1 parent d6d7cac commit 25bce5e

File tree

1 file changed

+2
-8
lines changed

1 file changed

+2
-8
lines changed

pytensor/tensor/shape_ops.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -96,16 +96,10 @@ def __init__(self, axis: int | None = None):
9696
self.axis = axis
9797

9898
def _make_output_shape(self, input_shape, shape):
99-
axis = self.axis
100-
99+
[axis] = normalize_axis_tuple(self.axis, len(input_shape))
101100
output_shapes = list(input_shape)
102-
shape = list(shape)
103-
104-
output_shapes[axis] = shape.pop(-1)
105-
for s in shape[::-1]:
106-
output_shapes.insert(axis, s)
107101

108-
return tuple(output_shapes)
102+
return *output_shapes[:axis], *shape, *output_shapes[axis + 1 :]
109103

110104
def make_node(self, x: Variable, shape: Variable) -> Apply:
111105
output_shapes = self._make_output_shape(x.type.shape, shape)

0 commit comments

Comments
 (0)