 from pytensor import Variable
 from pytensor.graph import Apply
 from pytensor.graph.op import Op
+from pytensor.graph.replace import _vectorize_node
 from pytensor.tensor import TensorLike, as_tensor_variable
 from pytensor.tensor.basic import (
+    atleast_1d,
     expand_dims,
-    get_underlying_scalar_constant_value,
+    get_scalar_constant_value,
     join,
     split,
 )


 class JoinDims(Op):
-    __props__ = ("axis",)
+    __props__ = (
+        "start_axis",
+        "n_axes",
+    )
     view_map = {0: [0]}

-    def __init__(self, axis: Sequence[int]):
-        if any(i < 0 for i in axis):
-            raise ValueError("JoinDims axis must be non-negative")
+    def __init__(self, input_ndims: int, start_axis: int | None, n_axes: int | None):
+        if start_axis < 0:
+            raise ValueError("JoinDims start_axis must be non-negative")

-        if len(axis) > 1 and np.diff(axis).max() > 1:
-            raise ValueError(
-                f"join_dims axis must be consecutive, got normalized axis: {axis}"
-            )
+        self.start_axis = start_axis
+        self.n_axes = n_axes
+        self.input_ndims = input_ndims

-        self.axis = axis
+        output_ndims = input_ndims - n_axes + 1
+
+        input_signature = ",".join(f"i{i}" for i in range(input_ndims))
+        output_signature = ",".join(f"o{i}" for i in range(output_ndims))
+
+        self.gufunc_signature = f"({input_signature})->({output_signature})"
+
+    @property
+    def axis_range(self):
+        return range(self.start_axis, self.start_axis + self.n_axes)
+
+    def output_shapes(self, input_shapes, joined_shape):
+        return (
+            *input_shapes[: self.start_axis],
+            joined_shape,
+            *input_shapes[self.start_axis + self.n_axes :],
+        )

     def make_node(self, x: Variable) -> Apply:  # type: ignore[override]
         static_shapes = x.type.shape
-        if x.type.ndim < max(self.axis) + 1:
+        if x.type.ndim != self.input_ndims:
             raise ValueError(
-                f"Input ndim {x.type.ndim} is less than the maximum axis {max(self.axis)} + 1"
+                f"Input ndim {x.type.ndim} is not equal to expected ndim {self.input_ndims}"
             )
+
+        axis_range = self.axis_range
+
         joined_shape = (
-            int(np.prod([static_shapes[i] for i in self.axis]))
-            if all(static_shapes[i] is not None for i in self.axis)
+            int(np.prod([static_shapes[i] for i in axis_range]))
+            if all(static_shapes[i] is not None for i in axis_range)
             else None
         )

-        output_shapes = (
-            *static_shapes[: min(self.axis)],
-            joined_shape,
-            *static_shapes[max(self.axis) + 1 :],
-        )
-
+        output_shapes = self.output_shapes(static_shapes, joined_shape)
         output_type = tensor(shape=output_shapes, dtype=x.type.dtype)
+
         return Apply(self, [x], [output_type])

     def infer_shape(self, fgraph, node, shapes):
         [input_shape] = shapes
-        joined_shape = prod([input_shape[i] for i in self.axis])
-        out_shape = (
-            *input_shape[: min(self.axis)],
-            joined_shape,
-            *input_shape[max(self.axis) + 1 :],
-        )
+        axis_range = self.axis_range

-        return [out_shape]
+        joined_shape = prod([input_shape[i] for i in axis_range])
+        return [self.output_shapes(input_shape, joined_shape)]

     def perform(self, node, inputs, outputs):
         (x,) = inputs
         (out,) = outputs

         output_shape = [
-            *x.shape[: min(self.axis)],
+            *x.shape[: self.start_axis],
             -1,
-            *x.shape[max(self.axis) + 1 :],
+            *x.shape[self.start_axis + self.n_axes :],
         ]

         out[0] = x.reshape(tuple(output_shape))
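Conceptually, JoinDims is a view-preserving reshape that collapses a run of consecutive axes into one, exactly as perform above shows. A minimal NumPy sketch of the same semantics (illustrative only, not part of the diff):

    import numpy as np

    x = np.arange(24).reshape(2, 3, 4)
    start_axis, n_axes = 1, 2
    # mirror JoinDims.perform: collapse axes [start_axis, start_axis + n_axes)
    out = x.reshape(*x.shape[:start_axis], -1, *x.shape[start_axis + n_axes :])
    assert out.shape == (2, 12)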
@@ -119,7 +134,22 @@ def join_dims(x: TensorLike, axis: Sequence[int] | int | None = None) -> TensorVariable:
         return x

     axis = normalize_axis_tuple(axis, x.ndim)
-    return type_cast(TensorVariable, JoinDims(axis)(x))
+
+    if any(i < 0 for i in axis):
+        raise ValueError("join_dims axis must be non-negative")
+
+    if len(axis) > 1 and np.diff(axis).max() > 1:
+        raise ValueError(
+            f"join_dims axis must be consecutive, got normalized axis: {axis}"
+        )
+
+    start_axis = min(axis)
+    n_axes = len(axis)
+
+    return type_cast(
+        TensorVariable,
+        JoinDims(input_ndims=x.ndim, start_axis=start_axis, n_axes=n_axes)(x),
+    )
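With join_dims in scope (it is defined in this module), usage might look like the following sketch; static shapes should propagate through make_node as written above:

    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(2, 3, 4))
    y = join_dims(x, axis=(1, 2))
    assert y.type.shape == (2, 12)  # joined static shape is 3 * 4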


 class SplitDims(Op):
@@ -131,26 +161,24 @@ def __init__(self, axis: int | None = None):
             raise ValueError("SplitDims axis must be non-negative")
         self.axis = axis

-    def _make_output_shape(self, input_shape, shape):
-        [axis] = normalize_axis_tuple(self.axis, len(input_shape))
-        output_shapes = list(input_shape)
+    def make_node(self, x: Variable, shape: Variable) -> Apply:  # type: ignore[override]
+        if shape.type.dtype not in ("int8", "int16", "int32", "int64"):
+            raise TypeError("shape must be an integer tensor")

         def _get_constant_shape(x):
             try:
-                # get_underling_scalar_constant_value returns a numpy scalar, we need a python int
-                return get_underlying_scalar_constant_value(x).item()
+                return get_scalar_constant_value(x).item()
             except NotScalarConstantError:
                 return x

-        constant_shape = [_get_constant_shape(x) for x in shape]
-
-        return *output_shapes[:axis], *constant_shape, *output_shapes[axis + 1 :]
-
-    def make_node(self, x: Variable, shape: Variable) -> Apply:  # type: ignore[override]
-        if shape.type.dtype not in ("int8", "int16", "int32", "int64"):
-            raise TypeError("shape must be an integer tensor")
+        axis = self.axis
+        constant_shape = [_get_constant_shape(s) for s in shape]

-        output_shapes = self._make_output_shape(x.type.shape, shape)
+        output_shapes = [
+            *x.type.shape[:axis],
+            *constant_shape,
+            *x.type.shape[axis + 1 :],
+        ]

         output = tensor(
             shape=tuple([x if isinstance(x, int) else None for x in output_shapes]),
@@ -161,15 +189,37 @@ def make_node(self, x: Variable, shape: Variable) -> Apply:  # type: ignore[override]
     def infer_shape(self, fgraph, node, shapes):
         [input_shape, _] = shapes
         _, shape = node.inputs
-        output_shape = self._make_output_shape(input_shape, shape)
+        output_shapes = list(input_shape)
+        axis = self.axis

-        return [output_shape]
+        inferred_shape = [*output_shapes[:axis], *shape, *output_shapes[axis + 1 :]]
+        return [inferred_shape]

     def perform(self, node, inputs, outputs):
         (x, shape) = inputs
         (out,) = outputs

-        out[0] = x.reshape(self._make_output_shape(x.shape, shape))
+        output_shape = [
+            *x.shape[: self.axis],
+            *shape,
+            *x.shape[self.axis + 1 :],
+        ]
+
+        out[0] = x.reshape(output_shape)
+
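SplitDims is the inverse reshape: one axis expands into several. A NumPy sketch of what perform computes (illustrative only, not part of the diff):

    import numpy as np

    x = np.arange(24).reshape(2, 12)
    axis, shape = 1, (3, 4)
    # mirror SplitDims.perform: expand `axis` into the sizes in `shape`
    out = x.reshape(*x.shape[:axis], *shape, *x.shape[axis + 1 :])
    assert out.shape == (2, 3, 4)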
+
+@_vectorize_node.register(SplitDims)
+def _vectorize_splitdims(op, node, x, shape):
+    from pytensor.tensor.blockwise import vectorize_node_fallback
+
+    old_x, _ = node.inputs
+    batched_ndims = x.type.ndim - old_x.type.ndim
+
+    if as_tensor_variable(shape).type.ndim != 1:
+        return vectorize_node_fallback(op, node, x, shape)
+
+    axis = op.axis
+    return split_dims(x, shape, axis=axis + batched_ndims).owner
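The rule above avoids a Blockwise fallback when only x gains batch dimensions and shape remains a plain 1-d vector: the op is simply re-applied with the axis shifted by the number of batch dims. A hedged sketch of the intended behaviour, using names from the diff:

    import pytensor.tensor as pt
    from pytensor.graph.replace import vectorize_graph

    x = pt.tensor("x", shape=(2, 12))
    y = split_dims(x, (3, 4), axis=1)         # core graph: (2, 12) -> (2, 3, 4)
    xb = pt.tensor("xb", shape=(5, 2, 12))    # one extra leading batch dim
    yb = vectorize_graph(y, replace={x: xb})  # re-dispatches with axis=1 + 1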


 def split_dims(
@@ -223,7 +273,9 @@ def split_dims(

     [axis] = normalize_axis_tuple(axis, x.ndim)  # type: ignore[misc]
     shape = as_tensor_variable(shape)  # type: ignore[arg-type]
-    return type_cast(TensorVariable, SplitDims(axis)(x, shape))
+
+    split_op = SplitDims(axis=axis)
+    return type_cast(TensorVariable, split_op(x, shape))
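Because make_node runs get_scalar_constant_value on each entry of shape, compile-time-constant split sizes survive as static output shapes, while symbolic entries fall back to None. A sketch grounded in the make_node logic above:

    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(2, 12))
    y = split_dims(x, (3, 4), axis=1)
    assert y.type.shape == (2, 3, 4)     # constant sizes recovered

    a = pt.lscalar("a")
    z = split_dims(x, (a, 4), axis=1)
    assert z.type.shape == (2, None, 4)  # symbolic size becomes None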


 def _analyze_axes_list(axes) -> tuple[int, int, int]:
@@ -419,7 +471,7 @@ def pack(
     if all([n_before == 0, n_after == 0, min_axes == 0]):
         # Special case -- we're raveling everything
         packed_shapes = [t.shape for t in tensor_list]
-        reshaped_tensors = [t.ravel() for t in tensor_list]
+        reshaped_tensors = [atleast_1d(join_dims(t, None)) for t in tensor_list]

         return join(0, *reshaped_tensors), packed_shapes

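Swapping t.ravel() for atleast_1d(join_dims(t, None)) routes the flattening through the new JoinDims machinery; the atleast_1d wrapper matters because join_dims appears to pass 0-d inputs through unchanged (the early return x above), and join(0, ...) needs at least 1-d operands. A minimal sketch of that edge case:

    import pytensor.tensor as pt

    a = pt.scalar("a")  # 0-d input: join_dims(a, None) is presumably a no-op
    flat = pt.atleast_1d(join_dims(a, None))
    assert flat.type.ndim == 1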