Merged
25 changes: 17 additions & 8 deletions ee/api/conversation.py
@@ -21,11 +21,16 @@
from posthog.exceptions_capture import capture_exception
from posthog.models.user import User
from posthog.rate_limit import AIBurstRateThrottle, AISustainedRateThrottle
from posthog.temporal.ai.conversation import AssistantConversationRunnerWorkflowInputs
from posthog.temporal.ai.chat_agent import (
CHAT_AGENT_STREAM_MAX_LENGTH,
CHAT_AGENT_WORKFLOW_TIMEOUT,
AssistantConversationRunnerWorkflow,
AssistantConversationRunnerWorkflowInputs,
)
from posthog.utils import get_instance_region

from ee.hogai.agent.executor import AgentExecutor
from ee.hogai.api.serializers import ConversationSerializer
from ee.hogai.stream.conversation_stream import ConversationStreamManager
from ee.hogai.utils.aio import async_to_sync
from ee.hogai.utils.sse import AssistantSSESerializer
from ee.hogai.utils.types.base import AssistantMode
@@ -131,7 +136,8 @@ def create(self, request: Request, *args, **kwargs):

has_message = serializer.validated_data.get("content") is not None
is_deep_research = serializer.validated_data.get("deep_research_mode", False)
mode = AssistantMode.DEEP_RESEARCH if is_deep_research else AssistantMode.ASSISTANT
if is_deep_research:
raise NotImplementedError("Deep research is not supported yet")
Comment on lines +139 to +140 (Copilot AI, Nov 19, 2025):
The check for deep research mode was replaced with a NotImplementedError. While this prevents the use of an incomplete feature, the variable is_deep_research is still being extracted from serializer.validated_data on line 140 but is no longer used. Consider removing the unused variable assignment.
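If the suggestion were applied, a minimal sketch of the cleanup (the check itself still has to run, so the lookup is inlined rather than deleted outright):

    if serializer.validated_data.get("deep_research_mode", False):
        raise NotImplementedError("Deep research is not supported yet")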

is_new_conversation = False
# Safely set the lookup kwarg for potential error handling
@@ -175,15 +181,18 @@ def create(self, request: Request, *args, **kwargs):
trace_id=serializer.validated_data["trace_id"],
session_id=request.headers.get("X-POSTHOG-SESSION-ID"), # Relies on posthog-js __add_tracing_headers
billing_context=serializer.validated_data.get("billing_context"),
mode=mode,
mode=AssistantMode.ASSISTANT,
)
workflow_class = AssistantConversationRunnerWorkflow

async def async_stream(
workflow_inputs: AssistantConversationRunnerWorkflowInputs,
) -> AsyncGenerator[bytes, None]:
serializer = AssistantSSESerializer()
stream_manager = ConversationStreamManager(conversation)
async for chunk in stream_manager.astream(workflow_inputs):
stream_manager = AgentExecutor(
conversation, timeout=CHAT_AGENT_WORKFLOW_TIMEOUT, max_length=CHAT_AGENT_STREAM_MAX_LENGTH
)
async for chunk in stream_manager.astream(workflow_class, workflow_inputs):
yield serializer.dumps(chunk).encode("utf-8")

return StreamingHttpResponse(
@@ -201,8 +210,8 @@ def cancel(self, request: Request, *args, **kwargs):
return Response(status=status.HTTP_204_NO_CONTENT)

async def cancel_workflow():
conversation_manager = ConversationStreamManager(conversation)
await conversation_manager.cancel_conversation()
agent_executor = AgentExecutor(conversation)
await agent_executor.cancel_workflow()

try:
asgi_async_to_sync(cancel_workflow)()
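The thread running through this file's changes: AgentExecutor.astream now takes the Temporal workflow class as its first positional argument instead of baking AssistantConversationRunnerWorkflow in. A sketch of the call-shape change (all names here come from this diff; the surrounding handler code is elided):

    # Before: the manager knew which workflow to start.
    #   stream_manager = ConversationStreamManager(conversation)
    #   stream_manager.astream(workflow_inputs)
    # After: the caller chooses the workflow class and its stream limits.
    executor = AgentExecutor(
        conversation,
        timeout=CHAT_AGENT_WORKFLOW_TIMEOUT,
        max_length=CHAT_AGENT_STREAM_MAX_LENGTH,
    )
    stream = executor.astream(AssistantConversationRunnerWorkflow, workflow_inputs)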
30 changes: 15 additions & 15 deletions ee/api/test/test_conversation.py
@@ -102,7 +102,7 @@ def test_create_conversation(self):
conversation_id = str(uuid.uuid4())

with patch(
"ee.hogai.stream.conversation_stream.ConversationStreamManager.astream",
"ee.hogai.agent.executor.AgentExecutor.astream",
return_value=_async_generator(),
) as mock_start_workflow_and_stream:
with patch("ee.api.conversation.StreamingHttpResponse", side_effect=self._create_mock_streaming_response):
@@ -120,7 +120,7 @@
# Check that the method was called with workflow_inputs
mock_start_workflow_and_stream.assert_called_once()
call_args = mock_start_workflow_and_stream.call_args
workflow_inputs = call_args[0][0]
workflow_inputs = call_args[0][1]
self.assertEqual(workflow_inputs.user_id, self.user.id)
self.assertEqual(workflow_inputs.is_new_conversation, True)
self.assertEqual(workflow_inputs.conversation_id, conversation.id)
@@ -129,7 +129,7 @@

def test_add_message_to_existing_conversation(self):
with patch(
"ee.hogai.stream.conversation_stream.ConversationStreamManager.astream",
"ee.hogai.agent.executor.AgentExecutor.astream",
return_value=_async_generator(),
) as mock_start_workflow_and_stream:
with patch("ee.api.conversation.StreamingHttpResponse", side_effect=self._create_mock_streaming_response):
@@ -149,7 +149,7 @@
# Check that the method was called with workflow_inputs
mock_start_workflow_and_stream.assert_called_once()
call_args = mock_start_workflow_and_stream.call_args
workflow_inputs = call_args[0][0]
workflow_inputs = call_args[0][1]
self.assertEqual(workflow_inputs.user_id, self.user.id)
self.assertEqual(workflow_inputs.is_new_conversation, False)
self.assertEqual(workflow_inputs.conversation_id, conversation.id)
@@ -195,7 +195,7 @@ def test_invalid_message_format(self):
def test_rate_limit_burst(self):
# Create multiple requests to trigger burst rate limit
with patch(
"ee.hogai.stream.conversation_stream.ConversationStreamManager.astream",
"ee.hogai.agent.executor.AgentExecutor.astream",
return_value=_async_generator(),
):
with patch("ee.api.conversation.StreamingHttpResponse", side_effect=self._create_mock_streaming_response):
@@ -227,7 +227,7 @@ def test_none_content_with_existing_conversation(self):
user=self.user, team=self.team, status=Conversation.Status.IN_PROGRESS
)
with patch(
"ee.hogai.stream.conversation_stream.ConversationStreamManager.astream",
"ee.hogai.agent.executor.AgentExecutor.astream",
return_value=_async_generator(),
) as mock_stream_conversation:
with patch("ee.api.conversation.StreamingHttpResponse", side_effect=self._create_mock_streaming_response):
@@ -263,7 +263,7 @@ def test_missing_trace_id(self):

def test_nonexistent_conversation(self):
with patch(
"ee.hogai.stream.conversation_stream.ConversationStreamManager.astream",
"ee.hogai.agent.executor.AgentExecutor.astream",
return_value=_async_generator(),
):
with patch("ee.api.conversation.StreamingHttpResponse", side_effect=self._create_mock_streaming_response):
@@ -296,7 +296,7 @@ def test_unauthenticated_request(self):
)
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)

@patch("ee.hogai.stream.conversation_stream.ConversationStreamManager.cancel_conversation")
@patch("ee.hogai.agent.executor.AgentExecutor.cancel_workflow")
def test_cancel_conversation(self, mock_cancel):
conversation = Conversation.objects.create(
user=self.user,
@@ -340,7 +340,7 @@ def test_cancel_other_teams_conversation(self):
)
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)

@patch("ee.hogai.stream.conversation_stream.ConversationStreamManager.cancel_conversation")
@patch("ee.hogai.agent.executor.AgentExecutor.cancel_workflow")
def test_cancel_conversation_with_async_cleanup(self, mock_cancel):
"""Test that cancel endpoint properly handles async cleanup."""
conversation = Conversation.objects.create(
@@ -360,7 +360,7 @@ def test_cancel_conversation_with_async_cleanup(self, mock_cancel):

self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)

@patch("ee.hogai.stream.conversation_stream.ConversationStreamManager.cancel_conversation")
@patch("ee.hogai.agent.executor.AgentExecutor.cancel_workflow")
def test_cancel_conversation_async_cleanup_failure(self, mock_cancel):
"""Test cancel endpoint behavior when async cleanup fails."""
conversation = Conversation.objects.create(
@@ -428,7 +428,7 @@ def test_stream_from_in_progress_conversation(self):
user=self.user, team=self.team, status=Conversation.Status.IN_PROGRESS
)
with patch(
"ee.hogai.stream.conversation_stream.ConversationStreamManager.astream",
"ee.hogai.agent.executor.AgentExecutor.astream",
return_value=_async_generator(),
) as mock_stream_conversation:
with patch("ee.api.conversation.StreamingHttpResponse", side_effect=self._create_mock_streaming_response):
@@ -606,7 +606,7 @@ def test_billing_context_validation_valid_data(self):
conversation = Conversation.objects.create(user=self.user, team=self.team)

with patch(
"ee.hogai.stream.conversation_stream.ConversationStreamManager.astream",
"ee.hogai.agent.executor.AgentExecutor.astream",
return_value=_async_generator(),
) as mock_start_workflow_and_stream:
with patch("ee.api.conversation.StreamingHttpResponse", side_effect=self._create_mock_streaming_response):
@@ -622,15 +622,15 @@
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
call_args = mock_start_workflow_and_stream.call_args
workflow_inputs = call_args[0][0]
workflow_inputs = call_args[0][1]
self.assertEqual(workflow_inputs.billing_context, self.billing_context)

def test_billing_context_validation_invalid_data(self):
"""Test that invalid billing context data is rejected."""
conversation = Conversation.objects.create(user=self.user, team=self.team)

with patch(
"ee.hogai.stream.conversation_stream.ConversationStreamManager.astream",
"ee.hogai.agent.executor.AgentExecutor.astream",
return_value=_async_generator(),
) as mock_start_workflow_and_stream:
with patch("ee.api.conversation.StreamingHttpResponse", side_effect=self._create_mock_streaming_response):
@@ -646,5 +646,5 @@
)
self.assertEqual(response.status_code, status.HTTP_200_OK)
call_args = mock_start_workflow_and_stream.call_args
workflow_inputs = call_args[0][0]
workflow_inputs = call_args[0][1]
self.assertEqual(workflow_inputs.billing_context, None)
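The recurring call_args[0][0] → call_args[0][1] edit in these tests follows directly from the signature change: astream(workflow_inputs) became astream(workflow_class, inputs), so the inputs object is now the second positional argument. A self-contained illustration of the unittest.mock indexing involved:

    from unittest.mock import MagicMock

    astream = MagicMock()
    astream("WorkflowClass", "inputs")  # new call shape: (workflow, inputs)
    assert astream.call_args[0] == ("WorkflowClass", "inputs")  # positional args tuple
    assert astream.call_args[0][1] == "inputs"                  # inputs moved to index 1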
ee/hogai/stream/conversation_stream.py → ee/hogai/agent/executor.py
@@ -11,13 +11,12 @@

from posthog.schema import AssistantEventType, FailureMessage

from posthog.temporal.ai.conversation import (
AssistantConversationRunnerWorkflow,
AssistantConversationRunnerWorkflowInputs,
)
from posthog.temporal.ai.base import AgentBaseWorkflow
from posthog.temporal.common.client import async_connect

from ee.hogai.stream.redis_stream import (
from ee.hogai.agent.redis_stream import (
CONVERSATION_STREAM_MAX_LENGTH,
CONVERSATION_STREAM_TIMEOUT,
ConversationEvent,
ConversationRedisStream,
GenerationStatusEvent,
@@ -33,38 +32,44 @@
logger = structlog.get_logger(__name__)


class ConversationStreamManager:
"""Manages conversation streaming from Redis streams."""
class AgentExecutor:
"""Manages executing an agent workflow and streaming the output."""

def __init__(self, conversation: Conversation) -> None:
def __init__(
self,
conversation: Conversation,
timeout: int = CONVERSATION_STREAM_TIMEOUT,
max_length: int = CONVERSATION_STREAM_MAX_LENGTH,
) -> None:
self._conversation = conversation
self._redis_stream = ConversationRedisStream(get_conversation_stream_key(conversation.id))
self._redis_stream = ConversationRedisStream(
get_conversation_stream_key(conversation.id), timeout=timeout, max_length=max_length
)
self._workflow_id = f"conversation-{conversation.id}"

async def astream(
self, workflow_inputs: AssistantConversationRunnerWorkflowInputs
) -> AsyncGenerator[AssistantOutput, Any]:
"""Stream conversation updates from Redis stream.
async def astream(self, workflow: type[AgentBaseWorkflow], inputs: Any) -> AsyncGenerator[AssistantOutput, Any]:
"""Stream agent workflow updates from Redis stream.

Args:
workflow_inputs: Temporal workflow inputs
workflow: Agent temporal workflow class
inputs: Agent temporal workflow inputs

Returns:
AssistantOutput generator
"""
# If this is a reconnection attempt, we resume streaming
if self._conversation.status != Conversation.Status.IDLE:
if workflow_inputs.message is not None:
if inputs.message is not None:
raise ValueError("Cannot resume streaming with a new message")
async for chunk in self.stream_conversation():
yield chunk
else:
# Otherwise, process the new message (new generation) or resume generation (no new message)
async for chunk in self.start_workflow(workflow_inputs):
async for chunk in self.start_workflow(workflow, inputs):
yield chunk

async def start_workflow(
self, workflow_inputs: AssistantConversationRunnerWorkflowInputs
self, workflow: type[AgentBaseWorkflow], inputs: Any
) -> AsyncGenerator[AssistantOutput, Any]:
try:
# Delete the stream to ensure we start fresh
@@ -74,8 +79,8 @@ async def start_workflow(
client = await async_connect()

handle = await client.start_workflow(
AssistantConversationRunnerWorkflow.run,
workflow_inputs,
workflow.run,
inputs,
id=self._workflow_id,
task_queue=settings.MAX_AI_TASK_QUEUE,
id_conflict_policy=WorkflowIDConflictPolicy.USE_EXISTING,
@@ -173,7 +178,7 @@ def _failure_message(self) -> AssistantOutput:
)
return (AssistantEventType.MESSAGE, failure_message)

async def cancel_conversation(self) -> None:
async def cancel_workflow(self) -> None:
"""Cancel the current conversation and clean up resources.

Raises:
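From a caller's perspective, the renamed API looks roughly like this (a sketch: the async context and the handle() callback are illustrative; the class, method, and constant names come from this diff):

    # Drive a chat-agent workflow and consume its output stream.
    executor = AgentExecutor(
        conversation,
        timeout=CHAT_AGENT_WORKFLOW_TIMEOUT,
        max_length=CHAT_AGENT_STREAM_MAX_LENGTH,
    )
    async for event_type, message in executor.astream(
        AssistantConversationRunnerWorkflow, workflow_inputs
    ):
        # AssistantOutput appears to be an (AssistantEventType, message) pair,
        # as the _failure_message helper above suggests.
        handle(event_type, message)

    # Cancellation goes through the same object (used by the cancel endpoint):
    await executor.cancel_workflow()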
19 changes: 13 additions & 6 deletions ee/hogai/stream/redis_stream.py → ee/hogai/agent/redis_stream.py
@@ -160,11 +160,18 @@ class StreamError(Exception):
class ConversationRedisStream:
"""Manages conversation streaming from Redis streams."""

def __init__(self, stream_key: str):
def __init__(
self,
stream_key: str,
timeout: int = CONVERSATION_STREAM_TIMEOUT,
max_length: int = CONVERSATION_STREAM_MAX_LENGTH,
):
self._stream_key = stream_key
self._redis_client = get_async_client(settings.REDIS_URL)
self._deletion_lock = asyncio.Lock()
self._serializer = ConversationStreamSerializer()
self._timeout = timeout
self._max_length = max_length

async def wait_for_stream(self) -> bool:
"""Wait for stream to be created using linear backoff.
@@ -220,7 +227,7 @@ async def read_stream(
start_time = asyncio.get_event_loop().time()

while True:
if asyncio.get_event_loop().time() - start_time > CONVERSATION_STREAM_TIMEOUT:
if asyncio.get_event_loop().time() - start_time > self._timeout:
raise StreamError("Stream timeout - conversation took too long to complete")

try:
@@ -283,15 +290,15 @@ async def write_to_stream(
callback: Callback to trigger after each message is written to the stream
"""
try:
await self._redis_client.expire(self._stream_key, CONVERSATION_STREAM_TIMEOUT)
await self._redis_client.expire(self._stream_key, self._timeout)

async for chunk in generator:
message = self._serializer.dumps(chunk)
if message is not None:
await self._redis_client.xadd(
self._stream_key,
message,
maxlen=CONVERSATION_STREAM_MAX_LENGTH,
maxlen=self._max_length,
approximate=True,
)
if callback:
@@ -303,7 +310,7 @@
await self._redis_client.xadd(
self._stream_key,
completion_message,
maxlen=CONVERSATION_STREAM_MAX_LENGTH,
maxlen=self._max_length,
approximate=True,
)

@@ -314,7 +321,7 @@
await self._redis_client.xadd(
self._stream_key,
message,
maxlen=CONVERSATION_STREAM_MAX_LENGTH,
maxlen=self._max_length,
approximate=True,
)
raise StreamError("Failed to write to stream")
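The net effect of the constructor change: the stream's lifetime and size limits are injectable rather than fixed module constants, which is what lets AgentExecutor pass the chat-agent values through. A minimal sketch (defaults vs. overrides; key construction as in executor.py above):

    # Default limits: behaves exactly as before this PR.
    stream = ConversationRedisStream(get_conversation_stream_key(conversation.id))

    # Injected limits: timeout bounds read_stream's loop and the key's EXPIRE;
    # max_length becomes the approximate MAXLEN handed to XADD.
    stream = ConversationRedisStream(
        get_conversation_stream_key(conversation.id),
        timeout=CHAT_AGENT_WORKFLOW_TIMEOUT,
        max_length=CHAT_AGENT_STREAM_MAX_LENGTH,
    )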