@@ -1219,14 +1219,15 @@ def __init__(self, start, size, min_size=None, **kwargs):
12191219 super (SliceNdLayer2 , self ).__init__ (** kwargs )
12201220 from returnn .tf .util .basic import slice_nd2 , DimensionTag
12211221 assert start .output .have_batch_axis () and self .input_data .have_batch_axis ()
1222+ self .start = start
12221223
12231224 input_data = self .input_data .copy_as_batch_major ()
12241225 start = start .output .copy_as_batch_major ()
12251226
12261227 # make sure axis of start are in input
12271228 is_equal_opts = dict (ignore_feature_dim = True , allow_same_spatial_dim = True , broadcast_matches = True )
12281229 for start_axis in range (start .batch_ndim ):
1229- assert input_data .get_dim_tag (start_axis ).is_equal (start .get_dim_tag (start_axis ), ** is_equal_opts ), "the input should hold all axis in start"
1230+ assert input_data .get_dim_tag (start_axis ).is_equal (start .get_dim_tag (start_axis ), ** is_equal_opts )
12301231
12311232 # Handle the case when layer is pulled out of rec loop but the input hasn't change
12321233 if self .optimized_out_of_loop_and_unchanged_input (input_data , start ):
@@ -1261,12 +1262,12 @@ def __init__(self, start, size, min_size=None, **kwargs):
12611262 self .output .placeholder = slices
12621263
12631264 @classmethod
1264- def optimized_out_of_loop_and_unchanged_input (cls , input , start ):
1265+ def optimized_out_of_loop_and_unchanged_input (cls , input_data , start ):
12651266 """
12661267 :rtype: bool
12671268 The idea is to check that the axis after the last common axis is a feature axis instead of spatial.
12681269 """
1269- return input .get_dim_tag (start .batch_ndim ) == input .get_dim_tag (input .get_feature_batch_axes ()[0 ])
1270+ return input_data .get_dim_tag (start .batch_ndim ) == input_data .get_dim_tag (input_data .get_feature_batch_axes ()[0 ])
12701271
12711272 def get_dep_layers (self ):
12721273 """
0 commit comments