Skip to content

Commit 89ff46f

Browse files
committed
release DLD code
1 parent 2bbb4c3 commit 89ff46f

File tree

24 files changed

+2369
-7
lines changed

24 files changed

+2369
-7
lines changed

davarocr/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .davar_nlp_common import *
2020
from .davar_ner import *
2121
from .davar_order import *
22+
from .davar_distill import *
2223
from .mmcv import *
2324
from .version import __version__
2425

davarocr/davar_distill/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""
2+
##################################################################################################
3+
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
4+
# Filename : __init__.py
5+
# Abstract :
6+
7+
# Current Version: 1.0.0
8+
# Date : 2022-07-07
9+
##################################################################################################
10+
"""
11+
from .models import *
12+
from .dataset import *
13+
from .core import *
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""
2+
##################################################################################################
3+
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
4+
# Filename : __init__.py
5+
# Abstract :
6+
7+
# Current Version: 1.0.0
8+
# Date : 2022-07-07
9+
##################################################################################################
10+
"""
11+
from .beam_search import beam_decode
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
"""
2+
##################################################################################################
3+
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
4+
# Filename : beam_search.py
5+
# Abstract : Beam search for attention decode
6+
7+
# Current Version: 1.0.0
8+
# Date : 2022-07-07
9+
##################################################################################################
10+
"""
11+
import torch
12+
from queue import PriorityQueue
13+
14+
15+
class BeamSearchNode(object):
    """ Single hypothesis step of the beam search.

    Nodes are chained backwards through `prev_node`, so a complete path is
    recovered by walking from an end node to the root.
    """

    def __init__(self, previous_node, char_id, logProb, length):
        """
        Args:
            previous_node (obj:`BeamSearchNode`): parent node on the path, None for the root
            char_id: character id selected at this step (`beam_decode` passes a 1x1 tensor)
            logProb (float): score accumulated along the path that ends at this node
            length (int): number of nodes on the path, root included
        """
        self.prev_node, self.char_id = previous_node, char_id
        self.logp, self.leng = logProb, length

    def eval(self):
        """ Calculate beam search path score

        Returns:
            float: accumulated score divided by the number of steps after the root
        """
        # The root is not counted as a step; the epsilon keeps the divisor non-zero for it.
        steps = float(self.leng - 1 + 1e-6)
        return self.logp / steps

    def __lt__(self, other):
        """ Ordering used by the priority queue.

        A node counts as "smaller" when its score is NOT lower than the other
        one, so a min-heap hands out the highest scoring node first.

        Args:
            other (obj:`BeamSearchNode`): beam search node to compare with

        Returns:
            bool: False if this node scores strictly lower than `other`, True otherwise
        """
        return not (self.eval() < other.eval())
48+
49+
50+
def beam_decode(encoder_outputs, beam_width=5, topk=1):
    """ Beam search decode

    Every sample is decoded independently with a best-first search: the start
    token is id 0, a path is finished once it emits id 1, and paths are ranked
    by `BeamSearchNode.eval()` (accumulated score / number of steps).

    Args:
        encoder_outputs (Tensor): per-step character scores of shape [B, T, C]
                                  where B is the batch size and T is the maximum length of the
                                  output sentence. Scores are summed along a path, so they are
                                  presumably log-probabilities - confirm with the caller.
        beam_width (int): number of candidates expanded from each node (clamped to C)
        topk (int): number of decoded paths wanted per sample

    Returns:
        list(list(Tensor)): for each sample, up to `topk` 1-D tensors of character ids
        (start token included), ordered from the best to the worst path score
    """
    max_steps = encoder_outputs.size(1)
    # torch.topk raises when k is larger than the size of the class dimension.
    expand_k = min(beam_width, encoder_outputs.size(-1))
    decoded_batch = []

    # decoding goes sentence by sentence
    for idx in range(encoder_outputs.size(0)):
        # Start with the start of the sentence token
        start_token = torch.tensor([[0]], device=encoder_outputs.device).long()

        # Number of sentences to generate
        endnodes = []
        number_required = topk

        # starting node - previous node, char id, logp, length
        nodes = PriorityQueue()
        nodes.put(BeamSearchNode(None, start_token, 0, 1))
        # Counts every node ever queued (finished nodes are not subtracted); used as a work budget.
        qsize = 1

        # Give up when decoding takes too long. The emptiness check matters because
        # PriorityQueue.get() would otherwise block forever once nothing is left to expand.
        while qsize <= 2000 and not nodes.empty():
            # fetch the best node
            priority_node = nodes.get()

            hit_eos = priority_node.prev_node is not None and priority_node.char_id.item() == 1
            # A path that already consumed all T steps has no scores left to expand from;
            # close it instead of indexing past the end of encoder_outputs.
            out_of_steps = priority_node.leng - 1 >= max_steps
            if hit_eos or out_of_steps:
                endnodes.append(priority_node)
                # stop once we reached the number of sentences required
                if len(endnodes) >= number_required:
                    break
                continue

            # expand the node with the best candidates of its time step
            log_prob, indexes = torch.topk(encoder_outputs[idx][priority_node.leng - 1], expand_k)
            for new_k in range(expand_k):
                decoded_t = indexes[new_k].view(1, -1)
                log_p = log_prob[new_k].item()
                nodes.put(BeamSearchNode(priority_node, decoded_t,
                                         priority_node.logp + log_p, priority_node.leng + 1))
            # one node was consumed, expand_k nodes were added
            qsize += expand_k - 1

        # No finished path: fall back to the best unfinished ones, never asking the
        # queue for more nodes than it holds (that would block as well).
        if not endnodes:
            endnodes = [nodes.get() for _ in range(min(topk, nodes.qsize()))]

        # choose nbest paths (best score first), back trace them
        stack_utterances = []
        for endnode in sorted(endnodes, key=lambda x: x.eval(), reverse=True):
            utterance = []
            node = endnode
            while node is not None:
                utterance.append(node.char_id)
                node = node.prev_node
            utterance.reverse()
            # each char id is a 1x1 tensor: stack to [1, 1, L] and drop the two leading dims
            stack_utterances.append(torch.stack(utterance, dim=-1).squeeze(0).squeeze(0))
        decoded_batch.append(stack_utterances)

    return decoded_batch
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""
2+
##################################################################################################
3+
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
4+
# Filename : __init__.py
5+
# Abstract :
6+
7+
# Current Version: 1.0.0
8+
# Date : 2022-07-07
9+
##################################################################################################
10+
"""
11+
from .pipelines import DistillFormatBundle
12+
13+
__all__ = ['DistillFormatBundle']
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""
2+
##################################################################################################
3+
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
4+
# Filename : __init__.py
5+
# Abstract :
6+
7+
# Current Version: 1.0.0
8+
# Date : 2022-07-07
9+
##################################################################################################
10+
"""
11+
from .distill_formating import DistillFormatBundle
12+
13+
__all__ = ['DistillFormatBundle']
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""
2+
##################################################################################################
3+
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
4+
# Filename : distill_formating.py
5+
# Abstract : Definition of data formating process for knowledge distillation
6+
7+
# Current Version: 1.0.0
8+
# Date : 2022-07-07
9+
##################################################################################################
10+
"""
11+
import numpy as np
12+
from mmcv.parallel import DataContainer as DC
13+
14+
from mmdet.datasets.builder import PIPELINES
15+
from mmdet.datasets.pipelines.formating import to_tensor, DefaultFormatBundle
16+
17+
18+
@PIPELINES.register_module()
class DistillFormatBundle(DefaultFormatBundle):
    """ Data format pipeline for knowledge distillation samples.

    Works on the regular keys and on their 'hr_'-prefixed counterparts:

    - 'img' / 'hr_img' are transposed to channel-first, converted to Tensor and wrapped in a
      stackable DataContainer
    - keys in ['proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels', 'stn_params'] (and their
      'hr_' versions) are converted to Tensor and wrapped in a DataContainer
    - 'gt_semantic_seg' gets a leading axis, is converted to Tensor and wrapped in a stackable
      DataContainer
    - keys in ['gt_masks', 'gt_poly_bboxes', 'gt_poly_bboxes_ignore', 'gt_cbboxes',
      'gt_cbboxes_ignore', 'gt_texts', 'gt_text', 'array_gt_texts', 'gt_bieo_labels'] (and their
      'hr_' versions) are wrapped as cpu_only DataContainer without conversion
    """

    def __call__(self, results):
        """ Wrap the entries of `results` in place.

        Args:
            results (dict): sample produced by the preceding pipeline steps

        Returns:
            dict: the same `results` object with the handled keys replaced by DataContainers
        """
        # Images: add a channel axis when it is missing, then move channels to the front
        for img_key in ('img', 'hr_img'):
            if img_key not in results:
                continue
            image = results[img_key]
            if len(image.shape) < 3:
                image = np.expand_dims(image, -1)
            channel_first = np.ascontiguousarray(image.transpose(2, 0, 1))
            results[img_key] = DC(to_tensor(channel_first), stack=True)

        # Annotations converted to Tensor, regular key first and its 'hr_' twin second
        tensor_keys = ('proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels', 'stn_params')
        for base_key in tensor_keys:
            for name in (base_key, 'hr_' + base_key):
                if name in results:
                    results[name] = DC(to_tensor(results[name]))

        if 'gt_semantic_seg' in results:
            seg_map = results['gt_semantic_seg'][None, ...]
            results['gt_semantic_seg'] = DC(to_tensor(seg_map), stack=True)

        # Updated keys by DavarCustom dataset: kept as they are, flagged cpu_only
        cpu_keys = ('gt_masks', 'gt_poly_bboxes', 'gt_poly_bboxes_ignore', 'gt_cbboxes',
                    'gt_cbboxes_ignore', 'gt_texts', 'gt_text', 'array_gt_texts', 'gt_bieo_labels')
        for base_key in cpu_keys:
            for name in (base_key, 'hr_' + base_key):
                if name in results:
                    results[name] = DC(results[name], cpu_only=True)

        return results
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""
2+
##################################################################################################
3+
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
4+
# Filename : __init__.py
5+
# Abstract :
6+
7+
# Current Version: 1.0.0
8+
# Date : 2022-07-07
9+
##################################################################################################
10+
"""
11+
from .connect import ResolutionSelector
12+
from .distillation import SpotResolutionDistillation
13+
from .spotters import KDTwoStageEndToEnd
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""
2+
##################################################################################################
3+
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
4+
# Filename : __init__.py
5+
# Abstract :
6+
7+
# Current Version: 1.0.0
8+
# Date : 2022-07-07
9+
##################################################################################################
10+
"""
11+
from .resolution_selector import ResolutionSelector
12+
13+
__all__ = ['ResolutionSelector']

0 commit comments

Comments
 (0)