Skip to content

Commit 8596f89

Browse files
fix(integrations): enhance input handling for embeddings in LiteLLM integration (#5127)
#### Issues

Closes https://linear.app/getsentry/issue/TET-1461/fix-embedding-support-for-litellm
1 parent 0e6e808 commit 8596f89

File tree

2 files changed

+207
-34
lines changed

2 files changed

+207
-34
lines changed

sentry_sdk/integrations/litellm.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -77,15 +77,40 @@ def _input_callback(kwargs):
7777
set_data_normalized(span, SPANDATA.GEN_AI_SYSTEM, provider)
7878
set_data_normalized(span, SPANDATA.GEN_AI_OPERATION_NAME, operation)
7979

80-
# Record messages if allowed
81-
messages = kwargs.get("messages", [])
82-
if messages and should_send_default_pii() and integration.include_prompts:
83-
scope = sentry_sdk.get_current_scope()
84-
messages_data = truncate_and_annotate_messages(messages, span, scope)
85-
if messages_data is not None:
86-
set_data_normalized(
87-
span, SPANDATA.GEN_AI_REQUEST_MESSAGES, messages_data, unpack=False
88-
)
80+
# Record input/messages if allowed
81+
if should_send_default_pii() and integration.include_prompts:
82+
if operation == "embeddings":
83+
# For embeddings, look for the 'input' parameter
84+
embedding_input = kwargs.get("input")
85+
if embedding_input:
86+
scope = sentry_sdk.get_current_scope()
87+
# Normalize to list format
88+
input_list = (
89+
embedding_input
90+
if isinstance(embedding_input, list)
91+
else [embedding_input]
92+
)
93+
messages_data = truncate_and_annotate_messages(input_list, span, scope)
94+
if messages_data is not None:
95+
set_data_normalized(
96+
span,
97+
SPANDATA.GEN_AI_EMBEDDINGS_INPUT,
98+
messages_data,
99+
unpack=False,
100+
)
101+
else:
102+
# For chat, look for the 'messages' parameter
103+
messages = kwargs.get("messages", [])
104+
if messages:
105+
scope = sentry_sdk.get_current_scope()
106+
messages_data = truncate_and_annotate_messages(messages, span, scope)
107+
if messages_data is not None:
108+
set_data_normalized(
109+
span,
110+
SPANDATA.GEN_AI_REQUEST_MESSAGES,
111+
messages_data,
112+
unpack=False,
113+
)
89114

90115
# Record other parameters
91116
params = {

tests/integrations/litellm/test_litellm.py

Lines changed: 173 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
import pytest
3+
import time
34
from unittest import mock
45
from datetime import datetime
56

@@ -17,6 +18,7 @@ async def __call__(self, *args, **kwargs):
1718
except ImportError:
1819
pytest.skip("litellm not installed", allow_module_level=True)
1920

21+
import sentry_sdk
2022
from sentry_sdk import start_transaction
2123
from sentry_sdk.consts import OP, SPANDATA
2224
from sentry_sdk.integrations.litellm import (
@@ -31,6 +33,36 @@ async def __call__(self, *args, **kwargs):
3133
LITELLM_VERSION = package_version("litellm")
3234

3335

36+
@pytest.fixture
37+
def clear_litellm_cache():
38+
"""
39+
Stop any active mock patches and flush litellm's in-memory client cache to ensure test isolation.
40+
41+
The LiteLLM integration uses setup_once() which only runs once per Python process.
42+
This fixture does not re-run setup_once(); it only clears leftover mocks and cached clients before and after each test.
43+
"""
44+
45+
# Stop all existing mocks
46+
mock.patch.stopall()
47+
48+
# Clear client cache
49+
if (
50+
hasattr(litellm, "in_memory_llm_clients_cache")
51+
and litellm.in_memory_llm_clients_cache
52+
):
53+
litellm.in_memory_llm_clients_cache.flush_cache()
54+
55+
yield
56+
57+
# Clean up after test as well
58+
mock.patch.stopall()
59+
if (
60+
hasattr(litellm, "in_memory_llm_clients_cache")
61+
and litellm.in_memory_llm_clients_cache
62+
):
63+
litellm.in_memory_llm_clients_cache.flush_cache()
64+
65+
3466
# Mock response objects
3567
class MockMessage:
3668
def __init__(self, role="assistant", content="Test response"):
@@ -87,6 +119,21 @@ def __init__(self, model="text-embedding-ada-002", data=None, usage=None):
87119
)
88120
self.object = "list"
89121

122+
def model_dump(self):
123+
return {
124+
"model": self.model,
125+
"data": [
126+
{"embedding": d.embedding, "index": d.index, "object": d.object}
127+
for d in self.data
128+
],
129+
"usage": {
130+
"prompt_tokens": self.usage.prompt_tokens,
131+
"completion_tokens": self.usage.completion_tokens,
132+
"total_tokens": self.usage.total_tokens,
133+
},
134+
"object": self.object,
135+
}
136+
90137

91138
@pytest.mark.parametrize(
92139
"send_default_pii, include_prompts",
@@ -201,44 +248,145 @@ def test_streaming_chat_completion(
201248
assert span["data"][SPANDATA.GEN_AI_RESPONSE_STREAMING] is True
202249

203250

204-
def test_embeddings_create(sentry_init, capture_events):
251+
def test_embeddings_create(sentry_init, capture_events, clear_litellm_cache):
252+
"""
253+
Test that litellm.embedding() calls are properly instrumented.
254+
255+
This test calls the actual litellm.embedding() function (not just callbacks)
256+
to ensure proper integration testing.
257+
"""
205258
sentry_init(
206259
integrations=[LiteLLMIntegration(include_prompts=True)],
207260
traces_sample_rate=1.0,
208261
send_default_pii=True,
209262
)
210263
events = capture_events()
211264

212-
messages = [{"role": "user", "content": "Some text to test embeddings"}]
213265
mock_response = MockEmbeddingResponse()
214266

215-
with start_transaction(name="litellm test"):
216-
kwargs = {
217-
"model": "text-embedding-ada-002",
218-
"input": "Hello!",
219-
"messages": messages,
220-
"call_type": "embedding",
221-
}
267+
# Mock within the test to ensure proper ordering with cache clearing
268+
with mock.patch(
269+
"litellm.openai_chat_completions.make_sync_openai_embedding_request"
270+
) as mock_http:
271+
# The function returns (headers, response)
272+
mock_http.return_value = ({}, mock_response)
273+
274+
with start_transaction(name="litellm test"):
275+
response = litellm.embedding(
276+
model="text-embedding-ada-002",
277+
input="Hello, world!",
278+
api_key="test-key", # Provide a fake API key to avoid authentication errors
279+
)
280+
# Allow time for callbacks to complete (they may run in separate threads)
281+
time.sleep(0.1)
282+
283+
# Response is processed by litellm, so just check it exists
284+
assert response is not None
285+
assert len(events) == 1
286+
(event,) = events
287+
288+
assert event["type"] == "transaction"
289+
assert len(event["spans"]) == 1
290+
(span,) = event["spans"]
291+
292+
assert span["op"] == OP.GEN_AI_EMBEDDINGS
293+
assert span["description"] == "embeddings text-embedding-ada-002"
294+
assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "embeddings"
295+
assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 5
296+
assert span["data"][SPANDATA.GEN_AI_REQUEST_MODEL] == "text-embedding-ada-002"
297+
# Check that embeddings input is captured (it's JSON serialized)
298+
embeddings_input = span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]
299+
assert json.loads(embeddings_input) == ["Hello, world!"]
300+
301+
302+
def test_embeddings_create_with_list_input(
303+
sentry_init, capture_events, clear_litellm_cache
304+
):
305+
"""Test embedding with list input."""
306+
sentry_init(
307+
integrations=[LiteLLMIntegration(include_prompts=True)],
308+
traces_sample_rate=1.0,
309+
send_default_pii=True,
310+
)
311+
events = capture_events()
222312

223-
_input_callback(kwargs)
224-
_success_callback(
225-
kwargs,
226-
mock_response,
227-
datetime.now(),
228-
datetime.now(),
229-
)
313+
mock_response = MockEmbeddingResponse()
230314

231-
assert len(events) == 1
232-
(event,) = events
315+
# Mock within the test to ensure proper ordering with cache clearing
316+
with mock.patch(
317+
"litellm.openai_chat_completions.make_sync_openai_embedding_request"
318+
) as mock_http:
319+
# The function returns (headers, response)
320+
mock_http.return_value = ({}, mock_response)
321+
322+
with start_transaction(name="litellm test"):
323+
response = litellm.embedding(
324+
model="text-embedding-ada-002",
325+
input=["First text", "Second text", "Third text"],
326+
api_key="test-key", # Provide a fake API key to avoid authentication errors
327+
)
328+
# Allow time for callbacks to complete (they may run in separate threads)
329+
time.sleep(0.1)
330+
331+
# Response is processed by litellm, so just check it exists
332+
assert response is not None
333+
assert len(events) == 1
334+
(event,) = events
335+
336+
assert event["type"] == "transaction"
337+
assert len(event["spans"]) == 1
338+
(span,) = event["spans"]
339+
340+
assert span["op"] == OP.GEN_AI_EMBEDDINGS
341+
assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "embeddings"
342+
# Check that list of embeddings input is captured (it's JSON serialized)
343+
embeddings_input = span["data"][SPANDATA.GEN_AI_EMBEDDINGS_INPUT]
344+
assert json.loads(embeddings_input) == [
345+
"First text",
346+
"Second text",
347+
"Third text",
348+
]
349+
350+
351+
def test_embeddings_no_pii(sentry_init, capture_events, clear_litellm_cache):
352+
"""Test that PII is not captured when disabled."""
353+
sentry_init(
354+
integrations=[LiteLLMIntegration(include_prompts=True)],
355+
traces_sample_rate=1.0,
356+
send_default_pii=False, # PII disabled
357+
)
358+
events = capture_events()
233359

234-
assert event["type"] == "transaction"
235-
assert len(event["spans"]) == 1
236-
(span,) = event["spans"]
360+
mock_response = MockEmbeddingResponse()
361+
362+
# Mock within the test to ensure proper ordering with cache clearing
363+
with mock.patch(
364+
"litellm.openai_chat_completions.make_sync_openai_embedding_request"
365+
) as mock_http:
366+
# The function returns (headers, response)
367+
mock_http.return_value = ({}, mock_response)
368+
369+
with start_transaction(name="litellm test"):
370+
response = litellm.embedding(
371+
model="text-embedding-ada-002",
372+
input="Hello, world!",
373+
api_key="test-key", # Provide a fake API key to avoid authentication errors
374+
)
375+
# Allow time for callbacks to complete (they may run in separate threads)
376+
time.sleep(0.1)
377+
378+
# Response is processed by litellm, so just check it exists
379+
assert response is not None
380+
assert len(events) == 1
381+
(event,) = events
382+
383+
assert event["type"] == "transaction"
384+
assert len(event["spans"]) == 1
385+
(span,) = event["spans"]
237386

238-
assert span["op"] == OP.GEN_AI_EMBEDDINGS
239-
assert span["description"] == "embeddings text-embedding-ada-002"
240-
assert span["data"][SPANDATA.GEN_AI_OPERATION_NAME] == "embeddings"
241-
assert span["data"][SPANDATA.GEN_AI_USAGE_INPUT_TOKENS] == 5
387+
assert span["op"] == OP.GEN_AI_EMBEDDINGS
388+
# Check that embeddings input is NOT captured when PII is disabled
389+
assert SPANDATA.GEN_AI_EMBEDDINGS_INPUT not in span["data"]
242390

243391

244392
def test_exception_handling(sentry_init, capture_events):

0 commit comments

Comments
 (0)