Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 136 additions & 0 deletions examples/offline_inference/whisper_multilora_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# SPDX-License-Identifier: Apache-2.0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this example is similar to multilora_inference.py, so do we need to add this example?

# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use multi-LoRA functionality with
Whisper models for speech-to-text transcription.
Usage:
python whisper_multilora_inference.py
Note: Replace LORA_PATH with your actual LoRA adapter path.
If you don't have a LoRA adapter, the example will run with
the base model only.
"""

import os

from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.lora.request import LoRARequest


def create_whisper_prompt(language: str = "en") -> dict:
"""Create a Whisper transcription prompt with audio input.
Args:
language: ISO 639-1 language code (e.g., "en", "ko", "ja")
Returns:
Dictionary with prompt and multi-modal data
"""
# Load sample audio from vLLM assets
audio_asset = AudioAsset("mary_had_lamb")
audio_data = audio_asset.audio_and_sample_rate

# Whisper prompt format:
# <|startoftranscript|><|language|><|task|><|notimestamps|>
prompt = f"<|startoftranscript|><|{language}|><|transcribe|><|notimestamps|>"

return {
"prompt": prompt,
"multi_modal_data": {
"audio": audio_data,
},
}


def run_base_model_inference(llm: LLM, sampling_params: SamplingParams) -> None:
"""Run inference using the base Whisper model without LoRA."""
print("\n" + "=" * 60)
print("Running inference with BASE MODEL (no LoRA)")
print("=" * 60)

inputs = create_whisper_prompt(language="en")
outputs = llm.generate([inputs], sampling_params=sampling_params)

for output in outputs:
print(f"Transcription: {output.outputs[0].text}")


def run_lora_inference(
llm: LLM,
sampling_params: SamplingParams,
lora_path: str,
lora_name: str,
lora_id: int,
) -> None:
"""Run inference using a specific LoRA adapter.
Args:
llm: The vLLM engine
sampling_params: Sampling parameters
lora_path: Path to the LoRA adapter
lora_name: Name identifier for the LoRA
lora_id: Unique integer ID for the LoRA
"""
print("\n" + "=" * 60)
print(f"Running inference with LoRA: {lora_name}")
print("=" * 60)

inputs = create_whisper_prompt(language="en")
lora_request = LoRARequest(lora_name, lora_id, lora_path)

outputs = llm.generate(
[inputs],
sampling_params=sampling_params,
lora_request=lora_request,
)

for output in outputs:
print(f"Transcription: {output.outputs[0].text}")


def main():
"""Main function demonstrating Whisper Multi-LoRA inference."""
# Initialize Whisper model with LoRA support enabled
print("Initializing Whisper model with Multi-LoRA support...")
llm = LLM(
model="openai/whisper-large-v3-turbo",
enable_lora=True,
max_loras=4, # Maximum number of LoRAs to keep in memory
max_lora_rank=64, # Maximum LoRA rank supported
max_model_len=448, # Whisper's max target positions
dtype="half",
gpu_memory_utilization=0.8,
trust_remote_code=True,
)

sampling_params = SamplingParams(
temperature=0,
max_tokens=200,
)

# Run base model inference
run_base_model_inference(llm, sampling_params)

# Example LoRA paths - replace with your actual LoRA adapters
lora_paths = [
("lora_adapter_1", "/path/to/your/lora_adapter_1"),
("lora_adapter_2", "/path/to/your/lora_adapter_2"),
]

# Run inference with each LoRA adapter (if paths exist)
for lora_id, (lora_name, lora_path) in enumerate(lora_paths, start=1):
if os.path.exists(lora_path):
run_lora_inference(llm, sampling_params, lora_path, lora_name, lora_id)
else:
print(f"\nSkipping {lora_name}: path does not exist ({lora_path})")
print("To use LoRA adapters, update lora_paths with valid paths.")

print("\n" + "=" * 60)
print("Multi-LoRA inference complete!")
print("=" * 60)


if __name__ == "__main__":
main()
168 changes: 168 additions & 0 deletions tests/lora/test_whisper_lora.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for Whisper Multi-LoRA support.

This module tests:
1. WhisperForConditionalGeneration LoRA interface compliance
2. MergedQKVParallelLinearWithLoRA support for KV-only (2-slice) configuration
3. WorkerLoRAManager compatibility with Whisper's max_target_positions
"""

import pytest
import torch

from vllm.lora.layers import (
MergedQKVParallelLinearWithLoRA,
)
from vllm.model_executor.layers.linear import QKVParallelLinear
from vllm.model_executor.models.whisper import WhisperForConditionalGeneration
from vllm.platforms import current_platform

pytestmark = pytest.mark.skipif(
not (current_platform.is_cuda_alike() or current_platform.is_cpu()),
reason="Backend not supported",
)


class TestWhisperLoRAInterface:
"""Test that WhisperForConditionalGeneration has proper LoRA support."""

def test_supports_lora_attribute(self):
"""Verify that WhisperForConditionalGeneration has SupportsLoRA interface."""
from vllm.model_executor.models.interfaces import SupportsLoRA

assert issubclass(WhisperForConditionalGeneration, SupportsLoRA), (
"WhisperForConditionalGeneration should inherit from SupportsLoRA"
)

def test_embedding_modules_defined(self):
"""Verify embedding_modules attribute is defined."""
assert hasattr(WhisperForConditionalGeneration, "embedding_modules")
assert isinstance(WhisperForConditionalGeneration.embedding_modules, dict)

def test_embedding_padding_modules_defined(self):
"""Verify embedding_padding_modules attribute is defined."""
assert hasattr(WhisperForConditionalGeneration, "embedding_padding_modules")
assert isinstance(
WhisperForConditionalGeneration.embedding_padding_modules, list
)

def test_packed_modules_mapping_format(self):
"""Verify packed_modules_mapping has correct format for LoRA."""
mapping = WhisperForConditionalGeneration.packed_modules_mapping

# Should have qkv_proj and kv_proj mappings
assert "qkv_proj" in mapping, "Missing qkv_proj in packed_modules_mapping"
assert "kv_proj" in mapping, "Missing kv_proj in packed_modules_mapping"

# qkv_proj should map to [q_proj, k_proj, v_proj]
assert mapping["qkv_proj"] == ["q_proj", "k_proj", "v_proj"]

# kv_proj should map to [k_proj, v_proj] (for cross-attention)
assert mapping["kv_proj"] == ["k_proj", "v_proj"]


class TestMergedQKVParallelLinearWithLoRAKVOnly:
"""Test MergedQKVParallelLinearWithLoRA with KV-only (2-slice) configuration."""

def test_can_replace_layer_accepts_2_modules(self):
"""Verify can_replace_layer accepts 2-module (KV-only) configurations."""
from vllm.config.lora import LoRAConfig

# Create a mock QKVParallelLinear layer
# This simulates a KV-only projection (like Whisper's encoder_attn.kv_proj)
linear = QKVParallelLinear(
hidden_size=512,
head_size=64,
total_num_heads=8,
total_num_kv_heads=8,
bias=False,
params_dtype=torch.float16,
)

lora_config = LoRAConfig(
max_lora_rank=32,
max_loras=4,
max_cpu_loras=4,
lora_extra_vocab_size=0,
)

# Test with 2 modules (KV-only, like encoder_attn.kv_proj)
packed_modules_2 = ["k_proj", "v_proj"]
result_2 = MergedQKVParallelLinearWithLoRA.can_replace_layer(
source_layer=linear,
lora_config=lora_config,
packed_modules_list=packed_modules_2,
model_config=None,
)
assert result_2 is True, "Should accept 2-module (KV-only) configuration"

# Test with 3 modules (QKV, like self_attn.qkv_proj)
packed_modules_3 = ["q_proj", "k_proj", "v_proj"]
result_3 = MergedQKVParallelLinearWithLoRA.can_replace_layer(
source_layer=linear,
lora_config=lora_config,
packed_modules_list=packed_modules_3,
model_config=None,
)
assert result_3 is True, "Should accept 3-module (QKV) configuration"

# Test with 1 module (should be rejected)
packed_modules_1 = ["q_proj"]
result_1 = MergedQKVParallelLinearWithLoRA.can_replace_layer(
source_layer=linear,
lora_config=lora_config,
packed_modules_list=packed_modules_1,
model_config=None,
)
assert result_1 is False, "Should reject 1-module configuration"


class TestWorkerLoRAManagerWhisperCompat:
"""Test WorkerLoRAManager compatibility with Whisper config."""

def test_max_position_embeddings_fallback(self):
"""Test that max_target_positions is used when missing."""

# Create a mock config similar to Whisper's
class MockWhisperConfig:
def __init__(self):
self.max_target_positions = 448
# Note: no max_position_embeddings attribute

def get_text_config(self):
return self

config = MockWhisperConfig()

# Simulate the logic from WorkerLoRAManager
max_pos = getattr(
config,
"max_position_embeddings",
getattr(config, "max_target_positions", None),
)

assert max_pos == 448, "Should fall back to max_target_positions"

def test_max_position_embeddings_priority(self):
"""Test that max_position_embeddings takes priority when present."""

class MockLLMConfig:
def __init__(self):
self.max_position_embeddings = 4096
self.max_target_positions = 448

def get_text_config(self):
return self

config = MockLLMConfig()

# Simulate the logic from WorkerLoRAManager
max_pos = getattr(
config,
"max_position_embeddings",
getattr(config, "max_target_positions", None),
)

assert max_pos == 4096, "Should use max_position_embeddings when present"
62 changes: 34 additions & 28 deletions vllm/lora/layers/column_parallel_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,8 +356,6 @@ class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):

def __init__(self, base_layer: QKVParallelLinear) -> None:
super().__init__(base_layer)
# There are three LoRA layer.
self.n_slices = len(self.base_layer.output_sizes)

self.q_proj_shard_size = self.base_layer.num_heads * self.base_layer.head_size
self.kv_proj_shard_size = (
Expand All @@ -366,16 +364,23 @@ def __init__(self, base_layer: QKVParallelLinear) -> None:
self.q_shard_id = self.tp_rank
self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas

self.output_slices = (
self.q_proj_shard_size,
self.kv_proj_shard_size,
self.kv_proj_shard_size,
)
self.output_ids = (
self.q_shard_id,
self.kv_shard_id,
self.kv_shard_id,
)
# Build output_slices and output_ids dynamically to support both
# QKV (3 slices) and KV-only (2 slices) configurations.
# KV-only is used in cross-attention layers (e.g., Whisper encoder_attn).
slices = []
ids = []
if self.q_proj_shard_size > 0:
slices.append(self.q_proj_shard_size)
ids.append(self.q_shard_id)
if self.kv_proj_shard_size > 0:
slices.append(self.kv_proj_shard_size)
ids.append(self.kv_shard_id)
slices.append(self.kv_proj_shard_size)
ids.append(self.kv_shard_id)

self.output_slices = tuple(slices)
self.output_ids = tuple(ids)
self.n_slices = len(self.output_slices)

def create_lora_weights(
self,
Expand All @@ -398,7 +403,11 @@ def can_replace_layer(
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 3
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use MergedColumnParallelLinear rather than QKVParallelLinear in base model?

# Support both QKV (3 modules) and KV-only (2 modules) configurations
return type(source_layer) is QKVParallelLinear and len(packed_modules_list) in (
2,
3,
)


# These following layers are based on the tensor parallelism strategy given in
Expand Down Expand Up @@ -539,21 +548,18 @@ class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
def slice_lora_a(
self, lora_a: list[torch.Tensor | None]
) -> list[torch.Tensor | None]:
# NOTE: lora_a contains 3 subloras, and each sublora could be None.
shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
lora_a = [
lora_a[0][start_idx[0] : start_idx[0] + shard_size[0], :]
if lora_a[0] is not None
else None,
lora_a[1][start_idx[1] : start_idx[1] + shard_size[1], :]
if lora_a[1] is not None
else None,
lora_a[2][start_idx[2] : start_idx[2] + shard_size[2], :]
if lora_a[2] is not None
else None,
]
return lora_a
# NOTE: lora_a contains n_slices subloras, and each sublora could be None.
# n_slices is 3 for QKV and 2 for KV-only configurations.
shard_size = [self.lora_a_stacked[i].shape[2] for i in range(self.n_slices)]
start_idx = [self.tp_rank * shard_size[i] for i in range(self.n_slices)]
result: list[torch.Tensor | None] = []
for i in range(self.n_slices):
lora_a_i = lora_a[i]
if lora_a_i is not None:
result.append(lora_a_i[start_idx[i] : start_idx[i] + shard_size[i], :])
else:
result.append(None)
return result

def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
return _mcp_apply(x, bias, self)
Expand Down
Loading