 from collections.abc import Iterable, Sequence
 from itertools import pairwise
+from typing import cast as type_cast
 
 import numpy as np
 from numpy.lib._array_utils_impl import normalize_axis_tuple
 from pytensor.graph import Apply
 from pytensor.graph.op import Op
 from pytensor.tensor import TensorLike, as_tensor_variable
-from pytensor.tensor.basic import expand_dims, join, split
+from pytensor.tensor.basic import (
+    expand_dims,
+    get_underlying_scalar_constant_value,
+    join,
+    split,
+)
+from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import prod
 from pytensor.tensor.shape import ShapeValueType
 from pytensor.tensor.type import tensor
-from pytensor.tensor.variable import TensorConstant, TensorVariable
+from pytensor.tensor.variable import TensorVariable
 
 
 class JoinDims(Op):
     __props__ = ("axis",)
     view_map = {0: [0]}
 
-    def __init__(self, axis: Sequence[int] | int | None = None):
-        if (isinstance(axis, int) and axis < 0) or (
-            isinstance(axis, Iterable) and any(i < 0 for i in axis)
-        ):
+    def __init__(self, axis: Sequence[int]):
+        if any(i < 0 for i in axis):
             raise ValueError("JoinDims axis must be non-negative")
 
         if len(axis) > 1 and np.diff(axis).max() > 1:
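(Aside: a quick sketch of what the tightened validation accepts and rejects; the non-consecutive case raises inside the lines collapsed below, and the values here are illustrative only.)

    JoinDims(axis=(1, 2))   # accepted: non-negative and consecutive
    JoinDims(axis=(-1, 0))  # ValueError: JoinDims axis must be non-negative
    JoinDims(axis=(0, 2))   # rejected by the np.diff(axis).max() > 1 check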
@@ -32,7 +37,7 @@ def __init__(self, axis: Sequence[int] | int | None = None):
 
         self.axis = axis
 
-    def make_node(self, x: Variable) -> Apply:
+    def make_node(self, x: Variable) -> Apply:  # type: ignore[override]
         static_shapes = x.type.shape
         joined_shape = (
             int(np.prod([static_shapes[i] for i in self.axis]))
@@ -73,7 +78,7 @@ def perform(self, node, inputs, outputs):
         out[0] = x.reshape(tuple(output_shape))
 
 
-def join_dims(x: Variable, axis: Sequence[int] | int | None = None) -> Variable:
+def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorVariable:
     """Join consecutive dimensions of a tensor into a single dimension.
 
     Parameters
@@ -96,17 +101,21 @@ def join_dims(x: Variable, axis: Sequence[int] | int | None = None) -> Variable:
     >>> y.type.shape
     (2, 12, 5)
     """
+    x = as_tensor_variable(x)
+
     if axis is None:
-        axis = range(x.ndim).tolist()
-    if not isinstance(axis, (list, tuple)):
+        axis = list(range(x.ndim))
+    elif isinstance(axis, int):
         axis = [axis]
+    elif not isinstance(axis, list | tuple):
+        raise TypeError("axis must be an int, a list/tuple of ints, or None")
 
     if not axis:
         # The user passed an empty list/tuple, so we return the input as is
         return x
 
     axis = normalize_axis_tuple(axis, x.ndim)
-    return JoinDims(axis)(x)
+    return type_cast(TensorVariable, JoinDims(axis)(x))
 
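(Aside: a minimal usage sketch of the new join_dims signature, extrapolated from the docstring example above; the pt.tensor call and shapes are illustrative assumptions, not part of the diff.)

    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(2, 3, 4, 5))
    y = join_dims(x, axis=(1, 2))  # joins dims 1 and 2 -> static shape (2, 12, 5)
    z = join_dims(x)               # axis=None joins all dims -> shape (120,)
    w = join_dims(x, axis=[])      # empty axis list -> x returned unchanged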
 
 class SplitDims(Op):
@@ -120,16 +129,25 @@ def _make_output_shape(self, input_shape, shape):
         [axis] = normalize_axis_tuple(self.axis, len(input_shape))
         output_shapes = list(input_shape)
 
-        return *output_shapes[:axis], *shape, *output_shapes[axis + 1 :]
+        def _get_constant_shape(x):
+            try:
+                # get_underlying_scalar_constant_value returns a numpy scalar; we need a Python int
+                return get_underlying_scalar_constant_value(x).item()
+            except NotScalarConstantError:
+                return x
+
+        constant_shape = [_get_constant_shape(x) for x in shape]
 
-    def make_node(self, x: Variable, shape: Variable) -> Apply:
+        return *output_shapes[:axis], *constant_shape, *output_shapes[axis + 1 :]
+
+    def make_node(self, x: Variable, shape: Variable) -> Apply:  # type: ignore[override]
         output_shapes = self._make_output_shape(x.type.shape, shape)
 
         output = tensor(
             shape=tuple([x if isinstance(x, int) else None for x in output_shapes]),
             dtype=x.type.dtype,
         )
-        return Apply(self, [x, as_tensor_variable(shape)], [output])
+        return Apply(self, [x, shape], [output])
 
     def infer_shape(self, fgraph, node, shapes):
         [input_shape, _] = shapes
@@ -146,7 +164,9 @@ def perform(self, node, inputs, outputs):
 
 
 def split_dims(
-    x: TensorLike, shape: ShapeValueType, axis: int | None = None
+    x: TensorLike,
+    shape: ShapeValueType | Sequence[ShapeValueType],
+    axis: int | None = None,
 ) -> TensorVariable:
     """Split a dimension of a tensor into multiple dimensions.
 
@@ -183,18 +203,17 @@ def split_dims(
 
     if isinstance(shape, int):
         shape = [shape]
-    elif isinstance(shape, TensorConstant):
-        shape = shape.data.tolist()
     else:
-        shape = list(shape)
+        shape = list(shape)  # type: ignore[arg-type]
 
     if not shape:
-        # If we get an empty shape, there is potentially a dummy dimension at the requested axis. This happens for example
-        # when splitting a packed tensor that had its dims expanded before packing (e.g. when packing shapes (3, ) and
-        # (3, 3) to (3, 4)
-        return x.squeeze(axis=axis)
+        # If we get an empty shape, there is potentially a dummy dimension at the requested axis. This happens for
+        # example when splitting a packed tensor that had its dims expanded before packing (e.g. when packing shapes
+        # (3,) and (3, 3) to (3, 4))
+        return type_cast(TensorVariable, x.squeeze(axis=axis))
 
-    return SplitDims(axis)(x, shape)
+    shape = as_tensor_variable(shape)  # type: ignore[arg-type]
+    return type_cast(TensorVariable, SplitDims(axis)(x, shape))
 
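(Aside: a complementary sketch for split_dims with the widened shape annotation; shapes are illustrative assumptions. With a constant shape, the new _get_constant_shape path lets the static output shape be inferred; an empty shape takes the squeeze branch above.)

    y = pt.tensor("y", shape=(2, 12, 5))
    x = split_dims(y, shape=(3, 4), axis=1)  # -> static shape (2, 3, 4, 5)

    d = pt.tensor("d", shape=(2, 1, 5))
    e = split_dims(d, shape=(), axis=1)      # empty shape squeezes the dummy axis -> (2, 5)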
 
 def _analyze_axes_list(axes) -> tuple[int, int, int]:
@@ -380,19 +399,21 @@ def pack(
     # packed_tensor has shape (2, 4 + 5, 3) == (2, 9, 3)
     # packed_shapes is [(4,), (5,)]
     """
+    tensor_list = [as_tensor_variable(t) for t in tensors]
+
     n_before, n_after, min_axes = _analyze_axes_list(axes)
 
+    reshaped_tensors: list[TensorVariable] = []
+    packed_shapes: list[ShapeValueType] = []
+
     if all([n_before == 0, n_after == 0, min_axes == 0]):
         # Special case -- we're raveling everything
-        packed_shapes = [tensor.shape for tensor in tensors]
-        reshaped_tensors = [tensor.ravel() for tensor in tensors]
+        packed_shapes = [t.shape for t in tensor_list]
+        reshaped_tensors = [t.ravel() for t in tensor_list]
 
         return join(0, *reshaped_tensors), packed_shapes
 
-    reshaped_tensors: list[TensorLike] = []
-    packed_shapes: list[ShapeValueType] = []
-
-    for i, input_tensor in enumerate(tensors):
+    for i, input_tensor in enumerate(tensor_list):
         n_dim = input_tensor.ndim
 
         if n_dim < min_axes:
@@ -458,6 +479,8 @@ def unpack(
     unpacked_tensors : list of TensorLike
         A list of unpacked tensors with their original shapes restored.
     """
+    packed_input = as_tensor_variable(packed_input)
+
     if axes is None:
         if packed_input.ndim != 1:
             raise ValueError(
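(Aside: a hedged sketch of the pack/unpack round trip suggested by the pack docstring above; the exact axes specification is an assumption and may differ from the real API.)

    a = pt.tensor("a", shape=(2, 4, 3))
    b = pt.tensor("b", shape=(2, 5, 3))
    packed, packed_shapes = pack([a, b], axes=[0, -1])    # assumed axes spec; packed: (2, 9, 3), packed_shapes: [(4,), (5,)]
    a2, b2 = unpack(packed, packed_shapes, axes=[0, -1])  # restores (2, 4, 3) and (2, 5, 3)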