Skip to content

Commit 347e2d2

Browse files
authored
Fix the bug of #5
1 parent d3d06ed commit 347e2d2

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

davarocr/davarocr/davar_ie/models/connects/multimodal_context_module.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
# Filename : multimodal_context_module.py
55
# Abstract : compute multimodal context for each bbox/ node.
66
7-
# Current Version: 1.0.0
8-
# Date : 2021-05-20
7+
# Current Version: 1.0.1
8+
# Date : 2021-07-28
99
######################################################################################################
1010
"""
1111
import copy
@@ -127,9 +127,9 @@ def pack_batch(self,
127127
last_idx = 0
128128

129129
# pack
130-
for _ in enumerate(pos_feat):
130+
for i, _ in enumerate(pos_feat):
131131
# visual feat
132-
b_s = pos_feat[_].size(0)
132+
b_s = pos_feat[i].size(0)
133133
img_feat = img_feat_all[0]
134134
img_feat_size = list(img_feat.size())
135135
img_feat_size[0] = max_length - b_s
@@ -144,8 +144,8 @@ def pack_batch(self,
144144
torch.cat((img_feat[last_idx: last_idx + b_s], img_feat.new_full(img_feat_size, 0)), 0))
145145

146146
# pos feat
147-
per_pos_feat = pos_feat[_]
148-
image_shape_h, image_shape_w = img_meta[_]['img_shape'][:2]
147+
per_pos_feat = pos_feat[i]
148+
image_shape_h, image_shape_w = img_meta[i]['img_shape'][:2]
149149
per_pos_feat_expand = per_pos_feat.new_full((per_pos_feat.size(0), 4), 0)
150150
per_pos_feat_expand[:, 0] = per_pos_feat[:, 0]
151151
per_pos_feat_expand[:, 1] = per_pos_feat[:, 1]
@@ -163,15 +163,15 @@ def pack_batch(self,
163163

164164
# classification labels
165165
if info_labels is not None:
166-
per_label = info_labels[_]
166+
per_label = info_labels[i]
167167
img_feat_size = list(per_label.size())
168168
img_feat_size[0] = max_length - b_s
169169
batched_img_label.append(
170170
torch.cat((per_label, per_label.new_full(img_feat_size, 255)), 0))
171171

172172
# bieo labels
173173
if bieo_labels is not None:
174-
per_label = copy.deepcopy(bieo_labels[_])
174+
per_label = copy.deepcopy(bieo_labels[i])
175175
per_label = torch.tensor(per_label, dtype=torch.long).to(img_feat.device)
176176

177177
img_feat_size = list(per_label.size())

0 commit comments

Comments
 (0)