def _split_sse_lines(buffer: bytes) -> tuple[list[str], bytes]:
    """Split *buffer* into complete decoded lines plus the unterminated tail.

    The split is done on the raw bytes, before decoding. The LF byte never
    occurs inside a multi-byte UTF-8 sequence, so a character that was cut in
    half by a network chunk boundary stays in the tail until the rest of it
    arrives, instead of being decoded prematurely into U+FFFD.

    Returns:
        ``(lines, tail)``: every line has its LF terminator and any trailing
        CR characters removed; ``tail`` holds the bytes after the last LF
        (empty when the buffer ended exactly on a line boundary).
    """
    *complete, tail = buffer.split(b"\n")
    lines = [raw.decode("utf-8", errors="replace").rstrip("\r") for raw in complete]
    return lines, tail


def _sse_request_headers(headers: Any) -> Any:
    """Return a copy of *headers* with the SSE request headers applied.

    A copy is used so the mapping passed in by the caller is never modified.
    ``None`` is treated as "no headers"; a sequence of pairs is converted to
    a dict.
    """
    if headers is None:
        merged = {}
    elif hasattr(headers, "keys") and hasattr(headers, "copy"):
        # Mapping types keep their own class (and lookup rules) when copied.
        merged = headers.copy()
    else:
        merged = dict(headers)
    merged["Accept"] = "text/event-stream"
    merged["Cache-Control"] = "no-store"
    return merged


class EventSource:
    """Iterate over the Server-Sent Events carried by a streaming response."""

    def __init__(self, response: httpx.Response) -> None:
        self._response = response

    def _check_content_type(self) -> None:
        """Raise ``SSEError`` unless the response is ``text/event-stream``."""
        content_type = self._response.headers.get("content-type", "").partition(";")[0]
        # Media type names are case-insensitive; compare in lower case but
        # report the value exactly as the server sent it.
        if "text/event-stream" not in content_type.lower():
            raise SSEError(
                "Expected response header Content-Type to contain 'text/event-stream', "
                f"got {content_type!r}"
            )

    @property
    def response(self) -> httpx.Response:
        """The underlying response object."""
        return self._response

    def iter_sse(self) -> Iterator[ServerSentEvent]:
        """Yield each event from the response body as it becomes complete.

        Raises:
            SSEError: if the response Content-Type is not ``text/event-stream``.
        """
        self._check_content_type()
        decoder = SSEDecoder()

        buffer = b""
        for chunk in self._response.iter_bytes():
            lines, buffer = _split_sse_lines(buffer + chunk)
            for line in lines:
                # An empty line marks the end of an event; only then does the
                # decoder return something other than None.
                sse = decoder.decode(line)
                if sse is not None:
                    yield sse
        # Bytes left in ``buffer`` belong to a line that was never terminated.
        # The decoder only emits an event on an empty line, so feeding it this
        # fragment could not produce one; it is dropped.

    async def aiter_sse(self) -> AsyncGenerator[ServerSentEvent, None]:
        """Asynchronous counterpart of ``iter_sse``.

        The raw bytes are parsed with the same line-splitting helper as the
        synchronous path, so both agree on line boundaries and on UTF-8
        decoding.

        Raises:
            SSEError: if the response Content-Type is not ``text/event-stream``.
        """
        self._check_content_type()
        decoder = SSEDecoder()
        chunks = self._response.aiter_bytes()

        buffer = b""
        try:
            async for chunk in chunks:
                lines, buffer = _split_sse_lines(buffer + chunk)
                for line in lines:
                    sse = decoder.decode(line)
                    if sse is not None:
                        yield sse
        finally:
            # Close the byte stream even when the consumer stops early.
            aclose = getattr(chunks, "aclose", None)
            if aclose is not None:
                await aclose()


@contextmanager
def connect_sse(
    client: httpx.Client, method: str, url: str, **kwargs: Any
) -> Iterator[EventSource]:
    """Open a streaming request and yield an ``EventSource`` for its response.

    ``Accept: text/event-stream`` and ``Cache-Control: no-store`` are set on a
    copy of the caller's headers; all other keyword arguments are passed to
    ``client.stream`` unchanged.
    """
    headers = _sse_request_headers(kwargs.pop("headers", None))

    with client.stream(method, url, headers=headers, **kwargs) as response:
        yield EventSource(response)


@asynccontextmanager
async def aconnect_sse(
    client: httpx.AsyncClient,
    method: str,
    url: str,
    **kwargs: Any,
) -> AsyncIterator[EventSource]:
    """Asynchronous counterpart of ``connect_sse``."""
    headers = _sse_request_headers(kwargs.pop("headers", None))

    async with client.stream(method, url, headers=headers, **kwargs) as response:
        yield EventSource(response)
@dataclass(frozen=True)
class ServerSentEvent:
    """Immutable record for one event read from an event stream."""

    event: str = "message"
    data: str = ""
    id: str = ""
    retry: Optional[int] = None

    def json(self) -> Any:
        """Parse the data field as JSON."""
        return json.loads(self.data)


class SSEDecoder:
    """Build ``ServerSentEvent`` objects from a stream, one line at a time.

    Feed every line (without its terminator) to ``decode``. Field lines are
    accumulated; an empty line ends the current event and returns it.
    """

    def __init__(self) -> None:
        self._event = ""
        self._data: List[str] = []
        self._last_event_id = ""
        self._retry: Optional[int] = None
        # True once an ``id`` field has been read for the event being built.
        # ``_last_event_id`` persists across events, so it cannot be used on
        # its own to tell whether the current event carries anything.
        self._id_seen = False

    def decode(self, line: str) -> Optional[ServerSentEvent]:
        """Consume one line and return an event when the line completes one.

        Args:
            line: a single line of the stream with the line terminator removed.

        Returns:
            The finished event when *line* is empty and at least one field was
            read since the previous event; otherwise ``None``.
        """
        # See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation # noqa: E501

        if not line:
            return self._dispatch()

        if line.startswith(":"):
            # Comment line.
            return None

        # A line without a colon is a field name with an empty value.
        fieldname, _, value = line.partition(":")

        if value.startswith(" "):
            value = value[1:]

        if fieldname == "event":
            self._event = value
        elif fieldname == "data":
            self._data.append(value)
        elif fieldname == "id":
            # An id containing NUL is ignored.
            if "\0" not in value:
                self._last_event_id = value
                self._id_seen = True
        elif fieldname == "retry":
            # Only a run of ASCII digits is a valid retry value; int() alone
            # would also accept signs, underscores and surrounding spaces.
            if value.isascii() and value.isdigit():
                self._retry = int(value)
        # Any other field name is ignored.

        return None

    def _dispatch(self) -> Optional[ServerSentEvent]:
        """Return the buffered event and reset per-event state, or ``None``."""
        has_id = self._id_seen and bool(self._last_event_id)
        if not (self._event or self._data or has_id or self._retry is not None):
            # Nothing was read since the last event (for example repeated
            # blank lines); do not emit an empty event.
            return None

        sse = ServerSentEvent(
            # An event without an ``event`` field has the default type.
            event=self._event or "message",
            data="\n".join(self._data),
            id=self._last_event_id,
            retry=self._retry,
        )

        # NOTE: as per the SSE spec, do not reset last_event_id.
        self._event = ""
        self._data = []
        self._retry = None
        self._id_seen = False

        return sse
@@ -224,20 +228,29 @@ def _stream() -> HttpResponse[typing.Iterator[V2ChatStreamResponse]]: if 200 <= _response.status_code < 300: def _iter(): - _event_source = httpx_sse.EventSource(_response) + _event_source = EventSource(_response) for _sse in _event_source.iter_sse(): - if _sse.data == None: - return try: yield typing.cast( V2ChatStreamResponse, construct_type( type_=V2ChatStreamResponse, # type: ignore - object_=json.loads(_sse.data), + object_=_sse.json(), ), ) - except Exception: - pass + except json.JSONDecodeError as e: + logger.warning( + f"Skipping SSE event with invalid JSON: {e}, sse: {_sse!r}" + ) + except (TypeError, ValueError, KeyError, AttributeError) as e: + logger.warning( + f"Skipping SSE event due to model construction error: {type(e).__name__}: {e}, sse: {_sse!r}" + ) + except Exception as e: + logger.error( + f"Unexpected error processing SSE event: {type(e).__name__}: {e}, sse: {_sse!r}" + ) + return return HttpResponse(response=_response, data=_iter()) @@ -1320,20 +1333,29 @@ async def _stream() -> AsyncHttpResponse[typing.AsyncIterator[V2ChatStreamRespon if 200 <= _response.status_code < 300: async def _iter(): - _event_source = httpx_sse.EventSource(_response) + _event_source = EventSource(_response) async for _sse in _event_source.aiter_sse(): - if _sse.data == None: - return try: yield typing.cast( V2ChatStreamResponse, construct_type( type_=V2ChatStreamResponse, # type: ignore - object_=json.loads(_sse.data), + object_=_sse.json(), ), ) - except Exception: - pass + except json.JSONDecodeError as e: + logger.warning( + f"Skipping SSE event with invalid JSON: {e}, sse: {_sse!r}" + ) + except (TypeError, ValueError, KeyError, AttributeError) as e: + logger.warning( + f"Skipping SSE event due to model construction error: {type(e).__name__}: {e}, sse: {_sse!r}" + ) + except Exception as e: + logger.error( + f"Unexpected error processing SSE event: {type(e).__name__}: {e}, sse: {_sse!r}" + ) + return return 
AsyncHttpResponse(response=_response, data=_iter())