Skip to content

Commit c58bdeb

Browse files
committed
release NER models
1 parent c473ce6 commit c58bdeb

File tree

98 files changed

+33368
-63
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

98 files changed

+33368
-63
lines changed

davarocr/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from .davar_layout import *
1717
from .davar_videotext import *
1818
from .davar_table import *
19+
from .davar_nlp_common import *
20+
from .davar_ner import *
1921
from .mmcv import *
2022
from .version import __version__
2123

davarocr/davar_common/apis/inference.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ def init_model(config, checkpoint=None, device='cuda:0', cfg_options=None):
5757
elif cfg_types == "SPOTTER":
5858
from davarocr.davar_spotting.models.builder import build_spotter
5959
model = build_spotter(config.model, test_cfg=config.get('test_cfg'))
60+
elif cfg_types == "NER":
61+
from davarocr.davar_ner.models.builder import build_ner
62+
model = build_ner(config.model, test_cfg=config.get('test_cfg'))
6063
else:
6164
raise NotImplementedError
6265

davarocr/davar_common/core/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@
88
# Date : 2021-05-20
99
##################################################################################################
1010
"""
11-
from .builder import POSTPROCESS, build_postprocess
11+
from .builder import POSTPROCESS, build_postprocess, CONVERTERS, build_converter
1212
from .evaluation import DavarDistEvalHook, DavarEvalHook
1313

1414

1515
__all__ = ['POSTPROCESS',
1616
'build_postprocess',
17-
17+
'CONVERTERS',
18+
'build_converter',
1819
"DavarEvalHook",
1920
"DavarDistEvalHook",
2021
]

davarocr/davar_common/core/builder.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,30 @@
88
# Date : 2020-05-31
99
##################################################################################################
1010
"""
11-
from mmcv.utils import Registry
12-
from mmdet.models.builder import build
11+
from mmcv.utils import Registry, build_from_cfg
1312

1413
POSTPROCESS = Registry('postprocess')
15-
14+
CONVERTERS = Registry('converter')
1615

1716
def build_postprocess(cfg):
1817
""" Build POSTPROCESS module
1918
2019
Args:
21-
cfg(dict): module configuration
20+
cfg(mmcv.Config): module configuration
2221
2322
Returns:
2423
obj: POSTPROCESS module
2524
"""
26-
return build(cfg, POSTPROCESS)
25+
return build_from_cfg(cfg, POSTPROCESS)
26+
27+
28+
def build_converter(cfg):
    """ Build CONVERTER module

    Args:
        cfg (mmcv.Config): converter configuration; it is handed unchanged to
            ``build_from_cfg`` together with the ``CONVERTERS`` registry.

    Returns:
        obj: CONVERTER module
    """
    return build_from_cfg(cfg, CONVERTERS)

davarocr/davar_ner/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
"""
2+
##################################################################################################
3+
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
4+
# Filename : __init__.py
5+
# Abstract :
6+
7+
# Current Version: 1.0.0
8+
# Date : 2022-05-06
9+
##################################################################################################
10+
"""
11+
from .datasets import *
12+
from .models import *
13+
from .core import *
14+
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""
2+
##################################################################################################
3+
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
4+
# Filename : __init__.py
5+
# Abstract :
6+
7+
# Current Version: 1.0.0
8+
# Date : 2022-05-06
9+
##################################################################################################
10+
"""
11+
from .evaluation import eval_ner_f1
12+
from .converters import SpanConverter, TransformersConverter
13+
14+
__all__ = [
15+
'eval_ner_f1',
16+
'SpanConverter',
17+
'TransformersConverter'
18+
]
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""
2+
##################################################################################################
3+
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
4+
# Filename : __init__.py
5+
# Abstract :
6+
7+
# Current Version: 1.0.0
8+
# Date : 2022-05-06
9+
##################################################################################################
10+
"""
11+
from .transformers_converter import TransformersConverter
12+
from .span_converter import SpanConverter
13+
14+
15+
__all__ = ['TransformersConverter', 'SpanConverter']
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
"""
2+
##################################################################################################
3+
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
4+
# Filename : base_converter.py
5+
# Abstract :
6+
7+
# Current Version: 1.0.0
8+
# Date : 2022-05-06
9+
##################################################################################################
10+
"""
11+
from abc import ABCMeta, abstractmethod
12+
13+
14+
class BaseConverter(metaclass=ABCMeta):
    """ Base converter, convert between text, index and tensor for NER pipeline.

    Concrete converters must implement all three abstract methods below;
    the class cannot be instantiated until they are overridden.
    """

    @abstractmethod
    def convert_text2id(self, results):
        """ Convert tokens to ids.

        Args:
            results (dict): A dict that must contain the following key:
                - tokens (list): Tokens list.

        Returns:
            dict: corresponding ids
        """

    @abstractmethod
    def convert_pred2entities(self, preds, masks, **kwargs):
        """ Get entities from preds.

        Args:
            preds (list): Sequence of preds.
            masks (Tensor): The valid part is 1 and the invalid part is 0.
            **kwargs: extra, implementation-specific inputs.

        Returns:
            list: List of entities.
        """

    @abstractmethod
    def convert_entity2label(self, labels):
        """ Convert labeled entities to ids.

        Args:
            labels (list): eg:['B-PER', 'I-PER']

        Returns:
            dict: corresponding labels
        """
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
"""
2+
##################################################################################################
3+
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
4+
# Filename : span_converter.py
5+
# Abstract :
6+
7+
# Current Version: 1.0.0
8+
# Date : 2022-05-06
9+
##################################################################################################
10+
"""
11+
from seqeval.scheme import Tokens, IOBES
12+
from davarocr.davar_common.core import CONVERTERS
13+
from .transformers_converter import TransformersConverter
14+
15+
16+
@CONVERTERS.register_module()
class SpanConverter(TransformersConverter):
    """Span converter, converter for span model.

    Encodes a tagged label sequence as two per-token id vectors (span starts
    and span ends) and maps span predictions back to entity triples.
    """

    def _generate_labelid_dict(self):
        """Build the label <-> id lookup tables.

        Id 0 is assigned to ``'O'``; the entries of ``self.label_list``
        follow in their listed order.

        Returns:
            tuple: ``(label2id_dict, id2label_dict)``
        """
        label2id_dict = {label: i for i, label in enumerate(['O'] + self.label_list)}
        id2label_dict = {value: key for key, value in label2id_dict.items()}
        return label2id_dict, id2label_dict

    def _extract_subjects(self, seq):
        """Get entities from label sequence

        Args:
            seq (list): tag sequence, parsed with seqeval's ``IOBES`` scheme.

        Returns:
            list: one 3-tuple per entity, built from items 1, 2 and 3 of the
                entity's ``to_tuple()`` (consumed by ``convert_entity2label``
                as label, start and end).
        """
        entities = []
        for entity in Tokens(seq, IOBES).entities:
            # Evaluate to_tuple() once per entity instead of once per field.
            fields = entity.to_tuple()
            entities.append((fields[1], fields[2], fields[3]))
        return entities

    def convert_entity2label(self, labels):
        """Convert labeled entities to ids.

        Args:
            labels (list): eg:['B-PER', 'I-PER']

        Returns:
            dict: ``start_positions`` and ``end_positions``, two equally long
                id lists holding the entity label id at each span's first /
                last token and 0 everywhere else (special tokens and padding
                included).
        """
        labels = self._labels_convert(labels, self.only_label_first_subword)
        cls_token_at_end = self.cls_token_at_end
        pad_on_left = self.pad_on_left
        label2id = self.label2id_dict

        start_ids = [0] * len(labels)
        end_ids = [0] * len(labels)
        for label, start, end in self._extract_subjects(labels):
            label_id = label2id[label]
            # set label for span
            start_ids[start] = label_id
            # the true position is end-1
            end_ids[end - 1] = label_id

        # Account for [CLS] and [SEP] with "- 2".
        special_tokens_count = 2
        content_limit = self.max_len - special_tokens_count
        if len(labels) > content_limit:
            start_ids = start_ids[:content_limit]
            end_ids = end_ids[:content_limit]

        # add sep
        start_ids += [0]
        end_ids += [0]
        if cls_token_at_end:
            # add [CLS] at end
            start_ids += [0]
            end_ids += [0]
        else:
            # add [CLS] at begin
            start_ids = [0] + start_ids
            end_ids = [0] + end_ids

        # Computed from the untruncated length: for an over-long sequence the
        # value is negative and ``[0] * padding_length`` adds nothing.
        padding_length = self.max_len - len(labels) - 2
        padding = [0] * padding_length
        if pad_on_left:
            # pad on left
            start_ids = padding + start_ids
            end_ids = padding + end_ids
        else:
            # pad on right
            start_ids += padding
            end_ids += padding
        return dict(start_positions=start_ids, end_positions=end_ids)

    def convert_pred2entities(self, preds, masks, **kwargs):
        """Gets entities from preds.

        Args:
            preds (list): Sequence of preds; each pred is an iterable of
                ``(label_id, start, end)`` items.
            masks (tensor): The valid part is 1 and the invalid part is 0.
                Not read by this implementation.
            **kwargs: must contain ``tokens_index``, an iterable whose items
                support ``.cpu().numpy().tolist()``; the first element of
                each resulting list is used.

        Returns:
            list: List of [[[entity_type,
                entity_start, entity_end]]].
        """
        id2label = self.id2label
        # Both positions are shifted down by one.
        # NOTE(review): presumably this removes the leading [CLS] slot - confirm.
        pred_entities = [
            [[id2label[tag[0]], tag[1] - 1, tag[2] - 1] for tag in pred]
            for pred in preds
        ]
        tokens_index = [index.cpu().numpy().tolist()[0] for index in kwargs['tokens_index']]
        return [self._labels_convert_ori(pred_entity, tokens_index) for pred_entity in pred_entities]

0 commit comments

Comments
 (0)