Skip to content

Commit 6d4a5e1

Browse files
authored
feat: add ability to return Command to interceptors and add end-to-end test (#372)
Allow returning Command to interceptors and add end-to-end test
1 parent 092da2a commit 6d4a5e1

File tree

5 files changed

+444
-15
lines changed

5 files changed

+444
-15
lines changed

langchain_mcp_adapters/interceptors.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,24 @@
1616
from mcp.types import CallToolResult
1717
from typing_extensions import NotRequired, TypedDict, Unpack
1818

19+
try:
20+
# langgraph installed
21+
import langgraph
22+
23+
LANGGRAPH_PRESENT = True
24+
except ImportError:
25+
LANGGRAPH_PRESENT = False
26+
27+
1928
if TYPE_CHECKING:
2029
from collections.abc import Awaitable, Callable
2130

22-
MCPToolCallResult = CallToolResult | ToolMessage
31+
if LANGGRAPH_PRESENT:
32+
from langgraph.types import Command
33+
34+
MCPToolCallResult = CallToolResult | ToolMessage | Command
35+
else:
36+
MCPToolCallResult = CallToolResult | ToolMessage
2337

2438

2539
class _MCPToolCallRequestOverrides(TypedDict, total=False):

langchain_mcp_adapters/tools.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,24 +36,41 @@
3636
)
3737
from langchain_mcp_adapters.sessions import Connection, create_session
3838

39+
try:
40+
# langgraph installed
41+
import langgraph
42+
from langgraph.types import Command
43+
44+
LANGGRAPH_PRESENT = True
45+
except ImportError:
46+
LANGGRAPH_PRESENT = False
47+
3948
NonTextContent = ImageContent | AudioContent | ResourceLink | EmbeddedResource
49+
50+
# Conditional type based on langgraph availability
51+
if LANGGRAPH_PRESENT:
52+
ConvertedToolResult = str | list[str] | ToolMessage | Command
53+
else:
54+
ConvertedToolResult = str | list[str] | ToolMessage
55+
4056
MAX_ITERATIONS = 1000
4157

4258

4359
def _convert_call_tool_result(
4460
call_tool_result: MCPToolCallResult,
45-
) -> tuple[str | list[str] | ToolMessage, list[NonTextContent] | None]:
61+
) -> tuple[ConvertedToolResult, list[NonTextContent] | None]:
4662
"""Convert MCP MCPToolCallResult to LangChain tool result format.
4763
4864
Args:
4965
call_tool_result: The result from calling an MCP tool. Can be either
50-
a CallToolResult (MCP format) or a ToolMessage (LangChain format).
66+
a CallToolResult (MCP format), a ToolMessage (LangChain format),
67+
or a Command (LangGraph format, if langgraph is installed).
5168
5269
Returns:
53-
A tuple containing the text content (which may be a ToolMessage) and any
54-
non-text content. When a ToolMessage is returned by an interceptor, it's
55-
placed in the first position of the tuple as the content, with None as
56-
the artifact.
70+
A tuple containing the text content (which may be a ToolMessage or Command)
71+
and any non-text content. When a ToolMessage or Command is returned by an
72+
interceptor, it's placed in the first position of the tuple as the content,
73+
with None as the artifact.
5774
5875
Raises:
5976
ToolException: If the tool call resulted in an error.
@@ -63,6 +80,10 @@ def _convert_call_tool_result(
6380
if isinstance(call_tool_result, ToolMessage):
6481
return call_tool_result, None
6582

83+
# If the interceptor returned a Command (LangGraph), return it directly
84+
if LANGGRAPH_PRESENT and isinstance(call_tool_result, Command):
85+
return call_tool_result, None
86+
6687
# Otherwise, convert from CallToolResult
6788
text_contents: list[TextContent] = []
6889
non_text_contents = []
@@ -188,7 +209,7 @@ def convert_mcp_tool_to_langchain_tool(
188209
async def call_tool(
189210
runtime: Any = None, # noqa: ANN401
190211
**arguments: dict[str, Any],
191-
) -> tuple[str | list[str] | ToolMessage, list[NonTextContent] | None]:
212+
) -> tuple[ConvertedToolResult, list[NonTextContent] | None]:
192213
"""Execute tool call with interceptor chain and return formatted result.
193214
194215
Args:
@@ -197,7 +218,8 @@ async def call_tool(
197218
198219
Returns:
199220
A tuple of (text_content, non_text_content), where text_content may be
200-
a ToolMessage if an interceptor returned one directly.
221+
a ToolMessage or Command (if langgraph is installed) if an interceptor
222+
returned one directly.
201223
"""
202224
mcp_callbacks = (
203225
callbacks.to_mcp_format(

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ dependencies = [
1919
]
2020

2121
[dependency-groups]
22+
integration = [
23+
"langchain>=1.0.8",
24+
]
2225
test = [
2326
"pytest>=8.0.0",
2427
"ruff>=0.9.4",

tests/test_tools.py

Lines changed: 145 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1-
from typing import Annotated
1+
import typing
2+
from collections.abc import Callable, Sequence
3+
from typing import Annotated, Any
24
from unittest.mock import AsyncMock, MagicMock
35

46
import httpx
57
import pytest
68
from langchain_core.callbacks import CallbackManagerForToolRun
7-
from langchain_core.messages import ToolMessage
9+
from langchain_core.language_models import LanguageModelInput
10+
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
11+
from langchain_core.runnables import Runnable
812
from langchain_core.tools import BaseTool, InjectedToolArg, ToolException, tool
913
from mcp.server import FastMCP
1014
from mcp.types import (
@@ -19,6 +23,7 @@
1923
from pydantic import BaseModel
2024

2125
from langchain_mcp_adapters.client import MultiServerMCPClient
26+
from langchain_mcp_adapters.interceptors import MCPToolCallRequest
2227
from langchain_mcp_adapters.tools import (
2328
_convert_call_tool_result,
2429
convert_mcp_tool_to_langchain_tool,
@@ -526,3 +531,141 @@ async def test_convert_mcp_tool_metadata_variants():
526531
"openWorldHint": None,
527532
"_meta": {"flag": True},
528533
}
534+
535+
536+
def _create_increment_server():
537+
server = FastMCP(port=8183)
538+
539+
@server.tool()
540+
def increment(value: int) -> str:
541+
"""Increment a counter"""
542+
return f"Incremented to {value + 1}"
543+
544+
return server
545+
546+
547+
try:
548+
import langchain
549+
550+
LANGCHAIN_INSTALLED = True
551+
except ImportError:
552+
LANGCHAIN_INSTALLED = False
553+
554+
from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
555+
556+
557+
class FixedGenericFakeChatModel(GenericFakeChatModel):
558+
def bind_tools(
559+
self,
560+
tools: Sequence[
561+
typing.Dict[str, Any] | type | Callable | BaseTool # noqa: UP006
562+
],
563+
*,
564+
tool_choice: str | None = None,
565+
**kwargs: Any,
566+
) -> Runnable[LanguageModelInput, AIMessage]:
567+
"""Override bind-tools."""
568+
return self
569+
570+
571+
@pytest.mark.skipif(not LANGCHAIN_INSTALLED, reason="langchain not installed")
572+
async def test_mcp_tools_with_agent_and_command_interceptor(socket_enabled) -> None:
573+
"""Test Command objects from interceptors work end-to-end with create_agent.
574+
575+
This test verifies that:
576+
1. MCP tools can be used with create_agent
577+
2. Interceptors can return Command objects to short-circuit execution
578+
3. Commands can update custom agent state
579+
"""
580+
from langchain.agents import AgentState, create_agent
581+
from langchain.tools import ToolRuntime
582+
from langgraph.checkpoint.memory import MemorySaver
583+
from langgraph.types import Command
584+
585+
from langchain_mcp_adapters.interceptors import MCPToolCallResult
586+
587+
# Interceptor that returns Command to update state
588+
async def counter_interceptor(
589+
request: MCPToolCallRequest,
590+
handler: Callable[[MCPToolCallRequest], typing.Awaitable[MCPToolCallResult]],
591+
) -> Command:
592+
# Instead of calling the tool, return a Command that updates state
593+
tool_runtime: ToolRuntime = request.runtime
594+
assert tool_runtime.tool_call_id == "call_1"
595+
return Command(
596+
update={
597+
"counter": 42,
598+
"messages": [
599+
ToolMessage(
600+
content="Counter updated!",
601+
tool_call_id=tool_runtime.tool_call_id,
602+
),
603+
AIMessage(content="hello"),
604+
],
605+
},
606+
goto="__end__",
607+
)
608+
609+
with run_streamable_http(_create_increment_server, 8183):
610+
# Initialize client and connect to server
611+
client = MultiServerMCPClient(
612+
{
613+
"increment": {
614+
"url": "http://localhost:8183/mcp",
615+
"transport": "streamable_http",
616+
}
617+
},
618+
tool_interceptors=[counter_interceptor],
619+
)
620+
621+
# Get tools from the server
622+
tools = await client.get_tools(server_name="increment")
623+
assert len(tools) == 1
624+
original_tool = tools[0]
625+
assert original_tool.name == "increment"
626+
627+
# Custom state schema with counter field
628+
class CustomAgentState(AgentState):
629+
counter: typing.NotRequired[int]
630+
631+
model = FixedGenericFakeChatModel(
632+
messages=iter(
633+
[
634+
AIMessage(
635+
content="",
636+
tool_calls=[
637+
{
638+
"name": "increment",
639+
"args": {"value": 1},
640+
"id": "call_1",
641+
"type": "tool_call",
642+
}
643+
],
644+
),
645+
AIMessage(
646+
content="The counter has been incremented.",
647+
),
648+
]
649+
)
650+
)
651+
# Create agent with custom state
652+
agent = create_agent(
653+
model,
654+
tools,
655+
state_schema=CustomAgentState,
656+
checkpointer=MemorySaver(),
657+
)
658+
659+
# Run agent
660+
result = await agent.ainvoke(
661+
{"messages": [HumanMessage(content="Increment the counter")], "counter": 0},
662+
{"configurable": {"thread_id": "test_1"}},
663+
)
664+
665+
# Verify Command updated the state
666+
assert result["counter"] == 42
667+
# Verify the Command's message was added
668+
assert any(
669+
isinstance(msg, ToolMessage) and msg.content == "Counter updated!"
670+
for msg in result["messages"]
671+
)

0 commit comments

Comments
 (0)