diff --git a/tests/entrypoints/openai/test_basic.py b/tests/entrypoints/openai/test_basic.py index 3d581a300b6a..1ff30de31bbe 100644 --- a/tests/entrypoints/openai/test_basic.py +++ b/tests/entrypoints/openai/test_basic.py @@ -232,7 +232,7 @@ def make_long_completion_request(): @pytest.mark.asyncio async def test_health_check_engine_dead_error(): # Import the health function directly to test it in isolation - from vllm.entrypoints.openai.api_server import health + from vllm.entrypoints.serve.instrumentator.health import health # Create a mock request that simulates what FastAPI would provide mock_request = Mock(spec=Request) diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 154cdeb42a3e..b59f7120551e 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -118,6 +118,7 @@ async def init_app( ) ) app.state.engine_client = engine + app.state.args = args return app diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index cdc316b65ba7..2fa6afa2bacb 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -20,21 +20,15 @@ from typing import Annotated, Any, Literal import model_hosting_container_standards.sagemaker as sagemaker_standards -import prometheus_client import pydantic -import regex as re import uvloop from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Query, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse -from prometheus_client import make_asgi_app -from prometheus_fastapi_instrumentator import Instrumentator from starlette.concurrency import iterate_in_threadpool from starlette.datastructures import URL, Headers, MutableHeaders, State -from starlette.routing import Mount from starlette.types import ASGIApp, Message, Receive, Scope, Send -from typing_extensions import assert_never import vllm.envs as envs from vllm.config import VllmConfig @@ -56,17 +50,11 @@ ChatCompletionResponse, CompletionRequest, CompletionResponse, - DetokenizeRequest, - DetokenizeResponse, ErrorInfo, ErrorResponse, - GenerateRequest, - GenerateResponse, ResponsesRequest, ResponsesResponse, StreamingResponsesResponse, - TokenizeRequest, - TokenizeResponse, TranscriptionRequest, TranscriptionResponseVariant, TranslationRequest, @@ -80,8 +68,6 @@ OpenAIServingModels, ) from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses -from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization -from vllm.entrypoints.openai.serving_tokens import ServingTokens from vllm.entrypoints.openai.serving_transcription import ( OpenAIServingTranscription, OpenAIServingTranslation, @@ -92,6 +78,11 @@ from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling from vllm.entrypoints.pooling.score.serving import ServingScores +from vllm.entrypoints.serve.disagg.serving import ServingTokens +from vllm.entrypoints.serve.elastic_ep.middleware import ( + ScalingMiddleware, +) +from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization from vllm.entrypoints.tool_server import DemoToolServer, MCPToolServer, ToolServer from vllm.entrypoints.utils import ( cli_env_setup, @@ -109,8 +100,6 @@ from vllm.utils.gc_utils import freeze_gc_heap from vllm.utils.network_utils import is_valid_ipv6_address from 
vllm.utils.system_utils import decorate_logs, set_ulimit -from vllm.v1.engine.exceptions import EngineDeadError -from vllm.v1.metrics.prometheus import get_prometheus_registry from vllm.version import __version__ as VLLM_VERSION prometheus_multiproc_dir: tempfile.TemporaryDirectory @@ -245,39 +234,6 @@ async def build_async_engine_client_from_engine_args( router = APIRouter() -class PrometheusResponse(Response): - media_type = prometheus_client.CONTENT_TYPE_LATEST - - -def mount_metrics(app: FastAPI): - """Mount prometheus metrics to a FastAPI app.""" - - registry = get_prometheus_registry() - - # `response_class=PrometheusResponse` is needed to return an HTTP response - # with header "Content-Type: text/plain; version=0.0.4; charset=utf-8" - # instead of the default "application/json" which is incorrect. - # See https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/163#issue-1296092364 - Instrumentator( - excluded_handlers=[ - "/metrics", - "/health", - "/load", - "/ping", - "/version", - "/server_info", - ], - registry=registry, - ).add().instrument(app).expose(app, response_class=PrometheusResponse) - - # Add prometheus asgi middleware to route /metrics requests - metrics_route = Mount("/metrics", make_asgi_app(registry=registry)) - - # Workaround for 307 Redirect for /metrics - metrics_route.path_regex = re.compile("^/metrics(?P.*)$") - app.routes.append(metrics_route) - - def base(request: Request) -> OpenAIServing: # Reuse the existing instance return tokenization(request) @@ -323,16 +279,6 @@ def generate_tokens(request: Request) -> ServingTokens | None: return request.app.state.serving_tokens -@router.get("/health", response_class=Response) -async def health(raw_request: Request) -> Response: - """Health check.""" - try: - await engine_client(raw_request).check_health() - return Response(status_code=200) - except EngineDeadError: - return Response(status_code=503) - - @router.get("/load") async def get_server_load_metrics(request: Request): # This endpoint returns the current server load metrics. @@ -352,167 +298,6 @@ async def get_server_load_metrics(request: Request): return JSONResponse(content={"server_load": request.app.state.server_load_metrics}) -@router.post("/pause") -async def pause_generation( - raw_request: Request, - wait_for_inflight_requests: bool = Query(False), - clear_cache: bool = Query(True), -) -> JSONResponse: - """Pause generation requests to allow weight updates. - - Args: - wait_for_inflight_requests: When ``True`` waits for in-flight - requests to finish before pausing. When ``False`` (default), - aborts any in-flight requests immediately. - clear_cache: Whether to clear KV/prefix caches after draining. 
- """ - - engine = engine_client(raw_request) - - try: - await engine.pause_generation( - wait_for_inflight_requests=wait_for_inflight_requests, - clear_cache=clear_cache, - ) - return JSONResponse( - content={"status": "paused"}, - status_code=HTTPStatus.OK.value, - ) - - except ValueError as err: - return JSONResponse( - content={"error": str(err)}, - status_code=HTTPStatus.BAD_REQUEST.value, - ) - except Exception as err: # pragma: no cover - defensive - logger.exception("Failed to pause generation") - return JSONResponse( - content={"error": f"Failed to pause generation: {err}"}, - status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - ) - - -@router.post("/resume") -async def resume_generation(raw_request: Request) -> JSONResponse: - """Resume generation after a pause.""" - - engine = engine_client(raw_request) - - try: - await engine.resume_generation() - return JSONResponse( - content={"status": "resumed"}, - status_code=HTTPStatus.OK.value, - ) - except Exception as err: # pragma: no cover - defensive - logger.exception("Failed to resume generation") - return JSONResponse( - content={"error": f"Failed to resume generation: {err}"}, - status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - ) - - -@router.get("/is_paused") -async def is_paused(raw_request: Request) -> JSONResponse: - """Return the current pause status.""" - - engine = engine_client(raw_request) - - try: - paused = await engine.is_paused() - except Exception as err: # pragma: no cover - defensive - logger.exception("Failed to fetch pause status") - return JSONResponse( - content={"error": f"Failed to fetch pause status: {err}"}, - status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, - ) - - return JSONResponse(content={"is_paused": paused}) - - -@router.post( - "/tokenize", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, - HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, - HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, - HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse}, - }, -) -@with_cancellation -async def tokenize(request: TokenizeRequest, raw_request: Request): - handler = tokenization(raw_request) - - try: - generator = await handler.create_tokenize(request, raw_request) - except NotImplementedError as e: - raise HTTPException( - status_code=HTTPStatus.NOT_IMPLEMENTED.value, detail=str(e) - ) from e - except Exception as e: - raise HTTPException( - status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) - ) from e - - if isinstance(generator, ErrorResponse): - return JSONResponse( - content=generator.model_dump(), status_code=generator.error.code - ) - elif isinstance(generator, TokenizeResponse): - return JSONResponse(content=generator.model_dump()) - - assert_never(generator) - - -@router.post( - "/detokenize", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, - HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, - HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, - }, -) -@with_cancellation -async def detokenize(request: DetokenizeRequest, raw_request: Request): - handler = tokenization(raw_request) - - try: - generator = await handler.create_detokenize(request, raw_request) - except OverflowError as e: - raise RequestValidationError(errors=[str(e)]) from e - except Exception as e: - raise HTTPException( - status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) - ) from e - - if isinstance(generator, 
ErrorResponse): - return JSONResponse( - content=generator.model_dump(), status_code=generator.error.code - ) - elif isinstance(generator, DetokenizeResponse): - return JSONResponse(content=generator.model_dump()) - - assert_never(generator) - - -def maybe_register_tokenizer_info_endpoint(args): - """Conditionally register the tokenizer info endpoint if enabled.""" - if getattr(args, "enable_tokenizer_info_endpoint", False): - - @router.get("/tokenizer_info") - async def get_tokenizer_info(raw_request: Request): - """Get comprehensive tokenizer information.""" - result = await tokenization(raw_request).get_tokenizer_info() - return JSONResponse( - content=result.model_dump(), - status_code=result.error.code - if isinstance(result, ErrorResponse) - else 200, - ) - - @router.get("/v1/models") async def show_available_models(raw_request: Request): handler = models(raw_request) @@ -898,33 +683,6 @@ async def reset_mm_cache(raw_request: Request): await engine_client(raw_request).reset_mm_cache() return Response(status_code=200) - @router.post("/sleep") - async def sleep(raw_request: Request): - # get POST params - level = raw_request.query_params.get("level", "1") - await engine_client(raw_request).sleep(int(level)) - # FIXME: in v0 with frontend multiprocessing, the sleep command - # is sent but does not finish yet when we return a response. - return Response(status_code=200) - - @router.post("/wake_up") - async def wake_up(raw_request: Request): - tags = raw_request.query_params.getlist("tags") - if tags == []: - # set to None to wake up all tags if no tags are provided - tags = None - logger.info("wake up the engine with tags: %s", tags) - await engine_client(raw_request).wake_up(tags) - # FIXME: in v0 with frontend multiprocessing, the wake-up command - # is sent but does not finish yet when we return a response. 
- return Response(status_code=200) - - @router.get("/is_sleeping") - async def is_sleeping(raw_request: Request): - logger.info("check whether the engine is sleeping") - is_sleeping = await engine_client(raw_request).is_sleeping() - return JSONResponse(content={"is_sleeping": is_sleeping}) - @router.post("/collective_rpc") async def collective_rpc(raw_request: Request): try: @@ -952,138 +710,13 @@ async def collective_rpc(raw_request: Request): return Response(status_code=200) response: list[Any] = [] for result in results: - if result is None or isinstance(result, (dict, list)): + if result is None or isinstance(result, dict | list): response.append(result) else: response.append(str(result)) return JSONResponse(content={"results": response}) -@router.post( - "/scale_elastic_ep", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.OK.value: {"model": dict}, - HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, - HTTPStatus.REQUEST_TIMEOUT.value: {"model": ErrorResponse}, - HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, - }, -) -async def scale_elastic_ep(raw_request: Request): - try: - body = await raw_request.json() - except json.JSONDecodeError as e: - raise HTTPException(status_code=400, detail="Invalid JSON format") from e # noqa: B904 - - new_data_parallel_size = body.get("new_data_parallel_size") - drain_timeout = body.get("drain_timeout", 120) # Default 2 minutes - - if new_data_parallel_size is None: - raise HTTPException( - status_code=400, detail="new_data_parallel_size is required" - ) - - if not isinstance(new_data_parallel_size, int) or new_data_parallel_size <= 0: - raise HTTPException( - status_code=400, detail="new_data_parallel_size must be a positive integer" - ) - - if not isinstance(drain_timeout, int) or drain_timeout <= 0: - raise HTTPException( - status_code=400, detail="drain_timeout must be a positive integer" - ) - - # Set scaling flag to prevent new requests - global _scaling_elastic_ep - _scaling_elastic_ep = True - client = engine_client(raw_request) - try: - await client.scale_elastic_ep(new_data_parallel_size, drain_timeout) - return JSONResponse( - { - "message": f"Scaled to {new_data_parallel_size} data parallel engines", - } - ) - except TimeoutError as e: - raise HTTPException( - status_code=408, - detail="Scale failed due to request drain timeout " - f"after {drain_timeout} seconds", - ) from e - except Exception as e: - logger.error("Scale failed: %s", e) - raise HTTPException(status_code=500, detail="Scale failed") from e - finally: - _scaling_elastic_ep = False - - -@router.post("/is_scaling_elastic_ep") -async def is_scaling_elastic_ep(raw_request: Request): - return JSONResponse({"is_scaling_elastic_ep": _scaling_elastic_ep}) - - -@router.post( - "/inference/v1/generate", - dependencies=[Depends(validate_json_request)], - responses={ - HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, - HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, - HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, - HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, - }, -) -@with_cancellation -@load_aware_call -async def generate(request: GenerateRequest, raw_request: Request): - handler = generate_tokens(raw_request) - if handler is None: - return base(raw_request).create_error_response( - message="The model does not support generate tokens API" - ) - try: - generator = await handler.serve_tokens(request, raw_request) - except Exception as e: - raise HTTPException( - 
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) - ) from e - if isinstance(generator, ErrorResponse): - return JSONResponse( - content=generator.model_dump(), status_code=generator.error.code - ) - - elif isinstance(generator, GenerateResponse): - return JSONResponse(content=generator.model_dump()) - - return StreamingResponse(content=generator, media_type="text/event-stream") - - -if envs.VLLM_TORCH_PROFILER_DIR: - logger.warning_once( - "Torch Profiler is enabled in the API server. This should ONLY be " - "used for local development!" - ) -elif envs.VLLM_TORCH_CUDA_PROFILE: - logger.warning_once( - "CUDA Profiler is enabled in the API server. This should ONLY be " - "used for local development!" - ) -if envs.VLLM_TORCH_PROFILER_DIR or envs.VLLM_TORCH_CUDA_PROFILE: - - @router.post("/start_profile") - async def start_profile(raw_request: Request): - logger.info("Starting profiler...") - await engine_client(raw_request).start_profile() - logger.info("Profiler started.") - return Response(status_code=200) - - @router.post("/stop_profile") - async def stop_profile(raw_request: Request): - logger.info("Stopping profiler...") - await engine_client(raw_request).stop_profile() - logger.info("Profiler stopped.") - return Response(status_code=200) - - def load_log_config(log_config_file: str | None) -> dict | None: if not log_config_file: return None @@ -1176,41 +809,6 @@ async def send_with_request_id(message: Message) -> None: return self.app(scope, receive, send_with_request_id) -# Global variable to track scaling state -_scaling_elastic_ep = False - - -class ScalingMiddleware: - """ - Middleware that checks if the model is currently scaling and - returns a 503 Service Unavailable response if it is. - - This middleware applies to all HTTP requests and prevents - processing when the model is in a scaling state. - """ - - def __init__(self, app: ASGIApp) -> None: - self.app = app - - def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]: - if scope["type"] != "http": - return self.app(scope, receive, send) - - # Check global scaling state - global _scaling_elastic_ep - if _scaling_elastic_ep: - # Return 503 Service Unavailable response - response = JSONResponse( - content={ - "error": "The model is currently scaling. Please try again later." - }, - status_code=503, - ) - return response(scope, receive, send) - - return self.app(scope, receive, send) - - def _extract_content_from_chunk(chunk_data: dict) -> str: """Extract content from a streaming response chunk.""" try: @@ -1353,15 +951,10 @@ def build_app(args: Namespace) -> FastAPI: ) else: app = FastAPI(lifespan=lifespan) + app.state.args = args + from vllm.entrypoints.serve import register_vllm_serve_api_routers - if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: - logger.warning( - "LoRA dynamic loading & unloading is enabled in the API server. " - "This should ONLY be used for local development!" 
- ) - from vllm.entrypoints.dynamic_lora import register_dynamic_lora_routes - - register_dynamic_lora_routes(router) + register_vllm_serve_api_routers(app) from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes @@ -1370,8 +963,6 @@ def build_app(args: Namespace) -> FastAPI: app.root_path = args.root_path - mount_metrics(app) - from vllm.entrypoints.pooling import register_pooling_api_routers register_pooling_api_routers(app) @@ -1462,31 +1053,6 @@ async def log_response(request: Request, call_next): ) app = sagemaker_standards.bootstrap(app) - # Optional endpoints - if args.tokens_only: - - @app.post("/abort_requests") - async def abort_requests(raw_request: Request): - """ - Abort one or more requests. To be used in a - Disaggregated Everything setup. - """ - try: - body = await raw_request.json() - except json.JSONDecodeError as e: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST.value, - detail=f"JSON decode error: {e}", - ) from e - request_ids = body.get("request_ids") - if request_ids is None: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST.value, - detail="Missing 'request_ids' in request body", - ) - # Abort requests in background - asyncio.create_task(engine_client(raw_request).abort(request_ids)) - return Response(status_code=200) return app @@ -1515,7 +1081,7 @@ async def init_app_state( state.engine_client = engine_client state.log_stats = not args.disable_log_stats state.vllm_config = vllm_config - + state.args = args supported_tasks = await engine_client.get_supported_tasks() logger.info("Supported tasks: %s", supported_tasks) @@ -1839,7 +1405,6 @@ async def run_server_worker( args, client_config=client_config, ) as engine_client: - maybe_register_tokenizer_info_endpoint(args) app = build_app(args) await init_app_state(engine_client, app.state, args) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 1d89aa011af2..67291f45a925 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -74,8 +74,6 @@ ErrorResponse, FunctionCall, FunctionDefinition, - GenerateRequest, - GenerateResponse, ResponsesRequest, TokenizeChatRequest, TokenizeCompletionRequest, @@ -87,6 +85,7 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig +from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse from vllm.entrypoints.utils import _validate_truncation_size from vllm.inputs.data import PromptType from vllm.inputs.data import TokensPrompt as EngineTokensPrompt diff --git a/vllm/entrypoints/sagemaker/routes.py b/vllm/entrypoints/sagemaker/routes.py index 108fdd773e32..ea88c0fc4b97 100644 --- a/vllm/entrypoints/sagemaker/routes.py +++ b/vllm/entrypoints/sagemaker/routes.py @@ -16,7 +16,6 @@ completion, create_chat_completion, create_completion, - health, validate_json_request, ) from vllm.entrypoints.openai.protocol import ( @@ -38,6 +37,7 @@ score, ) from vllm.entrypoints.pooling.score.protocol import RerankRequest, ScoreRequest +from vllm.entrypoints.serve.instrumentator.health import health # TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers # (requires typing_extensions >= 4.13) diff --git a/vllm/entrypoints/serve/__init__.py b/vllm/entrypoints/serve/__init__.py new file mode 100644 index 000000000000..c4fcc92db931 --- 
/dev/null +++ b/vllm/entrypoints/serve/__init__.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +from fastapi import FastAPI + + +def register_vllm_serve_api_routers(app: FastAPI): + from vllm.entrypoints.serve.lora.api_router import ( + attach_router as attach_lora_router, + ) + + attach_lora_router(app) + from vllm.entrypoints.serve.elastic_ep.api_router import ( + attach_router as attach_elastic_ep_router, + ) + + attach_elastic_ep_router(app) + + from vllm.entrypoints.serve.profile.api_router import ( + attach_router as attach_profile_router, + ) + + attach_profile_router(app) + + from vllm.entrypoints.serve.sleep.api_router import ( + attach_router as attach_sleep_router, + ) + + attach_sleep_router(app) + + from vllm.entrypoints.serve.tokenize.api_router import ( + attach_router as attach_tokenize_router, + ) + + attach_tokenize_router(app) + + from vllm.entrypoints.serve.disagg.api_router import ( + attach_router as attach_disagg_router, + ) + + attach_disagg_router(app) + + from vllm.entrypoints.serve.rlhf.api_router import ( + attach_router as attach_rlhf_router, + ) + + attach_rlhf_router(app) + + from vllm.entrypoints.serve.instrumentator.metrics import ( + attach_router as attach_metrics_router, + ) + + attach_metrics_router(app) + + from vllm.entrypoints.serve.instrumentator.health import ( + attach_router as attach_health_router, + ) + + attach_health_router(app) diff --git a/vllm/entrypoints/serve/disagg/__init__.py b/vllm/entrypoints/serve/disagg/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/entrypoints/serve/disagg/api_router.py b/vllm/entrypoints/serve/disagg/api_router.py new file mode 100644 index 000000000000..c38ede30dad1 --- /dev/null +++ b/vllm/entrypoints/serve/disagg/api_router.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import asyncio +import json +from http import HTTPStatus + +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response +from fastapi.responses import JSONResponse, StreamingResponse + +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.openai.api_server import validate_json_request +from vllm.entrypoints.openai.protocol import ( + ErrorResponse, +) +from vllm.entrypoints.serve.disagg.protocol import ( + GenerateRequest, + GenerateResponse, +) +from vllm.entrypoints.serve.disagg.serving import ( + ServingTokens, +) +from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization +from vllm.entrypoints.utils import ( + load_aware_call, + with_cancellation, +) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def tokenization(request: Request) -> OpenAIServingTokenization: + return request.app.state.openai_serving_tokenization + + +def generate_tokens(request: Request) -> ServingTokens | None: + return request.app.state.serving_tokens + + +def engine_client(request: Request) -> EngineClient: + return request.app.state.engine_client + + +router = APIRouter() + + +@router.post( + "/inference/v1/generate", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.OK.value: {"content": {"text/event-stream": {}}}, + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) +@with_cancellation +@load_aware_call +async def 
generate(request: GenerateRequest, raw_request: Request): + handler = generate_tokens(raw_request) + if handler is None: + return tokenization(raw_request).create_error_response( + message="The model does not support generate tokens API" + ) + try: + generator = await handler.serve_tokens(request, raw_request) + except Exception as e: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e + if isinstance(generator, ErrorResponse): + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) + + elif isinstance(generator, GenerateResponse): + return JSONResponse(content=generator.model_dump()) + + return StreamingResponse(content=generator, media_type="text/event-stream") + + +def attach_router(app: FastAPI): + if getattr(app.state.args, "tokens_only", False): + + @router.post("/abort_requests") + async def abort_requests(raw_request: Request): + """ + Abort one or more requests. To be used in a + Disaggregated Everything setup. + """ + try: + body = await raw_request.json() + except json.JSONDecodeError as e: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail=f"JSON decode error: {e}", + ) from e + request_ids = body.get("request_ids") + if request_ids is None: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST.value, + detail="Missing 'request_ids' in request body", + ) + # Abort requests in background + asyncio.create_task(engine_client(raw_request).abort(request_ids)) + return Response(status_code=200) + + app.include_router(router) diff --git a/vllm/entrypoints/serve/disagg/protocol.py b/vllm/entrypoints/serve/disagg/protocol.py new file mode 100644 index 000000000000..251fcf12ed7d --- /dev/null +++ b/vllm/entrypoints/serve/disagg/protocol.py @@ -0,0 +1,90 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Any + +from pydantic import BaseModel, Field + +from vllm.entrypoints.openai.protocol import ( + ChatCompletionLogProbs, + Logprob, + SamplingParams, + StreamOptions, +) +from vllm.utils import random_uuid + + +####### Tokens IN <> Tokens OUT ####### +class GenerateRequest(BaseModel): + request_id: str = Field( + default_factory=lambda: f"{random_uuid()}", + description=( + "The request_id related to this request. If the caller does " + "not set it, a random_uuid will be generated. This id is used " + "through out the inference process and return in response." + ), + ) + token_ids: list[int] + """The token ids to generate text from.""" + + # features: MultiModalFeatureSpec + # TODO (NickLucche): implement once Renderer work is completed + features: str | None = None + """The processed MM inputs for the model.""" + + sampling_params: SamplingParams + """The sampling parameters for the model.""" + + model: str | None = None + + stream: bool | None = False + stream_options: StreamOptions | None = None + cache_salt: str | None = Field( + default=None, + description=( + "If specified, the prefix cache will be salted with the provided " + "string to prevent an attacker to guess prompts in multi-user " + "environments. The salt should be random, protected from " + "access by 3rd parties, and long enough to be " + "unpredictable (e.g., 43 characters base64-encoded, corresponding " + "to 256 bit)." + ), + ) + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). 
Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling." + ), + ) + kv_transfer_params: dict[str, Any] | None = Field( + default=None, + description="KVTransfer parameters used for disaggregated serving.", + ) + + +class GenerateResponseChoice(BaseModel): + index: int + logprobs: ChatCompletionLogProbs | None = None + # per OpenAI spec this is the default + finish_reason: str | None = "stop" + token_ids: list[int] | None = None + + +class GenerateResponse(BaseModel): + request_id: str = Field( + default_factory=lambda: f"{random_uuid()}", + description=( + "The request_id related to this request. If the caller does " + "not set it, a random_uuid will be generated. This id is used " + "through out the inference process and return in response." + ), + ) + choices: list[GenerateResponseChoice] + + prompt_logprobs: list[dict[int, Logprob] | None] | None = None + + kv_transfer_params: dict[str, Any] | None = Field( + default=None, + description="KVTransfer parameters used for disaggregated serving.", + ) diff --git a/vllm/entrypoints/openai/serving_tokens.py b/vllm/entrypoints/serve/disagg/serving.py similarity index 99% rename from vllm/entrypoints/openai/serving_tokens.py rename to vllm/entrypoints/serve/disagg/serving.py index daa739e41fa0..5c1d17156a90 100644 --- a/vllm/entrypoints/openai/serving_tokens.py +++ b/vllm/entrypoints/serve/disagg/serving.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + import asyncio import time from collections.abc import AsyncGenerator @@ -14,15 +16,17 @@ ChatCompletionLogProbs, ChatCompletionLogProbsContent, ErrorResponse, - GenerateRequest, - GenerateResponse, - GenerateResponseChoice, PromptTokenUsageInfo, RequestResponseMetadata, UsageInfo, ) from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs from vllm.entrypoints.openai.serving_models import OpenAIServingModels +from vllm.entrypoints.serve.disagg.protocol import ( + GenerateRequest, + GenerateResponse, + GenerateResponseChoice, +) from vllm.inputs.data import TokensPrompt as EngineTokensPrompt from vllm.logger import init_logger from vllm.logprobs import Logprob diff --git a/vllm/entrypoints/serve/elastic_ep/__init__.py b/vllm/entrypoints/serve/elastic_ep/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/entrypoints/serve/elastic_ep/api_router.py b/vllm/entrypoints/serve/elastic_ep/api_router.py new file mode 100644 index 000000000000..21d5d2e60778 --- /dev/null +++ b/vllm/entrypoints/serve/elastic_ep/api_router.py @@ -0,0 +1,96 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import json +from http import HTTPStatus + +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request +from fastapi.responses import JSONResponse + +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.openai.api_server import validate_json_request +from vllm.entrypoints.openai.protocol import ( + ErrorResponse, +) +from vllm.entrypoints.serve.elastic_ep.middleware import ( + get_scaling_elastic_ep, + set_scaling_elastic_ep, +) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def engine_client(request: Request) -> EngineClient: + return request.app.state.engine_client + + +router = APIRouter() + + +@router.post( + "/scale_elastic_ep", + dependencies=[Depends(validate_json_request)], + responses={ + 
HTTPStatus.OK.value: {"model": dict}, + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.REQUEST_TIMEOUT.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) +async def scale_elastic_ep(raw_request: Request): + try: + body = await raw_request.json() + except json.JSONDecodeError as e: + raise HTTPException(status_code=400, detail="Invalid JSON format") from e # noqa: B904 + + new_data_parallel_size = body.get("new_data_parallel_size") + drain_timeout = body.get("drain_timeout", 120) # Default 2 minutes + + if new_data_parallel_size is None: + raise HTTPException( + status_code=400, detail="new_data_parallel_size is required" + ) + + if not isinstance(new_data_parallel_size, int) or new_data_parallel_size <= 0: + raise HTTPException( + status_code=400, + detail="new_data_parallel_size must be a positive integer", + ) + + if not isinstance(drain_timeout, int) or drain_timeout <= 0: + raise HTTPException( + status_code=400, detail="drain_timeout must be a positive integer" + ) + + # Set scaling flag to prevent new requests + set_scaling_elastic_ep(True) + client = engine_client(raw_request) + try: + await client.scale_elastic_ep(new_data_parallel_size, drain_timeout) + return JSONResponse( + { + "message": f"Scaled to {new_data_parallel_size} data parallel engines", + } + ) + except TimeoutError as e: + raise HTTPException( + status_code=408, + detail="Scale failed due to request drain timeout " + f"after {drain_timeout} seconds", + ) from e + except Exception as e: + logger.error("Scale failed: %s", e) + raise HTTPException(status_code=500, detail="Scale failed") from e + finally: + set_scaling_elastic_ep(False) + + +@router.post("/is_scaling_elastic_ep") +async def is_scaling_elastic_ep(raw_request: Request): + return JSONResponse({"is_scaling_elastic_ep": get_scaling_elastic_ep()}) + + +def attach_router(app: FastAPI): + app.include_router(router) diff --git a/vllm/entrypoints/serve/elastic_ep/middleware.py b/vllm/entrypoints/serve/elastic_ep/middleware.py new file mode 100644 index 000000000000..23f45eafeaa0 --- /dev/null +++ b/vllm/entrypoints/serve/elastic_ep/middleware.py @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Awaitable + +from fastapi.responses import JSONResponse +from starlette.types import ASGIApp, Receive, Scope, Send + +# Global variable to track scaling state +_scaling_elastic_ep = False + + +def get_scaling_elastic_ep(): + return _scaling_elastic_ep + + +def set_scaling_elastic_ep(value): + global _scaling_elastic_ep + _scaling_elastic_ep = value + + +class ScalingMiddleware: + """ + Middleware that checks if the model is currently scaling and + returns a 503 Service Unavailable response if it is. + + This middleware applies to all HTTP requests and prevents + processing when the model is in a scaling state. + """ + + def __init__(self, app: ASGIApp) -> None: + self.app = app + + def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]: + if scope["type"] != "http": + return self.app(scope, receive, send) + + # Check global scaling state + if get_scaling_elastic_ep(): + # Return 503 Service Unavailable response + response = JSONResponse( + content={ + "error": "The model is currently scaling. Please try again later." 
+ }, + status_code=503, + ) + return response(scope, receive, send) + + return self.app(scope, receive, send) diff --git a/vllm/entrypoints/serve/instrumentator/__init__.py b/vllm/entrypoints/serve/instrumentator/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/entrypoints/serve/instrumentator/health.py b/vllm/entrypoints/serve/instrumentator/health.py new file mode 100644 index 000000000000..029ef677aaa2 --- /dev/null +++ b/vllm/entrypoints/serve/instrumentator/health.py @@ -0,0 +1,33 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +from fastapi import APIRouter, Request +from fastapi.responses import Response + +from vllm.engine.protocol import EngineClient +from vllm.logger import init_logger +from vllm.v1.engine.exceptions import EngineDeadError + +logger = init_logger(__name__) + + +router = APIRouter() + + +def engine_client(request: Request) -> EngineClient: + return request.app.state.engine_client + + +@router.get("/health", response_class=Response) +async def health(raw_request: Request) -> Response: + """Health check.""" + try: + await engine_client(raw_request).check_health() + return Response(status_code=200) + except EngineDeadError: + return Response(status_code=503) + + +def attach_router(app): + app.include_router(router) diff --git a/vllm/entrypoints/serve/instrumentator/metrics.py b/vllm/entrypoints/serve/instrumentator/metrics.py new file mode 100644 index 000000000000..efe0c63a9071 --- /dev/null +++ b/vllm/entrypoints/serve/instrumentator/metrics.py @@ -0,0 +1,46 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import re + +import prometheus_client +from fastapi import FastAPI, Response +from prometheus_client import make_asgi_app +from prometheus_fastapi_instrumentator import Instrumentator +from starlette.routing import Mount + +from vllm.v1.metrics.prometheus import get_prometheus_registry + + +class PrometheusResponse(Response): + media_type = prometheus_client.CONTENT_TYPE_LATEST + + +def attach_router(app: FastAPI): + """Mount prometheus metrics to a FastAPI app.""" + + registry = get_prometheus_registry() + + # `response_class=PrometheusResponse` is needed to return an HTTP response + # with header "Content-Type: text/plain; version=0.0.4; charset=utf-8" + # instead of the default "application/json" which is incorrect. 
+ # See https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/163#issue-1296092364 + Instrumentator( + excluded_handlers=[ + "/metrics", + "/health", + "/load", + "/ping", + "/version", + "/server_info", + ], + registry=registry, + ).add().instrument(app).expose(app, response_class=PrometheusResponse) + + # Add prometheus asgi middleware to route /metrics requests + metrics_route = Mount("/metrics", make_asgi_app(registry=registry)) + + # Workaround for 307 Redirect for /metrics + metrics_route.path_regex = re.compile("^/metrics(?P.*)$") + app.routes.append(metrics_route) diff --git a/vllm/entrypoints/serve/lora/__init__.py b/vllm/entrypoints/serve/lora/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/entrypoints/dynamic_lora.py b/vllm/entrypoints/serve/lora/api_router.py similarity index 80% rename from vllm/entrypoints/dynamic_lora.py rename to vllm/entrypoints/serve/lora/api_router.py index cc0f437e5c77..6a57e73f334f 100644 --- a/vllm/entrypoints/dynamic_lora.py +++ b/vllm/entrypoints/serve/lora/api_router.py @@ -1,9 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + import model_hosting_container_standards.sagemaker as sagemaker_standards -from fastapi import APIRouter, Depends, Request +from fastapi import APIRouter, Depends, FastAPI, Request from fastapi.responses import JSONResponse, Response +from vllm import envs from vllm.entrypoints.openai.api_server import models, validate_json_request from vllm.entrypoints.openai.protocol import ( ErrorResponse, @@ -14,9 +17,18 @@ from vllm.logger import init_logger logger = init_logger(__name__) +router = APIRouter() -def register_dynamic_lora_routes(router: APIRouter): +def attach_router(app: FastAPI): + if not envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: + """If LoRA dynamic loading & unloading is not enabled, do nothing.""" + return + logger.warning( + "LoRA dynamic loading & unloading is enabled in the API server. " + "This should ONLY be used for local development!" 
+ ) + @sagemaker_standards.register_load_adapter_handler( request_shape={ "lora_name": "body.name", @@ -54,4 +66,5 @@ async def unload_lora_adapter( return Response(status_code=200, content=response) - return router + # register the router + app.include_router(router) diff --git a/vllm/entrypoints/serve/profile/__init__.py b/vllm/entrypoints/serve/profile/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/entrypoints/serve/profile/api_router.py b/vllm/entrypoints/serve/profile/api_router.py new file mode 100644 index 000000000000..166f13764eb3 --- /dev/null +++ b/vllm/entrypoints/serve/profile/api_router.py @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +from fastapi import APIRouter, FastAPI, Request +from fastapi.responses import Response + +import vllm.envs as envs +from vllm.engine.protocol import EngineClient +from vllm.logger import init_logger + +logger = init_logger(__name__) + +router = APIRouter() + + +def engine_client(request: Request) -> EngineClient: + return request.app.state.engine_client + + +@router.post("/start_profile") +async def start_profile(raw_request: Request): + logger.info("Starting profiler...") + await engine_client(raw_request).start_profile() + logger.info("Profiler started.") + return Response(status_code=200) + + +@router.post("/stop_profile") +async def stop_profile(raw_request: Request): + logger.info("Stopping profiler...") + await engine_client(raw_request).stop_profile() + logger.info("Profiler stopped.") + return Response(status_code=200) + + +def attach_router(app: FastAPI): + if envs.VLLM_TORCH_PROFILER_DIR: + logger.warning_once( + "Torch Profiler is enabled in the API server. This should ONLY be " + "used for local development!" + ) + elif envs.VLLM_TORCH_CUDA_PROFILE: + logger.warning_once( + "CUDA Profiler is enabled in the API server. This should ONLY be " + "used for local development!" + ) + if envs.VLLM_TORCH_PROFILER_DIR or envs.VLLM_TORCH_CUDA_PROFILE: + app.include_router(router) diff --git a/vllm/entrypoints/serve/rlhf/__init__.py b/vllm/entrypoints/serve/rlhf/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/entrypoints/serve/rlhf/api_router.py b/vllm/entrypoints/serve/rlhf/api_router.py new file mode 100644 index 000000000000..3b37840ae089 --- /dev/null +++ b/vllm/entrypoints/serve/rlhf/api_router.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +from http import HTTPStatus + +from fastapi import APIRouter, FastAPI, Query, Request +from fastapi.responses import JSONResponse + +from vllm.engine.protocol import EngineClient +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def engine_client(request: Request) -> EngineClient: + return request.app.state.engine_client + + +router = APIRouter() + + +@router.post("/pause") +async def pause_generation( + raw_request: Request, + wait_for_inflight_requests: bool = Query(False), + clear_cache: bool = Query(True), +) -> JSONResponse: + """Pause generation requests to allow weight updates. + + Args: + wait_for_inflight_requests: When ``True`` waits for in-flight + requests to finish before pausing. When ``False`` (default), + aborts any in-flight requests immediately. + clear_cache: Whether to clear KV/prefix caches after draining. 
+ """ + + engine = engine_client(raw_request) + + try: + await engine.pause_generation( + wait_for_inflight_requests=wait_for_inflight_requests, + clear_cache=clear_cache, + ) + return JSONResponse( + content={"status": "paused"}, + status_code=HTTPStatus.OK.value, + ) + + except ValueError as err: + return JSONResponse( + content={"error": str(err)}, + status_code=HTTPStatus.BAD_REQUEST.value, + ) + except Exception as err: # pragma: no cover - defensive + logger.exception("Failed to pause generation") + return JSONResponse( + content={"error": f"Failed to pause generation: {err}"}, + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + ) + + +@router.post("/resume") +async def resume_generation(raw_request: Request) -> JSONResponse: + """Resume generation after a pause.""" + + engine = engine_client(raw_request) + + try: + await engine.resume_generation() + return JSONResponse( + content={"status": "resumed"}, + status_code=HTTPStatus.OK.value, + ) + except Exception as err: # pragma: no cover - defensive + logger.exception("Failed to resume generation") + return JSONResponse( + content={"error": f"Failed to resume generation: {err}"}, + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + ) + + +@router.get("/is_paused") +async def is_paused(raw_request: Request) -> JSONResponse: + """Return the current pause status.""" + + engine = engine_client(raw_request) + + try: + paused = await engine.is_paused() + except Exception as err: # pragma: no cover - defensive + logger.exception("Failed to fetch pause status") + return JSONResponse( + content={"error": f"Failed to fetch pause status: {err}"}, + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, + ) + + return JSONResponse(content={"is_paused": paused}) + + +def attach_router(app: FastAPI): + app.include_router(router) diff --git a/vllm/entrypoints/serve/sleep/__init__.py b/vllm/entrypoints/serve/sleep/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/entrypoints/serve/sleep/api_router.py b/vllm/entrypoints/serve/sleep/api_router.py new file mode 100644 index 000000000000..bc01e185315c --- /dev/null +++ b/vllm/entrypoints/serve/sleep/api_router.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +from fastapi import APIRouter, FastAPI, Request +from fastapi.responses import JSONResponse, Response + +import vllm.envs as envs +from vllm.engine.protocol import EngineClient +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def engine_client(request: Request) -> EngineClient: + return request.app.state.engine_client + + +router = APIRouter() + + +@router.post("/sleep") +async def sleep(raw_request: Request): + # get POST params + level = raw_request.query_params.get("level", "1") + await engine_client(raw_request).sleep(int(level)) + # FIXME: in v0 with frontend multiprocessing, the sleep command + # is sent but does not finish yet when we return a response. + return Response(status_code=200) + + +@router.post("/wake_up") +async def wake_up(raw_request: Request): + tags = raw_request.query_params.getlist("tags") + if tags == []: + # set to None to wake up all tags if no tags are provided + tags = None + logger.info("wake up the engine with tags: %s", tags) + await engine_client(raw_request).wake_up(tags) + # FIXME: in v0 with frontend multiprocessing, the wake-up command + # is sent but does not finish yet when we return a response. 
+ return Response(status_code=200) + + +@router.get("/is_sleeping") +async def is_sleeping(raw_request: Request): + logger.info("check whether the engine is sleeping") + is_sleeping = await engine_client(raw_request).is_sleeping() + return JSONResponse(content={"is_sleeping": is_sleeping}) + + +def attach_router(app: FastAPI): + if not envs.VLLM_SERVER_DEV_MODE: + return + logger.warning( + "SECURITY WARNING: Development endpoints are enabled! " + "This should NOT be used in production!" + ) + + app.include_router(router) diff --git a/vllm/entrypoints/serve/tokenize/__init__.py b/vllm/entrypoints/serve/tokenize/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/entrypoints/serve/tokenize/api_router.py b/vllm/entrypoints/serve/tokenize/api_router.py new file mode 100644 index 000000000000..a10e78c8d28e --- /dev/null +++ b/vllm/entrypoints/serve/tokenize/api_router.py @@ -0,0 +1,118 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +from http import HTTPStatus + +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse +from typing_extensions import assert_never + +from vllm.entrypoints.openai.api_server import validate_json_request +from vllm.entrypoints.openai.protocol import ( + DetokenizeRequest, + DetokenizeResponse, + ErrorResponse, + TokenizeRequest, + TokenizeResponse, +) +from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization +from vllm.entrypoints.utils import ( + with_cancellation, +) +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def tokenization(request: Request) -> OpenAIServingTokenization: + return request.app.state.openai_serving_tokenization + + +router = APIRouter() + + +@router.post( + "/tokenize", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse}, + }, +) +@with_cancellation +async def tokenize(request: TokenizeRequest, raw_request: Request): + handler = tokenization(raw_request) + + try: + generator = await handler.create_tokenize(request, raw_request) + except NotImplementedError as e: + raise HTTPException( + status_code=HTTPStatus.NOT_IMPLEMENTED.value, detail=str(e) + ) from e + except Exception as e: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e + + if isinstance(generator, ErrorResponse): + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) + elif isinstance(generator, TokenizeResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +@router.post( + "/detokenize", + dependencies=[Depends(validate_json_request)], + responses={ + HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse}, + HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse}, + HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse}, + }, +) +@with_cancellation +async def detokenize(request: DetokenizeRequest, raw_request: Request): + handler = tokenization(raw_request) + + try: + generator = await handler.create_detokenize(request, raw_request) + except OverflowError as e: + raise RequestValidationError(errors=[str(e)]) from e + except Exception 
as e: + raise HTTPException( + status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e) + ) from e + + if isinstance(generator, ErrorResponse): + return JSONResponse( + content=generator.model_dump(), status_code=generator.error.code + ) + elif isinstance(generator, DetokenizeResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + +def attach_router(app: FastAPI): + if getattr(app.state.args, "enable_tokenizer_info_endpoint", False): + """Conditionally register the tokenizer info endpoint if enabled.""" + + @router.get("/tokenizer_info") + async def get_tokenizer_info(raw_request: Request): + """Get comprehensive tokenizer information.""" + result = await tokenization(raw_request).get_tokenizer_info() + return JSONResponse( + content=result.model_dump(), + status_code=result.error.code + if isinstance(result, ErrorResponse) + else 200, + ) + + app.include_router(router) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/serve/tokenize/serving.py similarity index 100% rename from vllm/entrypoints/openai/serving_tokenization.py rename to vllm/entrypoints/serve/tokenize/serving.py
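
The diff above moves each optional endpoint group into its own module under vllm/entrypoints/serve/, where it exposes an attach_router(app) hook that registers its APIRouter only when the relevant env var or CLI flag applies (flags are read from app.state.args, which build_app() and init_app_state() now populate); register_vllm_serve_api_routers(app) in vllm/entrypoints/serve/__init__.py then calls every hook. Below is a minimal sketch of how a new feature module would slot into that pattern. The module, the /my_feature/status route, and the enable_my_feature flag are hypothetical placeholders for illustration; only the attach_router / app.state.args / app.state.engine_client wiring comes from this PR.

# Hypothetical example (not part of this PR): a new endpoint group following
# the attach_router pattern introduced by this refactor. Names below are
# illustrative only.

from fastapi import APIRouter, FastAPI, Request
from fastapi.responses import JSONResponse

router = APIRouter()


@router.get("/my_feature/status")
async def my_feature_status(raw_request: Request) -> JSONResponse:
    # Handlers resolve shared objects from app.state, as the routers in this
    # PR do (engine_client, openai_serving_tokenization, serving_tokens, ...).
    engine = raw_request.app.state.engine_client
    return JSONResponse(content={"engine": type(engine).__name__})


def attach_router(app: FastAPI) -> None:
    # Optional endpoints gate themselves on CLI args stored by build_app()
    # via `app.state.args = args`, mirroring the tokens_only and
    # enable_tokenizer_info_endpoint checks in this diff. The
    # `enable_my_feature` flag here is an assumption, not a real vLLM flag.
    if not getattr(app.state.args, "enable_my_feature", False):
        return
    app.include_router(router)

Such a module would be wired in by adding one more import-and-call pair (from vllm.entrypoints.serve.my_feature.api_router import attach_router; attach_router(app)) to register_vllm_serve_api_routers, keeping api_server.py free of per-feature route definitions.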