Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ VOLCENGINE_TTS_ACCESS_TOKEN=xxx
# LANGSMITH_API_KEY="xxx"
# LANGSMITH_PROJECT="xxx"

# JWT Secret Key for authentication (required for production)
# Generate a secure random key at least 32 characters long
# Example: JWT_SECRET_KEY="your-very-secure-random-secret-key-with-mixed-characters-123!@#"
JWT_SECRET_KEY=your-very-secure-random-secret-key-with-mixed-characters-123!@#

# [!NOTE]
# For model settings and other configurations, please refer to `docs/configuration_guide.md`

Expand Down
74 changes: 70 additions & 4 deletions src/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@
import base64
import json
import logging

import os
from typing import Annotated, Any, List, cast
from typing import Annotated, Any, List, cast, Optional

from uuid import uuid4

from fastapi import FastAPI, HTTPException, Query
from fastapi import FastAPI, HTTPException, Query, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response, StreamingResponse
from fastapi.security import HTTPBearer
from pydantic import BaseModel
from langchain_core.messages import AIMessageChunk, BaseMessage, ToolMessage
from langgraph.checkpoint.mongodb import AsyncMongoDBSaver
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
Expand Down Expand Up @@ -51,6 +55,7 @@
)
from src.tools import VolcengineTTS
from src.utils.json_utils import sanitize_args
from src.server.middleware.auth import authenticate_user, create_access_token, get_current_user, require_admin_user, generate_csrf_token

logger = logging.getLogger(__name__)

Expand All @@ -59,6 +64,7 @@
if os.name == "nt":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())


INTERNAL_SERVER_ERROR_DETAIL = "Internal Server Error"

app = FastAPI(
Expand All @@ -67,6 +73,66 @@
version="0.1.0",
)

class LoginResponse(BaseModel):
access_token: str
token_type: str
user: dict
csrf_token: Optional[str] = None

@app.post("/api/auth/login", response_model=LoginResponse)
async def login(form_data: dict, response: Response):
"""Authenticate user and return JWT token"""
email = form_data.get("email")
password = form_data.get("password")

if not email or not password:
raise HTTPException(
status_code=400,
detail="Email and password are required"
)

user = authenticate_user(email, password)
if not user:
raise HTTPException(
status_code=401,
detail="Invalid credentials"
)

access_token = create_access_token(
data={"sub": user["id"], "email": user["email"], "role": user["role"]}
)

# Generate CSRF token
csrf_token = generate_csrf_token()

# Set CSRF token in cookie
response.set_cookie(
key="csrf_token",
value=csrf_token,
httponly=False,
secure=True,
samesite="strict",
max_age=86400 # 24 hours
)

login_response = LoginResponse(
access_token=access_token,
token_type="bearer",
user=user
)

# Add CSRF token if available (for new auth system)
if csrf_token:
login_response.csrf_token = csrf_token

return login_response

@app.post("/api/auth/logout")
async def logout():
"""Logout user"""
# In a real implementation, you might want to add the token to a blacklist
return {"message": "Successfully logged out"}

# Add CORS middleware
# It's recommended to load the allowed origins from an environment variable
# for better security and flexibility across different environments.
Expand All @@ -91,7 +157,7 @@


@app.post("/api/chat/stream")
async def chat_stream(request: ChatRequest):
async def chat_stream(request: ChatRequest, current_user: dict = Depends(get_current_user)):
# Check if MCP server configuration is enabled
mcp_enabled = get_bool_env("ENABLE_MCP_SERVER_CONFIGURATION", False)

Expand Down Expand Up @@ -561,7 +627,7 @@ async def enhance_prompt(request: EnhancePromptRequest):


@app.post("/api/mcp/server/metadata", response_model=MCPServerMetadataResponse)
async def mcp_server_metadata(request: MCPServerMetadataRequest):
async def mcp_server_metadata(request: MCPServerMetadataRequest, current_user: dict = Depends(require_admin_user)):
"""Get information about an MCP server."""
# Check if MCP server configuration is enabled
if not get_bool_env("ENABLE_MCP_SERVER_CONFIGURATION", False):
Expand Down
157 changes: 157 additions & 0 deletions src/server/middleware/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT

import jwt
import logging
import os
import secrets
from datetime import datetime, timedelta, timezone
from typing import Callable
from fastapi import HTTPException, Depends, Request
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials

logger = logging.getLogger(__name__)

# JWT configuration
SECRET_KEY = os.getenv("JWT_SECRET_KEY")
ALGORITHM = "HS256"

# Check if we're in a test environment
def is_test_environment() -> bool:
"""Check if we're running in a test environment"""
import sys

# Check if pytest is in the command line arguments
is_pytest_running = any('pytest' in arg for arg in sys.argv)

return (
is_pytest_running or
os.getenv("PYTEST_CURRENT_TEST") is not None or # pytest is running
os.getenv("TESTING") == "true" or # explicit test flag
os.getenv("APP_ENV") == "test" # test environment
)

# Validate secret key complexity
def validate_secret_key(key: str) -> bool:
"""Validate that the secret key meets security requirements"""
if len(key) < 32:
return False
# Check for basic complexity - should contain mix of character types
has_upper = any(c.isupper() for c in key)
has_lower = any(c.islower() for c in key)
has_digit = any(c.isdigit() for c in key)
has_special = any(not c.isalnum() for c in key)
return has_upper and has_lower and has_digit and has_special

# Set up JWT secret key with fallback for test environments
if not SECRET_KEY:
if is_test_environment():
# Use a test-only secret key in test environments
SECRET_KEY = "test-secret-key-for-development-only-do-not-use-in-production-123!@#ABC"
logger.warning("Using test JWT secret key. This should only be used in test environments.")
else:
raise ValueError("JWT_SECRET_KEY environment variable is required. Set a secure random secret key.")
elif not validate_secret_key(SECRET_KEY):
if is_test_environment():
logger.warning("JWT secret key does not meet complexity requirements. Using in test environment only.")
else:
raise ValueError("JWT_SECRET_KEY must be at least 32 characters and contain uppercase, lowercase, digits, and special characters.")

security = HTTPBearer()

def create_access_token(data: dict, expires_delta: timedelta = None) -> str:
"""Create JWT token with configurable expiration"""
to_encode = data.copy()

# Use provided expiration or default to 24 hours
now = datetime.now(timezone.utc)
if expires_delta:
expire = now + expires_delta
else:
expire = now + timedelta(hours=24)

to_encode.update({
"exp": expire,
"iat": now # Issued at time
})

encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt

def generate_csrf_token() -> str:
"""Generate a secure CSRF token"""
return secrets.token_urlsafe(32)

def validate_csrf_token(request: Request, csrf_token: str) -> bool:
"""Validate CSRF token against session or header"""
# For state-changing operations, validate CSRF token
# In production, you might store this in session or validate against user session
expected_token = request.headers.get("X-CSRF-Token") or request.cookies.get("csrf_token")
return secrets.compare_digest(csrf_token, expected_token) if expected_token else False

def csrf_protected(request: Request = None):
"""Decorator for CSRF protection on state-changing operations"""
def decorator(func: Callable) -> Callable:
async def wrapper(*args, **kwargs):
if request and request.method in ["POST", "PUT", "DELETE", "PATCH"]:
csrf_token = request.headers.get("X-CSRF-Token")
if not csrf_token or not validate_csrf_token(request, csrf_token):
raise HTTPException(status_code=403, detail="CSRF token validation failed")
return await func(*args, **kwargs)
return wrapper
return decorator

def verify_token(token: str) -> dict:
"""Verify JWT token with enhanced error handling"""
try:
import jwt
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])

# Additional validation
if not all(key in payload for key in ["sub", "email", "role"]):
return {}

return payload
except jwt.ExpiredSignatureError:
return {}
except jwt.InvalidTokenError:
return {}
except Exception as e:
logger.warning(f"Token verification error: {e}")
return {}

def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
"""Get current user from token"""
token = credentials.credentials
payload = verify_token(token)
if not payload:
raise HTTPException(
status_code=401,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
return payload

def require_admin_user(current_user: dict = Depends(get_current_user)):
"""Require admin role"""
if current_user.get("role") != "admin":
raise HTTPException(
status_code=403,
detail="Admin access required",
)
return current_user

def authenticate_user(email: str, password: str) -> dict:
"""Authenticate user - in production, check against database"""
# This is a simple mock implementation
# In production, you would check against a database
if email and password:
# Simple role assignment based on email for demo
role = "admin" if "admin" in email else "user"
return {
"id": f"user_{email}",
"email": email,
"name": email.split("@")[0],
"role": role
}
return {}
Loading
Loading