Skip to content

Commit 4c82d58

Browse files
authored
feat(ph-ai): agent modes (#41284)
1 parent 51216de commit 4c82d58

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+2763
-243
lines changed

ee/hogai/context/context.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import asyncio
22
from collections.abc import Sequence
3-
from functools import lru_cache
43
from typing import Any, Optional, cast
54
from uuid import uuid4
65

@@ -9,6 +8,7 @@
98
from posthoganalytics import capture_exception
109

1110
from posthog.schema import (
11+
AgentMode,
1212
ContextMessage,
1313
FunnelsQuery,
1414
HogQLQuery,
@@ -36,10 +36,13 @@
3636

3737
from ee.hogai.graph.mixins import AssistantContextMixin
3838
from ee.hogai.graph.query_executor.query_executor import AssistantQueryExecutor, SupportedQueryTypes
39-
from ee.hogai.utils.helpers import find_start_message, insert_messages_before_start
39+
from ee.hogai.utils.feature_flags import has_agent_modes_feature_flag
40+
from ee.hogai.utils.helpers import find_start_message, find_start_message_idx, insert_messages_before_start
41+
from ee.hogai.utils.prompt import format_prompt_string
4042
from ee.hogai.utils.types.base import AnyAssistantSupportedQuery, AssistantMessageUnion, BaseStateWithMessages
4143

4244
from .prompts import (
45+
CONTEXT_MODE_PROMPT,
4346
CONTEXTUAL_TOOLS_REMINDER_PROMPT,
4447
ROOT_DASHBOARD_CONTEXT_PROMPT,
4548
ROOT_DASHBOARDS_CONTEXT_PROMPT,
@@ -133,7 +136,6 @@ def get_groups(self):
133136
"""
134137
return GroupTypeMapping.objects.filter(project_id=self._team.project_id).order_by("group_type_index")
135138

136-
@lru_cache(maxsize=1)
137139
async def get_group_names(self) -> list[str]:
138140
"""
139141
Returns the names of the team's groups.
@@ -384,23 +386,33 @@ def _render_user_context_template(
384386
).to_string()
385387

386388
async def _get_context_prompts(self, state: BaseStateWithMessages) -> list[str]:
389+
are_modes_enabled = has_agent_modes_feature_flag(self._team, self._user)
390+
387391
prompts: list[str] = []
388-
if contextual_tools := self._get_contextual_tools_prompt():
392+
if (
393+
are_modes_enabled
394+
and find_start_message_idx(state.messages, state.start_id) == 0
395+
and (mode_prompt := self._get_mode_prompt(state.agent_mode))
396+
):
397+
prompts.append(mode_prompt)
398+
if contextual_tools := await self._get_contextual_tools_prompt():
389399
prompts.append(contextual_tools)
390400
if ui_context := await self._format_ui_context(self.get_ui_context(state)):
391401
prompts.append(ui_context)
392402
return self._deduplicate_context_messages(state, prompts)
393403

394-
def _get_contextual_tools_prompt(self) -> str | None:
404+
async def _get_contextual_tools_prompt(self) -> str | None:
395405
from ee.hogai.registry import get_contextual_tool_class
396406

397-
contextual_tools_prompt = [
398-
f"<{tool_name}>\n"
399-
f"{get_contextual_tool_class(tool_name)(team=self._team, user=self._user).format_context_prompt_injection(tool_context)}\n" # type: ignore
400-
f"</{tool_name}>"
401-
for tool_name, tool_context in self.get_contextual_tools().items()
402-
if get_contextual_tool_class(tool_name) is not None
403-
]
407+
contextual_tools_prompt: list[str] = []
408+
for tool_name, tool_context in self.get_contextual_tools().items():
409+
tool_class = get_contextual_tool_class(tool_name)
410+
if tool_class is None:
411+
continue
412+
tool = await tool_class.create_tool_class(team=self._team, user=self._user, context_manager=self)
413+
tool_prompt = tool.format_context_prompt_injection(tool_context)
414+
contextual_tools_prompt.append(f"<{tool_name}>\n" f"{tool_prompt}\n" f"</{tool_name}>")
415+
404416
if contextual_tools_prompt:
405417
tools = "\n".join(contextual_tools_prompt)
406418
return CONTEXTUAL_TOOLS_REMINDER_PROMPT.format(tools=tools)
@@ -417,3 +429,6 @@ def _inject_context_messages(
417429
context_messages = [ContextMessage(content=prompt, id=str(uuid4())) for prompt in context_prompts]
418430
# Insert context messages right before the start message
419431
return insert_messages_before_start(state.messages, context_messages, start_id=state.start_id)
432+
433+
def _get_mode_prompt(self, mode: AgentMode | None) -> str:
434+
return format_prompt_string(CONTEXT_MODE_PROMPT, mode=mode.value if mode else AgentMode.PRODUCT_ANALYTICS.value)

ee/hogai/context/prompts.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,7 @@
6363
IMPORTANT: this context may or may not be relevant to your tasks. You should not respond to this context unless it is highly relevant to your task.
6464
</system_reminder>
6565
""".strip()
66+
67+
CONTEXT_MODE_PROMPT = """
68+
<system_reminder>Your initial mode is {{{mode}}}.</system_reminder>
69+
""".strip()

ee/hogai/context/test/test_context.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from parameterized import parameterized
99

1010
from posthog.schema import (
11+
AgentMode,
1112
AssistantMessage,
1213
ContextMessage,
1314
DashboardFilter,
@@ -498,33 +499,25 @@ def test_format_entity_context_empty(self):
498499
result = self.context_manager._format_entity_context([], "events", "Event")
499500
self.assertEqual(result, "")
500501

501-
@patch("ee.hogai.registry.get_contextual_tool_class")
502-
def test_get_contextual_tools_prompt(self, mock_get_contextual_tool_class):
502+
async def test_get_contextual_tools_prompt(self):
503503
"""Test generation of contextual tools prompt"""
504-
# Mock the tool class
505-
mock_tool = MagicMock()
506-
mock_tool.format_context_prompt_injection.return_value = "Tool system prompt"
507-
mock_get_contextual_tool_class.return_value = lambda team, user: mock_tool
508-
509504
config = RunnableConfig(
510505
configurable={"contextual_tools": {"search_session_recordings": {"current_filters": {}}}}
511506
)
512507
context_manager = AssistantContextManager(self.team, self.user, config)
513508

514-
result = context_manager._get_contextual_tools_prompt()
515-
516-
self.assertIsNotNone(result)
517-
assert result is not None # Type guard for mypy
509+
result = await context_manager._get_contextual_tools_prompt()
510+
assert result is not None
518511
self.assertIn("<search_session_recordings>", result)
519-
self.assertIn("Tool system prompt", result)
512+
self.assertIn("Current recordings filters are", result)
520513
self.assertIn("</search_session_recordings>", result)
521514

522-
def test_get_contextual_tools_prompt_no_tools(self):
515+
async def test_get_contextual_tools_prompt_no_tools(self):
523516
"""Test generation of contextual tools prompt returns None when no tools"""
524517
config = RunnableConfig(configurable={})
525518
context_manager = AssistantContextManager(self.team, self.user, config)
526519

527-
result = context_manager._get_contextual_tools_prompt()
520+
result = await context_manager._get_contextual_tools_prompt()
528521

529522
self.assertIsNone(result)
530523

@@ -646,3 +639,20 @@ async def test_get_billing_context(self):
646639

647640
context_manager = AssistantContextManager(self.team, self.user, RunnableConfig(configurable={}))
648641
self.assertIsNone(context_manager.get_billing_context())
642+
643+
async def test_get_context_prompts_with_agent_mode_at_start(self):
644+
"""Test that mode prompt is added when feature flag is enabled and message is at start"""
645+
with patch("ee.hogai.context.context.has_agent_modes_feature_flag", return_value=True):
646+
state = AssistantState(
647+
messages=[HumanMessage(content="Test", id="1")],
648+
start_id="1",
649+
agent_mode=AgentMode.PRODUCT_ANALYTICS,
650+
)
651+
652+
result = await self.context_manager.get_state_messages_with_context(state)
653+
654+
assert result is not None
655+
self.assertEqual(len(result), 2)
656+
assert isinstance(result[0], ContextMessage)
657+
self.assertIn("Your initial mode is", result[0].content)
658+
self.assertIsInstance(result[1], HumanMessage)

ee/hogai/eval/ci/conftest.py

Lines changed: 47 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from collections.abc import Generator
44

55
import pytest
6+
from unittest.mock import patch
67

78
from django.test import override_settings
89

@@ -33,59 +34,62 @@
3334

3435
@pytest.fixture
3536
def call_root_for_insight_generation(demo_org_team_user):
36-
# This graph structure will first get a plan, then generate the SQL query.
37-
graph = (
38-
AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
39-
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
40-
.add_root()
41-
# TRICKY: We need to set a checkpointer here because async tests create a new event loop.
42-
.compile(checkpointer=DjangoCheckpointer())
43-
)
44-
45-
async def callable(query_with_extra_context: str | tuple[str, str]) -> PlanAndQueryOutput:
46-
# If query_with_extra_context is a tuple, the first element is the query, the second is the extra context
47-
# in case there's an ask_user tool call.
48-
query = query_with_extra_context[0] if isinstance(query_with_extra_context, tuple) else query_with_extra_context
49-
# Initial state for the graph
50-
initial_state = AssistantState(
51-
messages=[HumanMessage(content=f"Answer this question: {query}")],
37+
with patch("ee.hogai.utils.feature_flags.has_agent_modes_feature_flag", return_value=True):
38+
# This graph structure will first get a plan, then generate the SQL query.
39+
graph = (
40+
AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
41+
.add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
42+
.add_root()
43+
# TRICKY: We need to set a checkpointer here because async tests create a new event loop.
44+
.compile(checkpointer=DjangoCheckpointer())
5245
)
53-
conversation = await Conversation.objects.acreate(team=demo_org_team_user[1], user=demo_org_team_user[2])
5446

55-
# Invoke the graph. The state will be updated through planner and then generator.
56-
final_state_raw = await graph.ainvoke(initial_state, {"configurable": {"thread_id": conversation.id}})
47+
async def callable(query_with_extra_context: str | tuple[str, str]) -> PlanAndQueryOutput:
48+
# If query_with_extra_context is a tuple, the first element is the query, the second is the extra context
49+
# in case there's an ask_user tool call.
50+
query = (
51+
query_with_extra_context[0] if isinstance(query_with_extra_context, tuple) else query_with_extra_context
52+
)
53+
# Initial state for the graph
54+
initial_state = AssistantState(
55+
messages=[HumanMessage(content=f"Answer this question: {query}")],
56+
)
57+
conversation = await Conversation.objects.acreate(team=demo_org_team_user[1], user=demo_org_team_user[2])
5758

58-
final_state = AssistantState.model_validate(final_state_raw)
59+
# Invoke the graph. The state will be updated through planner and then generator.
60+
final_state_raw = await graph.ainvoke(initial_state, {"configurable": {"thread_id": conversation.id}})
5961

60-
# If we have extra context for the potential ask_user tool, and there's no message of type ai/failure
61-
# or ai/visualization, we should answer with that extra context. We only do this once at most in an eval case.
62-
if isinstance(query_with_extra_context, tuple) and not any(
63-
isinstance(m, VisualizationMessage | FailureMessage) for m in final_state.messages
64-
):
65-
final_state.messages = [*final_state.messages, HumanMessage(content=query_with_extra_context[1])]
66-
final_state.graph_status = "resumed"
67-
final_state_raw = await graph.ainvoke(final_state, {"configurable": {"thread_id": conversation.id}})
6862
final_state = AssistantState.model_validate(final_state_raw)
6963

70-
# The order is a viz message, tool call message, and assistant message.
71-
if (
72-
not final_state.messages
73-
or not len(final_state.messages) >= 3
74-
or not isinstance(final_state.messages[-3], VisualizationMessage)
75-
):
64+
# If we have extra context for the potential ask_user tool, and there's no message of type ai/failure
65+
# or ai/visualization, we should answer with that extra context. We only do this once at most in an eval case.
66+
if isinstance(query_with_extra_context, tuple) and not any(
67+
isinstance(m, VisualizationMessage | FailureMessage) for m in final_state.messages
68+
):
69+
final_state.messages = [*final_state.messages, HumanMessage(content=query_with_extra_context[1])]
70+
final_state.graph_status = "resumed"
71+
final_state_raw = await graph.ainvoke(final_state, {"configurable": {"thread_id": conversation.id}})
72+
final_state = AssistantState.model_validate(final_state_raw)
73+
74+
# The order is a viz message, tool call message, and assistant message.
75+
if (
76+
not final_state.messages
77+
or not len(final_state.messages) >= 3
78+
or not isinstance(final_state.messages[-3], VisualizationMessage)
79+
):
80+
return {
81+
"plan": None,
82+
"query": None,
83+
"query_generation_retry_count": final_state.query_generation_retry_count,
84+
}
85+
7686
return {
77-
"plan": None,
78-
"query": None,
87+
"plan": final_state.messages[-3].plan,
88+
"query": final_state.messages[-3].answer,
7989
"query_generation_retry_count": final_state.query_generation_retry_count,
8090
}
8191

82-
return {
83-
"plan": final_state.messages[-3].plan,
84-
"query": final_state.messages[-3].answer,
85-
"query_generation_retry_count": final_state.query_generation_retry_count,
86-
}
87-
88-
return callable
92+
yield callable
8993

9094

9195
@pytest.fixture(scope="package")
Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,29 @@
11
from langchain_core.runnables import RunnableConfig
22

3-
from posthog.schema import AgentMode
4-
53
from ee.hogai.graph.agent_modes.mode_manager import AgentModeManager
64
from ee.hogai.graph.base import AssistantNode
75
from ee.hogai.utils.types import AssistantState, PartialAssistantState
86

97

108
class AgentGraphNode(AssistantNode):
119
async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState | None:
12-
manager = AgentModeManager(
13-
team=self._team, user=self._user, node_path=self.node_path, mode=AgentMode.PRODUCT_ANALYTICS
14-
)
10+
manager = AgentModeManager(team=self._team, user=self._user, node_path=self.node_path, mode=state.agent_mode)
1511
new_state = await manager.node(state, config)
1612
return new_state
1713

1814
def router(self, state: AssistantState):
19-
manager = AgentModeManager(
20-
team=self._team, user=self._user, node_path=self.node_path, mode=AgentMode.PRODUCT_ANALYTICS
21-
)
15+
manager = AgentModeManager(team=self._team, user=self._user, node_path=self.node_path, mode=state.agent_mode)
2216
next_node = manager.node.router(state)
2317
return next_node
2418

2519

2620
class AgentGraphToolsNode(AssistantNode):
2721
async def arun(self, state: AssistantState, config: RunnableConfig) -> PartialAssistantState | None:
28-
manager = AgentModeManager(
29-
team=self._team, user=self._user, node_path=self.node_path, mode=AgentMode.PRODUCT_ANALYTICS
30-
)
22+
manager = AgentModeManager(team=self._team, user=self._user, node_path=self.node_path, mode=state.agent_mode)
3123
new_state = await manager.tools_node(state, config)
3224
return new_state
3325

3426
def router(self, state: AssistantState):
35-
manager = AgentModeManager(
36-
team=self._team, user=self._user, node_path=self.node_path, mode=AgentMode.PRODUCT_ANALYTICS
37-
)
27+
manager = AgentModeManager(team=self._team, user=self._user, node_path=self.node_path, mode=state.agent_mode)
3828
next_node = manager.tools_node.router(state)
3929
return next_node

ee/hogai/graph/agent_modes/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .compaction_manager import AnthropicConversationCompactionManager, ConversationCompactionManager
22
from .const import SLASH_COMMAND_INIT, SLASH_COMMAND_REMEMBER
3-
from .factory import AgentExample, AgentModeDefinition
3+
from .factory import AgentModeDefinition
44
from .mode_manager import AgentModeManager
55
from .nodes import AgentExecutable, AgentToolkit, AgentToolsExecutable
66

@@ -10,7 +10,6 @@
1010
"AgentToolkit",
1111
"AgentModeManager",
1212
"AgentModeDefinition",
13-
"AgentExample",
1413
"SLASH_COMMAND_INIT",
1514
"SLASH_COMMAND_REMEMBER",
1615
"AnthropicConversationCompactionManager",
Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,10 @@
1-
from dataclasses import dataclass, field
2-
3-
from pydantic import BaseModel
1+
from dataclasses import dataclass
42

53
from posthog.schema import AgentMode
64

75
from .nodes import AgentExecutable, AgentToolkit, AgentToolsExecutable
86

97

10-
class AgentExample(BaseModel):
11-
"""
12-
Custom agent example to correct the agent's behavior through few-shot prompting.
13-
The example will be formatted as follows:
14-
```
15-
<example>
16-
{example}
17-
18-
<reasoning>
19-
{reasoning}
20-
</reasoning>
21-
</example>
22-
```
23-
"""
24-
25-
example: str
26-
reasoning: str | None = None
27-
28-
298
@dataclass
309
class AgentModeDefinition:
3110
mode: AgentMode
@@ -38,7 +17,3 @@ class AgentModeDefinition:
3817
"""A custom node class to use for the agent."""
3918
tools_node_class: type[AgentToolsExecutable] = AgentToolsExecutable
4019
"""A custom tools node class to use for the agent."""
41-
positive_todo_examples: list[AgentExample] = field(default_factory=list)
42-
"""Positive examples that will be injected into the `todo_write` tool. Use this field to explain the agent how it should orchestrate complex tasks using provided tools."""
43-
negative_todo_examples: list[AgentExample] = field(default_factory=list)
44-
"""Negative examples that will be injected into the `todo_write` tool. Use this field to explain the agent how it should **NOT** orchestrate tasks using provided tools."""

ee/hogai/graph/agent_modes/mode_manager.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from posthog.models import Team, User
88

9+
from ee.hogai.utils.feature_flags import has_agent_modes_feature_flag
910
from ee.hogai.utils.types.base import NodePath
1011

1112
if TYPE_CHECKING:
@@ -31,7 +32,10 @@ def __init__(self, *, team: Team, user: User, node_path: tuple[NodePath, ...], m
3132
self._team = team
3233
self._user = user
3334
self._node_path = node_path
34-
self._mode = mode or AgentMode.PRODUCT_ANALYTICS
35+
if has_agent_modes_feature_flag(team, user):
36+
self._mode = mode or AgentMode.PRODUCT_ANALYTICS
37+
else:
38+
self._mode = AgentMode.PRODUCT_ANALYTICS
3539

3640
@property
3741
def node(self) -> "AgentExecutable":

0 commit comments

Comments
 (0)