Skip to content

Commit 2e00438

Browse files
committed
Format EVS helper and add SPDX header
Signed-off-by: zitian.zhao <[email protected]>
1 parent 8a3cedb commit 2e00438

File tree

2 files changed

+29
-20
lines changed

2 files changed

+29
-20
lines changed

test_evs_fix.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
#!/usr/bin/env python3
2+
# SPDX-License-Identifier: Apache-2.0
3+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
24
"""Simple harness to reason about EVS placeholder offsets.
35
46
The real implementation in ``iter_mm_grid_hw`` now relies on the
@@ -10,8 +12,8 @@
1012

1113
from __future__ import annotations
1214

15+
from collections.abc import Iterable
1316
from dataclasses import dataclass
14-
from typing import Iterable
1517

1618
import torch
1719

@@ -71,21 +73,26 @@ def extract_frame_offsets(
7173
f"Expected {expected_frames} frame segments, got {len(segments)}"
7274
)
7375

74-
offsets = [offset_start + int(segment[0].item()) for segment in segments[:expected_frames]]
76+
offsets = [
77+
offset_start + int(segment[0].item()) for segment in segments[:expected_frames]
78+
]
7579
lengths = [int(segment.numel()) for segment in segments[:expected_frames]]
7680
return offsets, lengths
7781

7882

7983
def test_sparse_distribution() -> None:
8084
print("\n=== 测试场景 1: 稀疏分布 (真实 EVS 行为) ===")
8185
per_frame = [50176, 15000, 12000, 10000, 8000, 145668, 5000, 5000]
82-
cfg = MaskSimulationConfig(tokens_per_frame=per_frame, prefix_tokens=3, suffix_tokens=2)
86+
cfg = MaskSimulationConfig(
87+
tokens_per_frame=per_frame, prefix_tokens=3, suffix_tokens=2
88+
)
8389
mask = build_is_embed_mask(cfg)
8490
offsets, lengths = extract_frame_offsets(128, mask, len(per_frame))
8591

8692
for idx, (off, size, expected) in enumerate(zip(offsets, lengths, per_frame), 1):
8793
print(
88-
f"Frame {idx:02d}: offset={off:6d}, retained={size:6d} tokens (expected {expected})"
94+
f"Frame {idx:02d}: offset={off:6d}, retained={size:6d} tokens "
95+
f"(expected {expected})"
8996
)
9097
assert size == expected
9198

@@ -95,7 +102,9 @@ def test_sparse_distribution() -> None:
95102
def test_uniform_distribution() -> None:
96103
print("\n=== 测试场景 2: 均匀分布 (处理器当前实现) ===")
97104
per_frame = [784 for _ in range(4)]
98-
cfg = MaskSimulationConfig(tokens_per_frame=per_frame, prefix_tokens=2, suffix_tokens=1)
105+
cfg = MaskSimulationConfig(
106+
tokens_per_frame=per_frame, prefix_tokens=2, suffix_tokens=1
107+
)
99108
mask = build_is_embed_mask(cfg)
100109
offsets, lengths = extract_frame_offsets(42, mask, len(per_frame))
101110

@@ -109,9 +118,7 @@ def test_uniform_distribution() -> None:
109118
for idx, (off, size, expected_offset) in enumerate(
110119
zip(offsets, lengths, expected_offsets), 1
111120
):
112-
print(
113-
f"Frame {idx:02d}: offset={off:5d}, retained={size:4d} tokens"
114-
)
121+
print(f"Frame {idx:02d}: offset={off:5d}, retained={size:4d} tokens")
115122
assert size == per_frame[idx - 1]
116123
assert off == expected_offset
117124

vllm/model_executor/models/qwen3_vl.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1276,7 +1276,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
12761276
if self.is_multimodal_pruning_enabled:
12771277
logger.debug(
12781278
"EVS (Efficient Video Sampling) enabled with pruning_rate=%.2f",
1279-
self.video_pruning_rate
1279+
self.video_pruning_rate,
12801280
)
12811281
if not multimodal_config.get_limit_per_prompt(
12821282
"image"
@@ -1527,9 +1527,7 @@ def _postprocess_video_embeds_evs(
15271527
second_per_grid_ts = torch.ones(len(grid_thw_list), dtype=torch.long)
15281528
else:
15291529
second_per_grid_ts = second_per_grid_ts.long()
1530-
tokens_per_second = getattr(
1531-
self.config.vision_config, "tokens_per_second", 1.0
1532-
)
1530+
tokens_per_second = getattr(self.config.vision_config, "tokens_per_second", 1.0)
15331531

15341532
video_embeds_out = []
15351533
for emb, size, video_second_per_grid_t in zip(
@@ -1549,9 +1547,11 @@ def _postprocess_video_embeds_evs(
15491547
"pruning_rate=%.2f, reduction=%.1f%%)",
15501548
emb.shape[0],
15511549
retention_mask.sum().item(),
1552-
size[0], size[1], size[2],
1550+
size[0],
1551+
size[1],
1552+
size[2],
15531553
self.video_pruning_rate,
1554-
(1 - retention_mask.float().mean().item()) * 100
1554+
(1 - retention_mask.float().mean().item()) * 100,
15551555
)
15561556

15571557
positions = compute_mrope_for_media(
@@ -1618,7 +1618,7 @@ def iter_mm_grid_hw(
16181618

16191619
# Check if EVS (Efficient Video Sampling) is enabled
16201620
is_evs_enabled = (
1621-
hasattr(self, 'video_pruning_rate')
1621+
hasattr(self, "video_pruning_rate")
16221622
and self.video_pruning_rate is not None
16231623
and self.video_pruning_rate > 0.0
16241624
)
@@ -1635,12 +1635,12 @@ def iter_mm_grid_hw(
16351635
# Fallback: distribute offsets uniformly when mask is missing
16361636
tokens_per_frame_original = llm_grid_h * llm_grid_w
16371637
total_retained_tokens = compute_retained_tokens_count(
1638-
tokens_per_frame_original,
1639-
t,
1640-
self.video_pruning_rate
1638+
tokens_per_frame_original, t, self.video_pruning_rate
16411639
)
16421640
tokens_per_frame = (
1643-
total_retained_tokens // t if t > 0 else tokens_per_frame_original
1641+
total_retained_tokens // t
1642+
if t > 0
1643+
else tokens_per_frame_original
16441644
)
16451645
for _ in range(t):
16461646
yield offset, llm_grid_h, llm_grid_w
@@ -1682,7 +1682,9 @@ def _extract_frame_offsets_from_mask(
16821682
if split_points.numel() == 0:
16831683
segments = [true_indices]
16841684
else:
1685-
segments = torch.tensor_split(true_indices, split_points.add(1).tolist())
1685+
segments = torch.tensor_split(
1686+
true_indices, split_points.add(1).tolist()
1687+
)
16861688

16871689
if len(segments) < expected_frames:
16881690
logger.debug(

0 commit comments

Comments (0)