From b23329e6195081fd287825b4b1165897bf7d742c Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Tue, 2 Dec 2025 10:12:44 +0100 Subject: [PATCH 1/4] Remove default values from `InitVar`s so that they're not stored Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config/scheduler.py | 31 +++++++++++++------------------ vllm/config/vllm.py | 4 +++- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 88f3e62fbd4e..9bc867bfd1af 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -28,6 +28,19 @@ class SchedulerConfig: """Scheduler configuration.""" + max_model_len: InitVar[int] + """Maximum length of a sequence (including prompt and generated text). + + Note: This is stored in the ModelConfig, and is used only here to + provide fallbacks and validate other attributes.""" + + is_encoder_decoder: InitVar[bool] + """True if the model is an encoder-decoder model. + + Note: This is stored in the ModelConfig, and is used only here to + disable chunked prefill and prefix caching for encoder-decoder models. + """ + DEFAULT_MAX_NUM_BATCHED_TOKENS: ClassVar[int] = 2048 DEFAULT_MAX_NUM_SEQS: ClassVar[int] = 128 @@ -73,19 +86,6 @@ class SchedulerConfig: is_multimodal_model: bool = False """True if the model is multimodal.""" - max_model_len: InitVar[int] = 8192 - """Maximum length of a sequence (including prompt and generated text). - - Note: This is stored in the ModelConfig, and is used only here to - provide fallbacks and validate other attributes.""" - - is_encoder_decoder: InitVar[bool] = False - """True if the model is an encoder-decoder model. - - Note: This is stored in the ModelConfig, and is used only here to - disable chunked prefill and prefix caching for encoder-decoder models. - """ - # TODO (ywang96): Make this configurable. max_num_encoder_input_tokens: int = Field(init=False) """Multimodal encoder compute budget, only used in V1. @@ -274,8 +274,3 @@ def verify_max_model_len(self, max_model_len: int) -> Self: ) return self - - def __getattribute__(self, name: str) -> Any: - if name == "max_model_len" or name == "is_encoder_decoder": - raise AttributeError(f"{name} is an init-only parameter. ") - return object.__getattribute__(self, name) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 4542866aa166..52762dd99803 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -154,6 +154,8 @@ def enable_fusion(cfg: "VllmConfig") -> bool: OptimizationLevel.O3: OPTIMIZATION_LEVEL_03, } +SCHEDULER_CONFIG_FACTORY = lambda: SchedulerConfig(8192, False) + @config @dataclass(config=ConfigDict(arbitrary_types_allowed=True)) @@ -170,7 +172,7 @@ class VllmConfig: """Cache configuration.""" parallel_config: ParallelConfig = Field(default_factory=ParallelConfig) """Parallel configuration.""" - scheduler_config: SchedulerConfig = Field(default_factory=SchedulerConfig) + scheduler_config: SchedulerConfig = Field(default_factory=SCHEDULER_CONFIG_FACTORY) """Scheduler configuration.""" device_config: DeviceConfig = Field(default_factory=DeviceConfig) """Device configuration.""" From df11d0c37bab5c1ecef906b76015e4f3dec80512 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Tue, 2 Dec 2025 10:27:31 +0100 Subject: [PATCH 2/4] SchedulerConfig should own its default factory Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config/scheduler.py | 7 +++++++ vllm/config/vllm.py | 6 +++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 9bc867bfd1af..0ab37c334acd 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -141,6 +141,13 @@ class SchedulerConfig: while a larger value (e.g., 10) reduces host overhead and may increase throughput by batching multiple tokens before sending.""" + @staticmethod + def default_factory(): + """ + Factory method to create `SchedulerConfig` with default values for `InitVar`s. + """ + return SchedulerConfig(max_model_len=8192, is_encoder_decoder=False) + def get_scheduler_cls(self) -> type["SchedulerInterface"]: if self.scheduler_cls is None: if self.async_scheduling: diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 52762dd99803..8692425094da 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -154,8 +154,6 @@ def enable_fusion(cfg: "VllmConfig") -> bool: OptimizationLevel.O3: OPTIMIZATION_LEVEL_03, } -SCHEDULER_CONFIG_FACTORY = lambda: SchedulerConfig(8192, False) - @config @dataclass(config=ConfigDict(arbitrary_types_allowed=True)) @@ -172,7 +170,9 @@ class VllmConfig: """Cache configuration.""" parallel_config: ParallelConfig = Field(default_factory=ParallelConfig) """Parallel configuration.""" - scheduler_config: SchedulerConfig = Field(default_factory=SCHEDULER_CONFIG_FACTORY) + scheduler_config: SchedulerConfig = Field( + default_factory=SchedulerConfig.default_factory, + ) """Scheduler configuration.""" device_config: DeviceConfig = Field(default_factory=DeviceConfig) """Device configuration.""" From 2a36a5e2b427f152708e8e20dd9aa7bc861c0f56 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Tue, 2 Dec 2025 10:46:26 +0100 Subject: [PATCH 3/4] Update call sites Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- benchmarks/benchmark_ngram_proposer.py | 5 +++- tests/compile/test_fusion_attn.py | 15 +++++++---- tests/lora/test_worker.py | 25 +++++++++++++------ tests/v1/attention/utils.py | 2 ++ tests/v1/core/test_kv_cache_utils.py | 11 ++++++-- tests/v1/core/test_scheduler.py | 13 +++++----- tests/v1/core/utils.py | 15 +++++------ tests/v1/cudagraph/test_cudagraph_dispatch.py | 4 ++- tests/v1/engine/test_engine_core.py | 13 +++++----- tests/v1/kv_connector/unit/utils.py | 13 +++++----- tests/v1/spec_decode/test_eagle.py | 5 +++- tests/v1/spec_decode/test_mtp.py | 5 +++- tests/v1/tpu/worker/test_tpu_model_runner.py | 11 ++++---- tests/v1/worker/test_gpu_model_runner.py | 20 ++++++++------- vllm/config/scheduler.py | 8 ++++-- 15 files changed, 105 insertions(+), 60 deletions(-) diff --git a/benchmarks/benchmark_ngram_proposer.py b/benchmarks/benchmark_ngram_proposer.py index dedb564fffac..cac401456b62 100644 --- a/benchmarks/benchmark_ngram_proposer.py +++ b/benchmarks/benchmark_ngram_proposer.py @@ -108,7 +108,10 @@ def benchmark_batched_propose(args): device_config=DeviceConfig(device=current_platform.device_type), parallel_config=ParallelConfig(), load_config=LoadConfig(), - scheduler_config=SchedulerConfig(), + scheduler_config=SchedulerConfig( + max_model_len=model_config.max_model_len, + is_encoder_decoder=model_config.is_encoder_decoder, + ), ) # monkey patch vllm.v1.worker.gpu_model_runner.get_pp_group diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index dbe12dc5de70..4d213e030edb 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -318,13 +318,18 @@ def test_attention_quant_pattern( torch.set_default_dtype(dtype) torch.manual_seed(42) + model_config = ModelConfig( + model=model_name, + max_model_len=2048, + dtype=dtype, + ) vllm_config = VllmConfig( - model_config=ModelConfig( - model=model_name, - max_model_len=2048, - dtype=dtype, + model_config=model_config, + scheduler_config=SchedulerConfig( + max_num_seqs=1024, + max_model_len=model_config.max_model_len, + is_encoder_decoder=model_config.is_encoder_decoder, ), - scheduler_config=SchedulerConfig(max_num_seqs=1024), compilation_config=CompilationConfig( mode=CompilationMode.VLLM_COMPILE, custom_ops=custom_ops_list, diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index b163559a9414..54059ec56190 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -33,14 +33,16 @@ def set_active_loras(worker: Worker, lora_requests: list[LoRARequest]): lora_requests, lora_mapping ) + model_config = ModelConfig( + MODEL_PATH, + seed=0, + dtype="float16", + max_model_len=127, + enforce_eager=True, + ) + vllm_config = VllmConfig( - model_config=ModelConfig( - MODEL_PATH, - seed=0, - dtype="float16", - max_model_len=127, - enforce_eager=True, - ), + model_config=model_config, load_config=LoadConfig( download_dir=None, load_format="dummy", @@ -50,7 +52,14 @@ def set_active_loras(worker: Worker, lora_requests: list[LoRARequest]): tensor_parallel_size=1, data_parallel_size=1, ), - scheduler_config=SchedulerConfig("generate", 32, 32, 32), + scheduler_config=SchedulerConfig( + max_model_len=model_config.max_model_len, + is_encoder_decoder=model_config.is_encoder_decoder, + runner_type="generate", + max_num_batched_tokens=32, + max_num_seqs=32, + max_num_partial_prefills=32, + ), device_config=DeviceConfig("cuda"), cache_config=CacheConfig( block_size=16, diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index df3d53332c7c..6cab129c116c 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -185,6 +185,8 @@ def create_vllm_config( max_num_seqs=max_num_seqs, max_num_batched_tokens=max_num_batched_tokens, enable_chunked_prefill=enable_chunked_prefill, + max_model_len=model_config.max_model_len, + is_encoder_decoder=model_config.is_encoder_decoder, ) device_config = DeviceConfig() diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 58a7a2692bfc..fd5cf6d3e74a 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -1128,7 +1128,11 @@ def test_estimate_max_model_len(model_id, max_model_len, want_estimated_max_len) dtype="float16", max_model_len=max_model_len, ) - scheduler_config = SchedulerConfig(max_num_batched_tokens=32768) + scheduler_config = SchedulerConfig( + max_num_batched_tokens=32768, + max_model_len=model_config.max_model_len, + is_encoder_decoder=model_config.is_encoder_decoder, + ) vllm_config = VllmConfig( model_config=model_config, @@ -1163,7 +1167,10 @@ def test_get_max_concurrency_for_kv_cache_config(): max_model_len=max_model_len, ) scheduler_config = SchedulerConfig( - max_num_batched_tokens=1024, enable_chunked_prefill=True + max_num_batched_tokens=1024, + enable_chunked_prefill=True, + max_model_len=model_config.max_model_len, + is_encoder_decoder=model_config.is_encoder_decoder, ) vllm_config = VllmConfig( diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 0051c11d18d8..c6c4a5085bff 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1508,6 +1508,12 @@ def create_scheduler_with_priority( Returns: {class}`Scheduler` instance with priority scheduling """ + model_config = ModelConfig( + model=model, + trust_remote_code=True, + dtype="float16", + seed=42, + ) if max_model_len is None: max_model_len = max_num_batched_tokens scheduler_config = SchedulerConfig( @@ -1517,14 +1523,9 @@ def create_scheduler_with_priority( long_prefill_token_threshold=long_prefill_token_threshold, disable_chunked_mm_input=disable_chunked_mm_input, enable_chunked_prefill=True, + is_encoder_decoder=model_config.is_encoder_decoder, policy="priority", # Enable priority scheduling ) - model_config = ModelConfig( - model=model, - trust_remote_code=True, - dtype="float16", - seed=42, - ) # Cache config, optionally force APC cache_config = CacheConfig( block_size=block_size, diff --git a/tests/v1/core/utils.py b/tests/v1/core/utils.py index 7537c7a60476..f5ba613d38db 100644 --- a/tests/v1/core/utils.py +++ b/tests/v1/core/utils.py @@ -69,6 +69,13 @@ def create_scheduler( Returns: {class}`Scheduler` instance """ + model_config = ModelConfig( + model=model, + trust_remote_code=True, + dtype="float16", + seed=42, + skip_tokenizer_init=skip_tokenizer_init, + ) if max_model_len is None: max_model_len = max_num_batched_tokens scheduler_config = SchedulerConfig( @@ -79,13 +86,7 @@ def create_scheduler( disable_chunked_mm_input=disable_chunked_mm_input, enable_chunked_prefill=enable_chunked_prefill, async_scheduling=async_scheduling, - ) - model_config = ModelConfig( - model=model, - trust_remote_code=True, - dtype="float16", - seed=42, - skip_tokenizer_init=skip_tokenizer_init, + is_encoder_decoder=model_config.is_encoder_decoder, ) # Cache config, optionally force APC cache_config = CacheConfig( diff --git a/tests/v1/cudagraph/test_cudagraph_dispatch.py b/tests/v1/cudagraph/test_cudagraph_dispatch.py index 314e7094ef97..b86534d3d438 100644 --- a/tests/v1/cudagraph/test_cudagraph_dispatch.py +++ b/tests/v1/cudagraph/test_cudagraph_dispatch.py @@ -40,7 +40,9 @@ def _create_vllm_config( ) -> MagicMock: mock_config = MagicMock(spec=VllmConfig) mock_config.compilation_config = compilation_config - mock_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs) + mock_config.scheduler_config = SchedulerConfig.default_factory( + max_num_seqs=max_num_seqs, + ) mock_config.parallel_config = ParallelConfig() mock_config.speculative_config = None # No speculative decoding if not lora_config: diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py index 3ba8ab26f552..48be8c15aba9 100644 --- a/tests/v1/engine/test_engine_core.py +++ b/tests/v1/engine/test_engine_core.py @@ -484,12 +484,6 @@ def test_encoder_instance_zero_kv_cache( vision encoder, so they don't need KV cache for text generation. """ # Form vllm config - scheduler_config = SchedulerConfig( - max_num_seqs=10, - max_num_batched_tokens=512, - max_model_len=512, - disable_hybrid_kv_cache_manager=True, - ) model_config = ModelConfig( model="llava-hf/llava-1.5-7b-hf", # Multimodal model enforce_eager=True, @@ -497,6 +491,13 @@ def test_encoder_instance_zero_kv_cache( dtype="float16", seed=42, ) + scheduler_config = SchedulerConfig( + max_num_seqs=10, + max_num_batched_tokens=512, + max_model_len=512, + disable_hybrid_kv_cache_manager=True, + is_encoder_decoder=model_config.is_encoder_decoder, + ) cache_config = CacheConfig( block_size=16, gpu_memory_utilization=gpu_memory_utilization, diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index f35f91bb3adf..98f1f44923b1 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -92,18 +92,19 @@ def create_vllm_config( enable_permute_local_kv: bool = False, ) -> VllmConfig: """Initialize VllmConfig For Testing.""" - scheduler_config = SchedulerConfig( - max_num_seqs=max_num_seqs, - max_num_batched_tokens=max_num_batched_tokens, - max_model_len=max_model_len, - enable_chunked_prefill=enable_chunked_prefill, - ) model_config = ModelConfig( model=model, trust_remote_code=True, dtype="float16", seed=42, ) + scheduler_config = SchedulerConfig( + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + max_model_len=max_model_len, + enable_chunked_prefill=enable_chunked_prefill, + is_encoder_decoder=model_config.is_encoder_decoder, + ) # Cache config, optionally force APC cache_config = CacheConfig( block_size=block_size, diff --git a/tests/v1/spec_decode/test_eagle.py b/tests/v1/spec_decode/test_eagle.py index 9436ab471c21..616e57de339e 100644 --- a/tests/v1/spec_decode/test_eagle.py +++ b/tests/v1/spec_decode/test_eagle.py @@ -66,7 +66,10 @@ def _create_proposer( device_config=DeviceConfig(device=current_platform.device_type), parallel_config=ParallelConfig(), load_config=LoadConfig(), - scheduler_config=SchedulerConfig(), + scheduler_config=SchedulerConfig( + max_model_len=model_config.max_model_len, + is_encoder_decoder=model_config.is_encoder_decoder, + ), ) return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type) diff --git a/tests/v1/spec_decode/test_mtp.py b/tests/v1/spec_decode/test_mtp.py index c5c0491abaf7..3b8813ceb818 100644 --- a/tests/v1/spec_decode/test_mtp.py +++ b/tests/v1/spec_decode/test_mtp.py @@ -51,7 +51,10 @@ def _create_mtp_proposer(num_speculative_tokens: int) -> EagleProposer: device_config=DeviceConfig(device=current_platform.device_type), parallel_config=ParallelConfig(), load_config=LoadConfig(), - scheduler_config=SchedulerConfig(), + scheduler_config=SchedulerConfig( + max_model_len=model_config.max_model_len, + is_encoder_decoder=model_config.is_encoder_decoder, + ), ) return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type) diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py index 7b3a07b4e12a..cfc06666e798 100644 --- a/tests/v1/tpu/worker/test_tpu_model_runner.py +++ b/tests/v1/tpu/worker/test_tpu_model_runner.py @@ -26,16 +26,17 @@ def get_vllm_config(): - scheduler_config = SchedulerConfig( - max_num_seqs=10, - max_num_batched_tokens=512, - max_model_len=512, - ) model_config = ModelConfig( model="facebook/opt-125m", dtype="bfloat16", # TPUs typically use bfloat16 seed=42, ) + scheduler_config = SchedulerConfig( + max_num_seqs=10, + max_num_batched_tokens=512, + max_model_len=512, + is_encoder_decoder=model_config.is_encoder_decoder, + ) cache_config = CacheConfig( block_size=16, gpu_memory_utilization=0.9, diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 89669ee8b71a..0439bef1226e 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -79,16 +79,17 @@ def initialize_kv_cache(runner: GPUModelRunner): def get_vllm_config(): - scheduler_config = SchedulerConfig( - max_num_seqs=10, - max_num_batched_tokens=512, - max_model_len=512, - ) model_config = ModelConfig( model="facebook/opt-125m", dtype="float16", seed=42, ) + scheduler_config = SchedulerConfig( + max_num_seqs=10, + max_num_batched_tokens=512, + max_model_len=512, + is_encoder_decoder=model_config.is_encoder_decoder, + ) cache_config = CacheConfig( block_size=BLOCK_SIZE, gpu_memory_utilization=0.9, @@ -784,14 +785,15 @@ def test_hybrid_attention_mamba_tensor_shapes(monkeypatch): initialize_model_parallel(tensor_model_parallel_size=1) torch.set_default_dtype(torch.float16) + model_config = ModelConfig( + model="ibm-granite/granite-4.0-tiny-preview", + dtype="float16", + ) scheduler_config = SchedulerConfig( max_num_seqs=10, max_num_batched_tokens=512, max_model_len=512, - ) - model_config = ModelConfig( - model="ibm-granite/granite-4.0-tiny-preview", - dtype="float16", + is_encoder_decoder=model_config.is_encoder_decoder, ) cache_config = CacheConfig( block_size=BLOCK_SIZE, diff --git a/vllm/config/scheduler.py b/vllm/config/scheduler.py index 0ab37c334acd..0215fe44e107 100644 --- a/vllm/config/scheduler.py +++ b/vllm/config/scheduler.py @@ -142,11 +142,15 @@ class SchedulerConfig: by batching multiple tokens before sending.""" @staticmethod - def default_factory(): + def default_factory(**kwargs): """ Factory method to create `SchedulerConfig` with default values for `InitVar`s. """ - return SchedulerConfig(max_model_len=8192, is_encoder_decoder=False) + if "max_model_len" not in kwargs: + kwargs["max_model_len"] = 8192 + if "is_encoder_decoder" not in kwargs: + kwargs["is_encoder_decoder"] = False + return SchedulerConfig(**kwargs) def get_scheduler_cls(self) -> type["SchedulerInterface"]: if self.scheduler_cls is None: From b4fc6e37f4180c842e25b06ce24386d6099b0038 Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Tue, 2 Dec 2025 11:04:12 +0100 Subject: [PATCH 4/4] Add test Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- tests/test_config.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_config.py b/tests/test_config.py index 76e0d94425fa..b7ed68fea92a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -6,12 +6,14 @@ from unittest.mock import patch import pytest +from pydantic import ValidationError from vllm.compilation.backends import VllmBackend from vllm.config import ( CompilationConfig, ModelConfig, PoolerConfig, + SchedulerConfig, VllmConfig, update_config, ) @@ -1095,3 +1097,14 @@ def test_vllm_config_explicit_overrides(): # Other fields should still use defaults assert config.compilation_config.mode == CompilationMode.VLLM_COMPILE assert config.compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE + + +def test_scheduler_config_init(): + with pytest.raises(ValidationError): + # Positional InitVars missing + # (InitVars cannot have defaults otherwise they will become attributes) + SchedulerConfig() + + with pytest.raises(AttributeError): + # InitVar does not become an attribute + print(SchedulerConfig.default_factory().max_model_len)