File tree Expand file tree Collapse file tree 2 files changed +11
-6
lines changed
Expand file tree Collapse file tree 2 files changed +11
-6
lines changed Original file line number Diff line number Diff line change @@ -1886,9 +1886,7 @@ def get_layer(name):
18861886 prev_layer = prev_layers [layer_name ]
18871887 assert layer .output .batch_shape == prev_layer .output .batch_shape
18881888 assert layer .output .batch_dim_axis == prev_layer .output .batch_dim_axis
1889- assert sorted (layer .output .size_placeholder .keys ()) == sorted (prev_layer .output .size_placeholder .keys ())
1890- for i in range (len (layer .output .size_placeholder )):
1891- assert layer .output .get_size_dim_tag (i ) == prev_layer .output .get_size_dim_tag (i )
1889+ assert layer .output .get_dyn_size_tags () == prev_layer .output .get_dyn_size_tags ()
18921890
18931891 def get_prev_template_layer (self , layer_name ):
18941892 """
Original file line number Diff line number Diff line change @@ -518,7 +518,7 @@ def is_dynamic(self):
518518 :return: whether the dim is not static. usually means that it has seq lengths
519519 :rtype: bool
520520 """
521- return self .dimension is not None
521+ return self .dimension is None and not self . is_batch_dim ()
522522
523523 def can_be_used_as_dim (self ):
524524 """
@@ -5412,13 +5412,20 @@ def get_time_dim_tag(self):
54125412 assert self .time_dim_axis is not None
54135413 return self .get_dim_tag (self .time_dim_axis )
54145414
5415+ def get_dyn_size_tags (self ):
5416+ """
5417+ :return: all dim tags with dynamic size
5418+ :rtype: list[Dim]
5419+ """
5420+ return [dim_tag for dim_tag in self ._dim_tags if dim_tag .is_dynamic ()]
5421+
54155422 def get_size_dim_tag (self , number ):
54165423 """
54175424 :param int number: index in sorted(size_placeholder.keys())
54185425 :rtype: Dim
54195426 """
5420- axis_wo_batch = sorted ( self .size_placeholder . keys ())[ number ]
5421- return self . get_dim_tag ( self . get_batch_axis ( axis_wo_batch ))
5427+ dyn_size_tags = self .get_dyn_size_tags ()
5428+ return dyn_size_tags [ number ]
54225429
54235430 def get_batch_shape_dim_tags (self ):
54245431 """
You can’t perform that action at this time.
0 commit comments