import asyncio
import gc
import hashlib
import importlib
import inspect
import json
import multiprocessing
import multiprocessing.forkserver as forkserver
import os
import secrets
import signal
import socket
import tempfile
import uuid
from argparse import Namespace
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable
from contextlib import asynccontextmanager
from http import HTTPStatus
from typing import Annotated, Any, Literal

import prometheus_client
import pydantic
import regex as re
import uvloop
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Query, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
from prometheus_fastapi_instrumentator import Instrumentator
from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import URL, Headers, MutableHeaders, State
from starlette.routing import Mount
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from typing_extensions import assert_never

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import Device, EngineClient
from vllm.entrypoints.anthropic.protocol import (
    AnthropicError,
    AnthropicErrorResponse,
    AnthropicMessagesRequest,
    AnthropicMessagesResponse,
)
from vllm.entrypoints.anthropic.serving_messages import AnthropicServingMessages
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
from vllm.entrypoints.openai.orca_metrics import metrics_header
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ClassificationRequest,
    ClassificationResponse,
    CompletionRequest,
    CompletionResponse,
    DetokenizeRequest,
    DetokenizeResponse,
    EmbeddingBytesResponse,
    EmbeddingRequest,
    EmbeddingResponse,
    ErrorInfo,
    ErrorResponse,
    IOProcessorResponse,
    LoadLoRAAdapterRequest,
    PoolingBytesResponse,
    PoolingRequest,
    PoolingResponse,
    RerankRequest,
    RerankResponse,
    ResponsesRequest,
    ResponsesResponse,
    ScoreRequest,
    ScoreResponse,
    StreamingResponsesResponse,
    TokenizeRequest,
    TokenizeResponse,
    TranscriptionRequest,
    TranscriptionResponse,
    TranslationRequest,
    TranslationResponse,
    UnloadLoRAAdapterRequest,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_classification import ServingClassification
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import (
    BaseModelPath,
    OpenAIServingModels,
)
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
from vllm.entrypoints.openai.serving_score import ServingScores
from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
from vllm.entrypoints.openai.serving_transcription import (
    OpenAIServingTranscription,
    OpenAIServingTranslation,
)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.entrypoints.serve.utils import validate_json_request
from vllm.entrypoints.tool_server import DemoToolServer, MCPToolServer, ToolServer
from vllm.entrypoints.utils import (
    cli_env_setup,
    load_aware_call,
    log_non_default_args,
    process_chat_template,
    process_lora_modules,
    with_cancellation,
)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.tasks import POOLING_TASKS
from vllm.usage.usage_lib import UsageContext
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.network_utils import is_valid_ipv6_address
from vllm.utils.system_utils import decorate_logs, set_ulimit
from vllm.v1.engine.exceptions import EngineDeadError
from vllm.v1.metrics.prometheus import get_prometheus_registry
from vllm.version import __version__ as VLLM_VERSION

router = APIRouter(tags=["Jina APIs"])

logger = init_logger("vllm.entrypoints.jina.api_server")


def tokenization(request: Request) -> OpenAIServingTokenization:
    """Resolve the tokenization handler attached to the app state."""
    return request.app.state.openai_serving_tokenization


def base(request: Request) -> OpenAIServing:
    """Return a base `OpenAIServing` handler, reusing the existing
    tokenization handler instance."""
    return tokenization(request)


def rerank(request: Request) -> ServingScores | None:
    """Resolve the rerank (score) handler, or None if the model has none."""
    return request.app.state.jina_serving_scores


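# The accessors above assume the serving handlers were attached to the app
# state during startup. A minimal sketch of that wiring (illustrative only;
# the actual assignments happen in the app-state initialization code, which
# is not part of this section):
#
#   app.state.openai_serving_tokenization = OpenAIServingTokenization(...)
#   app.state.jina_serving_scores = ServingScores(...)  # None if unsupported

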
@router.post(
    "/rerank",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def do_rerank(request: RerankRequest, raw_request: Request):
    handler = rerank(raw_request)
    if handler is None:
        return base(raw_request).create_error_response(
            message="The model does not support Rerank (Score) API"
        )
    try:
        generator = await handler.do_rerank(request, raw_request)
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e
    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
    elif isinstance(generator, RerankResponse):
        return JSONResponse(content=generator.model_dump())

    assert_never(generator)


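# Illustrative request against `/rerank` (hypothetical host and model name;
# the payload shape follows the JinaAI-style `RerankRequest` schema):
#
#   curl -s http://localhost:8000/rerank \
#     -H "Content-Type: application/json" \
#     -d '{
#           "model": "BAAI/bge-reranker-base",
#           "query": "What is the capital of France?",
#           "documents": ["Paris is the capital of France.",
#                         "Berlin is the capital of Germany."],
#           "top_n": 1
#         }'

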
@router.post(
    "/v1/rerank",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
async def do_rerank_v1(request: RerankRequest, raw_request: Request):
    logger.warning_once(
        "To indicate that the rerank API is not part of the standard OpenAI"
        " API, we have located it at `/rerank`. Please update your client "
        "accordingly. (Note: Conforms to JinaAI rerank API)"
    )

    return await do_rerank(request, raw_request)


@router.post(
    "/v2/rerank",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
async def do_rerank_v2(request: RerankRequest, raw_request: Request):
    return await do_rerank(request, raw_request)

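
# Minimal client sketch for the aliased endpoints (hypothetical host and
# model name; `requests` is not a dependency of this module, and the
# response fields follow the JinaAI-style `RerankResponse` schema):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/v2/rerank",
#       json={
#           "model": "BAAI/bge-reranker-base",
#           "query": "vector databases",
#           "documents": ["FAISS is a similarity search library.",
#                         "PostgreSQL is a relational database."],
#       },
#   )
#   resp.raise_for_status()
#   for item in resp.json()["results"]:
#       print(item["index"], item["relevance_score"])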