|
12 | 12 | from functools import wraps |
13 | 13 | from pathlib import Path |
14 | 14 | from typing import cast |
15 | | -from unittest.mock import AsyncMock, MagicMock, patch |
| 15 | +from unittest.mock import AsyncMock, patch |
16 | 16 | from uuid import uuid4 |
17 | 17 |
|
18 | 18 | import ldp.agent |
|
21 | 21 | Environment, |
22 | 22 | Tool, |
23 | 23 | ToolRequestMessage, |
| 24 | + ToolResponseMessage, |
24 | 25 | ToolsAdapter, |
25 | 26 | ToolSelector, |
26 | 27 | ) |
@@ -469,26 +470,27 @@ async def test_timeout(agent_test_settings: Settings, agent_type: str | type) -> |
469 | 470 | agent_test_settings.agent.timeout = 0.05 # Give time for Environment.reset() |
470 | 471 | agent_test_settings.llm = "gpt-4o-mini" |
471 | 472 | agent_test_settings.agent.tool_names = {"gen_answer", "complete"} |
472 | | - docs = Docs() |
| 473 | + orig_exec_tool_calls = PaperQAEnvironment.exec_tool_calls |
| 474 | + tool_responses: list[list[ToolResponseMessage]] = [] |
473 | 475 |
|
474 | | - async def custom_aget_evidence(*_, **kwargs) -> PQASession: # noqa: RUF029 |
475 | | - return kwargs["query"] |
| 476 | + async def spy_exec_tool_calls(*args, **kwargs) -> list[ToolResponseMessage]: |
| 477 | + responses = await orig_exec_tool_calls(*args, **kwargs) |
| 478 | + tool_responses.append(responses) |
| 479 | + return responses |
476 | 480 |
|
477 | | - with ( |
478 | | - patch.object(docs, "docs", {"stub_key": MagicMock(spec_set=Doc)}), |
479 | | - patch.multiple( |
480 | | - Docs, clear_docs=MagicMock(), aget_evidence=custom_aget_evidence |
481 | | - ), |
482 | | - ): |
| 481 | + with patch.object(PaperQAEnvironment, "exec_tool_calls", spy_exec_tool_calls): |
483 | 482 | response = await agent_query( |
484 | 483 | query="Are COVID-19 vaccines effective?", |
485 | 484 | settings=agent_test_settings, |
486 | | - docs=docs, |
487 | 485 | agent_type=agent_type, |
488 | 486 | ) |
489 | 487 | # Ensure that GenerateAnswerTool was called in truncation's failover |
490 | 488 | assert response.status == AgentStatus.TRUNCATED, "Agent did not timeout" |
491 | 489 | assert CANNOT_ANSWER_PHRASE in response.session.answer |
| 490 | + (last_response,) = tool_responses[-1] |
| 491 | + assert ( |
| 492 | + "no papers" in last_response.content |
| 493 | + ), "Expecting agent to have been shown specifics on the failure" |
492 | 494 |
|
493 | 495 |
|
494 | 496 | @pytest.mark.flaky(reruns=5, only_rerun=["AssertionError"]) |
|
0 commit comments