Skip to content

Commit 7588b6b

Browse files
committed
update mask_roi
1 parent 7ed0179 commit 7588b6b

File tree

3 files changed

+167
-33
lines changed

3 files changed

+167
-33
lines changed
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
"""
2+
##################################################################################################
3+
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
4+
# Filename : mask_roi_extractor.py
5+
# Abstract : Extract RoI masking features from a single level feature map.
6+
7+
# Current Version: 1.0.0
8+
# Date : 2021-07-14
9+
##################################################################################################
10+
"""
11+
import numpy as np
12+
13+
import torch
14+
from mmcv.runner import force_fp32
15+
16+
from mmdet.models.builder import ROI_EXTRACTORS
17+
from mmdet.models.roi_heads.roi_extractors.single_level_roi_extractor import SingleRoIExtractor
18+
19+
20+
@ROI_EXTRACTORS.register_module()
class MaskRoIExtractor(SingleRoIExtractor):
    """ Implementation of RoI masking feature extractor.

    RoI features are extracted by the per-level RoI layers and, when instance
    masks are supplied, multiplied element-wise by the masks cropped to each
    RoI and resized to the RoI output size.
    """

    def __init__(self,
                 roi_layer,
                 out_channels,
                 featmap_strides,
                 finest_scale=56):
        """
        Args:
            roi_layer (dict): Specify RoI layer type and arguments.
            out_channels (int): Output channels of RoI layers.
            featmap_strides (List[int]): Strides of input feature maps.
            finest_scale (int): Scale threshold of mapping to level 0. Default: 56.
        """

        super().__init__(roi_layer, out_channels, featmap_strides, finest_scale)

    @force_fp32(apply_to=('feats', ), out_fp16=True)
    def forward(self, feats, rois, masks, roi_scale_factor=None):
        """ Forward computation.

        Args:
            feats (list(Tensor)): original feature maps, in shape of [B x C x H x W]
            rois (Tensor): region of interest, in shape of [num_roi x 5]
            masks (list(BitmapMasks)): the mask corresponding to each img. When
                                       None or empty, no masking is applied.
            roi_scale_factor (tuple): scale factor that RoI will be multiplied by.
                                      Only applied when more than one feature
                                      level is given.

        Returns:
            Tensor: extracted RoI masking feature maps, in shape of [num_roi x C x H x W]
        """

        out_size = self.roi_layers[0].output_size
        num_levels = len(feats)
        expand_dims = (-1, self.out_channels * out_size[0] * out_size[1])
        if torch.onnx.is_in_onnx_export():
            # Work around to export mask-rcnn to onnx
            roi_feats = rois[:, :1].clone().detach()
            roi_feats = roi_feats.expand(*expand_dims)
            roi_feats = roi_feats.reshape(-1, self.out_channels, *out_size)
            roi_feats = roi_feats * 0
        else:
            roi_feats = feats[0].new_zeros(
                rois.size(0), self.out_channels, *out_size)

        # TODO: remove this when parrots supports
        if torch.__version__ == 'parrots':
            roi_feats.requires_grad = True

        if num_levels == 1:
            if len(rois) == 0:
                return roi_feats
            # The single-level path must be masked as well; returning the raw
            # RoI layer output here would silently ignore `masks`.
            roi_feats = self.roi_layers[0](feats[0], rois)
            return self._apply_masks(roi_feats, rois, masks)

        target_lvls = self.map_roi_levels(rois, num_levels)

        if roi_scale_factor is not None:
            rois = self.roi_rescale(rois, roi_scale_factor)

        for i in range(num_levels):
            mask = target_lvls == i
            if torch.onnx.is_in_onnx_export():
                # To keep all roi_align nodes exported to onnx
                # and skip nonzero op
                mask = mask.float().unsqueeze(-1).expand(*expand_dims).reshape(
                    roi_feats.shape)
                roi_feats_t = self.roi_layers[i](feats[i], rois)
                roi_feats_t *= mask
                roi_feats += roi_feats_t
                continue
            inds = mask.nonzero(as_tuple=False).squeeze(1)
            if inds.numel() > 0:
                rois_ = rois[inds]
                roi_feats_t = self.roi_layers[i](feats[i], rois_)
                roi_feats[inds] = roi_feats_t
            else:
                # Sometimes some pyramid levels will not be used for RoI
                # feature extraction and this will cause an incomplete
                # computation graph in one GPU, which is different from those
                # in other GPUs and will cause a hanging error.
                # Therefore, we add it to ensure each feature pyramid is
                # included in the computation graph to avoid runtime bugs.
                roi_feats += sum(
                    x.view(-1)[0]
                    for x in self.parameters()) * 0. + feats[i].sum() * 0.

        return self._apply_masks(roi_feats, rois, masks)

    def _apply_masks(self, roi_feats, rois, masks):
        """ Multiply RoI features by the instance masks cropped to each RoI.

        The RoIs are consumed image by image: for every entry of ``masks`` the
        next ``mask.masks.shape[0]`` rows of ``rois`` are paired with that
        image's mask instances, so ``rois`` is expected to be ordered per image
        with one RoI per mask instance.

        Args:
            roi_feats (Tensor): RoI features, in shape of [num_roi x C x H x W]
            rois (Tensor): region of interest, in shape of [num_roi x 5]; the
                           first column is dropped before cropping.
            masks (list(BitmapMasks)): the mask corresponding to each img.

        Returns:
            Tensor: masked RoI features, same shape as ``roi_feats``. Returned
                    unchanged when ``masks`` is None or empty.
        """
        # An empty list would make torch.cat fail below, so treat it like None.
        if masks is None or len(masks) == 0:
            return roi_feats

        output_size = self.roi_layers[0].output_size
        crop_masks = []
        left = 0
        for mask in masks:
            num = mask.masks.shape[0]
            right = left + num
            # Crop mask from gt_masks according to roi
            crop_mask = mask.crop_and_resize(rois[left:right, 1:], output_size,
                                             np.arange(num), device=rois.device)
            left = right
            crop_masks_t = torch.tensor(crop_mask.masks).to(roi_feats.device)
            crop_masks.append(crop_masks_t)

        # [num_roi x H x W] -> [num_roi x 1 x H x W] so it broadcasts over channels
        crop_masks = torch.cat(crop_masks).unsqueeze(1)
        # The masks only act as a constant gate on the features.
        return roi_feats * crop_masks.detach()

davarocr/davarocr/davar_spotting/models/roi_extractors/tps_roi_extractor.py

Lines changed: 42 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -76,30 +76,47 @@ def forward(self, feats, fiducial_points):
7676
Returns:
7777
Tensor: rectification feature of shape [K x C x output_size]
7878
"""
79+
roi_feats = []
80+
scale_factor = 4
81+
7982
# only using 4x feature
80-
x = self.relu(self.bn(self.conv(feats[0])))
83+
feats = self.relu(self.bn(self.conv(feats[0])))
84+
_, _, height, width = feats.size()
8185

82-
roi_feats = []
83-
for batch_id in range(len(x)):
84-
batch_C_prime = fiducial_points[batch_id]
85-
if len(batch_C_prime) == 0:
86+
for feat, points in zip(feats, fiducial_points):
87+
if len(points) == 0:
8688
continue
87-
# B x point_num x 2
88-
batch_C_prime = torch.Tensor(batch_C_prime).cuda(device=x.device)
89-
# B x C x H x W
90-
batch_I = x[batch_id].unsqueeze(0).expand(len(batch_C_prime), -1, -1, -1)
91-
# B x N (= output_size[0] x output_size[1]) x 2
92-
build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime)
93-
# B x output_size x 2
94-
build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0),
95-
self.output_size[0],
96-
self.output_size[1],
97-
2])
98-
# B x C x output_size
99-
batch_I_r = F.grid_sample(batch_I,
100-
build_P_prime_reshape,
101-
padding_mode='border')
102-
roi_feats.append(batch_I_r)
89+
points = torch.Tensor(points).cuda(device=feat.device)
90+
points = points / scale_factor
91+
for point in points:
92+
# Clip points
93+
point[:, 0] = torch.clip(point[:, 0], 0, width)
94+
point[:, 1] = torch.clip(point[:, 1], 0, height)
95+
96+
# Calculate points boundary
97+
x1 = int(torch.min(point[:, 0]))
98+
x2 = int(torch.max(point[:, 0])) + 1
99+
y1 = int(torch.min(point[:, 1]))
100+
y2 = int(torch.max(point[:, 1])) + 1
101+
102+
# Normalize points for tps
103+
point[:, 0] = 2 * (point[:, 0] - x1) / (x2 - x1) - 1
104+
point[:, 1] = 2 * (point[:, 1] - y1) / (y2 - y1) - 1
105+
106+
# B x N (= output_size[0] x output_size[1]) x 2
107+
build_P_prime = self.GridGenerator.build_P_prime(point.unsqueeze(0))
108+
# B x output_size x 2
109+
build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0),
110+
self.output_size[0],
111+
self.output_size[1],
112+
2])
113+
# Crop feature according to points boundary
114+
crop_feat = feat[:, y1:y2, x1:x2].unsqueeze(0)
115+
# B x C x output_size
116+
batch_I_r = F.grid_sample(crop_feat,
117+
build_P_prime_reshape,
118+
padding_mode='border')
119+
roi_feats.append(batch_I_r)
103120
roi_feats = torch.cat(roi_feats)
104121
return roi_feats
105122

@@ -151,23 +168,19 @@ def get_fiducial_points(self, imgs, polys):
151168
if len(batch_bboxes) > 0:
152169
batch_fiducial_points = np.stack(batch_fiducial_points, axis=0)
153170

154-
# Normalize fiducial points
155-
batch_fiducial_points[:, :, 0] = (2 * batch_fiducial_points[:, :, 0] - width) / width
156-
batch_fiducial_points[:, :, 1] = (2 * batch_fiducial_points[:, :, 1] - height) / height
157-
158171
fiducial_points.append(batch_fiducial_points)
159172
return fiducial_points
160173

161-
def normalize_fiducial_points(self, imgs, img_metas, fiducial_points):
162-
""" Normalize the fiducial points coordinates to [0,1].
174+
def rescale_fiducial_points(self, imgs, img_metas, fiducial_points):
175+
""" Rescale the fiducial points coordinates.
163176
164177
Args:
165178
imgs (Tensor): input image.
166179
img_metas (dict): image meta-info.
167180
fiducial_points list(np.array): tps fiducial points.
168181
169182
Returns:
170-
list(np.array): normalized points
183+
list(np.array): Rescaled points
171184
"""
172185
normalized_fiducial_points = []
173186
for img, img_meta, point in zip(imgs, img_metas, fiducial_points):
@@ -180,10 +193,7 @@ def normalize_fiducial_points(self, imgs, img_metas, fiducial_points):
180193
point[:, :, 0] = point[:, :, 0] * scale_factor[0]
181194
point[:, :, 1] = point[:, :, 1] * scale_factor[1]
182195

183-
# Normalize
184-
point[:, :, 0] = (2 * point[:, :, 0] - width) / width
185-
point[:, :, 1] = (2 * point[:, :, 1] - height) / height
186-
196+
# Change points order
187197
point_num = int(point.shape[1] / 2)
188198
point[:, point_num:, :] = point[:, point_num:, :][:, ::-1, :]
189199
normalized_fiducial_points.append(point)

davarocr/davarocr/davar_spotting/models/spotters/text_perceptron_spot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def simple_test(self,
187187
return results
188188

189189
# Compute rescaled fiducial points
190-
fiducial_points = self.recog_roi_extractor.normalize_fiducial_points(img, img_meta, fiducial_points)
190+
fiducial_points = self.recog_roi_extractor.rescale_fiducial_points(img, img_meta, fiducial_points)
191191

192192
# Extract feature according to fiducial point
193193
recog_feats = self.recog_roi_extractor(feat[:self.recog_roi_extractor.num_inputs], fiducial_points)

0 commit comments

Comments
 (0)