Skip to content

Commit b40ea1e

Browse files
committed
add davar_table and LGPMA
1 parent 7588b6b commit b40ea1e

File tree

26 files changed

+2573
-0
lines changed

26 files changed

+2573
-0
lines changed

davarocr/davarocr/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from .davar_spotting import *
1515
from .davar_ie import *
1616
from .davar_videotext import *
17+
from .davar_table import *
1718
from .mmcv import *
1819
from .version import __version__
1920

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
"""
2+
##################################################################################################
3+
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
4+
# Filename : __init__.py
5+
# Abstract :
6+
7+
# Current Version: 1.0.0
8+
# Date : 2021-09-18
9+
##################################################################################################
10+
"""
11+
12+
from .models import *
13+
from .core import *
14+
from .datasets import *
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
"""
2+
##################################################################################################
3+
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
4+
# Filename : __init__.py
5+
# Abstract :
6+
7+
# Current Version: 1.0.0
8+
# Date : 2021-09-18
9+
##################################################################################################
10+
"""
11+
12+
from .mask import BitmapMasksTable, get_lpmasks
13+
from .bbox import recon_noncell, recon_largecell
14+
from .post_processing import PostLGPMA
15+
16+
__all__ = ['BitmapMasksTable', 'get_lpmasks', 'recon_noncell', 'recon_largecell', 'PostLGPMA']
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
"""
2+
##################################################################################################
3+
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
4+
# Filename : __init__.py
5+
# Abstract :
6+
7+
# Current Version: 1.0.0
8+
# Date : 2021-09-18
9+
##################################################################################################
10+
"""
11+
12+
from .bbox_process import recon_noncell, recon_largecell, nms_inter_classes, bbox2adj
13+
14+
__all__ = ['recon_noncell', 'recon_largecell', 'nms_inter_classes', 'bbox2adj']
Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
"""
2+
##################################################################################################
3+
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
4+
# Filename : bbox_process.py
5+
# Abstract : Implementation of bboxes process used in LGPMA.
6+
7+
# Current Version: 1.0.0
8+
# Date : 2021-09-18
9+
##################################################################################################
10+
"""
11+
12+
import numpy as np
13+
14+
15+
def _pseudo_low_edge(starts, ends, lows, highs, index, padding):
    """Estimate the near edge (xmin or ymin) of an empty cell along one axis.

    Args:
        starts (np.array): (m,).Start row/column index of every non-empty cell.
        ends (np.array): (m,).End row/column index of every non-empty cell.
        lows (np.array): (m,).Near-side coordinate (x1 or y1) of every non-empty text bbox.
        highs (np.array): (m,).Far-side coordinate (x2 or y2) of every non-empty text bbox.
        index (int): Start row/column index of the empty cell.
        padding (int): Gap kept from the neighbouring content or from the image boundary.

    Returns:
        int | np.number: coordinate of the near edge.
    """

    aligned = np.where(starts == index)[0]
    if len(aligned):
        # A non-empty cell starts on the same row/column: reuse its near edge.
        return lows[aligned].min()

    # The whole row/column is empty: walk towards the first row/column and stop at the closest
    # one in which a non-empty cell ends, then stay `padding` pixels beyond its content.
    for previous in range(int(index) - 1, -1, -1):
        neighbours = np.where(ends == previous)[0]
        if len(neighbours):
            return highs[neighbours].max() + padding

    # Nothing before this row/column (or it is the first one): stick to the image boundary.
    return padding


def _pseudo_high_edge(starts, ends, lows, highs, index, last_index, limit, padding):
    """Estimate the far edge (xmax or ymax) of an empty cell along one axis.

    Args:
        starts (np.array): (m,).Start row/column index of every non-empty cell.
        ends (np.array): (m,).End row/column index of every non-empty cell.
        lows (np.array): (m,).Near-side coordinate (x1 or y1) of every non-empty text bbox.
        highs (np.array): (m,).Far-side coordinate (x2 or y2) of every non-empty text bbox.
        index (int): End row/column index of the empty cell.
        last_index (int): Largest end index among non-empty cells (-1 if there is none).
        limit (int): Image size along this axis.
        padding (int): Gap kept from the neighbouring content or from the image boundary.

    Returns:
        int | np.number: coordinate of the far edge.
    """

    aligned = np.where(ends == index)[0]
    if len(aligned):
        # A non-empty cell ends on the same row/column: reuse its far edge.
        return highs[aligned].max()

    # The whole row/column is empty: walk towards the last row/column and stop at the closest
    # one in which a non-empty cell starts, then stay `padding` pixels before its content.
    for following in range(int(index) + 1, last_index + 1):
        neighbours = np.where(starts == following)[0]
        if len(neighbours):
            return lows[neighbours].min() - padding

    # Nothing after this row/column (or it is the last one): stick to the image boundary.
    return limit - padding


def recon_noncell(bboxlist, celllist, imgshape, padding=1):
    """ Produce pseudo-bboxes for empty cells

    The pseudo-bbox of an empty cell is derived from the text bboxes of the non-empty cells that
    share its start/end row and column. If a whole row/column is empty, the closest non-empty
    row/column (or the image boundary) is used instead, shifted by `padding` pixels.

    Args:
        bboxlist (list): (n x 4).Bboxes of text region in each cell(empty cell is noted as [])
        celllist (list): (n x 4).Start row, start column, end row and end column of each cell
        imgshape (tuple): (height, width).The height and width of input image.
        padding (int): If cells in the first/last row/col are all empty, extend them to padding pixels from boundary.

    Returns:
        list(list): (n x 4).Bboxes of text region in each cell (including empty cells). The input
            list is not modified; non-empty entries are passed through unchanged.
    """

    bboxlist_append = bboxlist.copy()
    cellnp = np.array(celllist, dtype='int32')

    bboxes_non = [bbox for bbox in bboxlist if bbox]
    cells_non = [cell for bbox, cell in zip(bboxlist, celllist) if bbox]
    if bboxes_non:
        bboxes_non = np.array(bboxes_non)
        cells_non = np.array(cells_non)
    else:
        # Table without any text: keep the arrays 2-D so that the column slices below stay valid
        # and every empty cell falls back to the padded image boundary.
        bboxes_non = np.zeros((0, 4))
        cells_non = np.zeros((0, 4), dtype='int32')

    # Per-axis views, hoisted out of the loop because they do not depend on the current cell.
    row_starts, col_starts = cells_non[:, 0], cells_non[:, 1]
    row_ends, col_ends = cells_non[:, 2], cells_non[:, 3]
    x_lows, y_lows = bboxes_non[:, 0], bboxes_non[:, 1]
    x_highs, y_highs = bboxes_non[:, 2], bboxes_non[:, 3]
    last_row = int(row_ends.max()) if len(cells_non) else -1
    last_col = int(col_ends.max()) if len(cells_non) else -1

    for i, bbox in enumerate(bboxlist_append):
        if bbox:
            continue
        ymin = _pseudo_low_edge(row_starts, row_ends, y_lows, y_highs, cellnp[i, 0], padding)
        ymax = _pseudo_high_edge(row_starts, row_ends, y_lows, y_highs, cellnp[i, 2], last_row,
                                 imgshape[0], padding)
        xmin = _pseudo_low_edge(col_starts, col_ends, x_lows, x_highs, cellnp[i, 1], padding)
        xmax = _pseudo_high_edge(col_starts, col_ends, x_lows, x_highs, cellnp[i, 3], last_col,
                                 imgshape[1], padding)
        bboxlist_append[i] = list(map(int, [xmin, ymin, xmax, ymax]))

    return bboxlist_append
124+
125+
126+
def recon_largecell(bboxlist, celllist):
    """ Produce pseudo-bboxes for aligned cells

    Each cell is stretched so that its left/right edges match the extreme text edges of all cells
    starting/ending on the same column, and its top/bottom edges match the extreme text edges of
    all cells starting/ending on the same row.

    Args:
        bboxlist (list): (n x 4).Bboxes of text region in each cell (including empty cells)
        celllist (list): (n x 4).Start row, start column, end row and end column of each cell

    Returns:
        list(list): (n x 4).Bboxes of aligned cells (including empty cells)
    """

    bboxnp = np.array(bboxlist, dtype='int32')
    cellnp = np.array(celllist, dtype='int32')

    bboxlist_align = []
    # One output per input bbox; `bboxlist` and `celllist` are expected to have the same length.
    for cell in cellnp[:len(bboxlist)]:
        same_start_row = np.where(cellnp[:, 0] == cell[0])[0]
        same_start_col = np.where(cellnp[:, 1] == cell[1])[0]
        same_end_row = np.where(cellnp[:, 2] == cell[2])[0]
        same_end_col = np.where(cellnp[:, 3] == cell[3])[0]
        newbbox = [bboxnp[same_start_col, 0].min(), bboxnp[same_start_row, 1].min(),
                   bboxnp[same_end_col, 2].max(), bboxnp[same_end_row, 3].max()]
        bboxlist_align.append(list(map(int, newbbox)))

    return bboxlist_align
152+
153+
154+
def rect_max_iou(box_1, box_2):
    """Calculate the maximum IoU between two boxes: the intersect area / the area of the smaller box

    Args:
        box_1 (np.array | list): [x1, y1, x2, y2]
        box_2 (np.array | list): [x1, y1, x2, y2]

    Returns:
        float: maximum IoU between the two boxes. 0.0 is returned when the smaller box has no
            area, instead of dividing by zero.
    """

    addone = 0  # 0 in mmdet2.0 / 1 in mmdet 1.0
    box_1, box_2 = np.array(box_1), np.array(box_2)

    # Width and height of the intersection, clamped at 0 for disjoint boxes.
    inter_w = np.maximum(np.minimum(box_1[2], box_2[2]) - np.maximum(box_1[0], box_2[0]) + addone, 0)
    inter_h = np.maximum(np.minimum(box_1[3], box_2[3]) - np.maximum(box_1[1], box_2[1]) + addone, 0)
    overlap = inter_w * inter_h

    area1 = (box_1[2] - box_1[0] + addone) * (box_1[3] - box_1[1] + addone)
    area2 = (box_2[2] - box_2[0] + addone) * (box_2[3] - box_2[1] + addone)
    smaller_area = min(area1, area2)

    # A degenerate box cannot overlap anything; avoid the 0 / 0 -> NaN (and its RuntimeWarning).
    if smaller_area <= 0:
        return 0.0

    return overlap / smaller_area
178+
179+
180+
def nms_inter_classes(bboxes, iou_thres=0.3):
    """NMS between all classes

    Boxes of all classes are pooled and visited by descending score (last column). A kept box
    suppresses every lower-scored box whose maximum IoU with it (see rect_max_iou) reaches
    `iou_thres`, whatever the classes of the two boxes are.

    Args:
        bboxes(list): [bboxes in cls1(np.array), bboxes in cls2(np.array), ...]. bboxes of each classes,
            each array holding the box coordinates followed by the score in the last column
        iou_thres(float): nms threshold

    Returns:
        np.array: (n x 4).bboxes of targets after NMS between all classes
        list(list): (n x 1).labels of targets after NMS between all classes
    """

    if len(bboxes) == 0:
        # No class at all: nothing to pool, return an empty result instead of failing on indexing.
        return np.zeros((0, 4)), []

    merge_bboxes = np.concatenate(list(bboxes), axis=0)
    # The label of a box is the position of its class in `bboxes`.
    merge_labels = [label_id for label_id, bboxes_cls in enumerate(bboxes) for _ in range(len(bboxes_cls))]

    keep = np.ones(len(merge_bboxes), dtype=bool)
    score_index = merge_bboxes[:, -1].argsort()[::-1]
    for rank, cur in enumerate(score_index):
        if not keep[cur]:
            continue
        for ind in score_index[rank + 1:]:
            if keep[ind] and rect_max_iou(merge_bboxes[cur], merge_bboxes[ind]) >= iou_thres:
                keep[ind] = False

    new_bboxes = merge_bboxes[keep, :4]
    new_labels = [[label_id] for label_id, kept in zip(merge_labels, keep) if kept]

    return new_bboxes, new_labels
215+
216+
217+
def bbox2adj(bboxes_non, row_tolerance=4, col_tolerance=0):
    """Calculating row and column adjacent relationships according to bboxes of non-empty aligned cells

    Two cells are row (column) adjacent when the vertical (horizontal) centre of one of them lies
    strictly inside the vertical (horizontal) extent of the other. In addition, two cells whose
    extents overlap by more than the tolerance are marked adjacent when the centre of any cell
    falls strictly inside their overlapping range.

    Args:
        bboxes_non(np.array): (n x 4).bboxes of non-empty aligned cells
        row_tolerance(int): minimum vertical overlap (in pixels) two cells must exceed before the
            special row relationship is examined. Defaults to 4.
        col_tolerance(int): minimum horizontal overlap (in pixels) two cells must exceed before the
            special column relationship is examined. Defaults to 0.

    Returns:
        np.array: (n x n).row adjacent relationships of non-empty aligned cells
        np.array: (n x n).column adjacent relationships of non-empty aligned cells
    """

    num = bboxes_non.shape[0]
    adjr = np.zeros([num, num], dtype='int')
    adjc = np.zeros([num, num], dtype='int')
    x_middle = bboxes_non[:, ::2].mean(axis=1)
    y_middle = bboxes_non[:, 1::2].mean(axis=1)

    for i, box in enumerate(bboxes_non):
        # Cells whose extent strictly contains the centre of the current cell.
        indexr = np.where((bboxes_non[:, 1] < y_middle[i]) & (bboxes_non[:, 3] > y_middle[i]))[0]
        indexc = np.where((bboxes_non[:, 0] < x_middle[i]) & (bboxes_non[:, 2] > x_middle[i]))[0]
        adjr[indexr, i], adjr[i, indexr] = 1, 1
        adjc[indexc, i], adjc[i, indexc] = 1, 1

        for j, box2 in enumerate(bboxes_non):
            # Determine if there are special row relationship
            if not (box2[1] + row_tolerance >= box[3] or box[1] + row_tolerance >= box2[3]):
                top, bottom = max(box[1], box2[1]), min(box[3], box2[3])
                if np.any((top < y_middle) & (y_middle < bottom)):
                    adjr[j, i], adjr[i, j] = 1, 1

            # Determine if there are special column relationship
            if not (box2[0] + col_tolerance >= box[2] or box[0] + col_tolerance >= box2[2]):
                left, right = max(box[0], box2[0]), min(box[2], box2[2])
                if np.any((left < x_middle) & (x_middle < right)):
                    adjc[j, i], adjc[i, j] = 1, 1

    return adjr, adjc
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""
2+
##################################################################################################
3+
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
4+
# Filename : __init__.py
5+
# Abstract :
6+
7+
# Current Version: 1.0.0
8+
# Date : 2021-09-18
9+
##################################################################################################
10+
"""
11+
12+
from .structures import BitmapMasksTable
13+
from .lp_mask_target import get_lpmasks
14+
15+
__all__ = ['BitmapMasksTable', 'get_lpmasks']
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
"""
2+
##################################################################################################
3+
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
4+
# Filename : lp_mask_target.py
5+
# Abstract : Produce local pyramid mask according to gt_bbox and gt_mask.
6+
7+
# Current Version: 1.0.0
8+
# Date : 2021-09-18
9+
##################################################################################################
10+
"""
11+
12+
from math import ceil
13+
import numpy as np
14+
from .structures import BitmapMasksTable
15+
16+
17+
def get_lpmasks(gt_masks, gt_bboxes):
    """Produce local pyramid masks according to gt_bbox and gt_mask (for a batch of images).

    Every (mask, bboxes) pair of the batch is handed to get_lpmask_single and the two results
    of each image are collected into separate lists.

    Args:
        gt_masks(list(BitmapMasks)): masks of the text regions
        gt_bboxes(list(Tensor)): bboxes of the aligned cells

    Returns:
        list(BitmapMasks):pyramid masks in horizontal direction
        list(BitmapMasks):pyramid masks in vertical direction
    """

    gt_lpmasks_hor, gt_lpmasks_ver = [], []
    for mask_per_img, bbox_per_img in zip(gt_masks, gt_bboxes):
        lpmasks_per_img = get_lpmask_single(mask_per_img, bbox_per_img)
        gt_lpmasks_hor.append(lpmasks_per_img[0])
        gt_lpmasks_ver.append(lpmasks_per_img[1])

    return gt_lpmasks_hor, gt_lpmasks_ver
35+
36+
37+
def get_lpmask_single(gt_mask, gt_bbox):
    """Produce local pyramid mask according to gt_bbox and gt_mask (for one image).

    Inside the aligned cell bbox of every instance, the horizontal (vertical) mask grows linearly
    from the left/right (top/bottom) border of the bbox towards the column (row) of the text
    centre, which is the rounded mean position of the pixels equal to 1 in the text mask.
    Pixels outside the bbox stay 0.

    Args:
        gt_mask(BitmapMasks): masks of the text regions (for one image)
        gt_bbox(Tensor): (n x 4).bboxes of the aligned cells (for one image)

    Returns:
        BitmapMasksTable:pyramid masks in horizontal direction (for one image)
        BitmapMasksTable:pyramid masks in vertical direction (for one image)

    Raises:
        ValueError: if the text mask of an instance contains no pixel equal to 1, since its
            centre is then undefined.
    """

    (num, high, width) = gt_mask.masks.shape
    mask_s1 = np.zeros((num, high, width), np.float32)
    mask_s2 = np.zeros((num, high, width), np.float32)
    for ind, box_text in enumerate(gt_mask.masks):
        left_col, left_row, right_col, right_row = list(map(float, gt_bbox[ind, 0:4]))
        # Integer pixel range covered by the (possibly fractional) bbox.
        x_min, y_min, x_max, y_max = ceil(left_col), ceil(left_row), ceil(right_col) - 1, ceil(right_row) - 1

        # Locate the text pixels once; their mean position is the peak of both pyramids.
        text_rows, text_cols = np.where(box_text == 1)
        if text_rows.size == 0:
            raise ValueError('The text mask of instance {} is empty, its centre is undefined.'.format(ind))
        middle_x, middle_y = round(text_cols.mean()), round(text_rows.mean())

        # Calculate the pyramid mask in horizontal direction
        col_np = np.arange(x_min, x_max + 1).reshape(1, -1)
        col_np_1 = (col_np[:, :middle_x - x_min] - left_col) / (middle_x - left_col)
        col_np_2 = (right_col - col_np[:, middle_x - x_min:]) / (right_col - middle_x)
        col_np = np.concatenate((col_np_1, col_np_2), axis=1)
        mask_s1[ind, y_min:y_max + 1, x_min:x_max + 1] = col_np

        # Calculate the pyramid mask in vertical direction
        row_np = np.arange(y_min, y_max + 1).reshape(-1, 1)
        row_np_1 = (row_np[:middle_y - y_min, :] - left_row) / (middle_y - left_row)
        row_np_2 = (right_row - row_np[middle_y - y_min:, :]) / (right_row - middle_y)
        row_np = np.concatenate((row_np_1, row_np_2), axis=0)
        mask_s2[ind, y_min:y_max + 1, x_min:x_max + 1] = row_np

    mask_s1 = BitmapMasksTable(mask_s1, high, width)
    mask_s2 = BitmapMasksTable(mask_s2, high, width)

    return mask_s1, mask_s2

0 commit comments

Comments
 (0)