
Commit 29f9184

refactor: make processing multi_modal_input generic (#678)
This PR updates the multimodal input handling logic in areal/utils/data.py to remove its reliance on a single hard-coded key ("multi_modal_input"). The current implementation limits extensibility and can cause issues when users introduce additional multimodal fields, for example multi_modal_input_augmented. This refactor introduces a more flexible and scalable approach to detecting and processing multimodal inputs. The core convention is that all multimodal fields begin with the prefix "multi_modal_input"; with this naming convention in place, the code can use prefix matching to automatically identify every multimodal input key present in the data.
1 parent 1810edf commit 29f9184
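
As a rough sketch of the convention described in the commit message, the new helper matches on the key prefix, so any extra field that follows the naming rule is detected without further code changes. The multi_modal_input_augmented key below is a hypothetical example, not something added by this commit:

def is_multi_modal_key(key: str) -> bool:
    # Any key matching: multi_modal_input*
    return key.startswith("multi_modal_input")

sample = {
    "input_ids": [1, 2, 3],
    "multi_modal_input": [{"pixel_values": None}],
    "multi_modal_input_augmented": [{"pixel_values": None}],  # hypothetical extra field
}
print([k for k in sample if is_multi_modal_key(k)])
# -> ['multi_modal_input', 'multi_modal_input_augmented']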

File tree

1 file changed (+35, −29 lines)

areal/utils/data.py

Lines changed: 35 additions & 29 deletions
@@ -81,22 +81,26 @@ def list_of_dict2dict_of_list(
     return {key: [dict_item[key] for dict_item in list_of_dicts] for key in keys}


+def is_multi_modal_key(key: str) -> bool:
+    # Any key matching: multi_modal_input*
+    return key.startswith("multi_modal_input")
+
+
 def pad_sequences_to_tensors(
     sequence_list: list[dict[str, Any]], pad_value: float = 0.0
 ) -> dict[str, Any]:
     if not sequence_list:
         return {}
-    skip_keys = {"multi_modal_input"}
     max_length = max(
         len(seq)
         for item in sequence_list
         for key, seq in item.items()
-        if key not in skip_keys
+        if not is_multi_modal_key(key)
     )
     result = {}
     for key in sequence_list[0].keys():
         padded = []
-        if key == "multi_modal_input":
+        if is_multi_modal_key(key):
             for i in range(len(sequence_list)):
                 if sequence_list[i][key]:
                     item = sequence_list[i][key][0]
@@ -118,11 +122,20 @@ def pad_sequences_to_tensors(
                 padded.append(padded_x)
             result[key] = torch.stack(padded)
     attention_mask = [
-        [1] * len(next(iter(item[key] for key in item.keys() if key not in skip_keys)))
+        [1]
+        * len(
+            next(iter(item[key] for key in item.keys() if not is_multi_modal_key(key)))
+        )
         + [0]
         * (
             max_length
-            - len(next(iter(item[key] for key in item.keys() if key not in skip_keys)))
+            - len(
+                next(
+                    iter(
+                        item[key] for key in item.keys() if not is_multi_modal_key(key)
+                    )
+                )
+            )
         )
         for item in sequence_list
     ]
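
For intuition on the attention mask built in the hunk above: each row marks real tokens with 1 and padding with 0, measuring length against the first non-multimodal field in the item. A minimal standalone sketch with made-up values (max_length = 5, a 3-token sequence):

# Hypothetical illustration of the mask expression above; values are made up.
def is_multi_modal_key(key: str) -> bool:
    return key.startswith("multi_modal_input")

max_length = 5
item = {"input_ids": [7, 8, 9], "multi_modal_input": [{}]}
seq_len = len(next(iter(item[k] for k in item if not is_multi_modal_key(k))))
mask = [1] * seq_len + [0] * (max_length - seq_len)
print(mask)  # [1, 1, 1, 0, 0]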
@@ -163,31 +176,21 @@ def concat_padded_tensors(
     max_length = max([x["attention_mask"].shape[1] for x in tensor_dicts])
     result = {}

-    has_any_multi_modal = any("multi_modal_input" in td for td in tensor_dicts)
-
-    merged_multi_modal = None
-
-    if has_any_multi_modal:
+    multimodal_keys = {
+        key for td in tensor_dicts for key in td if is_multi_modal_key(key)
+    }
+    # Merge multimodal keys
+    for mm_key in multimodal_keys:
         merged_multi_modal = []
-
-        # Merge multi-modal data maintaining per-dp correspondence
-        for tensor_dict in tensor_dicts:
-            td_batch_size = get_batch_size(tensor_dict)
-
-            if "multi_modal_input" in tensor_dict:
-                # Has multi_modal_input - extend the lists
-                multi_modal = tensor_dict["multi_modal_input"]
-            else:
-                multi_modal = [{} for _ in range(td_batch_size)]
-
-            merged_multi_modal.extend(multi_modal)
-
-        result["multi_modal_input"] = merged_multi_modal
+        for td in tensor_dicts:
+            bs = get_batch_size(td)
+            merged_multi_modal.extend(td.get(mm_key, [{} for _ in range(bs)]))
+        result[mm_key] = merged_multi_modal

     # Process each key
     for key in tensor_dicts[0].keys():
         tensors_to_concat = []
-        if key == "multi_modal_input":
+        if is_multi_modal_key(key):
             continue
         for tensor_dict in tensor_dicts:
             tensor = tensor_dict[key]
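
The merge loop above lets tensor dicts that lack a given multimodal key contribute empty placeholders, so per-sample correspondence across data-parallel ranks is preserved. A standalone sketch of that behavior, with made-up inputs and a local stand-in for get_batch_size:

def is_multi_modal_key(key: str) -> bool:
    return key.startswith("multi_modal_input")

def get_batch_size(td):  # stand-in for the helper used in areal/utils/data.py
    return len(td["input_ids"])

tensor_dicts = [
    {"input_ids": [[1, 2], [3, 4]], "multi_modal_input": [{"pixel_values": 0}, {"pixel_values": 1}]},
    {"input_ids": [[5, 6]]},  # this dict has no multimodal field
]

multimodal_keys = {k for td in tensor_dicts for k in td if is_multi_modal_key(k)}
result = {}
for mm_key in multimodal_keys:
    merged = []
    for td in tensor_dicts:
        merged.extend(td.get(mm_key, [{} for _ in range(get_batch_size(td))]))
    result[mm_key] = merged
print(result["multi_modal_input"])
# -> [{'pixel_values': 0}, {'pixel_values': 1}, {}]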
@@ -444,11 +447,14 @@ def split_padded_tensor_dict_into_mb_list(
         .numpy()
     )

+    # check for multimodal input data
+    multimodal_keys = {key for key in data if is_multi_modal_key(key)}
+
     # check tensor shape, split only 1d tensors with length "total_lens"
     to_split = {}
     not_to_split = {}
     for key, value in data.items():
-        if key == "multi_modal_input":
+        if key in multimodal_keys:
             continue
         if key == "position_ids" or (
             torch.is_tensor(value) and value.numel() == bs * max_seqlen
@@ -493,8 +499,8 @@ def _split(tensor):

     to_split = dict_map(to_split, lambda x: _split(x))

-    if "multi_modal_input" in data:
-        multi_modal_input = data["multi_modal_input"]
+    for key in multimodal_keys:
+        multi_modal_input = data[key]

         # Prepare the pixel_values and image_grid_thw for each group
         multi_modal_input_split = []
@@ -504,7 +510,7 @@ def _split(tensor):
            # Stack pixel_values for each group (assuming pixel_values is a list of tensors)
            multi_modal_input_split.append(group_pixel_multi_modal_input)
        # Pack the split pixel_values and image_grid_thw back into the data
-        to_split["multi_modal_input"] = multi_modal_input_split
+        to_split[key] = multi_modal_input_split
    mbs = dict_of_list2list_of_dict(to_split)

    results = []
