70 changes: 37 additions & 33 deletions vllm/attention/selector.py
@@ -3,11 +3,11 @@

import inspect
from functools import cache
from typing import cast, get_args
from typing import NamedTuple, cast, get_args

import torch

from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.abstract import AttentionBackend, AttentionType
from vllm.attention.backends.registry import (
MAMBA_TYPE_TO_BACKEND_MAP,
MambaAttentionBackendEnum,
@@ -19,6 +19,30 @@
logger = init_logger(__name__)


class AttentionSelectorConfig(NamedTuple):
head_size: int
dtype: torch.dtype
kv_cache_dtype: CacheDType | None
block_size: int | None
use_mla: bool = False
has_sink: bool = False
use_sparse: bool = False
use_mm_prefix: bool = False
attn_type: str = AttentionType.DECODER

def __repr__(self):
return (
f"AttentionSelectorConfig(head_size={self.head_size}, "
f"dtype={self.dtype}, "
f"kv_cache_dtype={self.kv_cache_dtype}, "
f"block_size={self.block_size}, "
f"use_mla={self.use_mla}, "
f"has_sink={self.has_sink}, "
f"use_sparse={self.use_sparse}, "
f"attn_type={self.attn_type})"
)


def get_attn_backend(
head_size: int,
dtype: torch.dtype,
@@ -44,8 +68,7 @@ def get_attn_backend(
vllm_config = get_current_vllm_config()
backend_enum = vllm_config.attention_config.backend

return _cached_get_attn_backend(
backend=backend_enum,
attn_selector_config = AttentionSelectorConfig(
head_size=head_size,
dtype=dtype,
kv_cache_dtype=cast(CacheDType | None, kv_cache_dtype),
@@ -54,22 +77,19 @@
has_sink=has_sink,
use_sparse=use_sparse,
use_mm_prefix=use_mm_prefix,
attn_type=attn_type,
attn_type=attn_type or AttentionType.DECODER,
)

return _cached_get_attn_backend(
backend=backend_enum,
attn_selector_config=attn_selector_config,
)


@cache
def _cached_get_attn_backend(
backend,
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: CacheDType | None,
block_size: int | None,
use_mla: bool = False,
has_sink: bool = False,
use_sparse: bool = False,
use_mm_prefix: bool = False,
attn_type: str | None = None,
attn_selector_config: AttentionSelectorConfig,
) -> type[AttentionBackend]:
from vllm.platforms import current_platform

@@ -82,29 +102,13 @@ def _cached_get_attn_backend(
)
attention_cls = current_platform.get_attn_backend_cls(
backend,
head_size,
dtype,
kv_cache_dtype,
block_size,
True, # use_v1
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
attn_type,
use_v1=True,
**attn_selector_config._asdict(),
)
else:
attention_cls = current_platform.get_attn_backend_cls(
backend,
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
attn_type,
attn_selector_config=attn_selector_config,
)
if not attention_cls:
raise ValueError(
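Note on the design: because AttentionSelectorConfig is a NamedTuple, it is immutable and hashable, so the whole selector configuration can act as a single key for the functools.cache on _cached_get_attn_backend. A minimal standalone sketch of that property, assuming a vLLM checkout that contains this change (the field values below are illustrative only):

import torch

from vllm.attention.backends.abstract import AttentionType
from vllm.attention.selector import AttentionSelectorConfig

# Illustrative settings; real callers derive these from the model and cache config.
cfg = AttentionSelectorConfig(
    head_size=128,
    dtype=torch.bfloat16,
    kv_cache_dtype=None,
    block_size=16,
    use_mla=False,
    has_sink=False,
    use_sparse=False,
    use_mm_prefix=False,
    attn_type=AttentionType.DECODER,
)

# Equal configs hash identically, so repeated get_attn_backend() calls with the
# same settings reuse the same functools.cache entry in _cached_get_attn_backend.
assert cfg == AttentionSelectorConfig(*cfg)
assert hash(cfg) == hash(AttentionSelectorConfig(*cfg))
print(cfg)  # formatted by the __repr__ defined in this file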
15 changes: 4 additions & 11 deletions vllm/platforms/cpu.py
@@ -23,6 +23,7 @@
logger = init_logger(__name__)

if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig
else:
VllmConfig = None
@@ -126,21 +127,13 @@ def get_device_name(cls, device_id: int = 0) -> str:
def get_attn_backend_cls(
cls,
selected_backend: "AttentionBackendEnum",
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str | None,
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
attn_selector_config: "AttentionSelectorConfig",
) -> str:
if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
logger.info("Cannot use %s backend on CPU.", selected_backend)
if use_mla:
if attn_selector_config.use_mla:
raise NotImplementedError("MLA is not supported on CPU.")
if use_sparse:
if attn_selector_config.use_sparse:
raise NotImplementedError("Sparse Attention is not supported on CPU.")
return AttentionBackendEnum.CPU_ATTN.get_path()

73 changes: 14 additions & 59 deletions vllm/platforms/cuda.py
@@ -14,7 +14,6 @@

# import custom ops, trigger op registration
import vllm._C # noqa
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
from vllm.utils.import_utils import import_pynvml
@@ -23,6 +22,7 @@
from .interface import DeviceCapability, Platform, PlatformEnum

if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
else:
@@ -275,38 +275,24 @@ def get_vit_attn_backend(
@classmethod
def get_valid_backends(
cls,
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
device_capability: DeviceCapability,
attn_selector_config: "AttentionSelectorConfig",
) -> tuple[
list[tuple["AttentionBackendEnum", int]],
dict["AttentionBackendEnum", list[str]],
]:
valid_backends_priorities = []
invalid_reasons = {}

backend_priorities = _get_backend_priorities(use_mla, device_capability)
backend_priorities = _get_backend_priorities(
attn_selector_config.use_mla, device_capability
)
for priority, backend in enumerate(backend_priorities):
try:
backend_class = backend.get_class()
invalid_reasons_i = backend_class.validate_configuration(
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
device_capability=device_capability,
**attn_selector_config._asdict(),
)
except ImportError:
invalid_reasons_i = ["ImportError"]
@@ -321,19 +307,8 @@ def get_valid_backends(
def get_attn_backend_cls(
cls,
selected_backend: "AttentionBackendEnum",
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: "CacheDType | None",
block_size: int | None,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
attn_selector_config: "AttentionSelectorConfig",
) -> str:
if attn_type is None:
attn_type = AttentionType.DECODER

device_capability = cls.get_device_capability()
assert device_capability is not None

@@ -342,16 +317,8 @@
try:
backend_class = selected_backend.get_class()
invalid_reasons = backend_class.validate_configuration(
head_size,
dtype,
kv_cache_dtype,
None,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
device_capability=device_capability,
**attn_selector_config._asdict(),
)
except ImportError:
invalid_reasons = ["ImportError"]
@@ -367,16 +334,8 @@
# No selected backend or the selected backend is invalid,
# so we try finding a valid backend.
valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
head_size,
dtype,
kv_cache_dtype,
None,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
device_capability,
attn_type,
device_capability=device_capability,
attn_selector_config=attn_selector_config,
)
reasons_str = (
"{"
@@ -386,11 +345,7 @@
)
+ "}"
)
config_str = (
f"head_size: {head_size}, dtype: {dtype}, "
f"kv_cache_dtype: {kv_cache_dtype}, block_size: {block_size}, "
f"use_mla: {use_mla}, has_sink: {has_sink}, use_sparse: {use_sparse}"
)
config_str = repr(attn_selector_config)
logger.debug_once(
f"Some attention backends are not valid for {cls.device_name} with "
f"{config_str}. Reasons: {reasons_str}."
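The **attn_selector_config._asdict() expansion above works only because the NamedTuple field names match the keyword parameters of validate_configuration / get_attn_backend_cls one-to-one. A self-contained sketch of the pattern (the names _Cfg and _validate are stand-ins, not part of the vLLM API):

from typing import NamedTuple


class _Cfg(NamedTuple):  # stand-in for AttentionSelectorConfig
    head_size: int
    use_mla: bool = False


def _validate(
    *, head_size: int, use_mla: bool, device_capability: tuple[int, int]
) -> list[str]:
    # Stand-in for a backend's validate_configuration(): returns the reasons a
    # configuration is unsupported, or an empty list if it is valid.
    reasons = []
    if head_size % 16 != 0:
        reasons.append("head_size must be a multiple of 16")
    if use_mla and device_capability < (9, 0):
        reasons.append("MLA requires compute capability >= 9.0")
    return reasons


cfg = _Cfg(head_size=128)
# _asdict() yields {"head_size": 128, "use_mla": False}; the ** expansion turns
# each field into a keyword argument next to the explicit device_capability.
invalid_reasons = _validate(device_capability=(8, 0), **cfg._asdict())
assert invalid_reasons == []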
12 changes: 2 additions & 10 deletions vllm/platforms/interface.py
@@ -18,8 +18,8 @@
if TYPE_CHECKING:
from torch.distributed import PrefixStore, ProcessGroup

from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.inputs import ProcessorInputs, PromptType
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
@@ -232,15 +232,7 @@ def get_vit_attn_backend(
def get_attn_backend_cls(
cls,
selected_backend: "AttentionBackendEnum",
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: "CacheDType | None",
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
attn_selector_config: "AttentionSelectorConfig",
) -> str:
"""Get the attention backend class of a device."""
return ""
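With the base-class hook narrowed to selected_backend plus a single config object, an out-of-tree platform only reads the fields it cares about. A hypothetical override in the style of the CPU and TPU implementations in this PR (MyPlatform and the returned dotted path are placeholders, not real vLLM code):

from typing import TYPE_CHECKING

from vllm.platforms.interface import Platform

if TYPE_CHECKING:
    from vllm.attention.backends.registry import AttentionBackendEnum
    from vllm.attention.selector import AttentionSelectorConfig


class MyPlatform(Platform):  # hypothetical plugin platform
    @classmethod
    def get_attn_backend_cls(
        cls,
        selected_backend: "AttentionBackendEnum",
        attn_selector_config: "AttentionSelectorConfig",
    ) -> str:
        if attn_selector_config.use_mla:
            raise NotImplementedError("MLA is not supported on this device.")
        if attn_selector_config.use_sparse:
            raise NotImplementedError("Sparse attention is not supported here.")
        # Placeholder dotted path to this platform's attention backend class.
        return "my_plugin.attention.backends.MyAttentionBackend"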
22 changes: 9 additions & 13 deletions vllm/platforms/rocm.py
@@ -15,6 +15,7 @@
from .interface import DeviceCapability, Platform, PlatformEnum

if TYPE_CHECKING:
from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig

logger = init_logger(__name__)
@@ -208,21 +209,16 @@ def get_vit_attn_backend(
@classmethod
def get_attn_backend_cls(
cls,
selected_backend,
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
use_mm_prefix,
attn_type: str | None = None,
selected_backend: "AttentionBackendEnum",
attn_selector_config: "AttentionSelectorConfig",
) -> str:
from vllm._aiter_ops import rocm_aiter_ops

if use_sparse:
if kv_cache_dtype.startswith("fp8"):
block_size = attn_selector_config.block_size
kv_cache_dtype = attn_selector_config.kv_cache_dtype

if attn_selector_config.use_sparse:
if kv_cache_dtype and kv_cache_dtype.startswith("fp8"):
raise ValueError(
"ROCMAiterMLASparseBackend doesn't support fp8 kv_cache_dtype."
)
Expand All @@ -232,7 +228,7 @@ def get_attn_backend_cls(
logger.info_once("Using Sparse MLA backend on V1 engine.")
return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()

if use_mla:
if attn_selector_config.use_mla:
if selected_backend is None:
selected_backend = (
AttentionBackendEnum.ROCM_AITER_MLA
13 changes: 3 additions & 10 deletions vllm/platforms/tpu.py
@@ -16,6 +16,7 @@
if TYPE_CHECKING:
from typing import TypeAlias

from vllm.attention.selector import AttentionSelectorConfig
from vllm.config import VllmConfig
from vllm.config.cache import BlockSize
from vllm.pooling_params import PoolingParams
@@ -57,17 +58,9 @@ def import_kernels(cls) -> None:
def get_attn_backend_cls(
cls,
selected_backend: "AttentionBackendEnum",
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str | None,
block_size: int,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
use_mm_prefix: bool,
attn_type: str | None = None,
attn_selector_config: "AttentionSelectorConfig",
) -> str:
if use_sparse:
if attn_selector_config.use_sparse:
raise NotImplementedError("Sparse Attention is not supported on TPU.")
if selected_backend != AttentionBackendEnum.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend)