Skip to content

Commit 4eaa89a

Browse files
committed
change test to slice_nd
1 parent fd37e8e commit 4eaa89a

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

tests/test_TFNetworkRecLayer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3361,12 +3361,14 @@ def check_reclayer_optimize_out(subnet_layer_dict, other_subnet_layers=None, sha
33613361
with make_scope() as session:
33623362
print("Create non-optimized rec layer (with subnet layer moved out)")
33633363
rec_layer_dict["optimize_move_layers_out"] = False
3364+
rec_layer_dict["unit"]["window"]["class"] = "slice_nd"
33643365
net1 = TFNetwork(config=config, train_flag=True, name="<root_not_opt>")
33653366
if shared_base_net:
33663367
net1.construct_from_dict(shared_base_net)
33673368
for key in shared_base_net:
33683369
assert key in net1.layers
33693370
net1.construct_from_dict({"output_not_opt": rec_layer_dict})
3371+
rec_layer_dict["unit"]["window"]["class"] = "slice_nd2"
33703372
rec_layer_dict["optimize_move_layers_out"] = True
33713373
print("Create optimized rec layer (with subnet layer inside loop)")
33723374
net2 = TFNetwork(config=config, extern_data=net1.extern_data, train_flag=True, name="<root_opt>")
@@ -3611,7 +3613,7 @@ def random_start_positions(source, **kwargs):
36113613
from_="position",
36123614
other_subnet_layers={
36133615
"my_layer": {"class": "gather_nd", "from": "base:data", "position": ":i"},
3614-
"window": {"class": "slice_nd2", # no_opt: [B,4,D], opt: [B,T,4,D]
3616+
"window": {"class": "slice_nd", # no_opt: [B,4,D], opt: [B,T,4,D]
36153617
"from": "base:data", "start": "data:source", "size": 4, "is_output_layer": True},
36163618
},
36173619
shared_base_net={

0 commit comments

Comments
 (0)