Skip to content

Commit 7424d72

Browse files
Author: kappa90 committed — refactor(ph-ai): conversation stream manager to agent executor
1 parent 78b6f81 commit 7424d72

File tree

10 files changed

+160
-131
lines changed

10 files changed

+160
-131
lines changed

ee/api/conversation.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,16 @@
2121
from posthog.exceptions_capture import capture_exception
2222
from posthog.models.user import User
2323
from posthog.rate_limit import AIBurstRateThrottle, AISustainedRateThrottle
24-
from posthog.temporal.ai.conversation import AssistantConversationRunnerWorkflowInputs
24+
from posthog.temporal.ai.chat_agent import (
25+
CHAT_AGENT_STREAM_MAX_LENGTH,
26+
CHAT_AGENT_WORKFLOW_TIMEOUT,
27+
AssistantConversationRunnerWorkflow,
28+
AssistantConversationRunnerWorkflowInputs,
29+
)
2530
from posthog.utils import get_instance_region
2631

32+
from ee.hogai.agent.executor import AgentExecutor
2733
from ee.hogai.api.serializers import ConversationSerializer
28-
from ee.hogai.stream.conversation_stream import ConversationStreamManager
2934
from ee.hogai.utils.aio import async_to_sync
3035
from ee.hogai.utils.sse import AssistantSSESerializer
3136
from ee.hogai.utils.types.base import AssistantMode
@@ -133,7 +138,8 @@ def create(self, request: Request, *args, **kwargs):
133138

134139
has_message = serializer.validated_data.get("content") is not None
135140
is_deep_research = serializer.validated_data.get("deep_research_mode", False)
136-
mode = AssistantMode.DEEP_RESEARCH if is_deep_research else AssistantMode.ASSISTANT
141+
if is_deep_research:
142+
raise NotImplementedError("Deep research is not supported yet")
137143

138144
is_new_conversation = False
139145
# Safely set the lookup kwarg for potential error handling
@@ -174,15 +180,18 @@ def create(self, request: Request, *args, **kwargs):
174180
trace_id=serializer.validated_data["trace_id"],
175181
session_id=request.headers.get("X-POSTHOG-SESSION-ID"), # Relies on posthog-js __add_tracing_headers
176182
billing_context=serializer.validated_data.get("billing_context"),
177-
mode=mode,
183+
mode=AssistantMode.ASSISTANT,
178184
)
185+
workflow_class = AssistantConversationRunnerWorkflow
179186

180187
async def async_stream(
181188
workflow_inputs: AssistantConversationRunnerWorkflowInputs,
182189
) -> AsyncGenerator[bytes, None]:
183190
serializer = AssistantSSESerializer()
184-
stream_manager = ConversationStreamManager(conversation)
185-
async for chunk in stream_manager.astream(workflow_inputs):
191+
stream_manager = AgentExecutor(
192+
conversation, timeout=CHAT_AGENT_WORKFLOW_TIMEOUT, max_length=CHAT_AGENT_STREAM_MAX_LENGTH
193+
)
194+
async for chunk in stream_manager.astream(workflow_class, workflow_inputs):
186195
yield serializer.dumps(chunk).encode("utf-8")
187196

188197
return StreamingHttpResponse(
@@ -200,8 +209,8 @@ def cancel(self, request: Request, *args, **kwargs):
200209
return Response(status=status.HTTP_204_NO_CONTENT)
201210

202211
async def cancel_workflow():
203-
conversation_manager = ConversationStreamManager(conversation)
204-
await conversation_manager.cancel_conversation()
212+
agent_executor = AgentExecutor(conversation)
213+
await agent_executor.cancel_workflow()
205214

206215
try:
207216
asgi_async_to_sync(cancel_workflow)()

ee/api/test/test_conversation.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def test_create_conversation(self):
102102
conversation_id = str(uuid.uuid4())
103103

104104
with patch(
105-
"ee.hogai.stream.conversation_stream.ConversationStreamManager.astream",
105+
"ee.hogai.agent.executor.AgentExecutor.astream",
106106
return_value=_async_generator(),
107107
) as mock_start_workflow_and_stream:
108108
with patch("ee.api.conversation.StreamingHttpResponse", side_effect=self._create_mock_streaming_response):
@@ -120,7 +120,7 @@ def test_create_conversation(self):
120120
# Check that the method was called with workflow_inputs
121121
mock_start_workflow_and_stream.assert_called_once()
122122
call_args = mock_start_workflow_and_stream.call_args
123-
workflow_inputs = call_args[0][0]
123+
workflow_inputs = call_args[0][1]
124124
self.assertEqual(workflow_inputs.user_id, self.user.id)
125125
self.assertEqual(workflow_inputs.is_new_conversation, True)
126126
self.assertEqual(workflow_inputs.conversation_id, conversation.id)
@@ -129,7 +129,7 @@ def test_create_conversation(self):
129129

130130
def test_add_message_to_existing_conversation(self):
131131
with patch(
132-
"ee.hogai.stream.conversation_stream.ConversationStreamManager.astream",
132+
"ee.hogai.agent.executor.AgentExecutor.astream",
133133
return_value=_async_generator(),
134134
) as mock_start_workflow_and_stream:
135135
with patch("ee.api.conversation.StreamingHttpResponse", side_effect=self._create_mock_streaming_response):
@@ -149,7 +149,7 @@ def test_add_message_to_existing_conversation(self):
149149
# Check that the method was called with workflow_inputs
150150
mock_start_workflow_and_stream.assert_called_once()
151151
call_args = mock_start_workflow_and_stream.call_args
152-
workflow_inputs = call_args[0][0]
152+
workflow_inputs = call_args[0][1]
153153
self.assertEqual(workflow_inputs.user_id, self.user.id)
154154
self.assertEqual(workflow_inputs.is_new_conversation, False)
155155
self.assertEqual(workflow_inputs.conversation_id, conversation.id)
@@ -183,7 +183,7 @@ def test_invalid_message_format(self):
183183
def test_rate_limit_burst(self):
184184
# Create multiple requests to trigger burst rate limit
185185
with patch(
186-
"ee.hogai.stream.conversation_stream.ConversationStreamManager.astream",
186+
"ee.hogai.agent.executor.AgentExecutor.astream",
187187
return_value=_async_generator(),
188188
):
189189
with patch("ee.api.conversation.StreamingHttpResponse", side_effect=self._create_mock_streaming_response):
@@ -215,7 +215,7 @@ def test_none_content_with_existing_conversation(self):
215215
user=self.user, team=self.team, status=Conversation.Status.IN_PROGRESS
216216
)
217217
with patch(
218-
"ee.hogai.stream.conversation_stream.ConversationStreamManager.astream",
218+
"ee.hogai.agent.executor.AgentExecutor.astream",
219219
return_value=_async_generator(),
220220
) as mock_stream_conversation:
221221
with patch("ee.api.conversation.StreamingHttpResponse", side_effect=self._create_mock_streaming_response):
@@ -251,7 +251,7 @@ def test_missing_trace_id(self):
251251

252252
def test_nonexistent_conversation(self):
253253
with patch(
254-
"ee.hogai.stream.conversation_stream.ConversationStreamManager.astream",
254+
"ee.hogai.agent.executor.AgentExecutor.astream",
255255
return_value=_async_generator(),
256256
):
257257
with patch("ee.api.conversation.StreamingHttpResponse", side_effect=self._create_mock_streaming_response):
@@ -284,7 +284,7 @@ def test_unauthenticated_request(self):
284284
)
285285
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
286286

287-
@patch("ee.hogai.stream.conversation_stream.ConversationStreamManager.cancel_conversation")
287+
@patch("ee.hogai.agent.executor.AgentExecutor.cancel_workflow")
288288
def test_cancel_conversation(self, mock_cancel):
289289
conversation = Conversation.objects.create(
290290
user=self.user,
@@ -327,7 +327,7 @@ def test_cancel_other_teams_conversation(self):
327327
)
328328
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
329329

330-
@patch("ee.hogai.stream.conversation_stream.ConversationStreamManager.cancel_conversation")
330+
@patch("ee.hogai.agent.executor.AgentExecutor.cancel_workflow")
331331
def test_cancel_conversation_with_async_cleanup(self, mock_cancel):
332332
"""Test that cancel endpoint properly handles async cleanup."""
333333
conversation = Conversation.objects.create(
@@ -347,7 +347,7 @@ def test_cancel_conversation_with_async_cleanup(self, mock_cancel):
347347

348348
self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT)
349349

350-
@patch("ee.hogai.stream.conversation_stream.ConversationStreamManager.cancel_conversation")
350+
@patch("ee.hogai.agent.executor.AgentExecutor.cancel_workflow")
351351
def test_cancel_conversation_async_cleanup_failure(self, mock_cancel):
352352
"""Test cancel endpoint behavior when async cleanup fails."""
353353
conversation = Conversation.objects.create(
@@ -415,7 +415,7 @@ def test_stream_from_in_progress_conversation(self):
415415
user=self.user, team=self.team, status=Conversation.Status.IN_PROGRESS
416416
)
417417
with patch(
418-
"ee.hogai.stream.conversation_stream.ConversationStreamManager.astream",
418+
"ee.hogai.agent.executor.AgentExecutor.astream",
419419
return_value=_async_generator(),
420420
) as mock_stream_conversation:
421421
with patch("ee.api.conversation.StreamingHttpResponse", side_effect=self._create_mock_streaming_response):
@@ -435,7 +435,7 @@ def test_resume_generation_from_idle_conversation(self):
435435
"""Test resuming generation from an idle conversation with no new content."""
436436
conversation = Conversation.objects.create(user=self.user, team=self.team, status=Conversation.Status.IDLE)
437437
with patch(
438-
"ee.hogai.stream.conversation_stream.ConversationStreamManager.astream",
438+
"ee.hogai.agent.executor.AgentExecutor.astream",
439439
return_value=_async_generator(),
440440
) as mock_start_workflow_and_stream:
441441
with patch("ee.api.conversation.StreamingHttpResponse", side_effect=self._create_mock_streaming_response):
@@ -453,7 +453,7 @@ def test_resume_generation_from_idle_conversation(self):
453453
# Check that the method was called with workflow_inputs
454454
mock_start_workflow_and_stream.assert_called_once()
455455
call_args = mock_start_workflow_and_stream.call_args
456-
workflow_inputs = call_args[0][0]
456+
workflow_inputs = call_args[0][1]
457457
self.assertEqual(workflow_inputs.user_id, self.user.id)
458458
self.assertEqual(workflow_inputs.is_new_conversation, False)
459459
self.assertEqual(workflow_inputs.conversation_id, conversation.id)
@@ -595,7 +595,7 @@ def test_billing_context_validation_valid_data(self):
595595
conversation = Conversation.objects.create(user=self.user, team=self.team)
596596

597597
with patch(
598-
"ee.hogai.stream.conversation_stream.ConversationStreamManager.astream",
598+
"ee.hogai.agent.executor.AgentExecutor.astream",
599599
return_value=_async_generator(),
600600
) as mock_start_workflow_and_stream:
601601
with patch("ee.api.conversation.StreamingHttpResponse", side_effect=self._create_mock_streaming_response):
@@ -611,15 +611,15 @@ def test_billing_context_validation_valid_data(self):
611611
)
612612
self.assertEqual(response.status_code, status.HTTP_200_OK)
613613
call_args = mock_start_workflow_and_stream.call_args
614-
workflow_inputs = call_args[0][0]
614+
workflow_inputs = call_args[0][1]
615615
self.assertEqual(workflow_inputs.billing_context, self.billing_context)
616616

617617
def test_billing_context_validation_invalid_data(self):
618618
"""Test that invalid billing context data is rejected."""
619619
conversation = Conversation.objects.create(user=self.user, team=self.team)
620620

621621
with patch(
622-
"ee.hogai.stream.conversation_stream.ConversationStreamManager.astream",
622+
"ee.hogai.agent.executor.AgentExecutor.astream",
623623
return_value=_async_generator(),
624624
) as mock_start_workflow_and_stream:
625625
with patch("ee.api.conversation.StreamingHttpResponse", side_effect=self._create_mock_streaming_response):
@@ -635,5 +635,5 @@ def test_billing_context_validation_invalid_data(self):
635635
)
636636
self.assertEqual(response.status_code, status.HTTP_200_OK)
637637
call_args = mock_start_workflow_and_stream.call_args
638-
workflow_inputs = call_args[0][0]
638+
workflow_inputs = call_args[0][1]
639639
self.assertEqual(workflow_inputs.billing_context, None)

ee/hogai/stream/conversation_stream.py renamed to ee/hogai/agent/executor.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,12 @@
1111

1212
from posthog.schema import AssistantEventType, FailureMessage
1313

14-
from posthog.temporal.ai.conversation import (
15-
AssistantConversationRunnerWorkflow,
16-
AssistantConversationRunnerWorkflowInputs,
17-
)
14+
from posthog.temporal.ai.base import AgentBaseWorkflow
1815
from posthog.temporal.common.client import async_connect
1916

20-
from ee.hogai.stream.redis_stream import (
17+
from ee.hogai.agent.redis_stream import (
18+
CONVERSATION_STREAM_MAX_LENGTH,
19+
CONVERSATION_STREAM_TIMEOUT,
2120
ConversationEvent,
2221
ConversationRedisStream,
2322
GenerationStatusEvent,
@@ -33,38 +32,44 @@
3332
logger = structlog.get_logger(__name__)
3433

3534

36-
class ConversationStreamManager:
37-
"""Manages conversation streaming from Redis streams."""
35+
class AgentExecutor:
36+
"""Manages executing an agent workflow and streaming the output."""
3837

39-
def __init__(self, conversation: Conversation) -> None:
38+
def __init__(
39+
self,
40+
conversation: Conversation,
41+
timeout: int = CONVERSATION_STREAM_TIMEOUT,
42+
max_length: int = CONVERSATION_STREAM_MAX_LENGTH,
43+
) -> None:
4044
self._conversation = conversation
41-
self._redis_stream = ConversationRedisStream(get_conversation_stream_key(conversation.id))
45+
self._redis_stream = ConversationRedisStream(
46+
get_conversation_stream_key(conversation.id), timeout=timeout, max_length=max_length
47+
)
4248
self._workflow_id = f"conversation-{conversation.id}"
4349

44-
async def astream(
45-
self, workflow_inputs: AssistantConversationRunnerWorkflowInputs
46-
) -> AsyncGenerator[AssistantOutput, Any]:
47-
"""Stream conversation updates from Redis stream.
50+
async def astream(self, workflow: type[AgentBaseWorkflow], inputs: Any) -> AsyncGenerator[AssistantOutput, Any]:
51+
"""Stream agent workflow updates from Redis stream.
4852
4953
Args:
50-
workflow_inputs: Temporal workflow inputs
54+
workflow: Agent temporal workflow class
55+
inputs: Agent temporal workflow inputs
5156
5257
Returns:
5358
AssistantOutput generator
5459
"""
5560
# If this is a reconnection attempt, we resume streaming
5661
if self._conversation.status != Conversation.Status.IDLE:
57-
if workflow_inputs.message is not None:
62+
if inputs.message is not None:
5863
raise ValueError("Cannot resume streaming with a new message")
5964
async for chunk in self.stream_conversation():
6065
yield chunk
6166
else:
6267
# Otherwise, process the new message (new generation) or resume generation (no new message)
63-
async for chunk in self.start_workflow(workflow_inputs):
68+
async for chunk in self.start_workflow(workflow, inputs):
6469
yield chunk
6570

6671
async def start_workflow(
67-
self, workflow_inputs: AssistantConversationRunnerWorkflowInputs
72+
self, workflow: type[AgentBaseWorkflow], inputs: Any
6873
) -> AsyncGenerator[AssistantOutput, Any]:
6974
try:
7075
# Delete the stream to ensure we start fresh
@@ -74,8 +79,8 @@ async def start_workflow(
7479
client = await async_connect()
7580

7681
handle = await client.start_workflow(
77-
AssistantConversationRunnerWorkflow.run,
78-
workflow_inputs,
82+
workflow.run,
83+
inputs,
7984
id=self._workflow_id,
8085
task_queue=settings.MAX_AI_TASK_QUEUE,
8186
id_conflict_policy=WorkflowIDConflictPolicy.USE_EXISTING,
@@ -173,7 +178,7 @@ def _failure_message(self) -> AssistantOutput:
173178
)
174179
return (AssistantEventType.MESSAGE, failure_message)
175180

176-
async def cancel_conversation(self) -> None:
181+
async def cancel_workflow(self) -> None:
177182
"""Cancel the current conversation and clean up resources.
178183
179184
Raises:

ee/hogai/stream/redis_stream.py renamed to ee/hogai/agent/redis_stream.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -160,11 +160,18 @@ class StreamError(Exception):
160160
class ConversationRedisStream:
161161
"""Manages conversation streaming from Redis streams."""
162162

163-
def __init__(self, stream_key: str):
163+
def __init__(
164+
self,
165+
stream_key: str,
166+
timeout: int = CONVERSATION_STREAM_TIMEOUT,
167+
max_length: int = CONVERSATION_STREAM_MAX_LENGTH,
168+
):
164169
self._stream_key = stream_key
165170
self._redis_client = get_async_client(settings.REDIS_URL)
166171
self._deletion_lock = asyncio.Lock()
167172
self._serializer = ConversationStreamSerializer()
173+
self._timeout = timeout
174+
self._max_length = max_length
168175

169176
async def wait_for_stream(self) -> bool:
170177
"""Wait for stream to be created using linear backoff.
@@ -220,7 +227,7 @@ async def read_stream(
220227
start_time = asyncio.get_event_loop().time()
221228

222229
while True:
223-
if asyncio.get_event_loop().time() - start_time > CONVERSATION_STREAM_TIMEOUT:
230+
if asyncio.get_event_loop().time() - start_time > self._timeout:
224231
raise StreamError("Stream timeout - conversation took too long to complete")
225232

226233
try:
@@ -283,15 +290,15 @@ async def write_to_stream(
283290
callback: Callback to trigger after each message is written to the stream
284291
"""
285292
try:
286-
await self._redis_client.expire(self._stream_key, CONVERSATION_STREAM_TIMEOUT)
293+
await self._redis_client.expire(self._stream_key, self._timeout)
287294

288295
async for chunk in generator:
289296
message = self._serializer.dumps(chunk)
290297
if message is not None:
291298
await self._redis_client.xadd(
292299
self._stream_key,
293300
message,
294-
maxlen=CONVERSATION_STREAM_MAX_LENGTH,
301+
maxlen=self._max_length,
295302
approximate=True,
296303
)
297304
if callback:
@@ -303,7 +310,7 @@ async def write_to_stream(
303310
await self._redis_client.xadd(
304311
self._stream_key,
305312
completion_message,
306-
maxlen=CONVERSATION_STREAM_MAX_LENGTH,
313+
maxlen=self._max_length,
307314
approximate=True,
308315
)
309316

@@ -314,7 +321,7 @@ async def write_to_stream(
314321
await self._redis_client.xadd(
315322
self._stream_key,
316323
message,
317-
maxlen=CONVERSATION_STREAM_MAX_LENGTH,
324+
maxlen=self._max_length,
318325
approximate=True,
319326
)
320327
raise StreamError("Failed to write to stream")

0 commit comments

Comments (0)