diff --git a/vllm/model_executor/models/qwen3_vl.py b/vllm/model_executor/models/qwen3_vl.py index f0ba631e6680..8f03b244bf80 100644 --- a/vllm/model_executor/models/qwen3_vl.py +++ b/vllm/model_executor/models/qwen3_vl.py @@ -67,12 +67,19 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.evs import ( + compute_mrope_for_media, + compute_retained_tokens_count, + compute_retention_mask, + recompute_mrope_positions, +) from vllm.multimodal.inputs import ( MultiModalDataDict, MultiModalFeatureSpec, MultiModalFieldConfig, MultiModalKwargsItem, MultiModalKwargsItems, + PlaceholderRange, VideoItem, ) from vllm.multimodal.parse import ImageSize, MultiModalDataItems, MultiModalDataParser @@ -92,6 +99,7 @@ SupportsLoRA, SupportsMRoPE, SupportsMultiModal, + SupportsMultiModalPruning, SupportsPP, ) from .qwen2_5_vl import ( @@ -1042,13 +1050,39 @@ def get_video_replacement_qwen3vl(item_idx: int): tokenizer.encode(f"<{curr_time:.1f} seconds>", add_special_tokens=False) for curr_time in timestamps ] - num_tokens_per_frame = int(grid_thw[1:].prod()) // merge_length + tokens_per_frame = int(grid_thw[1:].prod()) // merge_length + per_frame_token_counts = [tokens_per_frame for _ in frames_idx_token] + + video_pruning_rate = self.info.ctx.get_mm_config().video_pruning_rate + if video_pruning_rate is not None and video_pruning_rate > 0.0: + total_retained = compute_retained_tokens_count( + tokens_per_frame, + len(frames_idx_token), + video_pruning_rate, + ) + if len(frames_idx_token) == 0: + per_frame_token_counts = [] + elif len(frames_idx_token) == 1: + per_frame_token_counts = [tokens_per_frame] + else: + first_frame_tokens = tokens_per_frame + remaining_tokens = max(total_retained - first_frame_tokens, 0) + base = remaining_tokens // (len(frames_idx_token) - 1) + remainder = remaining_tokens % (len(frames_idx_token) - 1) + per_frame_token_counts = [first_frame_tokens] + for frame_idx in range(1, len(frames_idx_token)): + extra = base + (1 if (frame_idx - 1) < remainder else 0) + per_frame_token_counts.append(extra) + placeholder = [] - for frame_idx in frames_idx_token: - placeholder.extend(frame_idx) + for frame_idx, timestamp_tokens in enumerate(frames_idx_token): + placeholder.extend(timestamp_tokens) + tokens_this_frame = per_frame_token_counts[ + frame_idx if frame_idx < len(per_frame_token_counts) else -1 + ] placeholder.extend( [vision_start_token_id] - + [video_token_id] * num_tokens_per_frame + + [video_token_id] * tokens_this_frame + [vision_end_token_id] ) return PromptUpdateDetails.select_token_id(placeholder, video_token_id) @@ -1189,6 +1223,7 @@ class Qwen3VLForConditionalGeneration( SupportsPP, SupportsMRoPE, SupportsEagle3, + SupportsMultiModalPruning, ): merge_by_field_config = True multimodal_cpu_fields = {"image_grid_thw", "video_grid_thw"} @@ -1234,6 +1269,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): self.config = config self.multimodal_config = multimodal_config self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data" + self.video_pruning_rate = multimodal_config.video_pruning_rate + self.is_multimodal_pruning_enabled = ( + multimodal_config.is_multimodal_pruning_enabled() + ) + if not multimodal_config.get_limit_per_prompt( "image" ) and not multimodal_config.get_limit_per_prompt("video"): @@ -1420,6 +1460,109 @@ def _process_video_input( sizes = (grid_thw.prod(-1) // merge_size // merge_size).tolist() return video_embeds.split(sizes) + def _postprocess_image_embeds_evs( + self, + image_embeds_split: tuple[torch.Tensor, ...], + image_input: Qwen2_5_VLImageInputs, + ) -> tuple[torch.Tensor, ...]: + """ + Append mrope positions for each for images. + This is necessary to recover correct mrope + positions after video pruning + + Args: + image_embeds_split: Tuple of image embeddings for + each image item. + image_input: Image input data. + + Returns: + Tuple of image embeddings for each image item. + Resulting embeddings will have extra 4 channels for + computed mrope positions. + """ + merge_size = self.visual.spatial_merge_size + grid_thw = image_input["image_grid_thw"] + grid_thw_list = grid_thw.tolist() + image_embeds_out = [] + for emb, size in zip(image_embeds_split, grid_thw_list): + positions = compute_mrope_for_media(size, merge_size).to(emb.device) + emb = torch.cat([emb, positions], dim=1) + image_embeds_out.append(emb) + image_embeds_split = image_embeds_out + return tuple(image_embeds_split) + + def _postprocess_video_embeds_evs( + self, + video_embeds_split: tuple[torch.Tensor, ...], + video_input: Qwen2_5_VLVideoInputs, + ) -> tuple[torch.Tensor, ...]: + """ + Prunes video embeddings via Efficient Video Sampling (EVS) + and then appends mrope positions for each retained embeddings + + Args: + video_embeds_split: Tuple of video embeddings for each video item. + video_input: Video input data. + + Returns: + Tuple of video embeddings for each video item. + Resulting embeddings will have extra 4 channels for + computed mrope positions. + """ + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + grid_thw_list = grid_thw.tolist() + merge_size = self.visual.spatial_merge_size + + # Cast to long to match the original code + # https://github.com/huggingface/transformers/blob/41980ce93e775f6c88500c51c8db7946fc6a2add/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py#L491 # noqa + second_per_grid_ts = video_input.get("second_per_grid_ts") + if second_per_grid_ts is None: + # For Qwen3-VL, second_per_grid_ts might not be available + # Use default value of 1.0 for each video + second_per_grid_ts = torch.ones(len(grid_thw_list), dtype=torch.long) + else: + second_per_grid_ts = second_per_grid_ts.long() + tokens_per_second = getattr(self.config.vision_config, "tokens_per_second", 1.0) + + video_embeds_out = [] + for emb, size, video_second_per_grid_t in zip( + video_embeds_split, grid_thw_list, second_per_grid_ts + ): + # For each video, we compute retention mask using EVS + retention_mask = compute_retention_mask( + emb, + size, + spatial_merge_size=self.visual.spatial_merge_size, + q=self.video_pruning_rate, + ) + + # Debug logging for EVS pruning + logger.debug( + "EVS: Video tokens pruned from %d to %d (T=%d,H=%d,W=%d, " + "pruning_rate=%.2f, reduction=%.1f%%)", + emb.shape[0], + retention_mask.sum().item(), + size[0], + size[1], + size[2], + self.video_pruning_rate, + (1 - retention_mask.float().mean().item()) * 100, + ) + + positions = compute_mrope_for_media( + size, + merge_size, + tokens_per_second=tokens_per_second, + video_second_per_grid=video_second_per_grid_t.item(), + ).to(emb.device) + + emb = emb[retention_mask] + positions = positions[retention_mask] + emb = torch.cat([emb, positions], dim=1) + video_embeds_out.append(emb) + return tuple(video_embeds_out) + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: mm_input_by_modality = {} for input_key in kwargs: @@ -1442,6 +1585,20 @@ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: def iter_mm_grid_hw( self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec] ) -> Iterator[tuple[int, int, int]]: + """ + Iterate over multimodal features and yield grid information. + + For videos with EVS (Efficient Video Sampling) enabled, this function + computes the offset based on the pruned token count rather than relying + on input_tokens.index(), which would fail when tokens are pruned. + + Args: + input_tokens: List of token IDs in the prompt + mm_features: List of multimodal feature specifications + + Yields: + Tuple of (offset, grid_h, grid_w) for each frame/image + """ video_token_id = self.config.video_token_id spatial_merge_size = self.config.vision_config.spatial_merge_size for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset): @@ -1454,42 +1611,289 @@ def iter_mm_grid_hw( t, h, w = mm_feature.data["video_grid_thw"].data.tolist() llm_grid_h = h // spatial_merge_size llm_grid_w = w // spatial_merge_size - for _ in range(t): - offset = input_tokens.index(video_token_id, offset) - yield offset, llm_grid_h, llm_grid_w - offset += llm_grid_h * llm_grid_w + + # Check if EVS (Efficient Video Sampling) is enabled + is_evs_enabled = ( + hasattr(self, "video_pruning_rate") + and self.video_pruning_rate is not None + and self.video_pruning_rate > 0.0 + ) + + if is_evs_enabled: + frame_offsets = self._extract_frame_offsets_from_mask( + mm_feature.mm_position, t + ) + if frame_offsets is not None: + for rel_offset in frame_offsets: + yield offset + rel_offset, llm_grid_h, llm_grid_w + continue + + # If EVS is enabled but mask is missing, this indicates a bug + # in the prompt processing pipeline. The is_embed mask should + # always be present when video_pruning_rate > 0. + raise RuntimeError( + f"EVS is enabled (pruning_rate={self.video_pruning_rate}) " + "but is_embed mask is missing from mm_position. " + "This indicates a bug in prompt processing." + ) + else: + # Non-EVS mode: Use original logic with input_tokens.index() + for _ in range(t): + offset = input_tokens.index(video_token_id, offset) + yield offset, llm_grid_h, llm_grid_w + offset += llm_grid_h * llm_grid_w else: raise ValueError(f"Unsupported modality: {mm_feature.modality}") + def _get_evs_mask_segments( + self, mm_position: PlaceholderRange, expected_frames: int + ) -> list[torch.Tensor] | None: + """Extract contiguous segments from EVS is_embed mask. + + The EVS (Efficient Video Sampling) mask marks which placeholder + positions should be filled with video embeddings. This method splits + the mask into contiguous segments, where each segment represents one + retained frame. + + This is a pure function - it does not modify any state and always + returns the same output for the same input (idempotent). + + Args: + mm_position: MultiModal position containing the is_embed mask + expected_frames: Expected number of frame segments + + Returns: + List of tensors, each containing indices for one frame segment, + or None if EVS is not enabled or validation fails. + """ + is_embed_mask = getattr(mm_position, "is_embed", None) + if is_embed_mask is None: + return None + + # Find all True positions in the mask + mask_tensor = torch.as_tensor(is_embed_mask, dtype=torch.bool).view(-1) + true_indices = torch.nonzero(mask_tensor, as_tuple=False).flatten() + if true_indices.numel() == 0: + return None + + # Split into contiguous segments (where diff > 1 indicates a gap) + if true_indices.numel() == 1: + segments = [true_indices] + else: + diffs = torch.diff(true_indices) + split_points = torch.nonzero(diffs != 1, as_tuple=False).flatten() + if split_points.numel() == 0: + segments = [true_indices] + else: + segments = torch.tensor_split( + true_indices, split_points.add(1).tolist() + ) + + # Validate segment count matches expected frames + if len(segments) < expected_frames: + logger.debug( + "EVS mask segments (%d) do not match expected frames (%d)", + len(segments), + expected_frames, + ) + return None + + return segments[:expected_frames] + + def _extract_frame_offsets_from_mask( + self, mm_position: PlaceholderRange, expected_frames: int + ) -> list[int] | None: + """Return relative offsets for each EVS-retained frame. + + The prompt processor stores a boolean mask inside ``mm_position`` that + marks which placeholder locations should be populated with video + embeddings. By splitting that mask into contiguous runs we can recover + the start of every retained frame without probing ``input_tokens``. + + Args: + mm_position: MultiModal position containing the is_embed mask + expected_frames: Expected number of frames + + Returns: + List of starting offsets (relative to mm_position) for each frame, + or None if EVS is not enabled. + """ + segments = self._get_evs_mask_segments(mm_position, expected_frames) + if segments is None: + return None + + return [int(segment[0].item()) for segment in segments] + + def _get_actual_frame_token_counts( + self, mm_position: PlaceholderRange, expected_frames: int + ) -> list[int] | None: + """Return actual token count for each EVS-retained frame. + + This function calculates the actual number of tokens per frame by + analyzing the is_embed mask, accounting for EVS pruning. Each frame + may have a different token count due to content-aware pruning. + + Args: + mm_position: MultiModal position containing the is_embed mask + expected_frames: Expected number of frames + + Returns: + List of token counts for each frame, or None if EVS is not enabled. + """ + segments = self._get_evs_mask_segments(mm_position, expected_frames) + if segments is None: + return None + + return [len(seg) for seg in segments] + + def recompute_mrope_positions( + self, + input_ids: list[int], + multimodal_embeddings: tuple[torch.Tensor, ...], + mrope_positions: torch.LongTensor, + num_computed_tokens: int, + ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor, int]: + """ + Update part of input mrope positions (starting with + num_computed_tokens index). Original mrope_positions are computed + for unpruned sequence and becomes incorrect once pruning occurs, + so once we prune media tokens we should reflect this in the + mrope_positions before we feed it to LLM. + + Args: + input_ids: (N,) All input tokens of the prompt (Containing + entire sequence). + multimodal_embeddings: Tuple of multimodal embeddings. + mrope_positions: Existing mrope positions (3, N) for entire + sequence + num_computed_tokens: A number of computed tokens so far. + + Returns: + Tuple of (multimodal_embeddings, mrope_positions, + mrope_position_delta). + """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + # Device + device = ( + multimodal_embeddings[0].device + if len(multimodal_embeddings) + else mrope_positions.device + ) + + # Tensors + input_ids_t = torch.as_tensor(input_ids, device=device, dtype=torch.long) + + mm_embeddings_out = [mm[:, :-4] for mm in multimodal_embeddings] + mm_embeddings_pos = [ + mm[:, -4:].permute(1, 0).long() for mm in multimodal_embeddings + ] + + positions, mrope_positions_delta = recompute_mrope_positions( + input_ids_t, + mm_embeddings_pos, + mrope_positions, + num_computed_tokens, + vision_start_token_id, + image_token_id, + video_token_id, + ) + + return tuple(mm_embeddings_out), positions, mrope_positions_delta + def get_mrope_input_positions( self, input_tokens: list[int], mm_features: list[MultiModalFeatureSpec], ) -> tuple[torch.Tensor, int]: + # Pre-collect actual frame token counts for EVS mode + frame_token_counts_map = {} + for mm_feature in mm_features: + if mm_feature.modality == "video": + is_evs_enabled = ( + hasattr(self, "video_pruning_rate") + and self.video_pruning_rate is not None + and self.video_pruning_rate > 0.0 + ) + if is_evs_enabled: + t = mm_feature.data["video_grid_thw"].data.tolist()[0] + token_counts = self._get_actual_frame_token_counts( + mm_feature.mm_position, t + ) + assert token_counts is not None, ( + "EVS enabled but failed to extract frame token counts " + "from is_embed mask" + ) + frame_token_counts_map[mm_feature.mm_position.offset] = token_counts + llm_pos_ids_list = [] st = 0 + frame_counts_idx = {} + for offset, llm_grid_h, llm_grid_w in self.iter_mm_grid_hw( input_tokens, mm_features ): text_len = offset - st st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - llm_pos_ids_list.append( + + # Determine actual token count for this frame + base_offset = None + for feat_offset in frame_token_counts_map: + if offset >= feat_offset: + base_offset = feat_offset + + if base_offset is not None: + # EVS mode: use actual token count from is_embed mask + assert base_offset in frame_token_counts_map, ( + f"Found base_offset {base_offset} but not in frame_token_counts_map" + ) + + if base_offset not in frame_counts_idx: + frame_counts_idx[base_offset] = 0 + + counts = frame_token_counts_map[base_offset] + idx = frame_counts_idx[base_offset] + + assert idx < len(counts), ( + f"EVS frame index {idx} out of range (total frames: {len(counts)})" + ) + + actual_frame_tokens = counts[idx] + frame_counts_idx[base_offset] += 1 + else: + # Non-EVS mode (or image): use theoretical grid size + actual_frame_tokens = llm_grid_h * llm_grid_w + + # Add text segment + text_positions = ( np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx ) + llm_pos_ids_list.append(text_positions) + st_idx += text_len + # Add frame segment with actual token count (not theoretical) grid_indices = np.indices((1, llm_grid_h, llm_grid_w)).reshape(3, -1) - llm_pos_ids_list.append(grid_indices + text_len + st_idx) - st = offset + llm_grid_h * llm_grid_w + # Only take the first actual_frame_tokens positions + frame_positions = grid_indices[:, :actual_frame_tokens] + st_idx + llm_pos_ids_list.append(frame_positions) + + # Update st using actual token count + st = offset + actual_frame_tokens + # Handle final text segment if st < len(input_tokens): st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 text_len = len(input_tokens) - st - llm_pos_ids_list.append( + final_text_positions = ( np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx ) + llm_pos_ids_list.append(final_text_positions) llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1) mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + return torch.from_numpy(llm_positions), mrope_position_delta def get_language_model(self) -> torch.nn.Module: @@ -1510,9 +1914,17 @@ def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings | None: multimodal_input = mm_input_by_modality[modality] if modality == "image": image_embeddings = self._process_image_input(multimodal_input) + if self.is_multimodal_pruning_enabled: + image_embeddings = self._postprocess_image_embeds_evs( + image_embeddings, multimodal_input + ) multimodal_embeddings += tuple(image_embeddings) if modality == "video": video_embeddings = self._process_video_input(multimodal_input) + if self.is_multimodal_pruning_enabled: + video_embeddings = self._postprocess_video_embeds_evs( + video_embeddings, multimodal_input + ) multimodal_embeddings += tuple(video_embeddings) return multimodal_embeddings