diff --git a/tests/entrypoints/anthropic/__init__.py b/tests/entrypoints/anthropic/__init__.py
deleted file mode 100644
index e69de29bb2d1..000000000000
diff --git a/tests/entrypoints/anthropic/test_messages.py b/tests/entrypoints/openai/test_messages.py
similarity index 68%
rename from tests/entrypoints/anthropic/test_messages.py
rename to tests/entrypoints/openai/test_messages.py
index 4e35554b4e33..3e390ad49642 100644
--- a/tests/entrypoints/anthropic/test_messages.py
+++ b/tests/entrypoints/openai/test_messages.py
@@ -5,7 +5,7 @@
 import pytest
 import pytest_asyncio
 
-from ...utils import RemoteAnthropicServer
+from ...utils import RemoteOpenAIServer
 
 MODEL_NAME = "Qwen/Qwen3-0.6B"
 
@@ -23,13 +23,13 @@ def server():  # noqa: F811
         "claude-3-7-sonnet-latest",
     ]
 
-    with RemoteAnthropicServer(MODEL_NAME, args) as remote_server:
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
         yield remote_server
 
 
 @pytest_asyncio.fixture
 async def client(server):
-    async with server.get_async_client() as async_client:
+    async with server.get_async_client_anthropic() as async_client:
         yield async_client
 
 
@@ -105,37 +105,37 @@ async def test_anthropic_tool_call(client: anthropic.AsyncAnthropic):
 
     print(f"Anthropic response: {resp.model_dump_json()}")
 
-    @pytest.mark.asyncio
-    async def test_anthropic_tool_call_streaming(client: anthropic.AsyncAnthropic):
-        resp = await client.messages.create(
-            model="claude-3-7-sonnet-latest",
-            max_tokens=1024,
-            messages=[
-                {
-                    "role": "user",
-                    "content": "What's the weather like in New York today?",
-                }
-            ],
-            tools=[
-                {
-                    "name": "get_current_weather",
-                    "description": "Useful for querying the weather "
-                    "in a specified city.",
-                    "input_schema": {
-                        "type": "object",
-                        "properties": {
-                            "location": {
-                                "type": "string",
-                                "description": "City or region, for example: "
-                                "New York, London, Tokyo, etc.",
-                            }
-                        },
-                        "required": ["location"],
+
+@pytest.mark.asyncio
+async def test_anthropic_tool_call_streaming(client: anthropic.AsyncAnthropic):
+    resp = await client.messages.create(
+        model="claude-3-7-sonnet-latest",
+        max_tokens=1024,
+        messages=[
+            {
+                "role": "user",
+                "content": "What's the weather like in New York today?",
+            }
+        ],
+        tools=[
+            {
+                "name": "get_current_weather",
+                "description": "Useful for querying the weather in a specified city.",
+                "input_schema": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "City or region, for example: "
+                            "New York, London, Tokyo, etc.",
+                        }
                     },
-                }
-            ],
-            stream=True,
-        )
+                    "required": ["location"],
+                },
+            }
+        ],
+        stream=True,
+    )
 
-        async for chunk in resp:
-            print(chunk.model_dump_json())
+    async for chunk in resp:
+        print(chunk.model_dump_json())
diff --git a/tests/utils.py b/tests/utils.py
index af4ce6ebaeda..c8f18384c511 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -247,6 +247,23 @@ def get_async_client(self, **kwargs):
             **kwargs,
         )
 
+    def get_client_anthropic(self, **kwargs):
+        if "timeout" not in kwargs:
+            kwargs["timeout"] = 600
+        return anthropic.Anthropic(
+            base_url=self.url_for(),
+            api_key=self.DUMMY_API_KEY,
+            max_retries=0,
+            **kwargs,
+        )
+
+    def get_async_client_anthropic(self, **kwargs):
+        if "timeout" not in kwargs:
+            kwargs["timeout"] = 600
+        return anthropic.AsyncAnthropic(
+            base_url=self.url_for(), api_key=self.DUMMY_API_KEY, max_retries=0, **kwargs
+        )
+
 
 class RemoteOpenAIServerCustom(RemoteOpenAIServer):
     """Launch test server with custom child process"""
@@ -293,131 +310,6 @@ def __exit__(self, exc_type, exc_value, traceback):
             self.proc.kill()
 
 
-class RemoteAnthropicServer:
-    DUMMY_API_KEY = "token-abc123"  # vLLM's Anthropic server does not need API key
-
-    def __init__(
-        self,
-        model: str,
-        vllm_serve_args: list[str],
-        *,
-        env_dict: dict[str, str] | None = None,
-        seed: int | None = 0,
-        auto_port: bool = True,
-        max_wait_seconds: float | None = None,
-    ) -> None:
-        if auto_port:
-            if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
-                raise ValueError(
-                    "You have manually specified the port when `auto_port=True`."
-                )
-
-            # Don't mutate the input args
-            vllm_serve_args = vllm_serve_args + ["--port", str(get_open_port())]
-        if seed is not None:
-            if "--seed" in vllm_serve_args:
-                raise ValueError(
-                    f"You have manually specified the seed when `seed={seed}`."
-                )
-
-            vllm_serve_args = vllm_serve_args + ["--seed", str(seed)]
-
-        parser = FlexibleArgumentParser(description="vLLM's remote Anthropic server.")
-        subparsers = parser.add_subparsers(required=False, dest="subparser")
-        parser = ServeSubcommand().subparser_init(subparsers)
-        args = parser.parse_args(["--model", model, *vllm_serve_args])
-        self.host = str(args.host or "localhost")
-        self.port = int(args.port)
-
-        self.show_hidden_metrics = args.show_hidden_metrics_for_version is not None
-
-        # download the model before starting the server to avoid timeout
-        is_local = os.path.isdir(model)
-        if not is_local:
-            engine_args = AsyncEngineArgs.from_cli_args(args)
-            model_config = engine_args.create_model_config()
-            load_config = engine_args.create_load_config()
-
-            model_loader = get_model_loader(load_config)
-            model_loader.download_model(model_config)
-
-        env = os.environ.copy()
-        # the current process might initialize cuda,
-        # to be safe, we should use spawn method
-        env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
-        if env_dict is not None:
-            env.update(env_dict)
-        self.proc = subprocess.Popen(
-            [
-                sys.executable,
-                "-m",
-                "vllm.entrypoints.anthropic.api_server",
-                model,
-                *vllm_serve_args,
-            ],
-            env=env,
-            stdout=sys.stdout,
-            stderr=sys.stderr,
-        )
-        max_wait_seconds = max_wait_seconds or 240
-        self._wait_for_server(url=self.url_for("health"), timeout=max_wait_seconds)
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        self.proc.terminate()
-        try:
-            self.proc.wait(8)
-        except subprocess.TimeoutExpired:
-            # force kill if needed
-            self.proc.kill()
-
-    def _wait_for_server(self, *, url: str, timeout: float):
-        # run health check
-        start = time.time()
-        while True:
-            try:
-                if requests.get(url).status_code == 200:
-                    break
-            except Exception:
-                # this exception can only be raised by requests.get,
-                # which means the server is not ready yet.
-                # the stack trace is not useful, so we suppress it
-                # by using `raise from None`.
-                result = self.proc.poll()
-                if result is not None and result != 0:
-                    raise RuntimeError("Server exited unexpectedly.") from None
-
-            time.sleep(0.5)
-            if time.time() - start > timeout:
-                raise RuntimeError("Server failed to start in time.") from None
-
-    @property
-    def url_root(self) -> str:
-        return f"http://{self.host}:{self.port}"
-
-    def url_for(self, *parts: str) -> str:
-        return self.url_root + "/" + "/".join(parts)
-
-    def get_client(self, **kwargs):
-        if "timeout" not in kwargs:
-            kwargs["timeout"] = 600
-        return anthropic.Anthropic(
-            base_url=self.url_for(),
-            api_key=self.DUMMY_API_KEY,
-            max_retries=0,
-            **kwargs,
-        )
-
-    def get_async_client(self, **kwargs):
-        if "timeout" not in kwargs:
-            kwargs["timeout"] = 600
-        return anthropic.AsyncAnthropic(
-            base_url=self.url_for(), api_key=self.DUMMY_API_KEY, max_retries=0, **kwargs
-        )
-
-
 def _test_completion(
     client: openai.OpenAI,
     model: str,
diff --git a/vllm/entrypoints/anthropic/api_server.py b/vllm/entrypoints/anthropic/api_server.py
deleted file mode 100644
index df877f99b084..000000000000
--- a/vllm/entrypoints/anthropic/api_server.py
+++ /dev/null
@@ -1,301 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-# Adapted from:
-# https://github.com/vllm/vllm/entrypoints/openai/api_server.py
-
-import asyncio
-import signal
-import tempfile
-from argparse import Namespace
-from http import HTTPStatus
-
-import uvloop
-from fastapi import APIRouter, Depends, FastAPI, Request
-from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import JSONResponse, Response, StreamingResponse
-from starlette.datastructures import State
-
-import vllm.envs as envs
-from vllm.engine.protocol import EngineClient
-from vllm.entrypoints.anthropic.protocol import (
-    AnthropicErrorResponse,
-    AnthropicMessagesRequest,
-    AnthropicMessagesResponse,
-)
-from vllm.entrypoints.anthropic.serving_messages import AnthropicServingMessages
-from vllm.entrypoints.launcher import serve_http
-from vllm.entrypoints.logger import RequestLogger
-from vllm.entrypoints.openai.api_server import (
-    build_async_engine_client,
-    create_server_socket,
-    lifespan,
-    load_log_config,
-    validate_api_server_args,
-    validate_json_request,
-)
-from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
-from vllm.entrypoints.openai.protocol import ErrorResponse
-from vllm.entrypoints.openai.serving_models import (
-    BaseModelPath,
-    OpenAIServingModels,
-)
-
-#
-# yapf: enable
-from vllm.entrypoints.openai.tool_parsers import ToolParserManager
-from vllm.entrypoints.utils import (
-    cli_env_setup,
-    load_aware_call,
-    process_chat_template,
-    process_lora_modules,
-    with_cancellation,
-)
-from vllm.logger import init_logger
-from vllm.utils.argparse_utils import FlexibleArgumentParser
-from vllm.utils.network_utils import is_valid_ipv6_address
-from vllm.utils.system_utils import set_ulimit
-from vllm.version import __version__ as VLLM_VERSION
-
-prometheus_multiproc_dir: tempfile.TemporaryDirectory
-
-# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
-logger = init_logger("vllm.entrypoints.anthropic.api_server")
-
-_running_tasks: set[asyncio.Task] = set()
-
-router = APIRouter()
-
-
-def messages(request: Request) -> AnthropicServingMessages:
-    return request.app.state.anthropic_serving_messages
-
-
-def engine_client(request: Request) -> EngineClient:
-    return request.app.state.engine_client
-
-
-@router.get("/health", response_class=Response)
-async def health(raw_request: Request) -> Response:
-    """Health check."""
-    await engine_client(raw_request).check_health()
-    return Response(status_code=200)
-
-
-@router.get("/ping", response_class=Response)
-@router.post("/ping", response_class=Response)
-async def ping(raw_request: Request) -> Response:
-    """Ping check. Endpoint required for SageMaker"""
-    return await health(raw_request)
-
-
-@router.post(
-    "/v1/messages",
-    dependencies=[Depends(validate_json_request)],
-    responses={
-        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
-        HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse},
-        HTTPStatus.NOT_FOUND.value: {"model": AnthropicErrorResponse},
-        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": AnthropicErrorResponse},
-    },
-)
-@with_cancellation
-@load_aware_call
-async def create_messages(request: AnthropicMessagesRequest, raw_request: Request):
-    handler = messages(raw_request)
-    if handler is None:
-        return messages(raw_request).create_error_response(
-            message="The model does not support Messages API"
-        )
-
-    generator = await handler.create_messages(request, raw_request)
-
-    if isinstance(generator, ErrorResponse):
-        return JSONResponse(content=generator.model_dump())
-
-    elif isinstance(generator, AnthropicMessagesResponse):
-        logger.debug(
-            "Anthropic Messages Response: %s", generator.model_dump(exclude_none=True)
-        )
-        return JSONResponse(content=generator.model_dump(exclude_none=True))
-
-    return StreamingResponse(content=generator, media_type="text/event-stream")
-
-
-async def init_app_state(
-    engine_client: EngineClient,
-    state: State,
-    args: Namespace,
-) -> None:
-    vllm_config = engine_client.vllm_config
-
-    if args.served_model_name is not None:
-        served_model_names = args.served_model_name
-    else:
-        served_model_names = [args.model]
-
-    if args.disable_log_requests:
-        request_logger = None
-    else:
-        request_logger = RequestLogger(max_log_len=args.max_log_len)
-
-    base_model_paths = [
-        BaseModelPath(name=name, model_path=args.model) for name in served_model_names
-    ]
-
-    state.engine_client = engine_client
-    state.log_stats = not args.disable_log_stats
-    state.vllm_config = vllm_config
-    model_config = vllm_config.model_config
-
-    default_mm_loras = (
-        vllm_config.lora_config.default_mm_loras
-        if vllm_config.lora_config is not None
-        else {}
-    )
-    lora_modules = process_lora_modules(args.lora_modules, default_mm_loras)
-
-    resolved_chat_template = await process_chat_template(
-        args.chat_template, engine_client, model_config
-    )
-
-    state.openai_serving_models = OpenAIServingModels(
-        engine_client=engine_client,
-        base_model_paths=base_model_paths,
-        lora_modules=lora_modules,
-    )
-    await state.openai_serving_models.init_static_loras()
-    state.anthropic_serving_messages = AnthropicServingMessages(
-        engine_client,
-        state.openai_serving_models,
-        args.response_role,
-        request_logger=request_logger,
-        chat_template=resolved_chat_template,
-        chat_template_content_format=args.chat_template_content_format,
-        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
-        enable_auto_tools=args.enable_auto_tool_choice,
-        tool_parser=args.tool_call_parser,
-        reasoning_parser=args.reasoning_parser,
-        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
-        enable_force_include_usage=args.enable_force_include_usage,
-    )
-
-
-def setup_server(args):
-    """Validate API server args, set up signal handler, create socket
-    ready to serve."""
-
-    logger.info("vLLM API server version %s", VLLM_VERSION)
-
-    if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
-        ToolParserManager.import_tool_parser(args.tool_parser_plugin)
-
-    validate_api_server_args(args)
-
-    # workaround to make sure that we bind the port before the engine is set up.
-    # This avoids race conditions with ray.
-    # see https://github.com/vllm-project/vllm/issues/8204
-    sock_addr = (args.host or "", args.port)
-    sock = create_server_socket(sock_addr)
-
-    # workaround to avoid footguns where uvicorn drops requests with too
-    # many concurrent requests active
-    set_ulimit()
-
-    def signal_handler(*_) -> None:
-        # Interrupt server on sigterm while initializing
-        raise KeyboardInterrupt("terminated")
-
-    signal.signal(signal.SIGTERM, signal_handler)
-
-    addr, port = sock_addr
-    is_ssl = args.ssl_keyfile and args.ssl_certfile
-    host_part = f"[{addr}]" if is_valid_ipv6_address(addr) else addr or "0.0.0.0"
-    listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}"
-
-    return listen_address, sock
-
-
-async def run_server(args, **uvicorn_kwargs) -> None:
-    """Run a single-worker API server."""
-    listen_address, sock = setup_server(args)
-    await run_server_worker(listen_address, sock, args, **uvicorn_kwargs)
-
-
-def build_app(args: Namespace) -> FastAPI:
-    app = FastAPI(lifespan=lifespan)
-    app.include_router(router)
-    app.root_path = args.root_path
-
-    app.add_middleware(
-        CORSMiddleware,
-        allow_origins=args.allowed_origins,
-        allow_credentials=args.allow_credentials,
-        allow_methods=args.allowed_methods,
-        allow_headers=args.allowed_headers,
-    )
-
-    return app
-
-
-async def run_server_worker(
-    listen_address, sock, args, client_config=None, **uvicorn_kwargs
-) -> None:
-    """Run a single API server worker."""
-
-    if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
-        ToolParserManager.import_tool_parser(args.tool_parser_plugin)
-
-    server_index = client_config.get("client_index", 0) if client_config else 0
-
-    # Load logging config for uvicorn if specified
-    log_config = load_log_config(args.log_config_file)
-    if log_config is not None:
-        uvicorn_kwargs["log_config"] = log_config
-
-    async with build_async_engine_client(
-        args,
-        client_config=client_config,
-    ) as engine_client:
-        app = build_app(args)
-
-        await init_app_state(engine_client, app.state, args)
-
-        logger.info("Starting vLLM API server %d on %s", server_index, listen_address)
-        shutdown_task = await serve_http(
-            app,
-            sock=sock,
-            enable_ssl_refresh=args.enable_ssl_refresh,
-            host=args.host,
-            port=args.port,
-            log_level=args.uvicorn_log_level,
-            # NOTE: When the 'disable_uvicorn_access_log' value is True,
-            # no access log will be output.
-            access_log=not args.disable_uvicorn_access_log,
-            timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
-            ssl_keyfile=args.ssl_keyfile,
-            ssl_certfile=args.ssl_certfile,
-            ssl_ca_certs=args.ssl_ca_certs,
-            ssl_cert_reqs=args.ssl_cert_reqs,
-            **uvicorn_kwargs,
-        )
-
-    # NB: Await server shutdown only after the backend context is exited
-    try:
-        await shutdown_task
-    finally:
-        sock.close()
-
-
-if __name__ == "__main__":
-    # NOTE(simon):
-    # This section should be in sync with vllm/entrypoints/cli/main.py for CLI
-    # entrypoints.
-    cli_env_setup()
-    parser = FlexibleArgumentParser(
-        description="vLLM Anthropic-Compatible RESTful API server."
-    )
-    parser = make_arg_parser(parser)
-    args = parser.parse_args()
-    validate_parsed_serve_args(args)
-
-    uvloop.run(run_server(args))
diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index 8fa71855f8f6..22b5584749ae 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -41,6 +41,13 @@
 from vllm.config import VllmConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.protocol import Device, EngineClient
+from vllm.entrypoints.anthropic.protocol import (
+    AnthropicError,
+    AnthropicErrorResponse,
+    AnthropicMessagesRequest,
+    AnthropicMessagesResponse,
+)
+from vllm.entrypoints.anthropic.serving_messages import AnthropicServingMessages
 from vllm.entrypoints.launcher import serve_http
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
@@ -308,6 +315,10 @@ def responses(request: Request) -> OpenAIServingResponses | None:
     return request.app.state.openai_serving_responses
 
 
+def messages(request: Request) -> AnthropicServingMessages:
+    return request.app.state.anthropic_serving_messages
+
+
 def chat(request: Request) -> OpenAIServingChat | None:
     return request.app.state.openai_serving_chat
 
@@ -591,6 +602,63 @@ async def cancel_responses(response_id: str, raw_request: Request):
     return JSONResponse(content=response.model_dump())
 
 
+@router.post(
+    "/v1/messages",
+    dependencies=[Depends(validate_json_request)],
+    responses={
+        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
+        HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse},
+        HTTPStatus.NOT_FOUND.value: {"model": AnthropicErrorResponse},
+        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": AnthropicErrorResponse},
+    },
+)
+@with_cancellation
+@load_aware_call
+async def create_messages(request: AnthropicMessagesRequest, raw_request: Request):
+    def translate_error_response(response: ErrorResponse) -> JSONResponse:
+        anthropic_error = AnthropicErrorResponse(
+            error=AnthropicError(
+                type=response.error.type,
+                message=response.error.message,
+            )
+        )
+        return JSONResponse(
+            status_code=response.error.code, content=anthropic_error.model_dump()
+        )
+
+    handler = messages(raw_request)
+    if handler is None:
+        error = base(raw_request).create_error_response(
+            message="The model does not support Messages API"
+        )
+        return translate_error_response(error)
+
+    try:
+        generator = await handler.create_messages(request, raw_request)
+    except Exception as e:
+        logger.exception("Error in create_messages: %s", e)
+        return JSONResponse(
+            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
+            content=AnthropicErrorResponse(
+                error=AnthropicError(
+                    type="internal_error",
+                    message=str(e),
+                )
+            ).model_dump(),
+        )
+
+    if isinstance(generator, ErrorResponse):
+        return translate_error_response(generator)
+
+    elif isinstance(generator, AnthropicMessagesResponse):
+        logger.debug(
+            "Anthropic Messages Response: %s", generator.model_dump(exclude_none=True)
+        )
+        return JSONResponse(content=generator.model_dump(exclude_none=True))
+
+    return StreamingResponse(content=generator, media_type="text/event-stream")
+
+
 @router.post(
     "/v1/chat/completions",
     dependencies=[Depends(validate_json_request)],
@@ -1817,6 +1885,24 @@ async def init_app_state(
         if "transcription" in supported_tasks
         else None
     )
+    state.anthropic_serving_messages = (
+        AnthropicServingMessages(
+            engine_client,
+            state.openai_serving_models,
+            args.response_role,
+            request_logger=request_logger,
+            chat_template=resolved_chat_template,
+            chat_template_content_format=args.chat_template_content_format,
+            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
+            enable_auto_tools=args.enable_auto_tool_choice,
+            tool_parser=args.tool_call_parser,
+            reasoning_parser=args.structured_outputs_config.reasoning_parser,
+            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
+            enable_force_include_usage=args.enable_force_include_usage,
+        )
+        if "generate" in supported_tasks
+        else None
+    )
 
     state.enable_server_load_tracking = args.enable_server_load_tracking
     state.server_load_metrics = 0
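
Usage sketch (illustrative, not part of the patch): with the Anthropic Messages API now served by the regular OpenAI-compatible server, a client only needs the anthropic SDK pointed at that server, which is what the renamed test does through RemoteOpenAIServer.get_async_client_anthropic(). The host/port and the launch command below are assumptions; the model name and dummy API key mirror MODEL_NAME and DUMMY_API_KEY from the tests.

# Minimal sketch, assuming a server was started with something like:
#   vllm serve Qwen/Qwen3-0.6B --port 8000
# base_url/port are illustrative; the key mirrors DUMMY_API_KEY in tests/utils.py.
import anthropic

client = anthropic.Anthropic(
    base_url="http://localhost:8000",  # assumed address of the vLLM OpenAI-compatible server
    api_key="token-abc123",  # dummy key, as used by the test fixtures
    max_retries=0,
)

# Non-streaming call against the consolidated /v1/messages route.
resp = client.messages.create(
    model="Qwen/Qwen3-0.6B",
    max_tokens=128,
    messages=[{"role": "user", "content": "What's the weather like in New York today?"}],
)
print(resp.model_dump_json())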