We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
make_node
1 parent e032a21 commit a166acaCopy full SHA for a166aca
pytensor/tensor/shape_ops.py
@@ -39,6 +39,10 @@ def __init__(self, axis: Sequence[int]):
39
40
def make_node(self, x: Variable) -> Apply: # type: ignore[override]
41
static_shapes = x.type.shape
42
+ if x.type.ndim < max(self.axis) + 1:
43
+ raise ValueError(
44
+ f"Input ndim {x.type.ndim} is less than the maximum axis {max(self.axis)} + 1"
45
+ )
46
joined_shape = (
47
int(np.prod([static_shapes[i] for i in self.axis]))
48
if all(static_shapes[i] is not None for i in self.axis)
0 commit comments