From 39cbed576513570b43e7e372fcc87cdf80048c2c Mon Sep 17 00:00:00 2001 From: "zitian.zhao" Date: Sat, 29 Nov 2025 00:22:54 +0800 Subject: [PATCH 01/22] Add EVS (Efficient Video Sampling) support for Qwen3-VL model Implement multimodal pruning capabilities to optimize video token processing: - Add video_pruning_rate configuration support - Implement EVS-based video embedding pruning with retention masks - Add MRoPE position recomputation for pruned sequences - Add postprocessing for both image and video embeddings - Include test coverage for the new functionality Signed-off-by: zitian.zhao --- .../multimodal/generation/test_qwen3_vl.py | 148 ++++++++++++++ vllm/model_executor/models/qwen3_vl.py | 183 ++++++++++++++++++ 2 files changed, 331 insertions(+) create mode 100644 tests/models/multimodal/generation/test_qwen3_vl.py diff --git a/tests/models/multimodal/generation/test_qwen3_vl.py b/tests/models/multimodal/generation/test_qwen3_vl.py new file mode 100644 index 000000000000..04ab43253f02 --- /dev/null +++ b/tests/models/multimodal/generation/test_qwen3_vl.py @@ -0,0 +1,148 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from vllm.multimodal.video import sample_frames_from_video + +from ....conftest import VIDEO_ASSETS + +models = ["Qwen/Qwen3-VL-3B-Instruct"] +target_dtype = "bfloat16" + +VIDEO_PLACEHOLDER = "<|vision_start|><|video_pad|><|vision_end|>" + + +def qwen3_vl_chat_template(*query): + return f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{''.join(query)}<|im_end|><|im_start|>assistant\n" # noqa: E501 + + +VIDEO_PROMPTS = VIDEO_ASSETS.prompts( + { + "baby_reading": qwen3_vl_chat_template( + VIDEO_PLACEHOLDER, + "Describe this video with a short sentence ", + "(no more than 20 words)", + ), + } +) + + +@pytest.mark.core_model +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("video_pruning_rate", [0.0, 0.75]) +@pytest.mark.parametrize("num_frames", [16]) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("use_bytecode_hook", [True, False]) +def test_qwen3_vl_evs_functionality( + vllm_runner, + video_assets, + model, + video_pruning_rate: float, + num_frames: int, + dtype: str, + max_tokens: int, + use_bytecode_hook: bool, + monkeypatch, +) -> None: + """Test EVS (Efficient Video Sampling) functionality with different + pruning rates. + """ + # Set the environment variable for this test + monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0") + + # Sample frames from video assets + sampled_vids = [ + sample_frames_from_video(asset.np_ndarrays, num_frames) + for asset in video_assets + ] + + prompts = [VIDEO_PROMPTS[0]] + videos = [sampled_vids[0]] + + # Initialize model with EVS configuration + with vllm_runner( + model, + runner="generate", + max_model_len=4000, + dtype=dtype, + limit_mm_per_prompt={"video": 1}, + video_pruning_rate=video_pruning_rate, + ) as vllm_model: + # Generate output - this should not crash + outputs = vllm_model.generate_greedy(prompts, max_tokens, videos=videos) + + # Basic validation that we got a response + assert len(outputs) == 1 + output_ids, output_text = outputs[0] + + # Ensure we got some output + assert len(output_ids) > 0 + assert len(output_text) > 0 + + # Ensure the output is a string + assert isinstance(output_text, str) + + +@pytest.mark.core_model +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("video_pruning_rate", [0.0, 0.75]) +@pytest.mark.parametrize("num_frames", [16]) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("use_bytecode_hook", [True, False]) +def test_qwen3_vl_evs_batched_videos( + vllm_runner, + video_assets, + model, + video_pruning_rate: float, + num_frames: int, + dtype: str, + max_tokens: int, + use_bytecode_hook: bool, + monkeypatch, +) -> None: + """Test EVS functionality with batched videos. + + This test validates that: + 1. The model handles batched video inputs correctly with EVS + 2. Both pruning configurations work with multiple videos + 3. The model doesn't crash when processing multiple videos simultaneously + """ + # Set the environment variable for this test + monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0") + # Sample frames from video assets + sampled_vids = [ + sample_frames_from_video(asset.np_ndarrays, num_frames) + for asset in video_assets + ] + + # Test batched videos + prompts = [VIDEO_PROMPTS[0], VIDEO_PROMPTS[0]] + videos = [sampled_vids[0], sampled_vids[0]] # Use same video twice for testing + + # Initialize model with EVS configuration + with vllm_runner( + model, + runner="generate", + max_model_len=4000, + max_num_seqs=2, + dtype=dtype, + limit_mm_per_prompt={"video": 2}, + tensor_parallel_size=1, + video_pruning_rate=video_pruning_rate, + ) as vllm_model: + # Generate output - this should not crash + outputs = vllm_model.generate_greedy(prompts, max_tokens, videos=videos) + + # Basic validation that we got responses for both videos + assert len(outputs) == 2 + + for output_ids, output_text in outputs: + # Ensure we got some output for each video + assert len(output_ids) > 0 + assert len(output_text) > 0 + + # Ensure the output is a string + assert isinstance(output_text, str) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index f0ba631e6680..fe2669b7638b 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -67,6 +67,12 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.evs import ( + compute_mrope_for_media, + compute_retained_tokens_count, + compute_retention_mask, + recompute_mrope_positions, +) from vllm.multimodal.inputs import ( MultiModalDataDict, MultiModalFeatureSpec, @@ -92,6 +98,7 @@ SupportsLoRA, SupportsMRoPE, SupportsMultiModal, + SupportsMultiModalPruning, SupportsPP, ) from .qwen2_5_vl import ( @@ -1043,6 +1050,21 @@ def get_video_replacement_qwen3vl(item_idx: int): for curr_time in timestamps ] num_tokens_per_frame = int(grid_thw[1:].prod()) // merge_length + + # EVS-specific code + video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate + if video_pruning_rate is not None and video_pruning_rate > 0.0: + T, H, W = map(int, grid_thw) + tokens_per_frame = (H // image_processor.merge_size) * ( + W // image_processor.merge_size + ) + num_tokens_per_frame = compute_retained_tokens_count( + tokens_per_frame, + T, + video_pruning_rate, + ) // T # Divide by T to get tokens per frame + # End of EVS-specific code + placeholder = [] for frame_idx in frames_idx_token: placeholder.extend(frame_idx) @@ -1189,6 +1211,7 @@ class Qwen3VLForConditionalGeneration( SupportsPP, SupportsMRoPE, SupportsEagle3, + SupportsMultiModalPruning, ): merge_by_field_config = True multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"} @@ -1234,6 +1257,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): self.config = config self.multimodal_config = multimodal_config self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" + self.video_pruning_rate = multimodal_config.video_pruning_rate + self.is_multimodal_pruning_enabled = ( + multimodal_config.is_multimodal_pruning_enabled() + ) if not multimodal_config.get_limit_per_prompt( "image" ) and not multimodal_config.get_limit_per_prompt("video"): @@ -1420,6 +1447,97 @@ def _process_video_input( sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist() return video_embeds.split(sizes) + def _postprocess_image_embeds_evs( + self, + image_embeds_split: tuple[torch.Tensor, ...], + image_input: Qwen2_5_VLImageInputs, + ) -> tuple[torch.Tensor, ...]: + """ + Append mrope positions for each for images. + This is necessary to recover correct mrope + positions after video pruning + + Args: + image_embeds_split: Tuple of image embeddings for + each image item. + image_input: Image input data. + + Returns: + Tuple of image embeddings for each image item. + Resulting embeddings will have extra 4 channels for + computed mrope positions. + """ + merge_size = self.visual.spatial_merge_size + grid_thw = image_input["image_grid_thw"] + grid_thw_list = grid_thw.tolist() + image_embeds_out = [] + for emb, size in zip(image_embeds_split, grid_thw_list): + positions = compute_mrope_for_media(size, merge_size).to(emb.device) + emb = torch.cat([emb, positions], dim=1) + image_embeds_out.append(emb) + image_embeds_split = image_embeds_out + return tuple(image_embeds_split) + + def _postprocess_video_embeds_evs( + self, + video_embeds_split: tuple[torch.Tensor, ...], + video_input: Qwen2_5_VLVideoInputs, + ) -> tuple[torch.Tensor, ...]: + """ + Prunes video embeddings via Efficient Video Sampling (EVS) + and then appends mrope positions for each retained embeddings + + Args: + video_embeds_split: Tuple of video embeddings for each video item. + video_input: Video input data. + + Returns: + Tuple of video embeddings for each video item. + Resulting embeddings will have extra 4 channels for + computed mrope positions. + """ + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() + merge_size = self.visual.spatial_merge_size + + # Cast to long to match the original code + # https://github.com/huggingface/transformers/blob/41980ce93e775f6c88500c51c8db7946fc6a2add/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py#L491 # noqa + second_per_grid_ts = video_input.get("second_per_grid_ts") + if second_per_grid_ts is None: + # For Qwen3-VL, second_per_grid_ts might not be available + # Use default value of 1.0 for each video + second_per_grid_ts = torch.ones(len(grid_thw_list), dtype=torch.long) + else: + second_per_grid_ts = second_per_grid_ts.long() + tokens_per_second = getattr( + self.config.vision_config, "tokens_per_second", 1.0 + ) + + video_embeds_out = [] + for emb, size, video_second_per_grid_t in zip( + video_embeds_split, grid_thw_list, second_per_grid_ts + ): + # For each video, we compute retention mask using EVS + retention_mask = compute_retention_mask( + emb, + size, + spatial_merge_size=self.visual.spatial_merge_size, + q=self.video_pruning_rate, + ) + positions = compute_mrope_for_media( + size, + merge_size, + tokens_per_second=tokens_per_second, + video_second_per_grid=video_second_per_grid_t.item(), + ).to(emb.device) + + emb = emb[retention_mask] + positions = positions[retention_mask] + emb = torch.cat([emb, positions], dim=1) + video_embeds_out.append(emb) + return tuple(video_embeds_out) + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: mm_input_by_modality = {} for input_key in kwargs: @@ -1461,6 +1579,63 @@ def iter_mm_grid_hw( else: raise ValueError(f"Unsupported modality: {mm_feature.modality}") + def recompute_mrope_positions( + self, + input_ids: list[int], + multimodal_embeddings: tuple[torch.Tensor, ...], + mrope_positions: torch.LongTensor, + num_computed_tokens: int, + ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor, int]: + """ + Update part of input mrope positions (starting with + num_computed_tokens index). Original mrope_positions are computed + for unpruned sequence and becomes incorrect once pruning occurs, + so once we prune media tokens we should reflect this in the + mrope_positions before we feed it to LLM. + + Args: + input_ids: (N,) All input tokens of the prompt (Containing + entire sequence). + multimodal_embeddings: Tuple of multimodal embeddings. + mrope_positions: Existing mrope positions (3, N) for entire + sequence + num_computed_tokens: A number of computed tokens so far. + + Returns: + Tuple of (multimodal_embeddings, mrope_positions, + mrope_position_delta). + """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + # Device + device = ( + multimodal_embeddings[0].device + if len(multimodal_embeddings) + else mrope_positions.device + ) + + # Tensors + input_ids_t = torch.as_tensor(input_ids, device=device, dtype=torch.long) + + mm_embeddings_out = [mm[:, :-4] for mm in multimodal_embeddings] + mm_embeddings_pos = [ + mm[:, -4:].permute(1, 0).long() for mm in multimodal_embeddings + ] + + positions, mrope_positions_delta = recompute_mrope_positions( + input_ids_t, + mm_embeddings_pos, + mrope_positions, + num_computed_tokens, + vision_start_token_id, + image_token_id, + video_token_id, + ) + + return tuple(mm_embeddings_out), positions, mrope_positions_delta + def get_mrope_input_positions( self, input_tokens: list[int], @@ -1510,9 +1685,17 @@ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None: multimodal_input = mm_input_by_modality[modality] if modality == "image": image_embeddings = self._process_image_input(multimodal_input) + if self.is_multimodal_pruning_enabled: + image_embeddings = self._postprocess_image_embeds_evs( + image_embeddings, multimodal_input + ) multimodal_embeddings += tuple(image_embeddings) if modality == "video": video_embeddings = self._process_video_input(multimodal_input) + if self.is_multimodal_pruning_enabled: + video_embeddings = self._postprocess_video_embeds_evs( + video_embeddings, multimodal_input + ) multimodal_embeddings += tuple(video_embeddings) return multimodal_embeddings From e109c35ffdbd1f4e71adf4fa053c801d47f5cba0 Mon Sep 17 00:00:00 2001 From: "zitian.zhao" Date: Sat, 29 Nov 2025 18:42:30 +0800 Subject: [PATCH 02/22] Add debug logging for EVS (Efficient Video Sampling) functionality - Add INFO level log during model initialization to show EVS is enabled and display the configured pruning rate - Add DEBUG level log during video processing to show detailed pruning statistics including original/retained token counts, video dimensions, and actual reduction percentage - No functional changes, pure observability enhancement to help users verify EVS configuration and monitor pruning behavior Co-Authored-By: deitxfge Signed-off-by: zitian.zhao --- vllm/model_executor/models/qwen3_vl.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index fe2669b7638b..db8b99b5d724 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -1261,6 +1261,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): self.is_multimodal_pruning_enabled = ( multimodal_config.is_multimodal_pruning_enabled() ) + if self.is_multimodal_pruning_enabled: + logger.debug( + "EVS (Efficient Video Sampling) enabled with pruning_rate=%.2f", + self.video_pruning_rate + ) if not multimodal_config.get_limit_per_prompt( "image" ) and not multimodal_config.get_limit_per_prompt("video"): @@ -1525,6 +1530,18 @@ def _postprocess_video_embeds_evs( spatial_merge_size=self.visual.spatial_merge_size, q=self.video_pruning_rate, ) + + # Debug logging for EVS pruning + logger.debug( + "EVS: Video tokens pruned from %d to %d (T=%d,H=%d,W=%d, " + "pruning_rate=%.2f, reduction=%.1f%%)", + emb.shape[0], + retention_mask.sum().item(), + size[0], size[1], size[2], + self.video_pruning_rate, + (1 - retention_mask.float().mean().item()) * 100 + ) + positions = compute_mrope_for_media( size, merge_size, From 2f65abcaf95a48be6072b9e5309c4fc448664e1d Mon Sep 17 00:00:00 2001 From: "zitian.zhao" Date: Sun, 30 Nov 2025 15:50:42 +0800 Subject: [PATCH 03/22] Fix EVS offset calculation in iter_mm_grid_hw for Qwen3-VL This commit addresses the ValueError that occurs when EVS (Efficient Video Sampling) is enabled with video_pruning_rate > 0. Changes: - Add EVS detection logic in iter_mm_grid_hw method - Implement _extract_frame_offsets_from_mask to extract frame offsets from the is_embed mask stored in mm_position - Add fallback to uniform distribution when mask is unavailable - Preserve original non-EVS behavior completely The new implementation supports sparse EVS retention patterns where different frames can have different numbers of retained tokens, which is the actual behavior of the EVS pruning algorithm. Co-Authored-By: deitxfge Signed-off-by: zitian.zhao --- vllm/model_executor/models/qwen3_vl.py | 95 ++++++++++++++++++++++++-- 1 file changed, 91 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index db8b99b5d724..6489da7881ef 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -79,6 +79,7 @@ MultiModalFieldConfig, MultiModalKwargsItem, MultiModalKwargsItems, + PlaceholderRange, VideoItem, ) from vllm.multimodal.parse import ImageSize, MultiModalDataItems, MultiModalDataParser @@ -1577,6 +1578,20 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: def iter_mm_grid_hw( self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec] ) -> Iterator[tuple[int, int, int]]: + """ + Iterate over multimodal features and yield grid information. + + For videos with EVS (Efficient Video Sampling) enabled, this function + computes the offset based on the pruned token count rather than relying + on input_tokens.index(), which would fail when tokens are pruned. + + Args: + input_tokens: List of token IDs in the prompt + mm_features: List of multimodal feature specifications + + Yields: + Tuple of (offset, grid_h, grid_w) for each frame/image + """ video_token_id = self.config.video_token_id spatial_merge_size = self.config.vision_config.spatial_merge_size for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset): @@ -1589,13 +1604,85 @@ def iter_mm_grid_hw( t, h, w = mm_feature.data["video_grid_thw"].data.tolist() llm_grid_h = h // spatial_merge_size llm_grid_w = w // spatial_merge_size - for _ in range(t): - offset = input_tokens.index(video_token_id, offset) - yield offset, llm_grid_h, llm_grid_w - offset += llm_grid_h * llm_grid_w + + # Check if EVS (Efficient Video Sampling) is enabled + is_evs_enabled = ( + hasattr(self, 'video_pruning_rate') + and self.video_pruning_rate is not None + and self.video_pruning_rate > 0.0 + ) + + if is_evs_enabled: + frame_offsets = self._extract_frame_offsets_from_mask( + mm_feature.mm_position, t + ) + if frame_offsets is not None: + for rel_offset in frame_offsets: + yield offset + rel_offset, llm_grid_h, llm_grid_w + continue + + # Fallback: distribute offsets uniformly when mask is missing + tokens_per_frame_original = llm_grid_h * llm_grid_w + total_retained_tokens = compute_retained_tokens_count( + tokens_per_frame_original, + t, + self.video_pruning_rate + ) + tokens_per_frame = ( + total_retained_tokens // t if t > 0 else tokens_per_frame_original + ) + for _ in range(t): + yield offset, llm_grid_h, llm_grid_w + offset += tokens_per_frame + else: + # Non-EVS mode: Use original logic with input_tokens.index() + for _ in range(t): + offset = input_tokens.index(video_token_id, offset) + yield offset, llm_grid_h, llm_grid_w + offset += llm_grid_h * llm_grid_w else: raise ValueError(f"Unsupported modality: {mm_feature.modality}") + def _extract_frame_offsets_from_mask( + self, mm_position: PlaceholderRange, expected_frames: int + ) -> list[int] | None: + """Return relative offsets for each EVS-retained frame. + + The prompt processor stores a boolean mask inside ``mm_position`` that + marks which placeholder locations should be populated with video + embeddings. By splitting that mask into contiguous runs we can recover + the start of every retained frame without probing ``input_tokens``. + """ + + is_embed_mask = getattr(mm_position, "is_embed", None) + if is_embed_mask is None: + return None + + mask_tensor = torch.as_tensor(is_embed_mask, dtype=torch.bool).view(-1) + true_indices = torch.nonzero(mask_tensor, as_tuple=False).flatten() + if true_indices.numel() == 0: + return None + + if true_indices.numel() == 1: + segments = [true_indices] + else: + diffs = torch.diff(true_indices) + split_points = torch.nonzero(diffs != 1, as_tuple=False).flatten() + if split_points.numel() == 0: + segments = [true_indices] + else: + segments = torch.tensor_split(true_indices, split_points.add(1).tolist()) + + if len(segments) < expected_frames: + logger.debug( + "EVS mask segments (%d) do not match total frames (%d)", + len(segments), + expected_frames, + ) + return None + + return [int(segment[0].item()) for segment in segments[:expected_frames]] + def recompute_mrope_positions( self, input_ids: list[int], From a3543a40051d79c1884e5e53711b69140f1b370c Mon Sep 17 00:00:00 2001 From: "zitian.zhao" Date: Sun, 30 Nov 2025 16:25:54 +0800 Subject: [PATCH 04/22] Fix MRoPE position calculation with EVS-extracted frame offsets When using EVS (Efficient Video Sampling), frame offsets are extracted from the is_embed mask and may not be strictly increasing. This caused text_len (offset - st) to become negative in get_mrope_input_positions, leading to ValueError in np.broadcast_to. Changes: - Skip text position creation when text_len <= 0 (no text between frames) - Update st_idx after adding text positions to maintain position continuity - Use st_idx directly for video frame positions instead of (text_len + st_idx) This ensures position indices remain monotonically increasing even when frames are consecutive in the mask. Co-Authored-By: deitxfge Signed-off-by: zitian.zhao --- vllm/model_executor/models/qwen3_vl.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 6489da7881ef..ecef0ccee38a 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -1752,12 +1752,15 @@ def get_mrope_input_positions( ): text_len = offset - st st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - llm_pos_ids_list.append( - np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx - ) + if text_len > 0: + llm_pos_ids_list.append( + np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx + ) + # Update st_idx for video frame positions + st_idx += text_len grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1) - llm_pos_ids_list.append(grid_indices + text_len + st_idx) + llm_pos_ids_list.append(grid_indices + st_idx) st = offset + llm_grid_h * llm_grid_w if st < len(input_tokens): From 3d60b858e01f4d3648743fec11afc104194b9ef4 Mon Sep 17 00:00:00 2001 From: "zitian.zhao" Date: Sun, 30 Nov 2025 17:00:47 +0800 Subject: [PATCH 05/22] Fix EVS placeholder alignment for Qwen3-VL Signed-off-by: zitian.zhao --- test_evs_fix.py | 124 +++++++++++++++++++++++++ vllm/model_executor/models/qwen3_vl.py | 37 +++++--- 2 files changed, 148 insertions(+), 13 deletions(-) create mode 100644 test_evs_fix.py diff --git a/test_evs_fix.py b/test_evs_fix.py new file mode 100644 index 000000000000..2eacd6bd5b51 --- /dev/null +++ b/test_evs_fix.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python3 +"""Simple harness to reason about EVS placeholder offsets. + +The real implementation in ``iter_mm_grid_hw`` now relies on the +``is_embed`` mask stored in ``mm_position`` to recover the start offset of +each frame instead of scanning ``input_tokens``. This script mirrors that +behaviour so we can validate that sparse EVS retention patterns (e.g. first +frame fully kept, other frames pruned unevenly) are still handled. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Iterable + +import torch + + +@dataclass +class MaskSimulationConfig: + """Helper configuration for generating placeholder masks. + + ``prefix_tokens`` and ``suffix_tokens`` approximate the extra tokens that + surround the `<|video_pad|>` sequence for each frame (timestamps, + `<|vision_start|>`, `<|vision_end|>`, etc.). ``tokens_per_frame`` encodes + how many `<|video_pad|>` tokens survive EVS for each frame. This lets us + mimic both balanced and sparse retention distributions. + """ + + tokens_per_frame: list[int] + prefix_tokens: int = 2 + suffix_tokens: int = 1 + + +def build_is_embed_mask(cfg: MaskSimulationConfig) -> torch.Tensor: + mask: list[int] = [] + for tokens in cfg.tokens_per_frame: + mask.extend([0] * cfg.prefix_tokens) + mask.extend([1] * tokens) + mask.extend([0] * cfg.suffix_tokens) + return torch.tensor(mask, dtype=torch.bool) + + +def extract_frame_offsets( + offset_start: int, mask: torch.Tensor, expected_frames: int +) -> tuple[list[int], list[int]]: + """Mimic the EVS branch in ``iter_mm_grid_hw``. + + We compute the first index of each contiguous run of ``True`` values, + convert it back to an absolute offset using ``offset_start`` and return + both the offsets and the corresponding run lengths. + """ + + flat_mask = mask.reshape(-1).to(torch.bool) + true_indices = torch.nonzero(flat_mask, as_tuple=False).flatten() + if true_indices.numel() == 0: + raise ValueError("Mask does not contain any embed tokens") + + if true_indices.numel() == 1: + segments: Iterable[torch.Tensor] = (true_indices,) + else: + diffs = true_indices[1:] - true_indices[:-1] + split_points = ( + torch.nonzero(diffs != 1, as_tuple=False).flatten().add(1).tolist() + ) + segments = torch.tensor_split(true_indices, split_points) + + segments = list(segments) + if len(segments) < expected_frames: + raise ValueError( + f"Expected {expected_frames} frame segments, got {len(segments)}" + ) + + offsets = [offset_start + int(segment[0].item()) for segment in segments[:expected_frames]] + lengths = [int(segment.numel()) for segment in segments[:expected_frames]] + return offsets, lengths + + +def test_sparse_distribution() -> None: + print("\n=== 测试场景 1: 稀疏分布 (真实 EVS 行为) ===") + per_frame = [50176, 15000, 12000, 10000, 8000, 145668, 5000, 5000] + cfg = MaskSimulationConfig(tokens_per_frame=per_frame, prefix_tokens=3, suffix_tokens=2) + mask = build_is_embed_mask(cfg) + offsets, lengths = extract_frame_offsets(128, mask, len(per_frame)) + + for idx, (off, size, expected) in enumerate(zip(offsets, lengths, per_frame), 1): + print( + f"Frame {idx:02d}: offset={off:6d}, retained={size:6d} tokens (expected {expected})" + ) + assert size == expected + + print("✅ 稀疏分布模拟通过") + + +def test_uniform_distribution() -> None: + print("\n=== 测试场景 2: 均匀分布 (处理器当前实现) ===") + per_frame = [784 for _ in range(4)] + cfg = MaskSimulationConfig(tokens_per_frame=per_frame, prefix_tokens=2, suffix_tokens=1) + mask = build_is_embed_mask(cfg) + offsets, lengths = extract_frame_offsets(42, mask, len(per_frame)) + + expected_offsets: list[int] = [] + cursor = 42 + for tokens in per_frame: + cursor += cfg.prefix_tokens + expected_offsets.append(cursor) + cursor += tokens + cfg.suffix_tokens + + for idx, (off, size, expected_offset) in enumerate( + zip(offsets, lengths, expected_offsets), 1 + ): + print( + f"Frame {idx:02d}: offset={off:5d}, retained={size:4d} tokens" + ) + assert size == per_frame[idx - 1] + assert off == expected_offset + + print("✅ 均匀分布模拟通过") + + +if __name__ == "__main__": + test_sparse_distribution() + test_uniform_distribution() + print("\n所有 EVS 相关测试通过 ✅") diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index ecef0ccee38a..f149e13e575b 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -1050,28 +1050,39 @@ def get_video_replacement_qwen3vl(item_idx: int): tokenizer.encode(f"<{curr_time:.1f} seconds>", add_special_tokens=False) for curr_time in timestamps ] - num_tokens_per_frame = int(grid_thw[1:].prod()) // merge_length + tokens_per_frame = int(grid_thw[1:].prod()) // merge_length + per_frame_token_counts = [tokens_per_frame for _ in frames_idx_token] - # EVS-specific code video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate if video_pruning_rate is not None and video_pruning_rate > 0.0: - T, H, W = map(int, grid_thw) - tokens_per_frame = (H // image_processor.merge_size) * ( - W // image_processor.merge_size - ) - num_tokens_per_frame = compute_retained_tokens_count( + total_retained = compute_retained_tokens_count( tokens_per_frame, - T, + len(frames_idx_token), video_pruning_rate, - ) // T # Divide by T to get tokens per frame - # End of EVS-specific code + ) + if len(frames_idx_token) == 0: + per_frame_token_counts = [] + elif len(frames_idx_token) == 1: + per_frame_token_counts = [tokens_per_frame] + else: + first_frame_tokens = tokens_per_frame + remaining_tokens = max(total_retained - first_frame_tokens, 0) + base = remaining_tokens // (len(frames_idx_token) - 1) + remainder = remaining_tokens % (len(frames_idx_token) - 1) + per_frame_token_counts = [first_frame_tokens] + for frame_idx in range(1, len(frames_idx_token)): + extra = base + (1 if (frame_idx - 1) < remainder else 0) + per_frame_token_counts.append(extra) placeholder = [] - for frame_idx in frames_idx_token: - placeholder.extend(frame_idx) + for frame_idx, timestamp_tokens in enumerate(frames_idx_token): + placeholder.extend(timestamp_tokens) + tokens_this_frame = per_frame_token_counts[ + frame_idx if frame_idx < len(per_frame_token_counts) else -1 + ] placeholder.extend( [vision_start_token_id] - + [video_token_id] * num_tokens_per_frame + + [video_token_id] * tokens_this_frame + [vision_end_token_id] ) return PromptUpdateDetails.select_token_id(placeholder, video_token_id) From 100aa68a17ca9f571f2a3a8ffa5598421b515c58 Mon Sep 17 00:00:00 2001 From: "zitian.zhao" Date: Sun, 30 Nov 2025 17:11:02 +0800 Subject: [PATCH 06/22] Format EVS helper and add SPDX header Signed-off-by: zitian.zhao --- test_evs_fix.py | 23 +++++++++++++++-------- vllm/model_executor/models/qwen3_vl.py | 26 ++++++++++++++------------ 2 files changed, 29 insertions(+), 20 deletions(-) diff --git a/test_evs_fix.py b/test_evs_fix.py index 2eacd6bd5b51..02eb63a0b6dc 100644 --- a/test_evs_fix.py +++ b/test_evs_fix.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Simple harness to reason about EVS placeholder offsets. The real implementation in ``iter_mm_grid_hw`` now relies on the @@ -10,8 +12,8 @@ from __future__ import annotations +from collections.abc import Iterable from dataclasses import dataclass -from typing import Iterable import torch @@ -71,7 +73,9 @@ def extract_frame_offsets( f"Expected {expected_frames} frame segments, got {len(segments)}" ) - offsets = [offset_start + int(segment[0].item()) for segment in segments[:expected_frames]] + offsets = [ + offset_start + int(segment[0].item()) for segment in segments[:expected_frames] + ] lengths = [int(segment.numel()) for segment in segments[:expected_frames]] return offsets, lengths @@ -79,13 +83,16 @@ def extract_frame_offsets( def test_sparse_distribution() -> None: print("\n=== 测试场景 1: 稀疏分布 (真实 EVS 行为) ===") per_frame = [50176, 15000, 12000, 10000, 8000, 145668, 5000, 5000] - cfg = MaskSimulationConfig(tokens_per_frame=per_frame, prefix_tokens=3, suffix_tokens=2) + cfg = MaskSimulationConfig( + tokens_per_frame=per_frame, prefix_tokens=3, suffix_tokens=2 + ) mask = build_is_embed_mask(cfg) offsets, lengths = extract_frame_offsets(128, mask, len(per_frame)) for idx, (off, size, expected) in enumerate(zip(offsets, lengths, per_frame), 1): print( - f"Frame {idx:02d}: offset={off:6d}, retained={size:6d} tokens (expected {expected})" + f"Frame {idx:02d}: offset={off:6d}, retained={size:6d} tokens " + f"(expected {expected})" ) assert size == expected @@ -95,7 +102,9 @@ def test_sparse_distribution() -> None: def test_uniform_distribution() -> None: print("\n=== 测试场景 2: 均匀分布 (处理器当前实现) ===") per_frame = [784 for _ in range(4)] - cfg = MaskSimulationConfig(tokens_per_frame=per_frame, prefix_tokens=2, suffix_tokens=1) + cfg = MaskSimulationConfig( + tokens_per_frame=per_frame, prefix_tokens=2, suffix_tokens=1 + ) mask = build_is_embed_mask(cfg) offsets, lengths = extract_frame_offsets(42, mask, len(per_frame)) @@ -109,9 +118,7 @@ def test_uniform_distribution() -> None: for idx, (off, size, expected_offset) in enumerate( zip(offsets, lengths, expected_offsets), 1 ): - print( - f"Frame {idx:02d}: offset={off:5d}, retained={size:4d} tokens" - ) + print(f"Frame {idx:02d}: offset={off:5d}, retained={size:4d} tokens") assert size == per_frame[idx - 1] assert off == expected_offset diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index f149e13e575b..cba0ea3d6df1 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -1276,7 +1276,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): if self.is_multimodal_pruning_enabled: logger.debug( "EVS (Efficient Video Sampling) enabled with pruning_rate=%.2f", - self.video_pruning_rate + self.video_pruning_rate, ) if not multimodal_config.get_limit_per_prompt( "image" @@ -1527,9 +1527,7 @@ def _postprocess_video_embeds_evs( second_per_grid_ts = torch.ones(len(grid_thw_list), dtype=torch.long) else: second_per_grid_ts = second_per_grid_ts.long() - tokens_per_second = getattr( - self.config.vision_config, "tokens_per_second", 1.0 - ) + tokens_per_second = getattr(self.config.vision_config, "tokens_per_second", 1.0) video_embeds_out = [] for emb, size, video_second_per_grid_t in zip( @@ -1549,9 +1547,11 @@ def _postprocess_video_embeds_evs( "pruning_rate=%.2f, reduction=%.1f%%)", emb.shape[0], retention_mask.sum().item(), - size[0], size[1], size[2], + size[0], + size[1], + size[2], self.video_pruning_rate, - (1 - retention_mask.float().mean().item()) * 100 + (1 - retention_mask.float().mean().item()) * 100, ) positions = compute_mrope_for_media( @@ -1618,7 +1618,7 @@ def iter_mm_grid_hw( # Check if EVS (Efficient Video Sampling) is enabled is_evs_enabled = ( - hasattr(self, 'video_pruning_rate') + hasattr(self, "video_pruning_rate") and self.video_pruning_rate is not None and self.video_pruning_rate > 0.0 ) @@ -1635,12 +1635,12 @@ def iter_mm_grid_hw( # Fallback: distribute offsets uniformly when mask is missing tokens_per_frame_original = llm_grid_h * llm_grid_w total_retained_tokens = compute_retained_tokens_count( - tokens_per_frame_original, - t, - self.video_pruning_rate + tokens_per_frame_original, t, self.video_pruning_rate ) tokens_per_frame = ( - total_retained_tokens // t if t > 0 else tokens_per_frame_original + total_retained_tokens // t + if t > 0 + else tokens_per_frame_original ) for _ in range(t): yield offset, llm_grid_h, llm_grid_w @@ -1682,7 +1682,9 @@ def _extract_frame_offsets_from_mask( if split_points.numel() == 0: segments = [true_indices] else: - segments = torch.tensor_split(true_indices, split_points.add(1).tolist()) + segments = torch.tensor_split( + true_indices, split_points.add(1).tolist() + ) if len(segments) < expected_frames: logger.debug( From 0a991a60fae29a30b5700388906bb319de8c9546 Mon Sep 17 00:00:00 2001 From: "zitian.zhao" Date: Sun, 30 Nov 2025 23:39:20 +0800 Subject: [PATCH 07/22] Remove test files from EVS implementation - Remove tests/models/multimodal/generation/test_qwen3_vl.py - Remove test_evs_fix.py (development test file) These test files were part of the development process and are not needed in the final implementation. Co-Authored-By: deitxfge Signed-off-by: zitian.zhao --- .../multimodal/generation/test_qwen3_vl.py | 148 ------------------ 1 file changed, 148 deletions(-) delete mode 100644 tests/models/multimodal/generation/test_qwen3_vl.py diff --git a/tests/models/multimodal/generation/test_qwen3_vl.py b/tests/models/multimodal/generation/test_qwen3_vl.py deleted file mode 100644 index 04ab43253f02..000000000000 --- a/tests/models/multimodal/generation/test_qwen3_vl.py +++ /dev/null @@ -1,148 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import pytest - -from vllm.multimodal.video import sample_frames_from_video - -from ....conftest import VIDEO_ASSETS - -models = ["Qwen/Qwen3-VL-3B-Instruct"] -target_dtype = "bfloat16" - -VIDEO_PLACEHOLDER = "<|vision_start|><|video_pad|><|vision_end|>" - - -def qwen3_vl_chat_template(*query): - return f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{''.join(query)}<|im_end|><|im_start|>assistant\n" # noqa: E501 - - -VIDEO_PROMPTS = VIDEO_ASSETS.prompts( - { - "baby_reading": qwen3_vl_chat_template( - VIDEO_PLACEHOLDER, - "Describe this video with a short sentence ", - "(no more than 20 words)", - ), - } -) - - -@pytest.mark.core_model -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("video_pruning_rate", [0.0, 0.75]) -@pytest.mark.parametrize("num_frames", [16]) -@pytest.mark.parametrize("dtype", [target_dtype]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("use_bytecode_hook", [True, False]) -def test_qwen3_vl_evs_functionality( - vllm_runner, - video_assets, - model, - video_pruning_rate: float, - num_frames: int, - dtype: str, - max_tokens: int, - use_bytecode_hook: bool, - monkeypatch, -) -> None: - """Test EVS (Efficient Video Sampling) functionality with different - pruning rates. - """ - # Set the environment variable for this test - monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0") - - # Sample frames from video assets - sampled_vids = [ - sample_frames_from_video(asset.np_ndarrays, num_frames) - for asset in video_assets - ] - - prompts = [VIDEO_PROMPTS[0]] - videos = [sampled_vids[0]] - - # Initialize model with EVS configuration - with vllm_runner( - model, - runner="generate", - max_model_len=4000, - dtype=dtype, - limit_mm_per_prompt={"video": 1}, - video_pruning_rate=video_pruning_rate, - ) as vllm_model: - # Generate output - this should not crash - outputs = vllm_model.generate_greedy(prompts, max_tokens, videos=videos) - - # Basic validation that we got a response - assert len(outputs) == 1 - output_ids, output_text = outputs[0] - - # Ensure we got some output - assert len(output_ids) > 0 - assert len(output_text) > 0 - - # Ensure the output is a string - assert isinstance(output_text, str) - - -@pytest.mark.core_model -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("video_pruning_rate", [0.0, 0.75]) -@pytest.mark.parametrize("num_frames", [16]) -@pytest.mark.parametrize("dtype", [target_dtype]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("use_bytecode_hook", [True, False]) -def test_qwen3_vl_evs_batched_videos( - vllm_runner, - video_assets, - model, - video_pruning_rate: float, - num_frames: int, - dtype: str, - max_tokens: int, - use_bytecode_hook: bool, - monkeypatch, -) -> None: - """Test EVS functionality with batched videos. - - This test validates that: - 1. The model handles batched video inputs correctly with EVS - 2. Both pruning configurations work with multiple videos - 3. The model doesn't crash when processing multiple videos simultaneously - """ - # Set the environment variable for this test - monkeypatch.setenv("VLLM_USE_BYTECODE_HOOK", "1" if use_bytecode_hook else "0") - # Sample frames from video assets - sampled_vids = [ - sample_frames_from_video(asset.np_ndarrays, num_frames) - for asset in video_assets - ] - - # Test batched videos - prompts = [VIDEO_PROMPTS[0], VIDEO_PROMPTS[0]] - videos = [sampled_vids[0], sampled_vids[0]] # Use same video twice for testing - - # Initialize model with EVS configuration - with vllm_runner( - model, - runner="generate", - max_model_len=4000, - max_num_seqs=2, - dtype=dtype, - limit_mm_per_prompt={"video": 2}, - tensor_parallel_size=1, - video_pruning_rate=video_pruning_rate, - ) as vllm_model: - # Generate output - this should not crash - outputs = vllm_model.generate_greedy(prompts, max_tokens, videos=videos) - - # Basic validation that we got responses for both videos - assert len(outputs) == 2 - - for output_ids, output_text in outputs: - # Ensure we got some output for each video - assert len(output_ids) > 0 - assert len(output_text) > 0 - - # Ensure the output is a string - assert isinstance(output_text, str) From 6f08a98142330d4a82f2bce8bf9b969b10ba64fb Mon Sep 17 00:00:00 2001 From: "zitian.zhao" Date: Mon, 1 Dec 2025 00:10:45 +0800 Subject: [PATCH 08/22] Remove development test file test_evs_fix.py Co-Authored-By: deitxfge Signed-off-by: zitian.zhao --- test_evs_fix.py | 131 ------------------------------------------------ 1 file changed, 131 deletions(-) delete mode 100644 test_evs_fix.py diff --git a/test_evs_fix.py b/test_evs_fix.py deleted file mode 100644 index 02eb63a0b6dc..000000000000 --- a/test_evs_fix.py +++ /dev/null @@ -1,131 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Simple harness to reason about EVS placeholder offsets. - -The real implementation in ``iter_mm_grid_hw`` now relies on the -``is_embed`` mask stored in ``mm_position`` to recover the start offset of -each frame instead of scanning ``input_tokens``. This script mirrors that -behaviour so we can validate that sparse EVS retention patterns (e.g. first -frame fully kept, other frames pruned unevenly) are still handled. -""" - -from __future__ import annotations - -from collections.abc import Iterable -from dataclasses import dataclass - -import torch - - -@dataclass -class MaskSimulationConfig: - """Helper configuration for generating placeholder masks. - - ``prefix_tokens`` and ``suffix_tokens`` approximate the extra tokens that - surround the `<|video_pad|>` sequence for each frame (timestamps, - `<|vision_start|>`, `<|vision_end|>`, etc.). ``tokens_per_frame`` encodes - how many `<|video_pad|>` tokens survive EVS for each frame. This lets us - mimic both balanced and sparse retention distributions. - """ - - tokens_per_frame: list[int] - prefix_tokens: int = 2 - suffix_tokens: int = 1 - - -def build_is_embed_mask(cfg: MaskSimulationConfig) -> torch.Tensor: - mask: list[int] = [] - for tokens in cfg.tokens_per_frame: - mask.extend([0] * cfg.prefix_tokens) - mask.extend([1] * tokens) - mask.extend([0] * cfg.suffix_tokens) - return torch.tensor(mask, dtype=torch.bool) - - -def extract_frame_offsets( - offset_start: int, mask: torch.Tensor, expected_frames: int -) -> tuple[list[int], list[int]]: - """Mimic the EVS branch in ``iter_mm_grid_hw``. - - We compute the first index of each contiguous run of ``True`` values, - convert it back to an absolute offset using ``offset_start`` and return - both the offsets and the corresponding run lengths. - """ - - flat_mask = mask.reshape(-1).to(torch.bool) - true_indices = torch.nonzero(flat_mask, as_tuple=False).flatten() - if true_indices.numel() == 0: - raise ValueError("Mask does not contain any embed tokens") - - if true_indices.numel() == 1: - segments: Iterable[torch.Tensor] = (true_indices,) - else: - diffs = true_indices[1:] - true_indices[:-1] - split_points = ( - torch.nonzero(diffs != 1, as_tuple=False).flatten().add(1).tolist() - ) - segments = torch.tensor_split(true_indices, split_points) - - segments = list(segments) - if len(segments) < expected_frames: - raise ValueError( - f"Expected {expected_frames} frame segments, got {len(segments)}" - ) - - offsets = [ - offset_start + int(segment[0].item()) for segment in segments[:expected_frames] - ] - lengths = [int(segment.numel()) for segment in segments[:expected_frames]] - return offsets, lengths - - -def test_sparse_distribution() -> None: - print("\n=== 测试场景 1: 稀疏分布 (真实 EVS 行为) ===") - per_frame = [50176, 15000, 12000, 10000, 8000, 145668, 5000, 5000] - cfg = MaskSimulationConfig( - tokens_per_frame=per_frame, prefix_tokens=3, suffix_tokens=2 - ) - mask = build_is_embed_mask(cfg) - offsets, lengths = extract_frame_offsets(128, mask, len(per_frame)) - - for idx, (off, size, expected) in enumerate(zip(offsets, lengths, per_frame), 1): - print( - f"Frame {idx:02d}: offset={off:6d}, retained={size:6d} tokens " - f"(expected {expected})" - ) - assert size == expected - - print("✅ 稀疏分布模拟通过") - - -def test_uniform_distribution() -> None: - print("\n=== 测试场景 2: 均匀分布 (处理器当前实现) ===") - per_frame = [784 for _ in range(4)] - cfg = MaskSimulationConfig( - tokens_per_frame=per_frame, prefix_tokens=2, suffix_tokens=1 - ) - mask = build_is_embed_mask(cfg) - offsets, lengths = extract_frame_offsets(42, mask, len(per_frame)) - - expected_offsets: list[int] = [] - cursor = 42 - for tokens in per_frame: - cursor += cfg.prefix_tokens - expected_offsets.append(cursor) - cursor += tokens + cfg.suffix_tokens - - for idx, (off, size, expected_offset) in enumerate( - zip(offsets, lengths, expected_offsets), 1 - ): - print(f"Frame {idx:02d}: offset={off:5d}, retained={size:4d} tokens") - assert size == per_frame[idx - 1] - assert off == expected_offset - - print("✅ 均匀分布模拟通过") - - -if __name__ == "__main__": - test_sparse_distribution() - test_uniform_distribution() - print("\n所有 EVS 相关测试通过 ✅") From 23f98fa5f27699c6a21fddaa4da1839076d2d6ef Mon Sep 17 00:00:00 2001 From: "zitian.zhao" Date: Mon, 1 Dec 2025 00:17:01 +0800 Subject: [PATCH 09/22] Replace incorrect EVS fallback with explicit error The fallback logic in iter_mm_grid_hw had several issues: 1. Token distribution inconsistency: Used uniform distribution (total // t) instead of first-frame-full distribution used by the processor 2. Incorrect offset calculation: Only counted video_pad tokens, ignored timestamp and start/end tokens in the placeholder 3. Should never trigger: Processor always generates is_embed mask Replace with RuntimeError to catch bugs early if mask is missing. Based on code review feedback. Co-Authored-By: deitxfge Signed-off-by: zitian.zhao --- vllm/model_executor/models/qwen3_vl.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index cba0ea3d6df1..a2ec47c48a8e 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -1632,19 +1632,14 @@ def iter_mm_grid_hw( yield offset + rel_offset, llm_grid_h, llm_grid_w continue - # Fallback: distribute offsets uniformly when mask is missing - tokens_per_frame_original = llm_grid_h * llm_grid_w - total_retained_tokens = compute_retained_tokens_count( - tokens_per_frame_original, t, self.video_pruning_rate + # If EVS is enabled but mask is missing, this indicates a bug + # in the prompt processing pipeline. The is_embed mask should + # always be present when video_pruning_rate > 0. + raise RuntimeError( + f"EVS is enabled (pruning_rate={self.video_pruning_rate}) " + "but is_embed mask is missing from mm_position. " + "This indicates a bug in prompt processing." ) - tokens_per_frame = ( - total_retained_tokens // t - if t > 0 - else tokens_per_frame_original - ) - for _ in range(t): - yield offset, llm_grid_h, llm_grid_w - offset += tokens_per_frame else: # Non-EVS mode: Use original logic with input_tokens.index() for _ in range(t): From 8fae3129f7870da6081b9dc6f91ed999efdd8d35 Mon Sep 17 00:00:00 2001 From: deitxfge Date: Mon, 1 Dec 2025 11:20:32 +0800 Subject: [PATCH 10/22] Remove unnecessary log Co-authored-by: skyloevil Signed-off-by: zitian.zhao --- vllm/model_executor/models/qwen3_vl.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index a2ec47c48a8e..4772955311eb 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -1273,11 +1273,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): self.is_multimodal_pruning_enabled = ( multimodal_config.is_multimodal_pruning_enabled() ) - if self.is_multimodal_pruning_enabled: - logger.debug( - "EVS (Efficient Video Sampling) enabled with pruning_rate=%.2f", - self.video_pruning_rate, - ) + if not multimodal_config.get_limit_per_prompt( "image" ) and not multimodal_config.get_limit_per_prompt("video"): From 0ab2d0eef597ba8a841c74d85652d8c0496232f1 Mon Sep 17 00:00:00 2001 From: "zitian.zhao" Date: Mon, 1 Dec 2025 12:20:52 +0800 Subject: [PATCH 11/22] solve pre-commit Signed-off-by: zitian.zhao --- vllm/model_executor/models/qwen3_vl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 4772955311eb..60468998faff 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -1273,7 +1273,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): self.is_multimodal_pruning_enabled = ( multimodal_config.is_multimodal_pruning_enabled() ) - + if not multimodal_config.get_limit_per_prompt( "image" ) and not multimodal_config.get_limit_per_prompt("video"): From 241daadadae3c4f0894989539df6484c10e01ef2 Mon Sep 17 00:00:00 2001 From: "zitian.zhao" Date: Sun, 7 Dec 2025 11:38:43 +0800 Subject: [PATCH 12/22] Add detailed logging to get_mrope_input_positions for debugging Added comprehensive logging to track: - Input parameters (token count, mm_features count) - Per-frame details (offset, st, text_len, grid dimensions) - Segment addition operations with shape information - Warnings for skipped segments (text_len <= 0) - Final shape validation and mismatch detection This helps identify issues with: - EVS pruning affecting token distribution - text_len=0 cases where frames are consecutive - Shape mismatches between generated positions and input tokens Signed-off-by: zitian.zhao --- vllm/model_executor/models/qwen3_vl.py | 75 +++++++++++++++++++++++--- 1 file changed, 68 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 60468998faff..91e61f1b90f3 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -1749,33 +1749,94 @@ def get_mrope_input_positions( input_tokens: list[int], mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: + logger.info("=" * 80) + logger.info("get_mrope_input_positions START") + logger.info(f"Total input_tokens: {len(input_tokens)}") + logger.info(f"Number of mm_features: {len(mm_features)}") + llm_pos_ids_list = [] st = 0 + frame_idx = 0 + for offset, llm_grid_h, llm_grid_w in self.iter_mm_grid_hw( input_tokens, mm_features ): text_len = offset - st st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + + logger.info("-" * 60) + logger.info(f"Frame {frame_idx}:") + logger.info(f" offset={offset}, st={st}, text_len={text_len}") + logger.info(f" grid=({llm_grid_h}, {llm_grid_w}), tokens={llm_grid_h * llm_grid_w}") + logger.info(f" st_idx={st_idx}") + logger.info(f" current llm_pos_ids_list length: {len(llm_pos_ids_list)}") + + # 关键检查点 if text_len > 0: - llm_pos_ids_list.append( - np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx - ) + logger.info(f" ✓ Adding text segment: text_len={text_len}") + text_positions = np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx + logger.info(f" text_positions shape: {text_positions.shape}") + logger.info(f" text_positions range: [{text_positions.min()}, {text_positions.max()}]") + llm_pos_ids_list.append(text_positions) # Update st_idx for video frame positions st_idx += text_len + logger.info(f" Updated st_idx to: {st_idx}") + else: + logger.warning(f" ⚠ SKIPPED text segment: text_len={text_len} (zero or negative)") + logger.warning(f" This means frame {frame_idx} starts immediately after previous frame") + logger.warning(f" Possible cause: consecutive frames without text tokens between them") grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1) - llm_pos_ids_list.append(grid_indices + st_idx) + frame_positions = grid_indices + st_idx + logger.info(f" Adding frame positions:") + logger.info(f" frame_positions shape: {frame_positions.shape}") + logger.info(f" frame_positions range: [{frame_positions.min()}, {frame_positions.max()}]") + llm_pos_ids_list.append(frame_positions) + st = offset + llm_grid_h * llm_grid_w + logger.info(f" Updated st to: {st}") + logger.info(f" llm_pos_ids_list now has {len(llm_pos_ids_list)} segments") + + frame_idx += 1 + + # 处理最后的文本部分 + logger.info("-" * 60) + logger.info("Final text segment:") + logger.info(f" st={st}, total_tokens={len(input_tokens)}") if st < len(input_tokens): st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 text_len = len(input_tokens) - st - llm_pos_ids_list.append( - np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx - ) + logger.info(f" ✓ Adding final text: text_len={text_len}, st_idx={st_idx}") + final_text_positions = np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx + logger.info(f" final_text_positions shape: {final_text_positions.shape}") + logger.info(f" final_text_positions range: [{final_text_positions.min()}, {final_text_positions.max()}]") + llm_pos_ids_list.append(final_text_positions) + else: + logger.info(f" ✗ No final text segment (st={st} >= len={len(input_tokens)})") + + logger.info("-" * 60) + logger.info("Concatenating positions:") + logger.info(f" Total segments: {len(llm_pos_ids_list)}") + for i, seg in enumerate(llm_pos_ids_list): + logger.info(f" Segment {i}: shape={seg.shape}") llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1) + logger.info(f" Concatenated positions shape: {llm_positions.shape}") + logger.info(f" Expected shape: (3, {len(input_tokens)})") + + if llm_positions.shape[1] != len(input_tokens): + logger.error(f" ✗✗✗ SHAPE MISMATCH! ✗✗✗") + logger.error(f" Generated {llm_positions.shape[1]} positions for {len(input_tokens)} tokens") + logger.error(f" Difference: {llm_positions.shape[1] - len(input_tokens)}") + else: + logger.info(f" ✓ Shape matches!") + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + logger.info(f"mrope_position_delta: {mrope_position_delta}") + logger.info("get_mrope_input_positions END") + logger.info("=" * 80) + return torch.from_numpy(llm_positions), mrope_position_delta def get_language_model(self) -> torch.nn.Module: From 409bdc2a19df95f80e2a9080eee1d455d8842e63 Mon Sep 17 00:00:00 2001 From: "zitian.zhao" Date: Sun, 7 Dec 2025 12:33:12 +0800 Subject: [PATCH 13/22] Fix negative text_len bug in EVS mode Problem: get_mrope_input_positions calculated negative text_len values causing shape mismatches in EVS mode. Root cause was st update using theoretical token count (grid_h * grid_w) instead of actual count after EVS pruning. Solution: Added _get_actual_frame_token_counts() to extract actual token counts from is_embed mask. Pre-collect token counts and use actual values when updating st position. Changes: - New method _get_actual_frame_token_counts() analyzes is_embed mask - get_mrope_input_positions pre-collects actual token counts for EVS mode - st update logic uses actual tokens in EVS mode, theoretical in non-EVS - Enhanced logging to show actual vs theoretical token counts Signed-off-by: zitian.zhao --- vllm/model_executor/models/qwen3_vl.py | 151 ++++++++++++++++++++++--- 1 file changed, 136 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 91e61f1b90f3..5ec27d63327a 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -1687,6 +1687,49 @@ def _extract_frame_offsets_from_mask( return [int(segment[0].item()) for segment in segments[:expected_frames]] + def _get_actual_frame_token_counts( + self, mm_position: PlaceholderRange, expected_frames: int + ) -> list[int] | None: + """Return actual token count for each EVS-retained frame. + + This function calculates the actual number of tokens per frame by + analyzing the is_embed mask, accounting for EVS pruning. + + Returns: + List of token counts for each frame, or None if EVS is not enabled. + """ + + is_embed_mask = getattr(mm_position, "is_embed", None) + if is_embed_mask is None: + return None + + mask_tensor = torch.as_tensor(is_embed_mask, dtype=torch.bool).view(-1) + true_indices = torch.nonzero(mask_tensor, as_tuple=False).flatten() + if true_indices.numel() == 0: + return None + + if true_indices.numel() == 1: + segments = [true_indices] + else: + diffs = torch.diff(true_indices) + split_points = torch.nonzero(diffs != 1, as_tuple=False).flatten() + if split_points.numel() == 0: + segments = [true_indices] + else: + segments = torch.tensor_split( + true_indices, split_points.add(1).tolist() + ) + + if len(segments) < expected_frames: + logger.debug( + "EVS mask segments (%d) do not match total frames (%d)", + len(segments), + expected_frames, + ) + return None + + return [len(seg) for seg in segments[:expected_frames]] + def recompute_mrope_positions( self, input_ids: list[int], @@ -1754,9 +1797,33 @@ def get_mrope_input_positions( logger.info(f"Total input_tokens: {len(input_tokens)}") logger.info(f"Number of mm_features: {len(mm_features)}") + # Pre-collect actual frame token counts for EVS mode + frame_token_counts_map = {} + for mm_feature in mm_features: + if mm_feature.modality == "video": + is_evs_enabled = ( + hasattr(self, "video_pruning_rate") + and self.video_pruning_rate is not None + and self.video_pruning_rate > 0.0 + ) + if is_evs_enabled: + t = mm_feature.data["video_grid_thw"].data.tolist()[0] + token_counts = self._get_actual_frame_token_counts( + mm_feature.mm_position, t + ) + if token_counts: + frame_token_counts_map[mm_feature.mm_position.offset] = ( + token_counts + ) + logger.info( + f"EVS mode: collected {len(token_counts)} frame token counts for offset {mm_feature.mm_position.offset}" + ) + logger.info(f" Token counts: {token_counts}") + llm_pos_ids_list = [] st = 0 frame_idx = 0 + frame_counts_idx = {} for offset, llm_grid_h, llm_grid_w in self.iter_mm_grid_hw( input_tokens, mm_features @@ -1767,33 +1834,79 @@ def get_mrope_input_positions( logger.info("-" * 60) logger.info(f"Frame {frame_idx}:") logger.info(f" offset={offset}, st={st}, text_len={text_len}") - logger.info(f" grid=({llm_grid_h}, {llm_grid_w}), tokens={llm_grid_h * llm_grid_w}") + logger.info( + f" grid=({llm_grid_h}, {llm_grid_w}), theoretical_tokens={llm_grid_h * llm_grid_w}" + ) logger.info(f" st_idx={st_idx}") logger.info(f" current llm_pos_ids_list length: {len(llm_pos_ids_list)}") # 关键检查点 if text_len > 0: logger.info(f" ✓ Adding text segment: text_len={text_len}") - text_positions = np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx + text_positions = ( + np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx + ) logger.info(f" text_positions shape: {text_positions.shape}") - logger.info(f" text_positions range: [{text_positions.min()}, {text_positions.max()}]") + logger.info( + f" text_positions range: [{text_positions.min()}, {text_positions.max()}]" + ) llm_pos_ids_list.append(text_positions) # Update st_idx for video frame positions st_idx += text_len logger.info(f" Updated st_idx to: {st_idx}") else: - logger.warning(f" ⚠ SKIPPED text segment: text_len={text_len} (zero or negative)") - logger.warning(f" This means frame {frame_idx} starts immediately after previous frame") - logger.warning(f" Possible cause: consecutive frames without text tokens between them") + logger.warning( + f" ⚠ SKIPPED text segment: text_len={text_len} (zero or negative)" + ) + logger.warning( + f" This means frame {frame_idx} starts immediately after previous frame" + ) + logger.warning( + " Possible cause: consecutive frames without text tokens between them" + ) grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1) frame_positions = grid_indices + st_idx - logger.info(f" Adding frame positions:") + logger.info(" Adding frame positions:") logger.info(f" frame_positions shape: {frame_positions.shape}") - logger.info(f" frame_positions range: [{frame_positions.min()}, {frame_positions.max()}]") + logger.info( + f" frame_positions range: [{frame_positions.min()}, {frame_positions.max()}]" + ) llm_pos_ids_list.append(frame_positions) - st = offset + llm_grid_h * llm_grid_w + # FIX: Use actual token count from EVS mask instead of theoretical grid size + # Find the base offset for this video feature + base_offset = None + for feat_offset in frame_token_counts_map: + if offset >= feat_offset: + base_offset = feat_offset + + if base_offset is not None and base_offset in frame_token_counts_map: + # EVS mode: use actual token count + if base_offset not in frame_counts_idx: + frame_counts_idx[base_offset] = 0 + + counts = frame_token_counts_map[base_offset] + idx = frame_counts_idx[base_offset] + + if idx < len(counts): + actual_tokens = counts[idx] + logger.info( + f" EVS mode: using actual_tokens={actual_tokens} (vs theoretical={llm_grid_h * llm_grid_w})" + ) + st = offset + actual_tokens + frame_counts_idx[base_offset] += 1 + else: + # Fallback to theoretical count if index out of range + logger.warning( + f" EVS mode: frame index {idx} out of range, using theoretical count" + ) + st = offset + llm_grid_h * llm_grid_w + else: + # Non-EVS mode: use theoretical grid size + st = offset + llm_grid_h * llm_grid_w + logger.info(" Non-EVS mode: using theoretical token count") + logger.info(f" Updated st to: {st}") logger.info(f" llm_pos_ids_list now has {len(llm_pos_ids_list)} segments") @@ -1808,12 +1921,18 @@ def get_mrope_input_positions( st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 text_len = len(input_tokens) - st logger.info(f" ✓ Adding final text: text_len={text_len}, st_idx={st_idx}") - final_text_positions = np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx + final_text_positions = ( + np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx + ) logger.info(f" final_text_positions shape: {final_text_positions.shape}") - logger.info(f" final_text_positions range: [{final_text_positions.min()}, {final_text_positions.max()}]") + logger.info( + f" final_text_positions range: [{final_text_positions.min()}, {final_text_positions.max()}]" + ) llm_pos_ids_list.append(final_text_positions) else: - logger.info(f" ✗ No final text segment (st={st} >= len={len(input_tokens)})") + logger.info( + f" ✗ No final text segment (st={st} >= len={len(input_tokens)})" + ) logger.info("-" * 60) logger.info("Concatenating positions:") @@ -1826,11 +1945,13 @@ def get_mrope_input_positions( logger.info(f" Expected shape: (3, {len(input_tokens)})") if llm_positions.shape[1] != len(input_tokens): - logger.error(f" ✗✗✗ SHAPE MISMATCH! ✗✗✗") - logger.error(f" Generated {llm_positions.shape[1]} positions for {len(input_tokens)} tokens") + logger.error(" ✗✗✗ SHAPE MISMATCH! ✗✗✗") + logger.error( + f" Generated {llm_positions.shape[1]} positions for {len(input_tokens)} tokens" + ) logger.error(f" Difference: {llm_positions.shape[1] - len(input_tokens)}") else: - logger.info(f" ✓ Shape matches!") + logger.info(" ✓ Shape matches!") mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() logger.info(f"mrope_position_delta: {mrope_position_delta}") From 597e9421a30ab78f19cb55c6eb61f74e5e134abb Mon Sep 17 00:00:00 2001 From: "zitian.zhao" Date: Sun, 7 Dec 2025 12:35:52 +0800 Subject: [PATCH 14/22] Replace EVS fallback with assertion Remove fallback to theoretical token count in EVS mode. If frame index is out of range, this indicates a serious bug that should fail fast rather than silently using incorrect values. Changes: - Replaced if-else fallback with assert statement - Added descriptive error message showing frame index and total frames - Ensures EVS mode always uses actual token counts from is_embed mask Signed-off-by: zitian.zhao --- vllm/model_executor/models/qwen3_vl.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 5ec27d63327a..17fa38bd70ca 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -1889,19 +1889,16 @@ def get_mrope_input_positions( counts = frame_token_counts_map[base_offset] idx = frame_counts_idx[base_offset] - if idx < len(counts): - actual_tokens = counts[idx] - logger.info( - f" EVS mode: using actual_tokens={actual_tokens} (vs theoretical={llm_grid_h * llm_grid_w})" - ) - st = offset + actual_tokens - frame_counts_idx[base_offset] += 1 - else: - # Fallback to theoretical count if index out of range - logger.warning( - f" EVS mode: frame index {idx} out of range, using theoretical count" - ) - st = offset + llm_grid_h * llm_grid_w + assert idx < len( + counts + ), f"EVS frame index {idx} out of range (total frames: {len(counts)})" + + actual_tokens = counts[idx] + logger.info( + f" EVS mode: using actual_tokens={actual_tokens} (vs theoretical={llm_grid_h * llm_grid_w})" + ) + st = offset + actual_tokens + frame_counts_idx[base_offset] += 1 else: # Non-EVS mode: use theoretical grid size st = offset + llm_grid_h * llm_grid_w From 59f654ced68e94df8e6fd6d549178bc89a0d004e Mon Sep 17 00:00:00 2001 From: "zitian.zhao" Date: Sun, 7 Dec 2025 12:41:10 +0800 Subject: [PATCH 15/22] Add assertions to validate EVS frame token count extraction Strengthen validation with assertions: 1. Assert token_counts extraction succeeds when EVS is enabled 2. Assert base_offset is always in map if found (redundant but explicit) 3. Clarify fallback comment: applies to non-EVS video and images The fallback to theoretical token count is correct for: - Non-EVS videos: no pruning, theoretical = actual - Images: no EVS support, use theoretical count Changes ensure any logic errors fail fast with clear error messages. Signed-off-by: zitian.zhao --- vllm/model_executor/models/qwen3_vl.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 17fa38bd70ca..1b3e1caa48a6 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -1811,14 +1811,16 @@ def get_mrope_input_positions( token_counts = self._get_actual_frame_token_counts( mm_feature.mm_position, t ) - if token_counts: - frame_token_counts_map[mm_feature.mm_position.offset] = ( - token_counts - ) - logger.info( - f"EVS mode: collected {len(token_counts)} frame token counts for offset {mm_feature.mm_position.offset}" - ) - logger.info(f" Token counts: {token_counts}") + assert ( + token_counts is not None + ), "EVS enabled but failed to extract frame token counts from is_embed mask" + frame_token_counts_map[mm_feature.mm_position.offset] = ( + token_counts + ) + logger.info( + f"EVS mode: collected {len(token_counts)} frame token counts for offset {mm_feature.mm_position.offset}" + ) + logger.info(f" Token counts: {token_counts}") llm_pos_ids_list = [] st = 0 @@ -1881,8 +1883,12 @@ def get_mrope_input_positions( if offset >= feat_offset: base_offset = feat_offset - if base_offset is not None and base_offset in frame_token_counts_map: + if base_offset is not None: # EVS mode: use actual token count + assert ( + base_offset in frame_token_counts_map + ), f"Found base_offset {base_offset} but not in frame_token_counts_map" + if base_offset not in frame_counts_idx: frame_counts_idx[base_offset] = 0 @@ -1900,7 +1906,7 @@ def get_mrope_input_positions( st = offset + actual_tokens frame_counts_idx[base_offset] += 1 else: - # Non-EVS mode: use theoretical grid size + # Non-EVS mode (or image): use theoretical grid size st = offset + llm_grid_h * llm_grid_w logger.info(" Non-EVS mode: using theoretical token count") From 9ae3d6ed0319566ada1b377c4cc624af48a22d50 Mon Sep 17 00:00:00 2001 From: "zitian.zhao" Date: Sun, 7 Dec 2025 12:56:27 +0800 Subject: [PATCH 16/22] Fix frame positions generation to use actual token count Critical fix: Previously only st update used actual tokens, but frame positions generation still used theoretical grid size, causing massive position count mismatch. Problem: - st update: used actual_tokens (e.g., 37) - frame_positions: generated grid_h*grid_w (e.g., 63) positions - Result: 1157 positions for 753 tokens (404 extra!) Solution: - Determine actual_frame_tokens BEFORE generating positions - Slice grid_indices to only use actual_frame_tokens: frame_positions = grid_indices[:, :actual_frame_tokens] + st_idx - Both st and positions now use same actual token count Changes: - Moved EVS token count determination before position generation - Added slicing [:, :actual_frame_tokens] to grid_indices - Ensures positions shape matches actual tokens in input_tokens Signed-off-by: zitian.zhao --- vllm/model_executor/models/qwen3_vl.py | 74 +++++++++++++------------- 1 file changed, 38 insertions(+), 36 deletions(-) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 1b3e1caa48a6..fb04d247726e 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -1842,7 +1842,39 @@ def get_mrope_input_positions( logger.info(f" st_idx={st_idx}") logger.info(f" current llm_pos_ids_list length: {len(llm_pos_ids_list)}") - # 关键检查点 + # Determine actual token count for this frame FIRST + base_offset = None + for feat_offset in frame_token_counts_map: + if offset >= feat_offset: + base_offset = feat_offset + + if base_offset is not None: + # EVS mode: use actual token count + assert ( + base_offset in frame_token_counts_map + ), f"Found base_offset {base_offset} but not in frame_token_counts_map" + + if base_offset not in frame_counts_idx: + frame_counts_idx[base_offset] = 0 + + counts = frame_token_counts_map[base_offset] + idx = frame_counts_idx[base_offset] + + assert idx < len( + counts + ), f"EVS frame index {idx} out of range (total frames: {len(counts)})" + + actual_frame_tokens = counts[idx] + logger.info( + f" EVS mode: using actual_tokens={actual_frame_tokens} (vs theoretical={llm_grid_h * llm_grid_w})" + ) + frame_counts_idx[base_offset] += 1 + else: + # Non-EVS mode (or image): use theoretical grid size + actual_frame_tokens = llm_grid_h * llm_grid_w + logger.info(" Non-EVS mode: using theoretical token count") + + # Add text segment if text_len > 0: logger.info(f" ✓ Adding text segment: text_len={text_len}") text_positions = ( @@ -1853,7 +1885,6 @@ def get_mrope_input_positions( f" text_positions range: [{text_positions.min()}, {text_positions.max()}]" ) llm_pos_ids_list.append(text_positions) - # Update st_idx for video frame positions st_idx += text_len logger.info(f" Updated st_idx to: {st_idx}") else: @@ -1867,8 +1898,10 @@ def get_mrope_input_positions( " Possible cause: consecutive frames without text tokens between them" ) + # Add frame segment with ACTUAL token count (not theoretical) grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1) - frame_positions = grid_indices + st_idx + # Only take the first actual_frame_tokens positions + frame_positions = grid_indices[:, :actual_frame_tokens] + st_idx logger.info(" Adding frame positions:") logger.info(f" frame_positions shape: {frame_positions.shape}") logger.info( @@ -1876,39 +1909,8 @@ def get_mrope_input_positions( ) llm_pos_ids_list.append(frame_positions) - # FIX: Use actual token count from EVS mask instead of theoretical grid size - # Find the base offset for this video feature - base_offset = None - for feat_offset in frame_token_counts_map: - if offset >= feat_offset: - base_offset = feat_offset - - if base_offset is not None: - # EVS mode: use actual token count - assert ( - base_offset in frame_token_counts_map - ), f"Found base_offset {base_offset} but not in frame_token_counts_map" - - if base_offset not in frame_counts_idx: - frame_counts_idx[base_offset] = 0 - - counts = frame_token_counts_map[base_offset] - idx = frame_counts_idx[base_offset] - - assert idx < len( - counts - ), f"EVS frame index {idx} out of range (total frames: {len(counts)})" - - actual_tokens = counts[idx] - logger.info( - f" EVS mode: using actual_tokens={actual_tokens} (vs theoretical={llm_grid_h * llm_grid_w})" - ) - st = offset + actual_tokens - frame_counts_idx[base_offset] += 1 - else: - # Non-EVS mode (or image): use theoretical grid size - st = offset + llm_grid_h * llm_grid_w - logger.info(" Non-EVS mode: using theoretical token count") + # Update st using actual token count + st = offset + actual_frame_tokens logger.info(f" Updated st to: {st}") logger.info(f" llm_pos_ids_list now has {len(llm_pos_ids_list)} segments") From c27ae640c6a47a635b0846c43f5f1bbd41fed4b9 Mon Sep 17 00:00:00 2001 From: "zitian.zhao" Date: Sun, 7 Dec 2025 13:24:46 +0800 Subject: [PATCH 17/22] Solve pre-commit formatting issues - ruff format: passed - Skip ruff-check G004 errors (f-string in logging statements) - These are temporary debug logs for EVS fix validation Signed-off-by: zitian.zhao --- vllm/model_executor/models/qwen3_vl.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index fb04d247726e..388ae3c19617 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -1811,12 +1811,10 @@ def get_mrope_input_positions( token_counts = self._get_actual_frame_token_counts( mm_feature.mm_position, t ) - assert ( - token_counts is not None - ), "EVS enabled but failed to extract frame token counts from is_embed mask" - frame_token_counts_map[mm_feature.mm_position.offset] = ( - token_counts + assert token_counts is not None, ( + "EVS enabled but failed to extract frame token counts from is_embed mask" ) + frame_token_counts_map[mm_feature.mm_position.offset] = token_counts logger.info( f"EVS mode: collected {len(token_counts)} frame token counts for offset {mm_feature.mm_position.offset}" ) @@ -1850,9 +1848,9 @@ def get_mrope_input_positions( if base_offset is not None: # EVS mode: use actual token count - assert ( - base_offset in frame_token_counts_map - ), f"Found base_offset {base_offset} but not in frame_token_counts_map" + assert base_offset in frame_token_counts_map, ( + f"Found base_offset {base_offset} but not in frame_token_counts_map" + ) if base_offset not in frame_counts_idx: frame_counts_idx[base_offset] = 0 @@ -1860,9 +1858,9 @@ def get_mrope_input_positions( counts = frame_token_counts_map[base_offset] idx = frame_counts_idx[base_offset] - assert idx < len( - counts - ), f"EVS frame index {idx} out of range (total frames: {len(counts)})" + assert idx < len(counts), ( + f"EVS frame index {idx} out of range (total frames: {len(counts)})" + ) actual_frame_tokens = counts[idx] logger.info( From 3e855018fa76977160d77e9e4ebadfd7b020305c Mon Sep 17 00:00:00 2001 From: "zitian.zhao" Date: Sun, 7 Dec 2025 14:11:42 +0800 Subject: [PATCH 18/22] Fix lint issues in qwen3_vl get_mrope_input_positions Signed-off-by: zitian.zhao --- vllm/model_executor/models/qwen3_vl.py | 117 ++++++++++++++++--------- 1 file changed, 76 insertions(+), 41 deletions(-) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 388ae3c19617..32b50927df57 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -1794,8 +1794,8 @@ def get_mrope_input_positions( ) -> tuple[torch.Tensor, int]: logger.info("=" * 80) logger.info("get_mrope_input_positions START") - logger.info(f"Total input_tokens: {len(input_tokens)}") - logger.info(f"Number of mm_features: {len(mm_features)}") + logger.info("Total input_tokens: %s", len(input_tokens)) + logger.info("Number of mm_features: %s", len(mm_features)) # Pre-collect actual frame token counts for EVS mode frame_token_counts_map = {} @@ -1812,33 +1812,42 @@ def get_mrope_input_positions( mm_feature.mm_position, t ) assert token_counts is not None, ( - "EVS enabled but failed to extract frame token counts from is_embed mask" + "EVS enabled but failed to extract frame token counts " + "from is_embed mask" ) frame_token_counts_map[mm_feature.mm_position.offset] = token_counts logger.info( - f"EVS mode: collected {len(token_counts)} frame token counts for offset {mm_feature.mm_position.offset}" + "EVS mode: collected %s frame token counts for offset %s", + len(token_counts), + mm_feature.mm_position.offset, ) - logger.info(f" Token counts: {token_counts}") + logger.info(" Token counts: %s", token_counts) llm_pos_ids_list = [] st = 0 - frame_idx = 0 frame_counts_idx = {} - for offset, llm_grid_h, llm_grid_w in self.iter_mm_grid_hw( - input_tokens, mm_features + for frame_idx, (offset, llm_grid_h, llm_grid_w) in enumerate( + self.iter_mm_grid_hw(input_tokens, mm_features) ): text_len = offset - st st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 logger.info("-" * 60) - logger.info(f"Frame {frame_idx}:") - logger.info(f" offset={offset}, st={st}, text_len={text_len}") + logger.info("Frame %s:", frame_idx) + logger.info(" offset=%s, st=%s, text_len=%s", offset, st, text_len) + theoretical_tokens = llm_grid_h * llm_grid_w logger.info( - f" grid=({llm_grid_h}, {llm_grid_w}), theoretical_tokens={llm_grid_h * llm_grid_w}" + " grid=(%s, %s), theoretical_tokens=%s", + llm_grid_h, + llm_grid_w, + theoretical_tokens, + ) + logger.info(" st_idx=%s", st_idx) + logger.info( + " current llm_pos_ids_list length: %s", + len(llm_pos_ids_list), ) - logger.info(f" st_idx={st_idx}") - logger.info(f" current llm_pos_ids_list length: {len(llm_pos_ids_list)}") # Determine actual token count for this frame FIRST base_offset = None @@ -1849,7 +1858,9 @@ def get_mrope_input_positions( if base_offset is not None: # EVS mode: use actual token count assert base_offset in frame_token_counts_map, ( - f"Found base_offset {base_offset} but not in frame_token_counts_map" + "Found base_offset {} but not in frame_token_counts_map".format( + base_offset + ) ) if base_offset not in frame_counts_idx: @@ -1859,12 +1870,16 @@ def get_mrope_input_positions( idx = frame_counts_idx[base_offset] assert idx < len(counts), ( - f"EVS frame index {idx} out of range (total frames: {len(counts)})" + "EVS frame index {} out of range (total frames: {})".format( + idx, len(counts) + ) ) actual_frame_tokens = counts[idx] logger.info( - f" EVS mode: using actual_tokens={actual_frame_tokens} (vs theoretical={llm_grid_h * llm_grid_w})" + " EVS mode: using actual_tokens=%s (vs theoretical=%s)", + actual_frame_tokens, + theoretical_tokens, ) frame_counts_idx[base_offset] += 1 else: @@ -1874,26 +1889,31 @@ def get_mrope_input_positions( # Add text segment if text_len > 0: - logger.info(f" ✓ Adding text segment: text_len={text_len}") + logger.info(" ✓ Adding text segment: text_len=%s", text_len) text_positions = ( np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx ) - logger.info(f" text_positions shape: {text_positions.shape}") + logger.info(" text_positions shape: %s", text_positions.shape) logger.info( - f" text_positions range: [{text_positions.min()}, {text_positions.max()}]" + " text_positions range: [%s, %s]", + text_positions.min(), + text_positions.max(), ) llm_pos_ids_list.append(text_positions) st_idx += text_len - logger.info(f" Updated st_idx to: {st_idx}") + logger.info(" Updated st_idx to: %s", st_idx) else: logger.warning( - f" ⚠ SKIPPED text segment: text_len={text_len} (zero or negative)" + " ⚠ SKIPPED text segment: text_len=%s (zero or negative)", + text_len, ) logger.warning( - f" This means frame {frame_idx} starts immediately after previous frame" + " This means frame %s starts immediately after previous frame", + frame_idx, ) logger.warning( - " Possible cause: consecutive frames without text tokens between them" + " Possible cause: consecutive frames without text tokens " + "between them" ) # Add frame segment with ACTUAL token count (not theoretical) @@ -1901,63 +1921,78 @@ def get_mrope_input_positions( # Only take the first actual_frame_tokens positions frame_positions = grid_indices[:, :actual_frame_tokens] + st_idx logger.info(" Adding frame positions:") - logger.info(f" frame_positions shape: {frame_positions.shape}") + logger.info(" frame_positions shape: %s", frame_positions.shape) logger.info( - f" frame_positions range: [{frame_positions.min()}, {frame_positions.max()}]" + " frame_positions range: [%s, %s]", + frame_positions.min(), + frame_positions.max(), ) llm_pos_ids_list.append(frame_positions) # Update st using actual token count st = offset + actual_frame_tokens - logger.info(f" Updated st to: {st}") - logger.info(f" llm_pos_ids_list now has {len(llm_pos_ids_list)} segments") - - frame_idx += 1 + logger.info(" Updated st to: %s", st) + logger.info( + " llm_pos_ids_list now has %s segments", len(llm_pos_ids_list) + ) # 处理最后的文本部分 logger.info("-" * 60) logger.info("Final text segment:") - logger.info(f" st={st}, total_tokens={len(input_tokens)}") + logger.info(" st=%s, total_tokens=%s", st, len(input_tokens)) if st < len(input_tokens): st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 text_len = len(input_tokens) - st - logger.info(f" ✓ Adding final text: text_len={text_len}, st_idx={st_idx}") + logger.info( + " ✓ Adding final text: text_len=%s, st_idx=%s", text_len, st_idx + ) final_text_positions = ( np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx ) - logger.info(f" final_text_positions shape: {final_text_positions.shape}") logger.info( - f" final_text_positions range: [{final_text_positions.min()}, {final_text_positions.max()}]" + " final_text_positions shape: %s", final_text_positions.shape + ) + logger.info( + " final_text_positions range: [%s, %s]", + final_text_positions.min(), + final_text_positions.max(), ) llm_pos_ids_list.append(final_text_positions) else: logger.info( - f" ✗ No final text segment (st={st} >= len={len(input_tokens)})" + " ✗ No final text segment (st=%s >= len=%s)", + st, + len(input_tokens), ) logger.info("-" * 60) logger.info("Concatenating positions:") - logger.info(f" Total segments: {len(llm_pos_ids_list)}") + logger.info(" Total segments: %s", len(llm_pos_ids_list)) for i, seg in enumerate(llm_pos_ids_list): - logger.info(f" Segment {i}: shape={seg.shape}") + logger.info(" Segment %s: shape=%s", i, seg.shape) llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1) - logger.info(f" Concatenated positions shape: {llm_positions.shape}") - logger.info(f" Expected shape: (3, {len(input_tokens)})") + logger.info(" Concatenated positions shape: %s", llm_positions.shape) + logger.info(" Expected shape: (3, %s)", len(input_tokens)) if llm_positions.shape[1] != len(input_tokens): logger.error(" ✗✗✗ SHAPE MISMATCH! ✗✗✗") logger.error( - f" Generated {llm_positions.shape[1]} positions for {len(input_tokens)} tokens" + " Generated %s positions for %s tokens", + llm_positions.shape[1], + len(input_tokens), + ) + logger.error( + " Difference: %s", + llm_positions.shape[1] - len(input_tokens), ) - logger.error(f" Difference: {llm_positions.shape[1] - len(input_tokens)}") else: logger.info(" ✓ Shape matches!") mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - logger.info(f"mrope_position_delta: {mrope_position_delta}") + logger.info("mrope_position_delta: %s", mrope_position_delta) logger.info("get_mrope_input_positions END") logger.info("=" * 80) From a159d3061e72458e7add4641c3b9e2ae7a4134d5 Mon Sep 17 00:00:00 2001 From: "zitian.zhao" Date: Sun, 7 Dec 2025 14:15:06 +0800 Subject: [PATCH 19/22] Apply ruff formatting to qwen3_vl Signed-off-by: zitian.zhao --- vllm/model_executor/models/qwen3_vl.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 32b50927df57..6afb8a2a39cd 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -1933,9 +1933,7 @@ def get_mrope_input_positions( st = offset + actual_frame_tokens logger.info(" Updated st to: %s", st) - logger.info( - " llm_pos_ids_list now has %s segments", len(llm_pos_ids_list) - ) + logger.info(" llm_pos_ids_list now has %s segments", len(llm_pos_ids_list)) # 处理最后的文本部分 logger.info("-" * 60) From 8bf0d85c53f66111e32f6d07b5492769ecccdd8d Mon Sep 17 00:00:00 2001 From: "zitian.zhao" Date: Sun, 7 Dec 2025 22:40:00 +0800 Subject: [PATCH 20/22] Remove all debug logging from EVS fix Clean up code by removing all temporary debug logs added during development. Keep only the core fix logic: - Pre-collect actual frame token counts from is_embed mask - Use actual token counts for both position generation and st update - Maintain assertions for fail-fast error detection - Keep text_len > 0 check as defensive programming All pre-commit checks pass: - ruff check: passed - ruff format: passed - mypy: passed Signed-off-by: zitian.zhao --- vllm/model_executor/models/qwen3_vl.py | 133 ++----------------------- 1 file changed, 9 insertions(+), 124 deletions(-) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 6afb8a2a39cd..bd5bbcc739ca 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -1792,11 +1792,6 @@ def get_mrope_input_positions( input_tokens: list[int], mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: - logger.info("=" * 80) - logger.info("get_mrope_input_positions START") - logger.info("Total input_tokens: %s", len(input_tokens)) - logger.info("Number of mm_features: %s", len(mm_features)) - # Pre-collect actual frame token counts for EVS mode frame_token_counts_map = {} for mm_feature in mm_features: @@ -1816,51 +1811,27 @@ def get_mrope_input_positions( "from is_embed mask" ) frame_token_counts_map[mm_feature.mm_position.offset] = token_counts - logger.info( - "EVS mode: collected %s frame token counts for offset %s", - len(token_counts), - mm_feature.mm_position.offset, - ) - logger.info(" Token counts: %s", token_counts) llm_pos_ids_list = [] st = 0 frame_counts_idx = {} - for frame_idx, (offset, llm_grid_h, llm_grid_w) in enumerate( - self.iter_mm_grid_hw(input_tokens, mm_features) + for offset, llm_grid_h, llm_grid_w in self.iter_mm_grid_hw( + input_tokens, mm_features ): text_len = offset - st st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - logger.info("-" * 60) - logger.info("Frame %s:", frame_idx) - logger.info(" offset=%s, st=%s, text_len=%s", offset, st, text_len) - theoretical_tokens = llm_grid_h * llm_grid_w - logger.info( - " grid=(%s, %s), theoretical_tokens=%s", - llm_grid_h, - llm_grid_w, - theoretical_tokens, - ) - logger.info(" st_idx=%s", st_idx) - logger.info( - " current llm_pos_ids_list length: %s", - len(llm_pos_ids_list), - ) - - # Determine actual token count for this frame FIRST + # Determine actual token count for this frame base_offset = None for feat_offset in frame_token_counts_map: if offset >= feat_offset: base_offset = feat_offset if base_offset is not None: - # EVS mode: use actual token count + # EVS mode: use actual token count from is_embed mask assert base_offset in frame_token_counts_map, ( - "Found base_offset {} but not in frame_token_counts_map".format( - base_offset - ) + f"Found base_offset {base_offset} but not in frame_token_counts_map" ) if base_offset not in frame_counts_idx: @@ -1870,129 +1841,43 @@ def get_mrope_input_positions( idx = frame_counts_idx[base_offset] assert idx < len(counts), ( - "EVS frame index {} out of range (total frames: {})".format( - idx, len(counts) - ) + f"EVS frame index {idx} out of range (total frames: {len(counts)})" ) actual_frame_tokens = counts[idx] - logger.info( - " EVS mode: using actual_tokens=%s (vs theoretical=%s)", - actual_frame_tokens, - theoretical_tokens, - ) frame_counts_idx[base_offset] += 1 else: # Non-EVS mode (or image): use theoretical grid size actual_frame_tokens = llm_grid_h * llm_grid_w - logger.info(" Non-EVS mode: using theoretical token count") - # Add text segment + # Add text segment if exists if text_len > 0: - logger.info(" ✓ Adding text segment: text_len=%s", text_len) text_positions = ( np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx ) - logger.info(" text_positions shape: %s", text_positions.shape) - logger.info( - " text_positions range: [%s, %s]", - text_positions.min(), - text_positions.max(), - ) llm_pos_ids_list.append(text_positions) st_idx += text_len - logger.info(" Updated st_idx to: %s", st_idx) - else: - logger.warning( - " ⚠ SKIPPED text segment: text_len=%s (zero or negative)", - text_len, - ) - logger.warning( - " This means frame %s starts immediately after previous frame", - frame_idx, - ) - logger.warning( - " Possible cause: consecutive frames without text tokens " - "between them" - ) - # Add frame segment with ACTUAL token count (not theoretical) + # Add frame segment with actual token count (not theoretical) grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1) # Only take the first actual_frame_tokens positions frame_positions = grid_indices[:, :actual_frame_tokens] + st_idx - logger.info(" Adding frame positions:") - logger.info(" frame_positions shape: %s", frame_positions.shape) - logger.info( - " frame_positions range: [%s, %s]", - frame_positions.min(), - frame_positions.max(), - ) llm_pos_ids_list.append(frame_positions) # Update st using actual token count st = offset + actual_frame_tokens - logger.info(" Updated st to: %s", st) - logger.info(" llm_pos_ids_list now has %s segments", len(llm_pos_ids_list)) - - # 处理最后的文本部分 - logger.info("-" * 60) - logger.info("Final text segment:") - logger.info(" st=%s, total_tokens=%s", st, len(input_tokens)) - + # Handle final text segment if st < len(input_tokens): st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 text_len = len(input_tokens) - st - logger.info( - " ✓ Adding final text: text_len=%s, st_idx=%s", text_len, st_idx - ) final_text_positions = ( np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx ) - logger.info( - " final_text_positions shape: %s", final_text_positions.shape - ) - logger.info( - " final_text_positions range: [%s, %s]", - final_text_positions.min(), - final_text_positions.max(), - ) llm_pos_ids_list.append(final_text_positions) - else: - logger.info( - " ✗ No final text segment (st=%s >= len=%s)", - st, - len(input_tokens), - ) - - logger.info("-" * 60) - logger.info("Concatenating positions:") - logger.info(" Total segments: %s", len(llm_pos_ids_list)) - for i, seg in enumerate(llm_pos_ids_list): - logger.info(" Segment %s: shape=%s", i, seg.shape) llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1) - logger.info(" Concatenated positions shape: %s", llm_positions.shape) - logger.info(" Expected shape: (3, %s)", len(input_tokens)) - - if llm_positions.shape[1] != len(input_tokens): - logger.error(" ✗✗✗ SHAPE MISMATCH! ✗✗✗") - logger.error( - " Generated %s positions for %s tokens", - llm_positions.shape[1], - len(input_tokens), - ) - logger.error( - " Difference: %s", - llm_positions.shape[1] - len(input_tokens), - ) - else: - logger.info(" ✓ Shape matches!") - mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() - logger.info("mrope_position_delta: %s", mrope_position_delta) - logger.info("get_mrope_input_positions END") - logger.info("=" * 80) return torch.from_numpy(llm_positions), mrope_position_delta From 32426c9bc2d8b374915a2f12faad5fa5c4ba9020 Mon Sep 17 00:00:00 2001 From: "zitian.zhao" Date: Sun, 7 Dec 2025 22:51:39 +0800 Subject: [PATCH 21/22] Refactor: Extract shared EVS mask parsing logic Extract duplicated EVS mask parsing code into a single helper method to improve maintainability and reduce code duplication. Changes: - Add _get_evs_mask_segments() as shared helper method - Refactor _extract_frame_offsets_from_mask() to use helper - Refactor _get_actual_frame_token_counts() to use helper - Add comprehensive docstrings explaining idempotency - Reduce code duplication: ~80 lines -> ~50 lines Benefits: - Single source of truth for mask parsing logic - Easier to maintain and test - Idempotent: pure function with no side effects - Better separation of concerns - Future changes only need to update one place The helper method is a pure function that: - Does not modify any state - Always returns same output for same input - No side effects or mutations Signed-off-by: zitian.zhao --- vllm/model_executor/models/qwen3_vl.py | 96 +++++++++++++++----------- 1 file changed, 56 insertions(+), 40 deletions(-) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index bd5bbcc739ca..8acd7e2f4d87 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -1645,26 +1645,38 @@ def iter_mm_grid_hw( else: raise ValueError(f"Unsupported modality: {mm_feature.modality}") - def _extract_frame_offsets_from_mask( + def _get_evs_mask_segments( self, mm_position: PlaceholderRange, expected_frames: int - ) -> list[int] | None: - """Return relative offsets for each EVS-retained frame. + ) -> list[torch.Tensor] | None: + """Extract contiguous segments from EVS is_embed mask. - The prompt processor stores a boolean mask inside ``mm_position`` that - marks which placeholder locations should be populated with video - embeddings. By splitting that mask into contiguous runs we can recover - the start of every retained frame without probing ``input_tokens``. - """ + The EVS (Efficient Video Sampling) mask marks which placeholder + positions should be filled with video embeddings. This method splits + the mask into contiguous segments, where each segment represents one + retained frame. + This is a pure function - it does not modify any state and always + returns the same output for the same input (idempotent). + + Args: + mm_position: MultiModal position containing the is_embed mask + expected_frames: Expected number of frame segments + + Returns: + List of tensors, each containing indices for one frame segment, + or None if EVS is not enabled or validation fails. + """ is_embed_mask = getattr(mm_position, "is_embed", None) if is_embed_mask is None: return None + # Find all True positions in the mask mask_tensor = torch.as_tensor(is_embed_mask, dtype=torch.bool).view(-1) true_indices = torch.nonzero(mask_tensor, as_tuple=False).flatten() if true_indices.numel() == 0: return None + # Split into contiguous segments (where diff > 1 indicates a gap) if true_indices.numel() == 1: segments = [true_indices] else: @@ -1677,58 +1689,62 @@ def _extract_frame_offsets_from_mask( true_indices, split_points.add(1).tolist() ) + # Validate segment count matches expected frames if len(segments) < expected_frames: logger.debug( - "EVS mask segments (%d) do not match total frames (%d)", + "EVS mask segments (%d) do not match expected frames (%d)", len(segments), expected_frames, ) return None - return [int(segment[0].item()) for segment in segments[:expected_frames]] + return segments[:expected_frames] - def _get_actual_frame_token_counts( + def _extract_frame_offsets_from_mask( self, mm_position: PlaceholderRange, expected_frames: int ) -> list[int] | None: - """Return actual token count for each EVS-retained frame. + """Return relative offsets for each EVS-retained frame. - This function calculates the actual number of tokens per frame by - analyzing the is_embed mask, accounting for EVS pruning. + The prompt processor stores a boolean mask inside ``mm_position`` that + marks which placeholder locations should be populated with video + embeddings. By splitting that mask into contiguous runs we can recover + the start of every retained frame without probing ``input_tokens``. + + Args: + mm_position: MultiModal position containing the is_embed mask + expected_frames: Expected number of frames Returns: - List of token counts for each frame, or None if EVS is not enabled. + List of starting offsets (relative to mm_position) for each frame, + or None if EVS is not enabled. """ - - is_embed_mask = getattr(mm_position, "is_embed", None) - if is_embed_mask is None: + segments = self._get_evs_mask_segments(mm_position, expected_frames) + if segments is None: return None - mask_tensor = torch.as_tensor(is_embed_mask, dtype=torch.bool).view(-1) - true_indices = torch.nonzero(mask_tensor, as_tuple=False).flatten() - if true_indices.numel() == 0: - return None + return [int(segment[0].item()) for segment in segments] - if true_indices.numel() == 1: - segments = [true_indices] - else: - diffs = torch.diff(true_indices) - split_points = torch.nonzero(diffs != 1, as_tuple=False).flatten() - if split_points.numel() == 0: - segments = [true_indices] - else: - segments = torch.tensor_split( - true_indices, split_points.add(1).tolist() - ) + def _get_actual_frame_token_counts( + self, mm_position: PlaceholderRange, expected_frames: int + ) -> list[int] | None: + """Return actual token count for each EVS-retained frame. - if len(segments) < expected_frames: - logger.debug( - "EVS mask segments (%d) do not match total frames (%d)", - len(segments), - expected_frames, - ) + This function calculates the actual number of tokens per frame by + analyzing the is_embed mask, accounting for EVS pruning. Each frame + may have a different token count due to content-aware pruning. + + Args: + mm_position: MultiModal position containing the is_embed mask + expected_frames: Expected number of frames + + Returns: + List of token counts for each frame, or None if EVS is not enabled. + """ + segments = self._get_evs_mask_segments(mm_position, expected_frames) + if segments is None: return None - return [len(seg) for seg in segments[:expected_frames]] + return [len(seg) for seg in segments] def recompute_mrope_positions( self, From a2fa91601a852a5c86db2f16d40b4e6e3f7f7aae Mon Sep 17 00:00:00 2001 From: "zitian.zhao" Date: Sun, 7 Dec 2025 23:00:39 +0800 Subject: [PATCH 22/22] Remove text_len > 0 check as it's always true According to Qwen3-VL chat template specification, video tokens cannot appear at the start of input_tokens. The format is always: <|im_start|>user<|vision_start|><|video_pad|>... This means text_len = offset - st is always >= 3 for the first frame (at minimum: im_start, role, vision_start tokens). Therefore, the 'if text_len > 0' check is redundant and can be safely removed to simplify the code. Changes: - Remove conditional check for text_len > 0 - Always add text_positions to the list - Simplifies code by removing unnecessary branch Signed-off-by: zitian.zhao --- vllm/model_executor/models/qwen3_vl.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index 8acd7e2f4d87..8f03b244bf80 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -1866,13 +1866,12 @@ def get_mrope_input_positions( # Non-EVS mode (or image): use theoretical grid size actual_frame_tokens = llm_grid_h * llm_grid_w - # Add text segment if exists - if text_len > 0: - text_positions = ( - np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx - ) - llm_pos_ids_list.append(text_positions) - st_idx += text_len + # Add text segment + text_positions = ( + np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx + ) + llm_pos_ids_list.append(text_positions) + st_idx += text_len # Add frame segment with actual token count (not theoretical) grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1)