Skip to content

Commit ef43a80

Browse files
More mypy
1 parent a166aca commit ef43a80

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pytensor/tensor/shape_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ class SplitDims(Op):
127127
view_map = {0: [0]}
128128

129129
def __init__(self, axis: int | None = None):
130-
if axis < 0:
130+
if axis is not None and axis < 0:
131131
raise ValueError("SplitDims axis must be non-negative")
132132
self.axis = axis
133133

@@ -221,7 +221,7 @@ def split_dims(
221221
# (3, ) and (3, 3) to (3, 4)
222222
return type_cast(TensorVariable, x.squeeze(axis=axis))
223223

224-
[axis] = normalize_axis_tuple(axis, x.ndim)
224+
[axis] = normalize_axis_tuple(axis, x.ndim) # type: ignore[misc]
225225
shape = as_tensor_variable(shape) # type: ignore[arg-type]
226226
return type_cast(TensorVariable, SplitDims(axis)(x, shape))
227227

0 commit comments

Comments
 (0)