Skip to content

Commit a166aca

Browse files
Validate input ndim in JoinDims make_node
1 parent e032a21 commit a166aca

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

pytensor/tensor/shape_ops.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ def __init__(self, axis: Sequence[int]):
3939

4040
def make_node(self, x: Variable) -> Apply: # type: ignore[override]
4141
static_shapes = x.type.shape
42+
if x.type.ndim < max(self.axis) + 1:
43+
raise ValueError(
44+
f"Input ndim {x.type.ndim} is less than the maximum axis {max(self.axis)} + 1"
45+
)
4246
joined_shape = (
4347
int(np.prod([static_shapes[i] for i in self.axis]))
4448
if all(static_shapes[i] is not None for i in self.axis)

0 commit comments

Comments
 (0)