We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
SplitDims._make_output_shape
1 parent d6d7cac commit 25bce5eCopy full SHA for 25bce5e
pytensor/tensor/shape_ops.py
@@ -96,16 +96,10 @@ def __init__(self, axis: int | None = None):
96
self.axis = axis
97
98
def _make_output_shape(self, input_shape, shape):
99
- axis = self.axis
100
-
+ [axis] = normalize_axis_tuple(self.axis, len(input_shape))
101
output_shapes = list(input_shape)
102
- shape = list(shape)
103
104
- output_shapes[axis] = shape.pop(-1)
105
- for s in shape[::-1]:
106
- output_shapes.insert(axis, s)
107
108
- return tuple(output_shapes)
+ return *output_shapes[:axis], *shape, *output_shapes[axis + 1 :]
109
110
def make_node(self, x: Variable, shape: Variable) -> Apply:
111
output_shapes = self._make_output_shape(x.type.shape, shape)
0 commit comments