Skip to content

Commit baa264b

Browse files
hukkaily015
authored and committed
[Feature] support repeat_aug (#2170)
1 parent e881686 commit baa264b

File tree

3 files changed

+226
-8
lines changed

3 files changed

+226
-8
lines changed

mmaction/datasets/__init__.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,12 @@
55
from .base import BaseActionDataset
66
from .pose_dataset import PoseDataset
77
from .rawframe_dataset import RawframeDataset
8+
from .repeat_aug_dataset import RepeatAugDataset, repeat_pseudo_collate
89
from .transforms import * # noqa: F401, F403
910
from .video_dataset import VideoDataset
1011

1112
__all__ = [
12-
'VideoDataset',
13-
'RawframeDataset',
14-
'AVADataset',
15-
'AVAKineticsDataset',
16-
'PoseDataset',
17-
'BaseActionDataset',
18-
'ActivityNetDataset',
19-
'AudioDataset',
13+
'AVADataset', 'AVAKineticsDataset', 'ActivityNetDataset', 'AudioDataset',
14+
'BaseActionDataset', 'PoseDataset', 'RawframeDataset', 'RepeatAugDataset',
15+
'VideoDataset', 'repeat_pseudo_collate'
2016
]
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from copy import deepcopy
3+
from typing import Any, Callable, List, Optional, Sequence, Union
4+
5+
import numpy as np
6+
from mmengine.dataset import COLLATE_FUNCTIONS, pseudo_collate
7+
8+
from mmaction.registry import DATASETS
9+
from mmaction.utils import ConfigType
10+
from .video_dataset import VideoDataset
11+
12+
13+
def get_type(transform: Union[dict, Callable]) -> str:
    """Get the type name of a pipeline transform.

    Args:
        transform (dict | Callable): Either a transform config dict holding
            a ``'type'`` key, or an already-built callable transform.

    Returns:
        str: For a config dict, the value of ``transform['type']`` (its
        ``__name__`` when the value is not a string, e.g. a class). For a
        callable, the part of ``repr(transform)`` before the first ``'('``.

    Raises:
        TypeError: If ``transform`` is neither a dict with a ``'type'`` key
            nor a callable.
    """
    if isinstance(transform, dict) and 'type' in transform:
        type_ = transform['type']
        if not isinstance(type_, str):
            # ``type`` may be given as a class/function object instead of a
            # registry name; normalise it so callers can compare by name.
            type_ = getattr(type_, '__name__', str(type_))
        return type_
    if callable(transform):
        # Relies on the transform's repr looking like ``Name(arg=...)``.
        return repr(transform).split('(')[0]
    raise TypeError(
        "transform must be a dict with a 'type' key or a callable, "
        f'but got {type(transform).__name__}')
21+
22+
23+
@DATASETS.register_module()
class RepeatAugDataset(VideoDataset):
    """Video dataset for action recognition with repeated augmentation.

    For every sample the video is initialised and decoded only once, while
    frame indices are sampled ``num_repeats`` times. The remaining
    transforms of the pipeline are then applied to each sampled clip
    separately, so indexing the dataset yields a list of ``num_repeats``
    results rather than a single dict. ``repeat_pseudo_collate`` flattens
    such lists when building a batch.

    The pipeline must start with ``DecordInit`` and have ``DecordDecode`` as
    its third transform; the second transform is used as the frame sampler
    (``SampleFrames``).

    The ann_file is a text file with multiple lines, and each line indicates
    a sample video with the filepath and label, which are split with a
    whitespace. Example of an annotation file:

    .. code-block:: txt

        some/path/000.mp4 1
        some/path/001.mp4 1
        some/path/002.mp4 2
        some/path/003.mp4 2
        some/path/004.mp4 3
        some/path/005.mp4 3


    Args:
        ann_file (str): Path to the annotation file.
        pipeline (List[Union[dict, ConfigDict, Callable]]): A sequence of
            data transforms.
        data_prefix (dict or ConfigDict): Path to a directory where videos
            are held. Defaults to ``dict(video='')``.
        num_repeats (int): Number of times frames are sampled (and the rest
            of the pipeline is applied) for each video. Must be at least 1.
            Defaults to 4.
        multi_class (bool): Determines whether the dataset is a multi-class
            dataset. Defaults to False.
        num_classes (int, optional): Number of classes of the dataset, used in
            multi-class datasets. Defaults to None.
        start_index (int): Specify a start index for frames in consideration of
            different filename format. However, when taking videos as input,
            it should be set to 0, since frames loaded from videos count
            from 0. Defaults to 0.
        modality (str): Modality of data. Support ``RGB``, ``Flow``.
            Defaults to ``RGB``.

    Note:
        The dataset is always built with ``test_mode=False``; ``test_mode``
        is not an accepted argument.
    """

    def __init__(self,
                 ann_file: str,
                 pipeline: List[Union[dict, Callable]],
                 data_prefix: ConfigType = dict(video=''),
                 num_repeats: int = 4,
                 multi_class: bool = False,
                 num_classes: Optional[int] = None,
                 start_index: int = 0,
                 modality: str = 'RGB',
                 **kwargs) -> None:

        # Check the length first so that a too-short pipeline fails with the
        # explanatory assertion below instead of an IndexError.
        use_decord = len(pipeline) >= 3 and \
            get_type(pipeline[0]) == 'DecordInit' and \
            get_type(pipeline[2]) == 'DecordDecode'

        assert use_decord, (
            'RepeatAugDataset requires decord as the video '
            'loading backend, will support more backends in the '
            'future')

        # ``prepare_data`` needs at least one sampling round to produce the
        # frame indices that are decoded.
        if num_repeats < 1:
            raise ValueError(
                f'num_repeats must be at least 1, but got {num_repeats}')

        super().__init__(
            ann_file,
            pipeline=pipeline,
            data_prefix=data_prefix,
            multi_class=multi_class,
            num_classes=num_classes,
            start_index=start_index,
            modality=modality,
            test_mode=False,
            **kwargs)
        self.num_repeats = num_repeats

    def prepare_data(self, idx: int) -> List[dict]:
        """Get data processed by ``self.pipeline``.

        The video is initialised and decoded a single time for all repeats
        to reduce the video loading and decompressing cost.

        Args:
            idx (int): The index of ``data_info``.

        Returns:
            List[dict]: A list of length num_repeats.
        """
        transforms = self.pipeline.transforms

        data_info = self.get_data_info(idx)
        data_info = transforms[0](data_info)  # DecordInit

        # ``frame_inds_length`` holds cumulative offsets: repeat ``i`` owns
        # the decoded frames in [frame_inds_length[i], frame_inds_length[i+1])
        frame_inds_list, frame_inds_length = [], [0]

        # Sample on a light-weight dict that only carries what the sampler
        # reads, rather than on the full ``data_info``.
        fake_data_info = dict(
            total_frames=data_info['total_frames'],
            start_index=data_info['start_index'])

        for _ in range(self.num_repeats):
            data_info_ = transforms[1](fake_data_info)  # SampleFrames
            frame_inds = data_info_['frame_inds']
            frame_inds_list.append(frame_inds.reshape(-1))
            frame_inds_length.append(frame_inds.size + frame_inds_length[-1])

        # Carry the keys written by the last sampling round over to the
        # real ``data_info``.
        data_info.update(data_info_)

        # Decode the frames of every repeat in one pass.
        data_info['frame_inds'] = np.concatenate(frame_inds_list)

        data_info = transforms[2](data_info)  # DecordDecode
        imgs = data_info.pop('imgs')

        data_info_list = []
        for repeat in range(self.num_repeats):
            # ``imgs`` was popped above, so only the metadata is copied.
            data_info_ = deepcopy(data_info)
            start = frame_inds_length[repeat]
            end = frame_inds_length[repeat + 1]
            data_info_['imgs'] = imgs[start:end]
            for transform in transforms[3:]:
                data_info_ = transform(data_info_)
            data_info_list.append(data_info_)
        return data_info_list
140+
141+
142+
@COLLATE_FUNCTIONS.register_module()
def repeat_pseudo_collate(data_batch: Sequence) -> Any:
    """Flatten a batch of per-sample lists and pass it to ``pseudo_collate``.

    Args:
        data_batch (Sequence): A sequence whose items are themselves
            iterables of samples, e.g. the lists returned by
            ``RepeatAugDataset.prepare_data``.

    Returns:
        Any: The result of ``pseudo_collate`` on the flattened samples.
    """
    flattened = []
    for repeats in data_batch:
        flattened.extend(repeats)
    return pseudo_collate(flattened)
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import pytest
3+
from mmengine.testing import assert_dict_has_keys
4+
5+
from mmaction.datasets import RepeatAugDataset
6+
from mmaction.utils import register_all_modules
7+
from .base import BaseTestDataset
8+
9+
10+
class TestVideoDataset(BaseTestDataset):
    """Tests for ``RepeatAugDataset``."""
    register_all_modules()

    @staticmethod
    def _decord_pipeline():
        """Return a fresh minimal decord pipeline accepted by the dataset."""
        return [
            dict(type='DecordInit'),
            dict(
                type='SampleFrames', clip_len=4, frame_interval=2,
                num_clips=1),
            dict(type='DecordDecode')
        ]

    def test_video_dataset(self):
        with pytest.raises(AssertionError):
            # Currently only support decord backend
            RepeatAugDataset(
                self.video_ann_file,
                self.video_pipeline,
                data_prefix={'video': self.data_prefix},
                start_index=3)

        video_pipeline = self._decord_pipeline()

        video_dataset = RepeatAugDataset(
            self.video_ann_file,
            video_pipeline,
            data_prefix={'video': self.data_prefix},
            start_index=3)
        assert len(video_dataset) == 2
        assert video_dataset.start_index == 3

        video_dataset = RepeatAugDataset(
            self.video_ann_file,
            video_pipeline,
            data_prefix={'video': self.data_prefix})
        assert video_dataset.start_index == 0

    def test_video_dataset_multi_label(self):
        video_dataset = RepeatAugDataset(
            self.video_ann_file_multi_label,
            self._decord_pipeline(),
            data_prefix={'video': self.data_prefix},
            multi_class=True,
            num_classes=100)
        assert video_dataset.start_index == 0

    def test_video_pipeline(self):
        target_keys = ['filename', 'label', 'start_index', 'modality']

        # RepeatAugDataset not in test mode
        video_dataset = RepeatAugDataset(
            self.video_ann_file,
            self._decord_pipeline(),
            data_prefix={'video': self.data_prefix})
        result = video_dataset[0]
        assert isinstance(result, (list, tuple))
        # One entry per repeat, each carrying the sample's metadata.
        assert len(result) == video_dataset.num_repeats
        for sample in result:
            assert assert_dict_has_keys(sample, target_keys)

0 commit comments

Comments
 (0)