Commit fd37e8e

update
1 parent d84996f commit fd37e8e

File tree

1 file changed (+4, -3)

returnn/tf/layers/basic.py

Lines changed: 4 additions & 3 deletions
@@ -1219,14 +1219,15 @@ def __init__(self, start, size, min_size=None, **kwargs):
     super(SliceNdLayer2, self).__init__(**kwargs)
     from returnn.tf.util.basic import slice_nd2, DimensionTag
     assert start.output.have_batch_axis() and self.input_data.have_batch_axis()
+    self.start = start

     input_data = self.input_data.copy_as_batch_major()
     start = start.output.copy_as_batch_major()

     # make sure axis of start are in input
     is_equal_opts = dict(ignore_feature_dim=True, allow_same_spatial_dim=True, broadcast_matches=True)
     for start_axis in range(start.batch_ndim):
-      assert input_data.get_dim_tag(start_axis).is_equal(start.get_dim_tag(start_axis), **is_equal_opts), "the input should hold all axis in start"
+      assert input_data.get_dim_tag(start_axis).is_equal(start.get_dim_tag(start_axis), **is_equal_opts)

     # Handle the case when layer is pulled out of rec loop but the input hasn't change
     if self.optimized_out_of_loop_and_unchanged_input(input_data, start):
@@ -1261,12 +1262,12 @@ def __init__(self, start, size, min_size=None, **kwargs):
     self.output.placeholder = slices

   @classmethod
-  def optimized_out_of_loop_and_unchanged_input(cls, input, start):
+  def optimized_out_of_loop_and_unchanged_input(cls, input_data, start):
     """
     :rtype: bool
     The idea is to check that the axis after the last common axis is a feature axis instead of spatial.
     """
-    return input.get_dim_tag(start.batch_ndim) == input.get_dim_tag(input.get_feature_batch_axes()[0])
+    return input_data.get_dim_tag(start.batch_ndim) == input_data.get_dim_tag(input_data.get_feature_batch_axes()[0])

     def get_dep_layers(self):
       """
