|
@@ -32,7 +32,7 @@
     get_scheduler_metadata,
     reshape_and_cache_flash,
 )
-from vllm.config import VllmConfig, get_layers_from_vllm_config
+from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
 from vllm.config.cache import CacheDType
 from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger

@@ -56,11 +56,26 @@
 class FlashAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    # NOTE(tdoublep): while in principle, FA supports
-    # MultipleOf(16), these are the block sizes that do not
-    # suffer from the NaN propagation problem described here:
-    # https://github.com/Dao-AILab/flash-attention/issues/1974
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64]
+
+    @staticmethod
+    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
+        vllm_config = get_current_vllm_config()
+        model_config = vllm_config.model_config
+        cache_config = vllm_config.cache_config
+        if (
+            model_config
+            and model_config.is_hybrid
+            and (
+                cache_config.mamba_ssm_cache_dtype == "float32"
+                or cache_config.mamba_cache_dtype == "float32"
+            )
+        ):
+            # NOTE(tdoublep): while in principle, FA supports
+            # MultipleOf(16), these are the block sizes that do not
+            # suffer from the NaN propagation problem described here:
+            # https://github.com/Dao-AILab/flash-attention/issues/1974
+            return [16, 32, 64]
+        return [MultipleOf(16)]

     @staticmethod
     def get_name() -> str:
|
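For context, here is a minimal, self-contained sketch of the selection logic introduced by this diff. The `MultipleOf` stand-in and the plain keyword-style arguments are simplified assumptions for illustration only; they are not vLLM's real config objects or its actual `MultipleOf` class.

from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class MultipleOf:
    # Stand-in for vLLM's MultipleOf marker: any block size divisible by
    # `base` is acceptable.
    base: int


def supported_kernel_block_sizes(
    is_hybrid: bool,
    mamba_ssm_cache_dtype: str,
    mamba_cache_dtype: str,
) -> list[int | MultipleOf]:
    # Hybrid (attention + mamba) models with a float32 mamba state cache are
    # pinned to block sizes known to avoid the FA NaN-propagation issue
    # (https://github.com/Dao-AILab/flash-attention/issues/1974).
    if is_hybrid and "float32" in (mamba_ssm_cache_dtype, mamba_cache_dtype):
        return [16, 32, 64]
    # Every other configuration accepts any block size that is a multiple of 16.
    return [MultipleOf(16)]


# A hybrid model with a float32 mamba cache gets the restricted list, while a
# non-hybrid model keeps the relaxed MultipleOf(16) constraint.
assert supported_kernel_block_sizes(True, "float32", "auto") == [16, 32, 64]
assert supported_kernel_block_sizes(False, "auto", "auto") == [MultipleOf(16)]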