Skip to content

Commit a39feb2

Browse files
authored
refactor(ph-ai): conversation stream manager to agent executor (#41613)
## Problem `ConversationStreamManager` is now `AgentExecutor`, a class to run an agent workflow and stream its content, taking as input the workflow type and workflow inputs. This will allow us to run deep research using a different workflow than the base chat agent one. ## Changes - Renamed `ConversationStreamManager` to `AgentExecutor` and moved it from `ee/hogai/stream/` to `ee/hogai/agent/` - Renamed `AssistantConversationRunnerWorkflow` to `ChatAgentWorkflow` for better clarity - Created a new base class `AgentBaseWorkflow` to support different types of agent workflows - Moved Redis stream implementation to the agent package - Made timeout and max length parameters configurable in the executor ## How did you test this code? - Migrated all existing tests to the new structure ## Changelog: Yes
1 parent 2a27ff5 commit a39feb2

File tree

10 files changed

+148
-119
lines changed

10 files changed

+148
-119
lines changed

ee/api/conversation.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,16 @@
2121
from posthog.exceptions_capture import capture_exception
2222
from posthog.models.user import User
2323
from posthog.rate_limit import AIBurstRateThrottle, AISustainedRateThrottle
24-
from posthog.temporal.ai.conversation import AssistantConversationRunnerWorkflowInputs
24+
from posthog.temporal.ai.chat_agent import (
25+
CHAT_AGENT_STREAM_MAX_LENGTH,
26+
CHAT_AGENT_WORKFLOW_TIMEOUT,
27+
AssistantConversationRunnerWorkflow,
28+
AssistantConversationRunnerWorkflowInputs,
29+
)
2530
from posthog.utils import get_instance_region
2631

32+
from ee.hogai.agent.executor import AgentExecutor
2733
from ee.hogai.api.serializers import ConversationSerializer
28-
from ee.hogai.stream.conversation_stream import ConversationStreamManager
2934
from ee.hogai.utils.aio import async_to_sync
3035
from ee.hogai.utils.sse import AssistantSSESerializer
3136
from ee.hogai.utils.types.base import AssistantMode
@@ -131,7 +136,8 @@ def create(self, request: Request, *args, **kwargs):
131136

132137
has_message = serializer.validated_data.get("content") is not None
133138
is_deep_research = serializer.validated_data.get("deep_research_mode", False)
134-
mode = AssistantMode.DEEP_RESEARCH if is_deep_research else AssistantMode.ASSISTANT
139+
if is_deep_research:
140+
raise NotImplementedError("Deep research is not supported yet")
135141

136142
is_new_conversation = False
137143
# Safely set the lookup kwarg for potential error handling
@@ -175,15 +181,18 @@ def create(self, request: Request, *args, **kwargs):
175181
trace_id=serializer.validated_data["trace_id"],
176182
session_id=request.headers.get("X-POSTHOG-SESSION-ID"), # Relies on posthog-js __add_tracing_headers
177183
billing_context=serializer.validated_data.get("billing_context"),
178-
mode=mode,
184+
mode=AssistantMode.ASSISTANT,
179185
)
186+
workflow_class = AssistantConversationRunnerWorkflow
180187

181188
async def async_stream(
182189
workflow_inputs: AssistantConversationRunnerWorkflowInputs,
183190
) -> AsyncGenerator[bytes, None]:
184191
serializer = AssistantSSESerializer()
185-
stream_manager = ConversationStreamManager(conversation)
186-
async for chunk in stream_manager.astream(workflow_inputs):
192+
stream_manager = AgentExecutor(
193+
conversation, timeout=CHAT_AGENT_WORKFLOW_TIMEOUT, max_length=CHAT_AGENT_STREAM_MAX_LENGTH
194+
)
195+
async for chunk in stream_manager.astream(workflow_class, workflow_inputs):
187196
yield serializer.dumps(chunk).encode("utf-8")
188197

189198
return StreamingHttpResponse(
@@ -201,8 +210,8 @@ def cancel(self, request: Request, *args, **kwargs):
201210
return Response(status=status.HTTP_204_NO_CONTENT)
202211

203212
async def cancel_workflow():
204-
conversation_manager = ConversationStreamManager(conversation)
205-
await conversation_manager.cancel_conversation()
213+
agent_executor = AgentExecutor(conversation)
214+
await agent_executor.cancel_workflow()
206215

207216
try:
208217
asgi_async_to_sync(cancel_workflow)()

ee/api/test/test_conversation.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def test_create_conversation(self):
102102
conversation_id = str(uuid.uuid4())
103103

104104
with patch(
105-
"ee.hogai.stream.conversation_stream.ConversationStreamManager.astream",
105+
"ee.hogai.agent.executor.AgentExecutor.astream",
106106
return_value=_async_generator(),
107107
) as mock_start_workflow_and_stream:
108108
with patch("ee.api.conversation.StreamingHttpResponse", side_effect=self._create_mock_streaming_response):
@@ -120,7 +120,7 @@ def test_create_conversation(self):
120120
# Check that the method was called with workflow_inputs
121121
mock_start_workflow_and_stream.assert_called_once()
122122
call_args = mock_start_workflow_and_stream.call_args
123-
workflow_inputs = call_args[0][0]
123+
workflow_inputs = call_args[0][1]
124124
self.assertEqual(workflow_inputs.user_id, self.user.id)
125125
self.assertEqual(workflow_inputs.is_new_conversation, True)
126126
self.assertEqual(workflow_inputs.conversation_id, conversation.id)
@@ -129,7 +129,7 @@ def test_create_conversation(self):
129129

130130
def test_add_message_to_existing_conversation(self):
131131
with patch(
132-
"ee.hogai.stream.conversation_stream.ConversationStreamManager.astream",
132+
"ee.hogai.agent.executor.AgentExecutor.astream",
133133
return_value=_async_generator(),
134134
) as mock_start_workflow_and_stream:
135135
with patch("ee.api.conversation.StreamingHttpResponse", side_effect=self._create_mock_streaming_response):
@@ -149,7 +149,7 @@ def test_add_message_to_existing_conversation(self):
149149
# Check that the method was called with workflow_inputs
150150
mock_start_workflow_and_stream.assert_called_once()
151151
call_args = mock_start_workflow_and_stream.call_args
152-
workflow_inputs = call_args[0][0]
152+
workflow_inputs = call_args[0][1]
153153
self.assertEqual(workflow_inputs.user_id, self.user.id)
154154
self.assertEqual(workflow_inputs.is_new_conversation, False)
155155
self.assertEqual(workflow_inputs.conversation_id, conversation.id)
@@ -195,7 +195,7 @@ def test_invalid_message_format(self):
195195
def test_rate_limit_burst(self):
196196
# Create multiple requests to trigger burst rate limit
197197
with patch(
198-
"ee.hogai.stream.conversation_stream.ConversationStreamManager.astream",
198+
"ee.hogai.agent.executor.AgentExecutor.astream",
199199
return_value=_async_generator(),
200200
):
201201
with patch("ee.api.conversation.StreamingHttpResponse", side_effect=self._create_mock_streaming_response):
@@ -227,7 +227,7 @@ def test_none_content_with_existing_conversation(self):
227227
user=self.user, team=self.team, status=Conversation.Status.IN_PROGRESS
228228
)
229229
with patch(
230-
"ee.hogai.stream.conversation_stream.ConversationStreamManager.astream",
230+
"ee.hogai.agent.executor.AgentExecutor.astream",
231231
return_value=_async_generator(),
232232
) as mock_stream_conversation:
233233
with patch("ee.api.conversation.StreamingHttpResponse", side_effect=self._create_mock_streaming_response):
@@ -263,7 +263,7 @@ def test_missing_trace_id(self):
263263

264264
def test_nonexistent_conversation(self):
265265
with patch(
266-
"ee.hogai.stream.conversation_stream.ConversationStreamManager.astream",
266+
"ee.hogai.agent.executor.AgentExecutor.astream",
267267
return_value=_async_generator(),
268268
):
269269
with patch("ee.api.conversation.StreamingHttpResponse", side_effect=self._create_mock_streaming_response):
@@ -296,7 +296,7 @@ def test_unauthenticated_request(self):
296296
)
297297
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
298298

299-
@patch("ee.hogai.stream.conversation_stream.ConversationStreamManager.cancel_conversation")
299+
@patch("ee.hogai.agent.executor.AgentExecutor.cancel_workflow")
300300
def test_cancel_conversation(self, mock_cancel):
301301
conversation = Conversation.objects.create(
302302
user=self.user,
@@ -340,7 +340,7 @@ def test_cancel_other_teams_conversation(self):
340340
)
341341
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
342342

343-
@patch("ee.hogai.stream.conversation_stream.ConversationStreamManager.cancel_conversation")
343+
@patch("ee.hogai.agent.executor.AgentExecutor.cancel_workflow")
344344
def test_cancel_conversation_with_async_cleanup(self, mock_cancel):
345345
"""Test that cancel endpoint properly handles async cleanup."""
346346
conversation = Conversation.objects.create(
@@ -360,7 +360,7 @@ def test_cancel_conversation_with_async_cleanup(self, mock_cancel):
360360

361361
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
362362

363-
@patch("ee.hogai.stream.conversation_stream.ConversationStreamManager.cancel_conversation")
363+
@patch("ee.hogai.agent.executor.AgentExecutor.cancel_workflow")
364364
def test_cancel_conversation_async_cleanup_failure(self, mock_cancel):
365365
"""Test cancel endpoint behavior when async cleanup fails."""
366366
conversation = Conversation.objects.create(
@@ -428,7 +428,7 @@ def test_stream_from_in_progress_conversation(self):
428428
user=self.user, team=self.team, status=Conversation.Status.IN_PROGRESS
429429
)
430430
with patch(
431-
"ee.hogai.stream.conversation_stream.ConversationStreamManager.astream",
431+
"ee.hogai.agent.executor.AgentExecutor.astream",
432432
return_value=_async_generator(),
433433
) as mock_stream_conversation:
434434
with patch("ee.api.conversation.StreamingHttpResponse", side_effect=self._create_mock_streaming_response):
@@ -606,7 +606,7 @@ def test_billing_context_validation_valid_data(self):
606606
conversation = Conversation.objects.create(user=self.user, team=self.team)
607607

608608
with patch(
609-
"ee.hogai.stream.conversation_stream.ConversationStreamManager.astream",
609+
"ee.hogai.agent.executor.AgentExecutor.astream",
610610
return_value=_async_generator(),
611611
) as mock_start_workflow_and_stream:
612612
with patch("ee.api.conversation.StreamingHttpResponse", side_effect=self._create_mock_streaming_response):
@@ -622,15 +622,15 @@ def test_billing_context_validation_valid_data(self):
622622
)
623623
self.assertEqual(response.status_code, status.HTTP_200_OK)
624624
call_args = mock_start_workflow_and_stream.call_args
625-
workflow_inputs = call_args[0][0]
625+
workflow_inputs = call_args[0][1]
626626
self.assertEqual(workflow_inputs.billing_context, self.billing_context)
627627

628628
def test_billing_context_validation_invalid_data(self):
629629
"""Test that invalid billing context data is rejected."""
630630
conversation = Conversation.objects.create(user=self.user, team=self.team)
631631

632632
with patch(
633-
"ee.hogai.stream.conversation_stream.ConversationStreamManager.astream",
633+
"ee.hogai.agent.executor.AgentExecutor.astream",
634634
return_value=_async_generator(),
635635
) as mock_start_workflow_and_stream:
636636
with patch("ee.api.conversation.StreamingHttpResponse", side_effect=self._create_mock_streaming_response):
@@ -646,5 +646,5 @@ def test_billing_context_validation_invalid_data(self):
646646
)
647647
self.assertEqual(response.status_code, status.HTTP_200_OK)
648648
call_args = mock_start_workflow_and_stream.call_args
649-
workflow_inputs = call_args[0][0]
649+
workflow_inputs = call_args[0][1]
650650
self.assertEqual(workflow_inputs.billing_context, None)

ee/hogai/stream/conversation_stream.py renamed to ee/hogai/agent/executor.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,12 @@
1111

1212
from posthog.schema import AssistantEventType, FailureMessage
1313

14-
from posthog.temporal.ai.conversation import (
15-
AssistantConversationRunnerWorkflow,
16-
AssistantConversationRunnerWorkflowInputs,
17-
)
14+
from posthog.temporal.ai.base import AgentBaseWorkflow
1815
from posthog.temporal.common.client import async_connect
1916

20-
from ee.hogai.stream.redis_stream import (
17+
from ee.hogai.agent.redis_stream import (
18+
CONVERSATION_STREAM_MAX_LENGTH,
19+
CONVERSATION_STREAM_TIMEOUT,
2120
ConversationEvent,
2221
ConversationRedisStream,
2322
GenerationStatusEvent,
@@ -33,38 +32,44 @@
3332
logger = structlog.get_logger(__name__)
3433

3534

36-
class ConversationStreamManager:
37-
"""Manages conversation streaming from Redis streams."""
35+
class AgentExecutor:
36+
"""Manages executing an agent workflow and streaming the output."""
3837

39-
def __init__(self, conversation: Conversation) -> None:
38+
def __init__(
39+
self,
40+
conversation: Conversation,
41+
timeout: int = CONVERSATION_STREAM_TIMEOUT,
42+
max_length: int = CONVERSATION_STREAM_MAX_LENGTH,
43+
) -> None:
4044
self._conversation = conversation
41-
self._redis_stream = ConversationRedisStream(get_conversation_stream_key(conversation.id))
45+
self._redis_stream = ConversationRedisStream(
46+
get_conversation_stream_key(conversation.id), timeout=timeout, max_length=max_length
47+
)
4248
self._workflow_id = f"conversation-{conversation.id}"
4349

44-
async def astream(
45-
self, workflow_inputs: AssistantConversationRunnerWorkflowInputs
46-
) -> AsyncGenerator[AssistantOutput, Any]:
47-
"""Stream conversation updates from Redis stream.
50+
async def astream(self, workflow: type[AgentBaseWorkflow], inputs: Any) -> AsyncGenerator[AssistantOutput, Any]:
51+
"""Stream agent workflow updates from Redis stream.
4852
4953
Args:
50-
workflow_inputs: Temporal workflow inputs
54+
workflow: Agent temporal workflow class
55+
inputs: Agent temporal workflow inputs
5156
5257
Returns:
5358
AssistantOutput generator
5459
"""
5560
# If this is a reconnection attempt, we resume streaming
5661
if self._conversation.status != Conversation.Status.IDLE:
57-
if workflow_inputs.message is not None:
62+
if inputs.message is not None:
5863
raise ValueError("Cannot resume streaming with a new message")
5964
async for chunk in self.stream_conversation():
6065
yield chunk
6166
else:
6267
# Otherwise, process the new message (new generation) or resume generation (no new message)
63-
async for chunk in self.start_workflow(workflow_inputs):
68+
async for chunk in self.start_workflow(workflow, inputs):
6469
yield chunk
6570

6671
async def start_workflow(
67-
self, workflow_inputs: AssistantConversationRunnerWorkflowInputs
72+
self, workflow: type[AgentBaseWorkflow], inputs: Any
6873
) -> AsyncGenerator[AssistantOutput, Any]:
6974
try:
7075
# Delete the stream to ensure we start fresh
@@ -74,8 +79,8 @@ async def start_workflow(
7479
client = await async_connect()
7580

7681
handle = await client.start_workflow(
77-
AssistantConversationRunnerWorkflow.run,
78-
workflow_inputs,
82+
workflow.run,
83+
inputs,
7984
id=self._workflow_id,
8085
task_queue=settings.MAX_AI_TASK_QUEUE,
8186
id_conflict_policy=WorkflowIDConflictPolicy.USE_EXISTING,
@@ -173,7 +178,7 @@ def _failure_message(self) -> AssistantOutput:
173178
)
174179
return (AssistantEventType.MESSAGE, failure_message)
175180

176-
async def cancel_conversation(self) -> None:
181+
async def cancel_workflow(self) -> None:
177182
"""Cancel the current conversation and clean up resources.
178183
179184
Raises:

ee/hogai/stream/redis_stream.py renamed to ee/hogai/agent/redis_stream.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -160,11 +160,18 @@ class StreamError(Exception):
160160
class ConversationRedisStream:
161161
"""Manages conversation streaming from Redis streams."""
162162

163-
def __init__(self, stream_key: str):
163+
def __init__(
164+
self,
165+
stream_key: str,
166+
timeout: int = CONVERSATION_STREAM_TIMEOUT,
167+
max_length: int = CONVERSATION_STREAM_MAX_LENGTH,
168+
):
164169
self._stream_key = stream_key
165170
self._redis_client = get_async_client(settings.REDIS_URL)
166171
self._deletion_lock = asyncio.Lock()
167172
self._serializer = ConversationStreamSerializer()
173+
self._timeout = timeout
174+
self._max_length = max_length
168175

169176
async def wait_for_stream(self) -> bool:
170177
"""Wait for stream to be created using linear backoff.
@@ -220,7 +227,7 @@ async def read_stream(
220227
start_time = asyncio.get_event_loop().time()
221228

222229
while True:
223-
if asyncio.get_event_loop().time() - start_time > CONVERSATION_STREAM_TIMEOUT:
230+
if asyncio.get_event_loop().time() - start_time > self._timeout:
224231
raise StreamError("Stream timeout - conversation took too long to complete")
225232

226233
try:
@@ -283,15 +290,15 @@ async def write_to_stream(
283290
callback: Callback to trigger after each message is written to the stream
284291
"""
285292
try:
286-
await self._redis_client.expire(self._stream_key, CONVERSATION_STREAM_TIMEOUT)
293+
await self._redis_client.expire(self._stream_key, self._timeout)
287294

288295
async for chunk in generator:
289296
message = self._serializer.dumps(chunk)
290297
if message is not None:
291298
await self._redis_client.xadd(
292299
self._stream_key,
293300
message,
294-
maxlen=CONVERSATION_STREAM_MAX_LENGTH,
301+
maxlen=self._max_length,
295302
approximate=True,
296303
)
297304
if callback:
@@ -303,7 +310,7 @@ async def write_to_stream(
303310
await self._redis_client.xadd(
304311
self._stream_key,
305312
completion_message,
306-
maxlen=CONVERSATION_STREAM_MAX_LENGTH,
313+
maxlen=self._max_length,
307314
approximate=True,
308315
)
309316

@@ -314,7 +321,7 @@ async def write_to_stream(
314321
await self._redis_client.xadd(
315322
self._stream_key,
316323
message,
317-
maxlen=CONVERSATION_STREAM_MAX_LENGTH,
324+
maxlen=self._max_length,
318325
approximate=True,
319326
)
320327
raise StreamError("Failed to write to stream")

0 commit comments

Comments
 (0)