Skip to content

Commit 2bbb4c3

Browse files
committed
some minor function update
1 parent 2e6d203 commit 2bbb4c3

File tree

22 files changed

+436
-53
lines changed

22 files changed

+436
-53
lines changed

davarocr/davar_common/apis/test.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,16 @@
2222
from mmdet.core import encode_mask_results
2323
from mmdet.apis.test import collect_results_cpu, collect_results_gpu
2424

25+
from davarocr.mmcv import DavarProgressBar
26+
2527

2628
def single_gpu_test(model,
2729
data_loader,
2830
show=False,
2931
out_dir=None,
3032
show_score_thr=0.3,
31-
model_type="DETECTOR"):
33+
model_type="DETECTOR",
34+
min_time_interval=1):
3235
""" Test model with single GPU, used for visualization.
3336
3437
Args:
@@ -37,15 +40,16 @@ def single_gpu_test(model,
3740
show (boolean): whether to show visualization
3841
out_dir (str): visualization results saved path
3942
show_score_thr (float): the threshold to show visualization.
40-
model_type(float): model type indicator, used to formalize final results.
43+
model_type(str): model type indicator, used to formalize final results.
44+
min_time_interval(int): progressbar minimal update unit
4145
Returns:
4246
dict: test results
4347
"""
4448

4549
model.eval()
4650
results = []
4751
dataset = data_loader.dataset
48-
prog_bar = mmcv.ProgressBar(len(dataset))
52+
prog_bar = DavarProgressBar(len(dataset), min_time_interval=min_time_interval)
4953
for _, data in enumerate(data_loader):
5054
with torch.no_grad():
5155
result = model(return_loss=False, rescale=True, **data)
@@ -95,7 +99,7 @@ def single_gpu_test(model,
9599
result = list(zip(result["text"], result["length"]))
96100
else:
97101
result = result["text"]
98-
batch_size = len(result)
102+
batch_size = len(result) if not isinstance(result[0], list) else len(result[0])
99103
elif model_type == "SPOTTER":
100104
pass
101105
# if isinstance(result[0], dict):
@@ -118,7 +122,8 @@ def multi_gpu_test(model,
118122
data_loader,
119123
tmpdir=None,
120124
gpu_collect=False,
121-
model_type="DETECTOR"):
125+
model_type="DETECTOR",
126+
min_time_interval=1):
122127
"""Test model with multiple gpus.
123128
124129
This method tests model with multiple gpus and collects the results
@@ -133,6 +138,8 @@ def multi_gpu_test(model,
133138
tmpdir (str): Path of directory to save the temporary results from
134139
different gpus under cpu mode.
135140
gpu_collect (bool): Option to use either gpu or cpu to collect results.
141+
model_type(str): model type indicator, used to formalize final results.
142+
min_time_interval(int): progressbar minimal update unit
136143
137144
Returns:
138145
list(dict): The prediction results.
@@ -142,7 +149,7 @@ def multi_gpu_test(model,
142149
dataset = data_loader.dataset
143150
rank, world_size = get_dist_info()
144151
if rank == 0:
145-
prog_bar = mmcv.ProgressBar(len(dataset))
152+
prog_bar = DavarProgressBar(len(dataset), min_time_interval=min_time_interval)
146153
time.sleep(2) # This line can prevent deadlock problem in some cases.
147154
for _, data in enumerate(data_loader):
148155

@@ -158,12 +165,17 @@ def multi_gpu_test(model,
158165
elif model_type == "RECOGNIZOR":
159166
if "prob" in result:
160167
result = result["text"]
168+
if isinstance(result[0], list):
169+
result = result[0]
161170
elif "length" in result and "text" not in result:
162171
result = result["length"]
163172
elif "length" in result and "text" in result:
164173
result = list(zip(result["text"], result["length"]))
165174
else:
166175
result = result["text"]
176+
if isinstance(result[0], list):
177+
result = result[0]
178+
167179
elif model_type == "SPOTTER":
168180
pass
169181
# if isinstance(result[0], dict):

davarocr/davar_common/apis/train.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,15 @@ def train_model(model,
137137
# Support batch_size > 1 in validation
138138
val_samples_per_gpu = cfg.data.val.pop('samples_per_gpu', 1)
139139
if val_samples_per_gpu > 1:
140-
# Replace 'ImageToTensor' to 'DefaultFormatBundle'
141-
cfg.data.val.pipeline = replace_ImageToTensor(
142-
cfg.data.val.get("pipeline", cfg.data.val.dataset.get("pipeline", None)))
140+
# in case the test dataset is concatenated
141+
val_pipeline = cfg.data.val.get("pipeline", cfg.data.val.dataset.get("pipeline", None))
142+
# support multiple datasets with different validation pipelines
143+
if isinstance(val_pipeline[0], dict):
144+
cfg.data.val.pipeline = replace_ImageToTensor(val_pipeline)
145+
elif isinstance(val_pipeline[0], list):
146+
cfg.data.val.pipeline = [
147+
replace_ImageToTensor(this_pipeline) for this_pipeline in val_pipeline]
148+
143149
val_dataset = davar_build_dataset(cfg.data.val, dict(test_mode=True))
144150
val_dataloader = davar_build_dataloader(
145151
val_dataset,

davarocr/davar_common/datasets/builder.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -254,9 +254,6 @@ def parameter_align(cfg):
254254
if isinstance(cfg["dataset"]["img_prefix"], str):
255255
cfg["dataset"]["img_prefix"] = cfg["dataset"]["img_prefix"].split('|')
256256

257-
assert len(batch_ratios) == len(cfg["dataset"]["ann_file"]),\
258-
'the numbers of the batch ratios should equal to the numbers of the annotation files'
259-
260257
dataset_num = len(batch_ratios)
261258

262259
for key, item in cfg["dataset"].items():

davarocr/davar_common/datasets/davar_custom.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,6 @@ def process_anns(self, idx):
303303
img_info = copy.deepcopy(self.data_infos[idx].get('ann', None))
304304
if self.classes_config is not None:
305305
img_info['labels'] = [per[0] for per in img_info['labels']]
306-
307306
bboxes = []
308307
labels = []
309308
bboxes_ignore = []
@@ -313,11 +312,16 @@ def process_anns(self, idx):
313312
cares = [1] * len(img_info['labels'])
314313

315314
for i, care in enumerate(cares):
315+
x_min = min(img_info['bboxes'][i][0::2])
316+
x_max = max(img_info['bboxes'][i][0::2])
317+
y_min = min(img_info['bboxes'][i][1::2])
318+
y_max = max(img_info['bboxes'][i][1::2])
319+
rect_box = [x_min, y_min, x_max, y_max]
316320
if care:
317-
bboxes.append(img_info['bboxes'][i])
321+
bboxes.append(rect_box)
318322
labels.append(self.classes_config['classes'].index(img_info['labels'][i]))
319323
else:
320-
bboxes_ignore.append(img_info['bboxes'][i])
324+
bboxes_ignore.append(rect_box)
321325
labels_ignore.append(self.classes_config['classes'].index(img_info['labels'][i]))
322326
bboxes = np.array(bboxes).reshape(-1, 4)
323327
bboxes_ignore = np.array(bboxes_ignore).reshape(-1, 4)
@@ -359,7 +363,16 @@ def evaluate(self,
359363
allowed_metrics = ['mAP', 'recall']
360364
if metric not in allowed_metrics:
361365
raise KeyError(f'metric {metric} is not supported')
362-
# annotations = [self.get_ann_info(i) for i in range(len(self))]
366+
if len(results) > 0 and isinstance(results[0], dict):
367+
num_classes = len(self.classes_config['classes'])
368+
tmp_results = []
369+
for res in results:
370+
points = np.array(res['points']).reshape(-1, 4)
371+
scores = np.array(res['scores']).reshape(-1, 1)
372+
labels = np.array(res['labels'])
373+
bboxes = np.concatenate([points, scores], axis=-1)
374+
tmp_results.append([bboxes[labels == i, :] for i in range(num_classes)])
375+
results = tmp_results
363376
annotations = [self.process_anns(i) for i in range(len(self))]
364377
eval_results = OrderedDict()
365378
iou_thrs = [iou_thr] if isinstance(iou_thr, float) else iou_thr

davarocr/davar_common/datasets/davar_multi_dataset.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,5 +153,25 @@ def evaluate(self,
153153
dict: model evaluation metric
154154
155155
"""
156-
validation_result = self.datasets[0].evaluate(results, metric, logger, **eval_kwargs)
156+
157+
# use the group samples to validate
158+
group_samples = self.flag["group_samples"]
159+
start_idx = 0
160+
validation_result = dict()
161+
for dataset_idx, group_sample in enumerate(group_samples):
162+
this_results = results[start_idx:start_idx + group_sample]
163+
this_validation_result = self.datasets[dataset_idx].evaluate(
164+
this_results, metric, logger, **eval_kwargs)
165+
# record each dataset's info
166+
for key, value in this_validation_result.items():
167+
this_key = "{}_set{}".format(key, dataset_idx)
168+
validation_result[this_key] = value
169+
# calculate the average performance
170+
if dataset_idx == 0:
171+
validation_result[key] = value / len(group_samples)
172+
else:
173+
validation_result[key] += value / len(group_samples)
174+
# update the sample index
175+
start_idx += group_sample
176+
157177
return validation_result

davarocr/davar_common/datasets/pipelines/davar_loading.py

Lines changed: 113 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import cv2
1818
import pycocotools.mask as maskUtils
1919
import numpy as np
20+
2021
from mmdet.datasets.builder import PIPELINES
2122
from mmdet.core import BitmapMasks, PolygonMasks
2223

@@ -134,6 +135,7 @@ def __init__(self,
134135
text_profile=None,
135136
label_start_index=0,
136137
poly2mask=True,
138+
only_quad=False
137139
):
138140
""" Parameter initialization
139141
@@ -172,6 +174,7 @@ def __init__(self,
172174
according to `classes_config`. The start label will be added. e.g., for mmdet 1.x,
173175
this value is set to [1]; for mmdet 2.x, this will be set to [0].
174176
poly2mask (boolean): Whether to convert the instance masks from polygons to bitmaps. Default: True.
177+
only_quad (boolean): Whether only quad-format annotations are supported.
175178
"""
176179
self.with_bbox = with_bbox
177180
self.with_poly_bbox = with_poly_bbox
@@ -182,9 +185,10 @@ def __init__(self,
182185
self.with_text = with_text
183186
self.bieo_labels = bieo_labels
184187
self.text_profile = text_profile
185-
self.label_start_index=label_start_index
188+
self.label_start_index = label_start_index
186189
self.with_cbbox = with_cbbox
187190
self.poly2mask = poly2mask
191+
self.only_quad = only_quad
188192

189193
assert not (self.with_label and self.with_multi_label), \
190194
"Only one of with_label and with_multi_label can be true"
@@ -324,10 +328,25 @@ def _load_poly_bboxes(self, results):
324328
gt_poly_bboxes = []
325329
gt_poly_bboxes_ignore = []
326330

331+
height, width = results['img_info']['height'], results['img_info']['width']
332+
327333
for i, box in enumerate(tmp_gt_bboxes):
334+
for cor_idx in range(0, len(box), 2):
335+
box[cor_idx] = min(max(0, box[cor_idx]), width)
336+
box[cor_idx + 1] = min(max(0, box[cor_idx + 1]), height)
337+
328338
# If the bboxes are labeled in 2-point form, then transfer it into 4-point form.
329339
if len(box) == 4:
330340
box = [box[0], box[1], box[2], box[1], box[2], box[3], box[0], box[3]]
341+
342+
if self.only_quad and len(box) != 8:
343+
continue
344+
345+
if self.only_quad:
346+
box = self.sorted_bbox_convex(box.copy())
347+
if not self.is_convex(box.copy()):
348+
continue
349+
331350
if cares[i] == 1:
332351
gt_poly_bboxes.append(np.array(box))
333352
else:
@@ -390,6 +409,67 @@ def process_polygons(self, polygons):
390409
valid_polygons.append(polygon)
391410
return valid_polygons
392411

412+
def is_convex(self, bbox, area=2):
    """Check whether a quadrilateral is convex with a consistent winding.

    For every vertex the cross product of the incoming and outgoing edge is
    computed.  The quadrilateral is accepted only if each cross product is
    at least `area` and no cross product flips sign relative to the
    previously accepted one.

    Args:
        bbox (list[float]): flat coordinates [x1, y1, x2, y2, x3, y3, x4, y4]
        area (int): lower bound applied to every edge cross product

    Returns:
        bool: True if all four corners pass the check, False otherwise
    """
    total = 8
    last_cross = 1
    for vertex in range(total // 2):
        base = vertex * 2
        # Three consecutive corners, wrapping around the end of the list.
        x0, y0 = bbox[base], bbox[base + 1]
        x1, y1 = bbox[(base + 2) % total], bbox[(base + 3) % total]
        x2, y2 = bbox[(base + 4) % total], bbox[(base + 5) % total]

        cross = (x1 - x0) * (y2 - y1) - (x2 - x1) * (y1 - y0)

        # Reject corners that are too flat / wrongly oriented, or that turn
        # the opposite way from the previous corner.
        if cross < area or cross * last_cross < 0:
            return False
        last_cross = cross
    return True
436+
437+
def sorted_bbox_convex(self, bbox):
    """Reorder the four corners of a quadrilateral into a canonical order.

    The result starts with the upper (smaller y) of the two left-most
    corners, continues with the top-most remaining corner that lies strictly
    to its right, and ends with the last two corners in descending x order.
    Assuming image coordinates with y growing downward, an axis-aligned box
    comes out as top-left, top-right, bottom-right, bottom-left.

    Args:
        bbox (list[float]): flat coordinates [x1, y1, x2, y2, x3, y3, x4, y4]

    Returns:
        list[float]: the same four corners, reordered; always 8 values

    Raises:
        AssertionError: if `bbox` does not contain exactly 8 values
    """
    assert len(bbox) == 8

    points = [[bbox[idx], bbox[idx + 1]] for idx in range(0, 8, 2)]
    remaining = sorted(points, key=lambda pt: pt[0])

    # Start from whichever of the two left-most corners has the smaller y.
    start_idx = 0 if remaining[0][1] < remaining[1][1] else 1
    ordered = [remaining.pop(start_idx)]

    # Second corner: top-most remaining corner strictly right of the start.
    # Degenerate boxes (e.g. every x clamped to the same image border) have
    # no such corner; fall back to the top-most remaining one so that the
    # result always keeps all four corners instead of silently dropping one.
    remaining.sort(key=lambda pt: pt[1])
    second_idx = next(
        (idx for idx, pt in enumerate(remaining) if pt[0] > ordered[0][0]), 0)
    ordered.append(remaining.pop(second_idx))

    # Remaining two corners go right-to-left.
    remaining.sort(key=lambda pt: pt[0], reverse=True)
    ordered.extend(remaining)

    return [value for point in ordered for value in point]
472+
393473
def _load_polymasks(self, results):
394474
"""Private function to load mask annotations.
395475
@@ -407,21 +487,46 @@ def _load_polymasks(self, results):
407487
cares = results["cares"]
408488
polygons = ann.get('bboxes', [])
409489
valid_polygons = []
490+
invalid_polygons = []
491+
410492
for i, box in enumerate(polygons):
493+
for cor_idx in range(0, len(box), 2):
494+
box[cor_idx] = min(max(0, box[cor_idx]), width)
495+
box[cor_idx + 1] = min(max(0, box[cor_idx + 1]), height)
496+
497+
# If the bboxes are labeled in 2-point form, then transfer it into 4-point form.
498+
if len(box) == 4:
499+
box = [box[0], box[1], box[2], box[1], box[2], box[3], box[0], box[3]]
500+
501+
if self.only_quad and len(box) != 8:
502+
continue
503+
504+
if self.only_quad:
505+
box = self.sorted_bbox_convex(box.copy())
506+
if not self.is_convex(box.copy()):
507+
continue
508+
411509
if cares[i] == 1:
412-
# Handle the case of 2-point annotation
413-
if len(box) == 4:
414-
box = [box[0], box[1], box[2], box[1], box[2], box[3], box[0], box[3]]
415510
valid_polygons.append([np.array(box)])
511+
else:
512+
invalid_polygons.append([np.array(box)])
416513

417514
if self.poly2mask:
418515
gt_masks = BitmapMasks(
419516
[self._poly2mask(mask, height, width) for mask in valid_polygons], height, width)
517+
gt_masks_ignore = BitmapMasks(
518+
[self._poly2mask(mask, height, width) for mask in invalid_polygons], height, width)
420519
else:
421520
gt_masks = PolygonMasks(
422521
[self.process_polygons(polygons) for polygons in valid_polygons], height, width)
522+
gt_masks_ignore = PolygonMasks(
523+
[self.process_polygons(polygons) for polygons in invalid_polygons], height, width)
524+
423525
results['gt_masks'] = gt_masks
526+
results['gt_masks_ignore'] = gt_masks_ignore
527+
424528
results['mask_fields'].append('gt_masks')
529+
results['mask_fields'].append('gt_masks_ignore')
425530
return results
426531

427532
def _load_labels(self, results):
@@ -445,7 +550,10 @@ def _load_labels(self, results):
445550
self.label_start_index = self.label_start_index[0]
446551

447552
# If there is no `labels` in annotation, set `label_start_index` as the default value for all bboxes.
448-
if tmp_labels is None or len(tmp_labels)==0:
553+
if tmp_labels is None:
554+
tmp_labels = [[self.label_start_index]] * bboxes_length
555+
# If `labels` in annotation are empty, set `label_start_index` as the default value for all bboxes.
556+
elif len(tmp_labels) == 0:
449557
tmp_labels = [[self.label_start_index]] * bboxes_length
450558

451559
gt_labels = []

0 commit comments

Comments
 (0)