Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions litellm/integrations/sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from litellm.types.utils import StandardLoggingPayload

from .custom_batch_logger import CustomBatchLogger
from litellm.types.integrations.base_health_check import IntegrationHealthCheckStatus

_BASE64_INLINE_PATTERN = re.compile(
r"data:(?:application|image|audio|video)/[a-zA-Z0-9.+-]+;base64,[A-Za-z0-9+/=\s]+",
Expand Down Expand Up @@ -354,3 +355,19 @@ async def async_send_message(self, payload: StandardLoggingPayload) -> None:
response.raise_for_status()
except Exception as e:
verbose_logger.exception(f"Error sending to SQS: {str(e)}")

async def async_health_check(self) -> IntegrationHealthCheckStatus:
"""
Health check for SQS by sending a small test message to the configured queue.
"""
try:
from litellm.litellm_core_utils.litellm_logging import (
create_dummy_standard_logging_payload,
)
# Create a minimal standard logging payload
standard_logging_object: StandardLoggingPayload = create_dummy_standard_logging_payload()
# Attempt to send a single message
await self.async_send_message(standard_logging_object)
return IntegrationHealthCheckStatus(status="healthy", error_message=None)
except Exception as e:
return IntegrationHealthCheckStatus(status="unhealthy", error_message=str(e))
12 changes: 12 additions & 0 deletions litellm/proxy/health_endpoints/_health_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
Literal[
"slack_budget_alerts",
"langfuse",
"langfuse_otel",
"slack",
"openmeter",
"webhook",
Expand All @@ -46,6 +47,7 @@
"datadog",
"generic_api",
"arize",
"sqs"
],
str,
]
Expand Down Expand Up @@ -106,6 +108,7 @@ async def health_services_endpoint( # noqa: PLR0915
"slack_budget_alerts",
"email",
"langfuse",
"langfuse_otel",
"slack",
"openmeter",
"webhook",
Expand All @@ -116,6 +119,7 @@ async def health_services_endpoint( # noqa: PLR0915
"datadog",
"generic_api",
"arize",
"sqs"
]:
raise HTTPException(
status_code=400,
Expand Down Expand Up @@ -196,6 +200,14 @@ async def health_services_endpoint( # noqa: PLR0915
type="user_budget",
user_info=user_info,
)
elif service == "sqs":
from litellm.integrations.sqs import SQSLogger
sqs_logger = SQSLogger()
response = await sqs_logger.async_health_check()
return {
"status": response["status"],
"message": response["error_message"],
}

if service == "slack" or service == "slack_budget_alerts":
if "slack" in general_settings.get("alerting", []):
Expand Down
24 changes: 24 additions & 0 deletions tests/logging_callback_tests/test_sqs_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,3 +432,27 @@ async def test_strip_base64_recursive_redaction():
s = json.dumps(c).lower()
# allow "[base64_redacted]" but nothing else
assert "base64," not in s, f"Found real base64 blob in: {s}"


@pytest.mark.asyncio
async def test_async_health_check_healthy(monkeypatch):
monkeypatch.setattr("litellm.aws_sqs_callback_params", {})
monkeypatch.setattr(asyncio, "create_task", MagicMock())
logger = SQSLogger(sqs_queue_url="https://example.com", sqs_region_name="us-west-2")
logger.async_send_message = AsyncMock(return_value=None)

result = await logger.async_health_check()
assert result["status"] == "healthy"
assert result.get("error_message") is None


@pytest.mark.asyncio
async def test_async_health_check_unhealthy(monkeypatch):
monkeypatch.setattr("litellm.aws_sqs_callback_params", {})
monkeypatch.setattr(asyncio, "create_task", MagicMock())
logger = SQSLogger(sqs_queue_url="https://example.com", sqs_region_name="us-west-2")
logger.async_send_message = AsyncMock(side_effect=Exception("boom"))

result = await logger.async_health_check()
assert result["status"] == "unhealthy"
assert "boom" in (result.get("error_message") or "")
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import sys
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, patch, AsyncMock

sys.path.insert(
0, os.path.abspath("../../..")
Expand All @@ -16,6 +16,7 @@
from litellm.proxy.health_endpoints._health_endpoints import (
_db_health_readiness_check,
db_health_cache,
health_services_endpoint,
)


Expand Down Expand Up @@ -97,3 +98,31 @@ async def test_db_health_readiness_check_with_error_and_flag_off(prisma_error):

# Verify that the raised exception is the same
assert excinfo.value == prisma_error


@pytest.mark.asyncio
@pytest.mark.parametrize(
"status,error_message",
[
("healthy", ""),
("unhealthy", "queue not reachable"),
],
)
async def test_health_services_endpoint_sqs(status, error_message):
"""
Verify the /health/services SQS branch returns expected status and message
based on SQSLogger.async_health_check().
"""
with patch("litellm.integrations.sqs.SQSLogger") as MockSQSLogger:
mock_instance = MagicMock()
mock_instance.async_health_check = AsyncMock(
return_value={"status": status, "error_message": error_message}
)
MockSQSLogger.return_value = mock_instance

result = await health_services_endpoint(service="sqs")

assert result["status"] == status
assert result["message"] == error_message
mock_instance.async_health_check.assert_awaited_once()

Loading