Skip to content

Commit e881686

Browse files
hukkaily015
authored and committed
[Feature] Support Omni-source training on ImageNet and Kinetics dataset. (#2143)
1 parent 6fdad85 commit e881686

File tree

17 files changed

+1158
-42
lines changed

17 files changed

+1158
-42
lines changed
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
_base_ = ['../../_base_/default_runtime.py']

# model settings
# One recognizer whose classification head is configured with two class
# counts: one for the image source and one for the video source.
model = dict(
    type='RecognizerOmni',
    backbone=dict(type='OmniResNet'),
    cls_head=dict(
        type='OmniHead',
        image_classes=1000,
        video_classes=400,
        in_channels=2048,
        average_clips='prob',
    ),
    # Normalisation statistics and the input layout handed to the
    # preprocessor.
    data_preprocessor=dict(
        type='ActionDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        format_shape='MIX2d3d',
    ),
)
18+
19+
# dataset settings
# Image source: dataset root and training annotation file.
image_root = 'data/imagenet/'
image_ann_train = 'meta/train.txt'

# Video source: train/val video directories and their annotation lists.
video_root = 'data/kinetics400/videos_train'
video_root_val = 'data/kinetics400/videos_val'
video_ann_train = 'data/kinetics400/kinetics400_train_list_videos.txt'
video_ann_val = 'data/kinetics400/kinetics400_val_list_videos.txt'

num_images = 1281167  # number of training samples in the ImageNet dataset
num_videos = 240435  # number of training samples in the Kinetics400 dataset

# Video batch size (multiplied by the GPU count below) and the GPU count.
batchsize_video = 16
num_gpus = 8

# Number of full video batches over all GPUs.
num_iter = num_videos // (num_gpus * batchsize_video)
# Image batch size obtained by dividing the image count over the same number
# of iterations and GPUs.
batchsize_image = num_images // (num_gpus * num_iter)
34+
35+
# Training: sample a single 8-frame clip (frame interval 8), decode it,
# resize, take a random resized crop, resize to 224x224, randomly flip, then
# format and pack.
train_pipeline = [
    dict(type='DecordInit'),
    dict(
        type='SampleFrames',
        clip_len=8,
        frame_interval=8,
        num_clips=1,
    ),
    dict(type='DecordDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='RandomResizedCrop'),
    dict(type='Resize', scale=(224, 224), keep_ratio=False),
    dict(type='Flip', flip_ratio=0.5),
    dict(type='FormatShape', input_format='NCTHW'),
    dict(type='PackActionInputs'),
]

# Validation: a single clip sampled in test mode with a 224 centre crop.
val_pipeline = [
    dict(type='DecordInit'),
    dict(
        type='SampleFrames',
        clip_len=8,
        frame_interval=8,
        num_clips=1,
        test_mode=True,
    ),
    dict(type='DecordDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='CenterCrop', crop_size=224),
    dict(type='FormatShape', input_format='NCTHW'),
    dict(type='PackActionInputs'),
]

# Testing: ten clips sampled in test mode, each passed through ThreeCrop
# with crop size 256.
test_pipeline = [
    dict(type='DecordInit'),
    dict(
        type='SampleFrames',
        clip_len=8,
        frame_interval=8,
        num_clips=10,
        test_mode=True,
    ),
    dict(type='DecordDecode'),
    dict(type='Resize', scale=(-1, 256)),
    dict(type='ThreeCrop', crop_size=256),
    dict(type='FormatShape', input_format='NCTHW'),
    dict(type='PackActionInputs'),
]
76+
77+
# Video training loader: shuffled, batch size taken from `batchsize_video`.
train_dataloader = dict(
    batch_size=batchsize_video,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type='VideoDataset',
        ann_file=video_ann_train,
        data_prefix=dict(video=video_root),
        pipeline=train_pipeline,
    ),
)

# Video validation loader: unshuffled, reads the validation list/directory.
val_dataloader = dict(
    batch_size=16,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type='VideoDataset',
        ann_file=video_ann_val,
        data_prefix=dict(video=video_root_val),
        pipeline=val_pipeline,
        test_mode=True,
    ),
)

# Video test loader: same validation data, one sample per batch, run through
# the test pipeline.
test_dataloader = dict(
    batch_size=1,
    num_workers=8,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type='VideoDataset',
        ann_file=video_ann_val,
        data_prefix=dict(video=video_root_val),
        pipeline=test_pipeline,
        test_mode=True,
    ),
)
111+
112+
# Image-source pipeline: load an RGB image, take a random resized crop with
# scale 224, flip horizontally with probability 0.5, then pack the inputs.
# The `mmcls.`-prefixed types are referenced by their scoped names.
imagenet_pipeline = [
    dict(type='LoadRGBFromFile'),
    dict(type='mmcls.RandomResizedCrop', scale=224),
    dict(
        type='mmcls.RandomFlip',
        prob=0.5,
        direction='horizontal',
    ),
    dict(type='mmcls.PackClsInputs'),
]
118+
119+
# Image-source loader: shuffled, batch size taken from `batchsize_image`.
image_dataloader = dict(
    batch_size=batchsize_image,
    num_workers=8,
    dataset=dict(
        type='mmcls.ImageNet',
        data_root=image_root,
        ann_file=image_ann_train,
        data_prefix='train',
        pipeline=imagenet_pipeline,
    ),
    sampler=dict(type='DefaultSampler', shuffle=True),
)
130+
131+
# Validation and testing share one evaluator config object.
val_evaluator = dict(type='AccMetric')
test_evaluator = val_evaluator
133+
134+
# Training loop: the image loader is passed as an additional loader next to
# the main (video) train_dataloader; validation runs every 4 epochs.
train_cfg = dict(
    type='MultiLoaderEpochBasedTrainLoop',
    other_loaders=[image_dataloader],
    max_epochs=256,
    val_interval=4,
)

# Validation and test loops.
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
142+
143+
# learning policy
# Two epoch-based phases, both converted to per-iteration updates:
#   1. epochs 0-34: linear ramp starting from 0.1x the base learning rate;
#   2. epochs 34-256: cosine annealing down to 0 (T_max = 256 - 34 = 222).
param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=0.1,
        by_epoch=True,
        begin=0,
        end=34,
        convert_to_iter_based=True,
    ),
    dict(
        type='CosineAnnealingLR',
        T_max=222,
        eta_min=0,
        by_epoch=True,
        begin=34,
        end=256,
        convert_to_iter_based=True,
    ),
]
161+
"""
162+
The learning rate is for total_batch_size = 8 x 16 (num_gpus x batchsize_video)
163+
If you want to use other batch size or number of GPU settings, please update
164+
the learning rate with the linear scaling rule.
165+
"""
166+
# SGD optimizer with gradient clipping at an L2 norm of 40.
optim_wrapper = dict(
    optimizer=dict(
        type='SGD',
        lr=0.1,
        momentum=0.9,
        weight_decay=0.0001,
    ),
    clip_grad=dict(max_norm=40, norm_type=2),
)

# runtime settings
# Checkpoint hook settings: interval of 4, at most 3 checkpoints kept.
default_hooks = dict(
    checkpoint=dict(interval=4, max_keep_ckpts=3),
)

mmaction/datasets/transforms/__init__.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
DecordInit, DenseSampleFrames,
77
GenerateLocalizationLabels, ImageDecode,
88
LoadAudioFeature, LoadHVULabel, LoadLocalizationFeature,
9-
LoadProposals, OpenCVDecode, OpenCVInit, PIMSDecode,
10-
PIMSInit, PyAVDecode, PyAVDecodeMotionVector, PyAVInit,
11-
RawFrameDecode, SampleAVAFrames, SampleFrames,
9+
LoadProposals, LoadRGBFromFile, OpenCVDecode, OpenCVInit,
10+
PIMSDecode, PIMSInit, PyAVDecode, PyAVDecodeMotionVector,
11+
PyAVInit, RawFrameDecode, SampleAVAFrames, SampleFrames,
1212
UniformSample, UntrimmedSampleFrames)
1313
from .pose_transforms import (GeneratePoseTarget, GenSkeFeat, JointToBone,
1414
LoadKineticsPose, MergeSkeFeat, PadTo,
@@ -21,20 +21,20 @@
2121
from .wrappers import ImgAug, PytorchVideoWrapper, TorchVisionWrapper
2222

2323
# Public names exported by the transforms package. Kept alphabetically sorted
# with exactly one entry per name ('DecordInit', 'OpenCVInit' and 'PyAVInit'
# were previously listed twice).
__all__ = [
    'ArrayDecode', 'AudioAmplify', 'AudioDecode', 'AudioDecodeInit',
    'AudioFeatureSelector', 'BuildPseudoClip', 'CenterCrop', 'ColorJitter',
    'DecordDecode', 'DecordInit', 'DenseSampleFrames', 'Flip',
    'FormatAudioShape', 'FormatGCNInput', 'FormatShape', 'Fuse', 'GenSkeFeat',
    'GenerateLocalizationLabels', 'GeneratePoseTarget', 'ImageDecode',
    'ImgAug', 'JointToBone', 'LoadAudioFeature', 'LoadHVULabel',
    'LoadKineticsPose', 'LoadLocalizationFeature', 'LoadProposals',
    'LoadRGBFromFile', 'MelSpectrogram', 'MergeSkeFeat', 'MultiScaleCrop',
    'OpenCVDecode', 'OpenCVInit', 'PIMSDecode', 'PIMSInit',
    'PackActionInputs', 'PackLocalizationInputs', 'PadTo', 'PoseCompact',
    'PoseDecode', 'PreNormalize2D', 'PreNormalize3D', 'PyAVDecode',
    'PyAVDecodeMotionVector', 'PyAVInit', 'PytorchVideoWrapper',
    'RandomCrop', 'RandomRescale', 'RandomResizedCrop', 'RawFrameDecode',
    'Resize', 'SampleAVAFrames', 'SampleFrames', 'TenCrop', 'ThreeCrop',
    'ToMotion', 'TorchVisionWrapper', 'Transpose', 'UniformSample',
    'UniformSampleFrames', 'UntrimmedSampleFrames'
]

mmaction/datasets/transforms/loading.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,92 @@
1616
from mmaction.utils import get_random_string, get_shm_dir, get_thread_id
1717

1818

19+
@TRANSFORMS.register_module()
class LoadRGBFromFile(BaseTransform):
    """Load an RGB image from file.

    Required Keys:

        - img_path

    Modified Keys:

        - img
        - img_shape
        - ori_shape

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is an uint8 array.
            Defaults to False.
        color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
            Defaults to 'color'.
        imdecode_backend (str): The image decoding backend type. The backend
            argument for :func:`mmcv.imfrombytes`.
            See :func:`mmcv.imfrombytes` for details.
            Defaults to 'cv2'.
        io_backend (str): IO backend where the images are stored.
            Defaults to 'disk'.
        ignore_empty (bool): Whether to allow loading empty image or file path
            not existent. Defaults to False.
        kwargs (dict): Args for file client.
    """

    def __init__(self,
                 to_float32: bool = False,
                 color_type: str = 'color',
                 imdecode_backend: str = 'cv2',
                 io_backend: str = 'disk',
                 ignore_empty: bool = False,
                 **kwargs) -> None:
        self.ignore_empty = ignore_empty
        self.to_float32 = to_float32
        self.color_type = color_type
        self.imdecode_backend = imdecode_backend
        # The file client is built once here and reused for every sample.
        self.file_client = FileClient(io_backend, **kwargs)
        self.io_backend = io_backend

    def transform(self, results: dict) -> dict:
        """Load the image referenced by ``results['img_path']``.

        Args:
            results (dict): Result dict that contains the key ``img_path``.

        Returns:
            dict: ``results`` updated with ``img``, ``img_shape`` and
            ``ori_shape``. When reading or decoding fails and
            ``ignore_empty`` is True, ``None`` is returned instead.

        Raises:
            Exception: Any error raised while reading or decoding the file
                is propagated unchanged when ``ignore_empty`` is False.
        """
        filename = results['img_path']
        try:
            img_bytes = self.file_client.get(filename)
            img = mmcv.imfrombytes(
                img_bytes,
                flag=self.color_type,
                channel_order='rgb',
                backend=self.imdecode_backend)
        except Exception:
            if self.ignore_empty:
                return None
            # Bare ``raise`` re-raises the active exception as-is.
            raise
        if self.to_float32:
            img = img.astype(np.float32)

        results['img'] = img
        # Both shape entries hold the first two dimensions of the loaded
        # array.
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results

    def __repr__(self) -> str:
        """Return a string listing the transform's configuration."""
        repr_str = (f'{self.__class__.__name__}('
                    f'ignore_empty={self.ignore_empty}, '
                    f'to_float32={self.to_float32}, '
                    f"color_type='{self.color_type}', "
                    f"imdecode_backend='{self.imdecode_backend}', "
                    f"io_backend='{self.io_backend}')")
        return repr_str
103+
104+
19105
@TRANSFORMS.register_module()
20106
class LoadHVULabel(BaseTransform):
21107
"""Convert the HVU label from dictionaries to torch tensors.

mmaction/engine/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
from .hooks import * # noqa: F401, F403
33
from .model import * # noqa: F401, F403
44
from .optimizers import * # noqa: F401, F403
5+
from .runner import * # noqa: F401, F403

mmaction/engine/runner/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from .multi_loop import MultiLoaderEpochBasedTrainLoop
3+
4+
__all__ = ['MultiLoaderEpochBasedTrainLoop']

0 commit comments

Comments
 (0)