Skip to content
4 changes: 3 additions & 1 deletion .fernignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,6 @@ src/cohere/aws_client.py
src/cohere/sagemaker_client.py
src/cohere/client_v2.py
mypy.ini
src/cohere/aliases.py
src/cohere/aliases.py
src/cohere/v2/raw_client.py # remove when SSE updates are released
src/cohere/core/http_sse/* # remove when SSE updates are released
14 changes: 14 additions & 0 deletions src/cohere/core/http_sse/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from ._api import EventSource, aconnect_sse, connect_sse
from ._exceptions import SSEError
from ._models import ServerSentEvent

# Version string exported by this package.
# NOTE(review): presumably tracks the upstream httpx-sse release this code was
# vendored from — confirm before bumping.
__version__ = "0.4.1"

# Names that make up the package's public API; this is also what
# `from <package> import *` exposes.
__all__ = [
    "__version__",
    "EventSource",
    "connect_sse",
    "aconnect_sse",
    "ServerSentEvent",
    "SSEError",
]
93 changes: 93 additions & 0 deletions src/cohere/core/http_sse/_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager, contextmanager
from typing import Any, AsyncIterator, Iterator, cast

import httpx

from ._decoders import SSEDecoder
from ._exceptions import SSEError
from ._models import ServerSentEvent


class _SSELineBuffer:
    """Incrementally split a raw byte stream into decoded SSE lines.

    Splitting happens on the raw bytes, before any text decoding, for two
    reasons:

    * A multi-byte UTF-8 character may be cut in half by a network chunk
      boundary. CR (0x0D) and LF (0x0A) never occur inside a multi-byte
      sequence, so decoding only *complete* lines is always safe.
    * Only CRLF, LF and CR terminate a line in an event stream. Other
      characters that ``str.splitlines`` treats as boundaries (for example
      U+2028) are ordinary payload characters and must not split a line.
    """

    def __init__(self) -> None:
        # Bytes of the current, not yet terminated line (never holds CR/LF).
        self._pending = b""
        # True when the previous chunk ended with CR: if the next chunk starts
        # with LF, that LF is the second half of a CRLF pair, not a new line.
        self._skip_leading_lf = False

    def feed(self, chunk: bytes) -> list[str]:
        """Consume ``chunk`` and return every line it completes.

        Returned lines carry no terminator. Undecodable bytes are replaced
        with U+FFFD rather than raising.
        """
        if self._skip_leading_lf:
            if not chunk:
                # Nothing to inspect yet; keep waiting for the next byte.
                return []
            self._skip_leading_lf = False
            if chunk.startswith(b"\n"):
                chunk = chunk[1:]

        self._skip_leading_lf = chunk.endswith(b"\r")

        # Normalise all three terminators to LF, then split once.
        data = (self._pending + chunk).replace(b"\r\n", b"\n").replace(b"\r", b"\n")
        *complete, self._pending = data.split(b"\n")
        return [raw.decode("utf-8", errors="replace") for raw in complete]


class EventSource:
    """Read server-sent events from a streaming ``httpx.Response``."""

    def __init__(self, response: httpx.Response) -> None:
        self._response = response

    def _check_content_type(self) -> None:
        """Raise ``SSEError`` unless the response is a ``text/event-stream``."""
        # Only the media type matters; drop parameters such as "; charset=utf-8".
        content_type = self._response.headers.get("content-type", "").partition(";")[0]
        if "text/event-stream" not in content_type:
            raise SSEError(
                "Expected response header Content-Type to contain 'text/event-stream', "
                f"got {content_type!r}"
            )

    @property
    def response(self) -> httpx.Response:
        """The underlying HTTP response."""
        return self._response

    def iter_sse(self) -> Iterator[ServerSentEvent]:
        """Yield events from the response body as they are completed.

        Raises:
            SSEError: if the response Content-Type is not ``text/event-stream``.
        """
        self._check_content_type()
        decoder = SSEDecoder()
        lines = _SSELineBuffer()

        for chunk in self._response.iter_bytes():
            for line in lines.feed(chunk):
                # An empty line marks the end of an event; that is the only
                # point at which the decoder returns one.
                sse = decoder.decode(line)
                if sse is not None:
                    yield sse

        # Bytes left over after the last terminator belong to an event that
        # was never closed by a blank line. The decoder can only emit on a
        # blank line, so feeding them could not produce an event; they are
        # dropped, as the SSE specification requires for incomplete events.

    async def aiter_sse(self) -> AsyncGenerator[ServerSentEvent, None]:
        """Asynchronous counterpart of ``iter_sse`` with identical line handling.

        Raises:
            SSEError: if the response Content-Type is not ``text/event-stream``.
        """
        self._check_content_type()
        decoder = SSEDecoder()
        lines = _SSELineBuffer()
        # Read raw bytes rather than ``aiter_lines()`` so that both code paths
        # agree on what terminates a line.
        chunks = cast(AsyncGenerator[bytes, None], self._response.aiter_bytes())
        try:
            async for chunk in chunks:
                for line in lines.feed(chunk):
                    sse = decoder.decode(line)
                    if sse is not None:
                        yield sse
        finally:
            # Close the byte stream even when the consumer stops early.
            await chunks.aclose()


@contextmanager
def connect_sse(
    client: httpx.Client, method: str, url: str, **kwargs: Any
) -> Iterator[EventSource]:
    """Open a streaming request and yield an ``EventSource`` over its response.

    Args:
        client: The ``httpx.Client`` used to send the request.
        method: HTTP method.
        url: Request URL.
        **kwargs: Forwarded to ``client.stream``. A ``headers`` entry, if
            given, is copied; ``Accept`` and ``Cache-Control`` are then
            overridden with the values an event stream requires.

    Yields:
        An ``EventSource`` wrapping the open response. The response is closed
        when the ``with`` block exits.
    """
    # Build a fresh httpx.Headers instead of writing into the caller's object:
    # this leaves the caller's dict untouched, accepts None and every header
    # form httpx supports, and replaces an existing header case-insensitively
    # instead of sending e.g. both "accept" and "Accept".
    headers = httpx.Headers(kwargs.pop("headers", None))
    headers["Accept"] = "text/event-stream"
    headers["Cache-Control"] = "no-store"

    with client.stream(method, url, headers=headers, **kwargs) as response:
        yield EventSource(response)


@asynccontextmanager
async def aconnect_sse(
    client: httpx.AsyncClient,
    method: str,
    url: str,
    **kwargs: Any,
) -> AsyncIterator[EventSource]:
    """Open an async streaming request and yield an ``EventSource`` over it.

    Args:
        client: The ``httpx.AsyncClient`` used to send the request.
        method: HTTP method.
        url: Request URL.
        **kwargs: Forwarded to ``client.stream``. A ``headers`` entry, if
            given, is copied; ``Accept`` and ``Cache-Control`` are then
            overridden with the values an event stream requires.

    Yields:
        An ``EventSource`` wrapping the open response. The response is closed
        when the ``async with`` block exits.
    """
    # Build a fresh httpx.Headers instead of writing into the caller's object:
    # this leaves the caller's dict untouched, accepts None and every header
    # form httpx supports, and replaces an existing header case-insensitively
    # instead of sending e.g. both "accept" and "Accept".
    headers = httpx.Headers(kwargs.pop("headers", None))
    headers["Accept"] = "text/event-stream"
    headers["Cache-Control"] = "no-store"

    async with client.stream(method, url, headers=headers, **kwargs) as response:
        yield EventSource(response)
64 changes: 64 additions & 0 deletions src/cohere/core/http_sse/_decoders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from typing import List, Optional

from ._models import ServerSentEvent


class SSEDecoder:
    """Stateful line-by-line parser for a server-sent event stream.

    Feed each line (without its terminator) to ``decode``. Field lines are
    accumulated; an empty line closes the current event and returns it.
    """

    def __init__(self) -> None:
        self._event = ""
        self._data: List[str] = []
        self._last_event_id = ""
        self._retry: Optional[int] = None
        # Whether the event being accumulated carried its own non-empty
        # ``id`` field. ``_last_event_id`` survives across events, so it
        # cannot be used to tell whether the *current* event has content.
        self._id_seen = False

    def decode(self, line: str) -> Optional[ServerSentEvent]:
        """Process one line of the stream.

        Args:
            line: A single line with its line terminator already removed.

        Returns:
            The completed ``ServerSentEvent`` when ``line`` is empty and at
            least one field was received since the previous event, otherwise
            ``None``.
        """
        # See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation # noqa: E501

        if not line:
            if (
                not self._event
                and not self._data
                and not self._id_seen
                and self._retry is None
            ):
                # Nothing was received for this event (e.g. a blank line after
                # a comment-only keep-alive): do not emit a phantom event.
                return None

            sse = ServerSentEvent(
                # The event type defaults to "message" when no ``event`` field
                # was sent.
                event=self._event or "message",
                data="\n".join(self._data),
                id=self._last_event_id,
                retry=self._retry,
            )

            # NOTE: as per the SSE spec, do not reset last_event_id.
            self._event = ""
            self._data = []
            self._retry = None
            self._id_seen = False

            return sse

        if line.startswith(":"):
            # Comment line.
            return None

        fieldname, _, value = line.partition(":")

        # A single leading space after the colon is not part of the value.
        if value.startswith(" "):
            value = value[1:]

        if fieldname == "event":
            self._event = value
        elif fieldname == "data":
            self._data.append(value)
        elif fieldname == "id":
            # An id containing NUL is ignored.
            if "\0" not in value:
                self._last_event_id = value
                self._id_seen = bool(value)
        elif fieldname == "retry":
            # Only a value made up solely of ASCII digits is valid; signs,
            # whitespace and underscores that int() would accept are rejected.
            if value.isascii() and value.isdigit():
                try:
                    self._retry = int(value)
                except ValueError:
                    # int() can still refuse absurdly long digit strings.
                    pass
        # Any other field name is ignored.

        return None
5 changes: 5 additions & 0 deletions src/cohere/core/http_sse/_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import httpx


class SSEError(httpx.TransportError):
    """Error raised for server-sent-event problems; a kind of ``httpx.TransportError``."""
15 changes: 15 additions & 0 deletions src/cohere/core/http_sse/_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from dataclasses import dataclass
import json
from typing import Any, Optional


@dataclass(frozen=True)
class ServerSentEvent:
    """Immutable record holding the fields of one server-sent event."""

    # Event type; "message" unless a value is supplied.
    event: str = "message"
    # Event payload as text.
    data: str = ""
    # Event identifier; empty when none is supplied.
    id: str = ""
    # Value of the retry field, or None when none is supplied.
    retry: Optional[int] = None

    def json(self) -> Any:
        """Return ``data`` deserialized as a JSON document."""
        payload = self.data
        return json.loads(payload)
48 changes: 35 additions & 13 deletions src/cohere/v2/raw_client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# This file was auto-generated by Fern from our API Definition.

import contextlib
from dataclasses import asdict
import json
import logging
import typing
from json.decoder import JSONDecodeError

import httpx_sse
from ..core.http_sse._api import EventSource
from ..core.api_error import ApiError
from ..core.client_wrapper import AsyncClientWrapper, SyncClientWrapper
from ..core.http_response import AsyncHttpResponse, HttpResponse
Expand Down Expand Up @@ -44,6 +46,8 @@
from .types.v2embed_request_truncate import V2EmbedRequestTruncate
from .types.v2rerank_response import V2RerankResponse

logger = logging.getLogger(__name__)

# this is used as the default value for optional parameters
OMIT = typing.cast(typing.Any, ...)

Expand Down Expand Up @@ -224,20 +228,29 @@ def _stream() -> HttpResponse[typing.Iterator[V2ChatStreamResponse]]:
if 200 <= _response.status_code < 300:

def _iter():
_event_source = httpx_sse.EventSource(_response)
_event_source = EventSource(_response)
for _sse in _event_source.iter_sse():
if _sse.data == None:
return
try:
yield typing.cast(
V2ChatStreamResponse,
construct_type(
type_=V2ChatStreamResponse, # type: ignore
object_=json.loads(_sse.data),
object_=_sse.json(),
),
)
except Exception:
pass
except json.JSONDecodeError as e:
logger.warning(
f"Skipping SSE event with invalid JSON: {e}, sse: {_sse!r}"
)
except (TypeError, ValueError, KeyError, AttributeError) as e:
logger.warning(
f"Skipping SSE event due to model construction error: {type(e).__name__}: {e}, sse: {_sse!r}"
)
except Exception as e:
logger.error(
f"Unexpected error processing SSE event: {type(e).__name__}: {e}, sse: {_sse!r}"
)

return

return HttpResponse(response=_response, data=_iter())
Expand Down Expand Up @@ -1320,20 +1333,29 @@ async def _stream() -> AsyncHttpResponse[typing.AsyncIterator[V2ChatStreamRespon
if 200 <= _response.status_code < 300:

async def _iter():
_event_source = httpx_sse.EventSource(_response)
_event_source = EventSource(_response)
async for _sse in _event_source.aiter_sse():
if _sse.data == None:
return
try:
yield typing.cast(
V2ChatStreamResponse,
construct_type(
type_=V2ChatStreamResponse, # type: ignore
object_=json.loads(_sse.data),
object_=_sse.json(),
),
)
except Exception:
pass
except json.JSONDecodeError as e:
logger.warning(
f"Skipping SSE event with invalid JSON: {e}, sse: {_sse!r}"
)
except (TypeError, ValueError, KeyError, AttributeError) as e:
logger.warning(
f"Skipping SSE event due to model construction error: {type(e).__name__}: {e}, sse: {_sse!r}"
)
except Exception as e:
logger.error(
f"Unexpected error processing SSE event: {type(e).__name__}: {e}, sse: {_sse!r}"
)

return

return AsyncHttpResponse(response=_response, data=_iter())
Expand Down
Loading