Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions tests/ut/core/test_schedule_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,13 @@ def setUp(self):
max_model_len=8192,
is_multimodal_model=False,
send_delta_data=False,
is_encoder_decoder=False,
)

def test_initialize_from_config_with_default(self):
# No additional config given, check the default value here.
ascend_config = AscendSchedulerConfig.initialize_from_config(
self.basic_scheduler_config, {})
self.basic_scheduler_config, {}, 8192, False)
self.assertEqual(ascend_config.enable_chunked_prefill, False)
self.assertEqual(ascend_config.policy, "fcfs")
self.assertEqual(ascend_config.scheduler_cls,
Expand All @@ -53,6 +54,8 @@ def test_initialize_from_config_with_override(self):
max_long_partial_prefills=1,
long_prefill_token_threshold=512,
),
8192,
False,
)
self.assertEqual(ascend_config.enable_chunked_prefill, False)
self.assertEqual(ascend_config.policy, "fcfs")
Expand All @@ -72,6 +75,8 @@ def test_not_implemented_policy(self):
max_num_batched_tokens=8192,
max_model_len=2048,
),
8192,
False,
)
self.assertIn(
"currently AscendScheduler only supports fcfs policy",
Expand All @@ -80,14 +85,21 @@ def test_not_implemented_policy(self):

def test_no_override(self):
ascend_config = AscendSchedulerConfig.initialize_from_config(
self.basic_scheduler_config, {})
self.basic_scheduler_config, {}, 8192, False)
self.assertEqual(ascend_config.max_num_encoder_input_tokens, 8192)
self.assertEqual(ascend_config.encoder_cache_size, 8192)

def test_valid_config_with_multimodal(self):
config = AscendSchedulerConfig.initialize_from_config(
SchedulerConfig(is_multimodal_model=True,
max_num_batched_tokens=8192), {})
SchedulerConfig(
is_multimodal_model=True,
max_num_batched_tokens=8192,
is_encoder_decoder=False,
),
{},
8192,
False,
)
self.assertTrue(config.is_multimodal_model)

def test_valid_config_with_chunked_prefill(self):
Expand All @@ -98,6 +110,8 @@ def test_valid_config_with_chunked_prefill(self):
max_num_batched_tokens=8192,
max_model_len=8192,
),
8192,
False,
)
self.assertEqual(ascend_config.max_num_batched_tokens, 8192)
self.assertEqual(ascend_config.max_model_len, 8192)
Expand All @@ -112,6 +126,8 @@ def test_invalid_config_without_chunked_prefill(self):
max_num_batched_tokens=2048,
max_model_len=8192,
),
8192,
False,
)
self.assertIn(
"Ascend scheduler is enabled without chunked prefill feature",
Expand All @@ -129,6 +145,8 @@ def test_initialize_from_config_with_pd_transfer(self):
max_num_batched_tokens=8192,
max_model_len=4096,
),
8192,
False,
)
self.assertEqual(ascend_config.enable_pd_transfer, True)
self.assertEqual(ascend_config.decode_max_num_seqs, 48)
30 changes: 20 additions & 10 deletions vllm_ascend/core/schedule_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,28 @@ def initialize_from_config(
cls,
vllm_scheduler_config: SchedulerConfig,
ascend_scheduler_config,
max_model_len: int,
is_encoder_decoder: bool = False,
):
scheduler_config = {
field.name: getattr(vllm_scheduler_config, field.name)
for field in fields(vllm_scheduler_config) if field.init
}
scheduler_config = {}
for field in fields(vllm_scheduler_config):
if not field.init:
continue
try:
scheduler_config[field.name] = getattr(vllm_scheduler_config,
field.name)
except AttributeError:
pass

scheduler_config["max_model_len"] = max_model_len
scheduler_config["is_encoder_decoder"] = is_encoder_decoder
# Override default values into original SchedulerConfig
scheduler_config["enable_chunked_prefill"] = False
scheduler_config["max_long_partial_prefills"] = None
scheduler_config["long_prefill_token_threshold"] = None
scheduler_config["policy"] = "fcfs"
scheduler_config["scheduler_cls"] = (
"vllm_ascend.core.scheduler.AscendScheduler")
scheduler_config[
"scheduler_cls"] = "vllm_ascend.core.scheduler.AscendScheduler"
scheduler_config["enable_pd_transfer"] = False
scheduler_config["decode_max_num_seqs"] = 0
# Override params in original SchedulerConfig with params in ascend_scheduler_config
Expand Down Expand Up @@ -78,13 +88,13 @@ def __post_init__(self, *args) -> None:
self.max_long_partial_prefills = 1
self.long_prefill_token_threshold = MAX_INT

if self.long_prefill_token_threshold is None or \
self.long_prefill_token_threshold <= 0:
if (self.long_prefill_token_threshold is None
or self.long_prefill_token_threshold <= 0):
if self.max_model_len is None:
self.long_prefill_token_threshold = MAX_INT
else:
self.long_prefill_token_threshold = \
max(1, int(self.max_model_len * 0.04))
self.long_prefill_token_threshold = max(
1, int(self.max_model_len * 0.04))

if self.max_long_partial_prefills < 0:
raise ValueError(
Expand Down
Loading
Loading