We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent a166aca commit ef43a80Copy full SHA for ef43a80
pytensor/tensor/shape_ops.py
@@ -127,7 +127,7 @@ class SplitDims(Op):
127
view_map = {0: [0]}
128
129
def __init__(self, axis: int | None = None):
130
- if axis < 0:
+ if axis is not None and axis < 0:
131
raise ValueError("SplitDims axis must be non-negative")
132
self.axis = axis
133
@@ -221,7 +221,7 @@ def split_dims(
221
# (3, ) and (3, 3) to (3, 4)
222
return type_cast(TensorVariable, x.squeeze(axis=axis))
223
224
- [axis] = normalize_axis_tuple(axis, x.ndim)
+ [axis] = normalize_axis_tuple(axis, x.ndim) # type: ignore[misc]
225
shape = as_tensor_variable(shape) # type: ignore[arg-type]
226
return type_cast(TensorVariable, SplitDims(axis)(x, shape))
227
0 commit comments