Commit f7644b6

Appease mypy
1 parent 04f09ac commit f7644b6


pytensor/tensor/shape_ops.py

Lines changed: 52 additions & 29 deletions
@@ -1,5 +1,6 @@
 from collections.abc import Iterable, Sequence
 from itertools import pairwise
+from typing import cast as type_cast
 
 import numpy as np
 from numpy.lib._array_utils_impl import normalize_axis_tuple
@@ -8,21 +9,25 @@
 from pytensor.graph import Apply
 from pytensor.graph.op import Op
 from pytensor.tensor import TensorLike, as_tensor_variable
-from pytensor.tensor.basic import expand_dims, join, split
+from pytensor.tensor.basic import (
+    expand_dims,
+    get_underlying_scalar_constant_value,
+    join,
+    split,
+)
+from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import prod
 from pytensor.tensor.shape import ShapeValueType
 from pytensor.tensor.type import tensor
-from pytensor.tensor.variable import TensorConstant, TensorVariable
+from pytensor.tensor.variable import TensorVariable
 
 
 class JoinDims(Op):
     __props__ = ("axis",)
     view_map = {0: [0]}
 
-    def __init__(self, axis: Sequence[int] | int | None = None):
-        if (isinstance(axis, int) and axis < 0) or (
-            isinstance(axis, Iterable) and any(i < 0 for i in axis)
-        ):
+    def __init__(self, axis: Sequence[int]):
+        if any(i < 0 for i in axis):
             raise ValueError("JoinDims axis must be non-negative")
 
         if len(axis) > 1 and np.diff(axis).max() > 1:
@@ -32,7 +37,7 @@ def __init__(self, axis: Sequence[int] | int | None = None):
 
         self.axis = axis
 
-    def make_node(self, x: Variable) -> Apply:
+    def make_node(self, x: Variable) -> Apply:  # type: ignore[override]
         static_shapes = x.type.shape
         joined_shape = (
             int(np.prod([static_shapes[i] for i in self.axis]))
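The # type: ignore[override] added above is needed because this make_node fixes a more specific signature than the base Op.make_node, which mypy's override check rejects. A minimal sketch of the same situation, with illustrative class names that are not from PyTensor:

class Base:
    def make_node(self, *inputs: object) -> object:
        return inputs


class Child(Base):
    # Fixing one positional parameter of a narrower type is an
    # incompatible override in mypy's eyes; the per-line ignore
    # silences exactly that error and nothing else.
    def make_node(self, x: int) -> int:  # type: ignore[override]
        return x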
@@ -73,7 +78,7 @@ def perform(self, node, inputs, outputs):
         out[0] = x.reshape(tuple(output_shape))
 
 
-def join_dims(x: Variable, axis: Sequence[int] | int | None = None) -> Variable:
+def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorVariable:
     """Join consecutive dimensions of a tensor into a single dimension.
 
     Parameters
@@ -96,17 +101,21 @@ def join_dims(x: Variable, axis: Sequence[int] | int | None = None) -> Variable:
     >>> y.type.shape
     (2, 12, 5)
     """
+    x = as_tensor_variable(x)
+
     if axis is None:
-        axis = range(x.ndim).tolist()
-    if not isinstance(axis, (list, tuple)):
+        axis = list(range(x.ndim))
+    elif isinstance(axis, int):
         axis = [axis]
+    elif not isinstance(axis, list | tuple):
+        raise TypeError("axis must be an int, a list/tuple of ints, or None")
 
     if not axis:
         # The user passed an empty list/tuple, so we return the input as is
         return x
 
     axis = normalize_axis_tuple(axis, x.ndim)
-    return JoinDims(axis)(x)
+    return type_cast(TensorVariable, JoinDims(axis)(x))
 
 
 class SplitDims(Op):
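With this hunk, join_dims normalizes axis up front (int, list/tuple, or None), coerces x through as_tensor_variable, and casts the result so callers statically get a TensorVariable. A small usage sketch, mirroring the (2, 12, 5) doctest above and assuming the module path from the file header:

import pytensor.tensor as pt
from pytensor.tensor.shape_ops import join_dims

x = pt.tensor("x", shape=(2, 3, 4, 5))

y = join_dims(x, axis=(1, 2))  # merge the two middle dimensions
print(y.type.shape)  # (2, 12, 5)

flat = join_dims(x)  # axis=None joins every dimension (ravel)
same = join_dims(x, axis=())  # empty axis returns x unchanged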
@@ -120,16 +129,25 @@ def _make_output_shape(self, input_shape, shape):
         [axis] = normalize_axis_tuple(self.axis, len(input_shape))
         output_shapes = list(input_shape)
 
-        return *output_shapes[:axis], *shape, *output_shapes[axis + 1 :]
+        def _get_constant_shape(x):
+            try:
+                # get_underlying_scalar_constant_value returns a numpy scalar, we need a Python int
+                return get_underlying_scalar_constant_value(x).item()
+            except NotScalarConstantError:
+                return x
+
+        constant_shape = [_get_constant_shape(x) for x in shape]
 
-    def make_node(self, x: Variable, shape: Variable) -> Apply:
+        return *output_shapes[:axis], *constant_shape, *output_shapes[axis + 1 :]
+
+    def make_node(self, x: Variable, shape: Variable) -> Apply:  # type: ignore[override]
         output_shapes = self._make_output_shape(x.type.shape, shape)
 
         output = tensor(
             shape=tuple([x if isinstance(x, int) else None for x in output_shapes]),
             dtype=x.type.dtype,
         )
-        return Apply(self, [x, as_tensor_variable(shape)], [output])
+        return Apply(self, [x, shape], [output])
 
     def infer_shape(self, fgraph, node, shapes):
         [input_shape, _] = shapes
@@ -146,7 +164,9 @@ def perform(self, node, inputs, outputs):
 
 
 def split_dims(
-    x: TensorLike, shape: ShapeValueType, axis: int | None = None
+    x: TensorLike,
+    shape: ShapeValueType | Sequence[ShapeValueType],
+    axis: int | None = None,
 ) -> TensorVariable:
     """Split a dimension of a tensor into multiple dimensions.
 
@@ -183,18 +203,17 @@ def split_dims(
 
     if isinstance(shape, int):
         shape = [shape]
-    elif isinstance(shape, TensorConstant):
-        shape = shape.data.tolist()
     else:
-        shape = list(shape)
+        shape = list(shape)  # type: ignore[arg-type]
 
     if not shape:
-        # If we get an empty shape, there is potentially a dummy dimension at the requested axis. This happens for example
-        # when splitting a packed tensor that had its dims expanded before packing (e.g. when packing shapes (3, ) and
-        # (3, 3) to (3, 4)
-        return x.squeeze(axis=axis)
+        # If we get an empty shape, there is potentially a dummy dimension at the requested axis. This happens for
+        # example when splitting a packed tensor that had its dims expanded before packing (e.g. when packing shapes
+        # (3, ) and (3, 3) to (3, 4)
+        return type_cast(TensorVariable, x.squeeze(axis=axis))
 
-    return SplitDims(axis)(x, shape)
+    shape = as_tensor_variable(shape)  # type: ignore[arg-type]
+    return type_cast(TensorVariable, SplitDims(axis)(x, shape))
 
 
 def _analyze_axes_list(axes) -> tuple[int, int, int]:
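After this hunk, split_dims always passes a symbolic shape vector to SplitDims, and _make_output_shape (above) recovers Python ints from constant entries via get_underlying_scalar_constant_value, so static shape information is preserved. A hedged sketch of the resulting behavior, again assuming the module path from the file header:

import pytensor.tensor as pt
from pytensor.tensor.shape_ops import split_dims

x = pt.tensor("x", shape=(2, 12, 5))

# Constant shape entries fold into the static output shape
y = split_dims(x, shape=(3, 4), axis=1)
print(y.type.shape)  # (2, 3, 4, 5)

# Symbolic entries are allowed; their static dims become None
n = pt.scalar("n", dtype="int64")
z = split_dims(x, shape=(3, n), axis=1)
print(z.type.shape)  # (2, 3, None, 5)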
@@ -380,19 +399,21 @@ def pack(
     # packed_tensor has shape (2, 4 + 5, 3) == (2, 9, 3)
     # packed_shapes is [(4,), (5,)]
     """
+    tensor_list = [as_tensor_variable(t) for t in tensors]
+
     n_before, n_after, min_axes = _analyze_axes_list(axes)
 
+    reshaped_tensors: list[TensorVariable] = []
+    packed_shapes: list[ShapeValueType] = []
+
     if all([n_before == 0, n_after == 0, min_axes == 0]):
         # Special case -- we're raveling everything
-        packed_shapes = [tensor.shape for tensor in tensors]
-        reshaped_tensors = [tensor.ravel() for tensor in tensors]
+        packed_shapes = [t.shape for t in tensor_list]
+        reshaped_tensors = [t.ravel() for t in tensor_list]
 
         return join(0, *reshaped_tensors), packed_shapes
 
-    reshaped_tensors: list[TensorLike] = []
-    packed_shapes: list[ShapeValueType] = []
-
-    for i, input_tensor in enumerate(tensors):
+    for i, input_tensor in enumerate(tensor_list):
         n_dim = input_tensor.ndim
 
         if n_dim < min_axes:
@@ -458,6 +479,8 @@ def unpack(
     unpacked_tensors : list of TensorLike
         A list of unpacked tensors with their original shapes restored.
     """
+    packed_input = as_tensor_variable(packed_input)
+
     if axes is None:
        if packed_input.ndim != 1:
            raise ValueError(
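Stepping back, the recurring pattern in this commit is typing.cast, imported as type_cast: the result of calling an Op is declared with a type wider than TensorVariable, so functions annotated to return TensorVariable narrow the result explicitly. cast only informs the type checker and is a no-op at runtime, as this standalone sketch illustrates (all names here are illustrative, not PyTensor's):

from typing import cast


class Variable: ...


class TensorVariable(Variable): ...


def apply_op() -> Variable:
    # Stand-in for JoinDims(axis)(x): the runtime value is a
    # TensorVariable, but the declared return type is the wider Variable
    return TensorVariable()


def typed_wrapper() -> TensorVariable:
    # cast() simply returns its second argument; the narrowing exists
    # only for static checkers such as mypy
    return cast(TensorVariable, apply_op())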