diff --git a/any_llm_client/clients/openai.py b/any_llm_client/clients/openai.py
index 034a662..e5fdded 100644
--- a/any_llm_client/clients/openai.py
+++ b/any_llm_client/clients/openai.py
@@ -86,7 +86,7 @@ class OneStreamingChoice(pydantic.BaseModel):
 
 
 class ChatCompletionsStreamingEvent(pydantic.BaseModel):
-    choices: typing.Annotated[list[OneStreamingChoice], annotated_types.MinLen(1)]
+    choices: list[OneStreamingChoice]
 
 
 class OneNotStreamingChoiceMessage(pydantic.BaseModel):
@@ -269,7 +269,8 @@ async def _iter_response_chunks(self, response: httpx.Response) -> typing.AsyncI
                 _handle_validation_error(content=event.data.encode(), original_error=validation_error)
 
             if not (
-                (validated_delta := validated_response.choices[0].delta)
+                (validated_choices := validated_response.choices)
+                and (validated_delta := validated_choices[0].delta)
                 and (validated_delta.content or validated_delta.reasoning_content)
             ):
                 continue
diff --git a/tests/test_openai_client.py b/tests/test_openai_client.py
index ef79c8e..0ba37b9 100644
--- a/tests/test_openai_client.py
+++ b/tests/test_openai_client.py
@@ -91,6 +91,7 @@ async def test_ok(self, faker: faker.Faker, func_request: LLMFuncRequest) -> Non
                 + ChatCompletionsStreamingEvent(choices=[OneStreamingChoice(delta=one_message)]).model_dump_json()
                 for one_message in generated_messages
             )
+            + f"\n\ndata: {ChatCompletionsStreamingEvent(choices=[]).model_dump_json()}"
             + f"\n\ndata: [DONE]\n\ndata: {faker.pystr()}\n\n"
         )
         response: typing.Final = httpx.Response(
@@ -104,23 +105,6 @@ async def test_ok(self, faker: faker.Faker, func_request: LLMFuncRequest) -> Non
         assert result == expected_result
 
-    async def test_fails_without_alternatives(self) -> None:
-        response_content: typing.Final = (
-            f"data: {ChatCompletionsStreamingEvent.model_construct(choices=[]).model_dump_json()}\n\n"
-        )
-        response: typing.Final = httpx.Response(
-            200,
-            headers={"Content-Type": "text/event-stream"},
-            content=response_content,
-        )
-        client: typing.Final = any_llm_client.get_client(
-            OpenAIConfigFactory.build(),
-            transport=httpx.MockTransport(lambda _: response),
-        )
-
-        with pytest.raises(LLMResponseValidationError):
-            await consume_llm_message_chunks(client.stream_llm_message_chunks(**LLMFuncRequestFactory.build()))
-
 
 class TestOpenAILLMErrors:
     @pytest.mark.parametrize("stream", [True, False])
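
Not part of the patch: a minimal sketch of the behaviour the change above is after. With the MinLen(1) constraint dropped from ChatCompletionsStreamingEvent.choices, a streaming event whose choices list is empty is now skipped by the guard in _iter_response_chunks instead of raising a validation error. The field names come from the diff; the model name OneStreamingChoiceDelta and the helper iter_chunk_texts are illustrative assumptions, not the library's actual API.

# Sketch only; model/function names other than those in the diff are assumptions.
import typing

import pydantic


class OneStreamingChoiceDelta(pydantic.BaseModel):  # name assumed for illustration
    content: str | None = None
    reasoning_content: str | None = None


class OneStreamingChoice(pydantic.BaseModel):
    delta: OneStreamingChoiceDelta | None = None


class ChatCompletionsStreamingEvent(pydantic.BaseModel):
    choices: list[OneStreamingChoice]  # MinLen(1) constraint removed by the patch


def iter_chunk_texts(raw_events: typing.Iterable[str]) -> typing.Iterator[str]:
    """Yield chunk text, silently skipping events like {"choices": []}."""
    for raw_event in raw_events:
        event = ChatCompletionsStreamingEvent.model_validate_json(raw_event)
        if not (
            (choices := event.choices)
            and (delta := choices[0].delta)
            and (delta.content or delta.reasoning_content)
        ):
            continue
        yield delta.content or delta.reasoning_content or ""


# The second event would previously have failed MinLen(1) validation; now it is skipped.
assert list(
    iter_chunk_texts(
        [
            '{"choices": [{"delta": {"content": "Hi"}}]}',
            '{"choices": []}',
        ]
    )
) == ["Hi"]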