
Commit 1073ba6

[LoRA] Optimize 3D MoE logic (#29222)
Signed-off-by: Jee Jee Li <[email protected]>
Parent: c309bb5

File tree

11 files changed: +395 -103 lines


tests/lora/test_gptoss_tp.py

Lines changed: 6 additions & 1 deletion
@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import pytest
+
 import vllm
 from vllm.lora.request import LoRARequest
 
@@ -84,14 +86,17 @@ def test_gpt_oss_lora(gptoss20b_lora_files):
 
 
 @multi_gpu_test(num_gpus=2)
-def test_gpt_oss_lora_tp2(gptoss20b_lora_files):
+@pytest.mark.parametrize("fully_sharded_loras", [False, True])
+def test_gpt_oss_lora_tp2(gptoss20b_lora_files, fully_sharded_loras):
     llm = vllm.LLM(
         MODEL_PATH,
         max_model_len=1024,
         enable_lora=True,
         max_loras=2,
         max_lora_rank=8,
+        max_num_seqs=16,
         tensor_parallel_size=2,
+        fully_sharded_loras=fully_sharded_loras,
         compilation_config=vllm.config.CompilationConfig(  # Avoid OOM
             cudagraph_specialize_lora=False,
         ),
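
For reference, the new fully_sharded_loras knob that the test now parametrizes over is an ordinary vllm.LLM constructor argument, so the test mirrors what a user would configure directly. A minimal standalone sketch, assuming the gpt-oss-20B checkpoint and a local LoRA adapter path (both placeholders here, not values taken from the test body):

# Sketch only: serving a LoRA adapter with fully sharded LoRA enabled.
# MODEL_PATH and LORA_PATH are placeholders, not values from the test.
import vllm
from vllm.lora.request import LoRARequest

MODEL_PATH = "openai/gpt-oss-20b"      # assumed checkpoint
LORA_PATH = "/path/to/gptoss20b_lora"  # placeholder adapter directory

llm = vllm.LLM(
    MODEL_PATH,
    max_model_len=1024,
    enable_lora=True,
    max_loras=2,
    max_lora_rank=8,
    max_num_seqs=16,
    tensor_parallel_size=2,
    fully_sharded_loras=True,  # the flag the test now exercises in both modes
)
outputs = llm.generate(
    ["Give me a one-sentence summary of LoRA."],
    vllm.SamplingParams(max_tokens=32),
    lora_request=LoRARequest("gptoss-lora", 1, LORA_PATH),
)
print(outputs[0].outputs[0].text)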

vllm/lora/layers/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -11,7 +11,7 @@
     QKVParallelLinearWithLoRA,
     QKVParallelLinearWithShardedLoRA,
 )
-from vllm.lora.layers.fused_moe import FusedMoEWithLoRA
+from vllm.lora.layers.fused_moe import FusedMoE3DWithLoRA, FusedMoEWithLoRA
 from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA
 from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA
 from vllm.lora.layers.row_parallel_linear import (
@@ -38,4 +38,5 @@
     "ReplicatedLinearWithLoRA",
     "LoRAMapping",
     "FusedMoEWithLoRA",
+    "FusedMoE3DWithLoRA",
 ]
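
With the added export, the 3D variant can be imported from the package root exactly like the existing layer; a one-line usage sketch:

# Both symbols are re-exported from vllm.lora.layers after this change.
from vllm.lora.layers import FusedMoE3DWithLoRA, FusedMoEWithLoRA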

vllm/lora/layers/base.py

Lines changed: 2 additions & 2 deletions
@@ -42,8 +42,8 @@ def reset_lora(self, index: int):
     def set_lora(
         self,
         index: int,
-        lora_a: torch.Tensor,
-        lora_b: torch.Tensor,
+        lora_a: torch.Tensor | list[torch.Tensor],
+        lora_b: torch.Tensor | list[torch.Tensor],
     ):
         """Overwrites lora tensors at index."""
         ...
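
The widened annotation is what lets MoE-style layers hand set_lora a group of tensors while ordinary linear layers keep passing a single one. A rough illustration of the two call shapes the interface now admits (the shapes, expert count, and layer objects are invented for illustration, not taken from vLLM):

import torch

# Illustrative sizes only.
rank, hidden, inter, num_experts = 8, 2048, 5632, 4

# A plain linear LoRA layer still supplies single 2-D tensors:
lora_a = torch.randn(rank, hidden)
lora_b = torch.randn(inter, rank)
# linear_layer.set_lora(index=0, lora_a=lora_a, lora_b=lora_b)

# A fused-MoE LoRA layer may instead supply a list, e.g. one stacked
# (num_experts, ...) tensor per projection, which is why the base
# signature now also accepts list[torch.Tensor]:
lora_a_list = [torch.randn(num_experts, rank, hidden) for _ in range(2)]
lora_b_list = [torch.randn(num_experts, inter, rank) for _ in range(2)]
# moe_layer.set_lora(index=0, lora_a=lora_a_list, lora_b=lora_b_list)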

vllm/lora/layers/base_linear.py

Lines changed: 4 additions & 2 deletions
@@ -94,13 +94,15 @@ def reset_lora(self, index: int):
     def set_lora(
         self,
         index: int,
-        lora_a: torch.Tensor,
-        lora_b: torch.Tensor,
+        lora_a: torch.Tensor | list[torch.Tensor],
+        lora_b: torch.Tensor | list[torch.Tensor],
     ):
         # Except for QKVParallelLinearWithLoRA and
         # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
         # store weights in a tuple of size 1. These two layers will
         # override this function.
+        assert isinstance(lora_a, torch.Tensor)
+        assert isinstance(lora_b, torch.Tensor)
         assert (
             len(self.lora_a_stacked) == len(self.lora_b_stacked) == self.n_slices == 1
         )
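
The new assertions do two jobs at once: they guard at runtime that the shared linear path only ever receives single tensors, and they narrow the torch.Tensor | list[torch.Tensor] union so the tensor-only code below stays well-typed. A generic sketch of that narrowing pattern, independent of the vLLM classes:

import torch

def first_dim(weight: torch.Tensor | list[torch.Tensor]) -> int:
    # After the assert, both the runtime and static type checkers know
    # `weight` is a plain tensor, so .shape is safe to use.
    assert isinstance(weight, torch.Tensor)
    return weight.shape[0]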

vllm/lora/layers/column_parallel_linear.py

Lines changed: 2 additions & 2 deletions
@@ -246,8 +246,8 @@ def slice_lora_b(
     def set_lora(
         self,
         index: int,
-        lora_a: torch.Tensor,
-        lora_b: torch.Tensor,
+        lora_a: torch.Tensor | list[torch.Tensor],
+        lora_b: torch.Tensor | list[torch.Tensor],
     ):
         self.reset_lora(index)
 
