|
@@ -3,6 +3,7 @@
 from collections.abc import Generator
 
 import pytest
+from unittest.mock import patch
 
 from django.test import override_settings
 
@@ -33,59 +34,62 @@
 
 @pytest.fixture
 def call_root_for_insight_generation(demo_org_team_user):
-    # This graph structure will first get a plan, then generate the SQL query.
-    graph = (
-        AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
-        .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
-        .add_root()
-        # TRICKY: We need to set a checkpointer here because async tests create a new event loop.
-        .compile(checkpointer=DjangoCheckpointer())
-    )
-
-    async def callable(query_with_extra_context: str | tuple[str, str]) -> PlanAndQueryOutput:
-        # If query_with_extra_context is a tuple, the first element is the query, the second is the extra context
-        # in case there's an ask_user tool call.
-        query = query_with_extra_context[0] if isinstance(query_with_extra_context, tuple) else query_with_extra_context
-        # Initial state for the graph
-        initial_state = AssistantState(
-            messages=[HumanMessage(content=f"Answer this question: {query}")],
+    with patch("ee.hogai.utils.feature_flags.has_agent_modes_feature_flag", return_value=True):
+        # This graph structure will first get a plan, then generate the SQL query.
+        graph = (
+            AssistantGraph(demo_org_team_user[1], demo_org_team_user[2])
+            .add_edge(AssistantNodeName.START, AssistantNodeName.ROOT)
+            .add_root()
+            # TRICKY: We need to set a checkpointer here because async tests create a new event loop.
+            .compile(checkpointer=DjangoCheckpointer())
         )
-        conversation = await Conversation.objects.acreate(team=demo_org_team_user[1], user=demo_org_team_user[2])
 
-        # Invoke the graph. The state will be updated through planner and then generator.
-        final_state_raw = await graph.ainvoke(initial_state, {"configurable": {"thread_id": conversation.id}})
+        async def callable(query_with_extra_context: str | tuple[str, str]) -> PlanAndQueryOutput:
+            # If query_with_extra_context is a tuple, the first element is the query, the second is the extra context
+            # in case there's an ask_user tool call.
+            query = (
+                query_with_extra_context[0] if isinstance(query_with_extra_context, tuple) else query_with_extra_context
+            )
+            # Initial state for the graph
+            initial_state = AssistantState(
+                messages=[HumanMessage(content=f"Answer this question: {query}")],
+            )
+            conversation = await Conversation.objects.acreate(team=demo_org_team_user[1], user=demo_org_team_user[2])
 
-        final_state = AssistantState.model_validate(final_state_raw)
+            # Invoke the graph. The state will be updated through planner and then generator.
+            final_state_raw = await graph.ainvoke(initial_state, {"configurable": {"thread_id": conversation.id}})
 
-        # If we have extra context for the potential ask_user tool, and there's no message of type ai/failure
-        # or ai/visualization, we should answer with that extra context. We only do this once at most in an eval case.
-        if isinstance(query_with_extra_context, tuple) and not any(
-            isinstance(m, VisualizationMessage | FailureMessage) for m in final_state.messages
-        ):
-            final_state.messages = [*final_state.messages, HumanMessage(content=query_with_extra_context[1])]
-            final_state.graph_status = "resumed"
-            final_state_raw = await graph.ainvoke(final_state, {"configurable": {"thread_id": conversation.id}})
             final_state = AssistantState.model_validate(final_state_raw)
 
-        # The order is a viz message, tool call message, and assistant message.
-        if (
-            not final_state.messages
-            or not len(final_state.messages) >= 3
-            or not isinstance(final_state.messages[-3], VisualizationMessage)
-        ):
+            # If we have extra context for the potential ask_user tool, and there's no message of type ai/failure
+            # or ai/visualization, we should answer with that extra context. We only do this once at most in an eval case.
+            if isinstance(query_with_extra_context, tuple) and not any(
+                isinstance(m, VisualizationMessage | FailureMessage) for m in final_state.messages
+            ):
+                final_state.messages = [*final_state.messages, HumanMessage(content=query_with_extra_context[1])]
+                final_state.graph_status = "resumed"
+                final_state_raw = await graph.ainvoke(final_state, {"configurable": {"thread_id": conversation.id}})
+                final_state = AssistantState.model_validate(final_state_raw)
+
+            # The order is a viz message, tool call message, and assistant message.
+            if (
+                not final_state.messages
+                or not len(final_state.messages) >= 3
+                or not isinstance(final_state.messages[-3], VisualizationMessage)
+            ):
+                return {
+                    "plan": None,
+                    "query": None,
+                    "query_generation_retry_count": final_state.query_generation_retry_count,
+                }
+
             return {
-                "plan": None,
-                "query": None,
+                "plan": final_state.messages[-3].plan,
+                "query": final_state.messages[-3].answer,
                 "query_generation_retry_count": final_state.query_generation_retry_count,
             }
 
-        return {
-            "plan": final_state.messages[-3].plan,
-            "query": final_state.messages[-3].answer,
-            "query_generation_retry_count": final_state.query_generation_retry_count,
-        }
-
-    return callable
+        yield callable
 
 
 @pytest.fixture(scope="package")
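The switch from return callable to yield callable is what makes the patch(...) wrapper effective: a pytest fixture that yields from inside a context manager stays suspended there while the test runs, so the mock is only undone at fixture teardown, whereas returning would exit the with block and restore the real feature-flag check before the test ever calls the fixture. A minimal, self-contained sketch of that pattern, using a hypothetical example_flag helper and fixture name in place of has_agent_modes_feature_flag and the real fixture:

from collections.abc import Callable, Generator
from unittest.mock import patch

import pytest


def example_flag() -> bool:
    # Hypothetical stand-in for the real feature-flag check.
    return False


@pytest.fixture
def call_with_flag_enabled() -> Generator[Callable[[], bool], None, None]:
    # The patch is entered once per test that requests this fixture...
    with patch(f"{__name__}.example_flag", return_value=True):
        def call() -> bool:
            return example_flag()

        # ...and stays active while the fixture is suspended at this yield,
        # i.e. for the whole test body. Returning instead of yielding would
        # exit the `with` (and undo the patch) before the test runs.
        yield call
    # Teardown: leaving the `with` block restores the original example_flag.


def test_flag_is_patched(call_with_flag_enabled):
    assert call_with_flag_enabled() is True

The fixture in this commit follows the same shape, with the AssistantGraph-invoking callable taking the place of call.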