1616
1717
1818class JoinDims (Op ):
19+ __props__ = ("axis" ,)
20+ view_map = {0 : [0 ]}
21+
1922 def __init__ (self , axis : Sequence [int ] | int | None = None ):
23+ if (isinstance (axis , int ) and axis < 0 ) or (
24+ isinstance (axis , Iterable ) and any (i < 0 for i in axis )
25+ ):
26+ raise ValueError ("JoinDims axis must be non-negative" )
27+
28+ if len (axis ) > 1 and np .diff (axis ).max () > 1 :
29+ raise ValueError (
30+ f"join_dims axis must be consecutive, got normalized axis: { axis } "
31+ )
32+
2033 self .axis = axis
2134
2235 def make_node (self , x : Variable ) -> Apply :
@@ -36,6 +49,17 @@ def make_node(self, x: Variable) -> Apply:
3649 output_type = tensor (shape = output_shapes , dtype = x .type .dtype )
3750 return Apply (self , [x ], [output_type ])
3851
52+ def infer_shape (self , fgraph , node , shapes ):
53+ [input_shape ] = shapes
54+ joined_shape = prod ([input_shape [i ] for i in self .axis ])
55+ out_shape = (
56+ * input_shape [: min (self .axis )],
57+ joined_shape ,
58+ * input_shape [max (self .axis ) + 1 :],
59+ )
60+
61+ return [out_shape ]
62+
3963 def perform (self , node , inputs , outputs ):
4064 (x ,) = inputs
4165 (out ,) = outputs
@@ -82,16 +106,13 @@ def join_dims(x: Variable, axis: Sequence[int] | int | None = None) -> Variable:
82106 return x
83107
84108 axis = normalize_axis_tuple (axis , x .ndim )
85-
86- if len (axis ) > 1 and np .diff (axis ).max () > 1 :
87- raise ValueError (
88- f"join_dims axis must be consecutive, got normalized axis: { axis } "
89- )
90-
91109 return JoinDims (axis )(x )
92110
93111
94112class SplitDims (Op ):
113+ __props__ = ("axis" ,)
114+ view_map = {0 : [0 ]}
115+
95116 def __init__ (self , axis : int | None = None ):
96117 self .axis = axis
97118
@@ -110,6 +131,13 @@ def make_node(self, x: Variable, shape: Variable) -> Apply:
110131 )
111132 return Apply (self , [x , as_tensor_variable (shape )], [output ])
112133
134+ def infer_shape (self , fgraph , node , shapes ):
135+ [input_shape , _ ] = shapes
136+ _ , shape = node .inputs
137+ output_shape = self ._make_output_shape (input_shape , shape )
138+
139+ return [output_shape ]
140+
113141 def perform (self , node , inputs , outputs ):
114142 (x , shape ) = inputs
115143 (out ,) = outputs
0 commit comments