diff --git a/components/src/dynamo/sglang/main.py b/components/src/dynamo/sglang/main.py index 024226fdbb..18807e4d3a 100644 --- a/components/src/dynamo/sglang/main.py +++ b/components/src/dynamo/sglang/main.py @@ -382,16 +382,24 @@ async def init_multimodal_encode_worker(runtime: DistributedRuntime, config: Con await pd_worker_client.wait_for_instances() - tasks = [ - generate_endpoint.serve_endpoint( - handler.generate, - graceful_shutdown=True, - metrics_labels=[("model", server_args.served_model_name)], - ) - ] + ready_event = asyncio.Event() try: - await asyncio.gather(*tasks) + await asyncio.gather( + generate_endpoint.serve_endpoint( + handler.generate, + graceful_shutdown=True, + metrics_labels=[("model", server_args.served_model_name)], + ), + register_llm_with_readiness_gate( + None, # encode worker doesn't have engine + generate_endpoint, + server_args, + dynamo_args, + input_type=ModelInput.Text, + readiness_gate=ready_event, + ), + ) except Exception as e: logging.error(f"Failed to serve endpoints: {e}") raise @@ -425,11 +433,24 @@ async def init_multimodal_worker(runtime: DistributedRuntime, config: Config): await handler.async_init() + health_check_payload = SglangHealthCheckPayload(engine).to_dict() + ready_event = asyncio.Event() + try: - await generate_endpoint.serve_endpoint( - handler.generate, - metrics_labels=[("model", server_args.served_model_name)], - graceful_shutdown=True, + await asyncio.gather( + generate_endpoint.serve_endpoint( + handler.generate, + metrics_labels=[("model", server_args.served_model_name)], + graceful_shutdown=True, + health_check_payload=health_check_payload, + ), + register_llm_with_readiness_gate( + engine, + generate_endpoint, + server_args, + dynamo_args, + readiness_gate=ready_event, + ), ) except Exception as e: logging.error(f"Failed to serve endpoints: {e}") @@ -454,6 +475,7 @@ async def init_multimodal_prefill_worker(runtime: DistributedRuntime, config: Co await handler.async_init() health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict() + ready_event = asyncio.Event() try: await asyncio.gather( @@ -462,7 +484,14 @@ async def init_multimodal_prefill_worker(runtime: DistributedRuntime, config: Co graceful_shutdown=True, metrics_labels=[("model", server_args.served_model_name)], health_check_payload=health_check_payload, - ) + ), + register_llm_with_readiness_gate( + engine, + generate_endpoint, + server_args, + dynamo_args, + readiness_gate=ready_event, + ), ) except Exception as e: logging.error(f"Failed to serve endpoints: {e}")