diff --git a/.changeset/stupid-plums-wear.md b/.changeset/stupid-plums-wear.md new file mode 100644 index 00000000..f4dac530 --- /dev/null +++ b/.changeset/stupid-plums-wear.md @@ -0,0 +1,7 @@ +--- +"agents": patch +--- + +fix: add session ID and header support to SSE transport + +The SSE transport now properly forwards session IDs and request headers to MCP message handlers, achieving closer header parity with StreamableHTTP transport. This allows MCP servers using SSE to access request headers for session management. diff --git a/packages/agents/src/mcp/index.ts b/packages/agents/src/mcp/index.ts index 03cf4735..082b9314 100644 --- a/packages/agents/src/mcp/index.ts +++ b/packages/agents/src/mcp/index.ts @@ -1,7 +1,10 @@ import type { Server } from "@modelcontextprotocol/sdk/server/index.js"; import type { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; import type { Transport } from "@modelcontextprotocol/sdk/shared/transport.js"; -import type { JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js"; +import type { + JSONRPCMessage, + MessageExtraInfo +} from "@modelcontextprotocol/sdk/types.js"; import { JSONRPCMessageSchema, isJSONRPCError, @@ -77,7 +80,7 @@ export abstract class McpAgent< } /** Get the unique WebSocket. SSE transport only. */ - private getWebSocket() { + getWebSocket() { const websockets = Array.from(this.getConnections()); if (websockets.length === 0) { return null; @@ -89,7 +92,7 @@ export abstract class McpAgent< private initTransport() { switch (this.getTransportType()) { case "sse": { - return new McpSSETransport(() => this.getWebSocket()); + return new McpSSETransport(); } case "streamable-http": { return new StreamableHTTPServerTransport({}); @@ -188,7 +191,8 @@ export abstract class McpAgent< /** Handles MCP Messages for the legacy SSE transport. */ async onSSEMcpMessage( _sessionId: string, - messageBody: unknown + messageBody: unknown, + extraInfo?: MessageExtraInfo ): Promise { // Since we address the DO via both the protocol and the session id, // this should never happen, but let's enforce it just in case @@ -210,7 +214,7 @@ export abstract class McpAgent< return null; // Message was handled by elicitation system } - this._transport?.onmessage?.(parsedMessage); + this._transport?.onmessage?.(parsedMessage, extraInfo); return null; } catch (error) { console.error("Error forwarding message to SSE:", error); diff --git a/packages/agents/src/mcp/transport.ts b/packages/agents/src/mcp/transport.ts index 9dade9f4..243044d3 100644 --- a/packages/agents/src/mcp/transport.ts +++ b/packages/agents/src/mcp/transport.ts @@ -16,16 +16,21 @@ import { MessageType } from "../ai-types"; import { MCP_HTTP_METHOD_HEADER, MCP_MESSAGE_HEADER } from "./utils"; export class McpSSETransport implements Transport { - sessionId?: string; + sessionId: string; // Set by the server in `server.connect(transport)` onclose?: () => void; onerror?: (error: Error) => void; - onmessage?: (message: JSONRPCMessage) => void; + onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void; private _getWebSocket: () => WebSocket | null; private _started = false; - constructor(getWebSocket: () => WebSocket | null) { - this._getWebSocket = getWebSocket; + constructor() { + const { agent } = getCurrentAgent(); + if (!agent) + throw new Error("McpAgent was not found in Transport constructor"); + + this.sessionId = agent.getSessionId(); + this._getWebSocket = () => agent.getWebSocket(); } async start() { diff --git a/packages/agents/src/mcp/utils.ts b/packages/agents/src/mcp/utils.ts index 0996b909..66154dbc 100644 --- a/packages/agents/src/mcp/utils.ts +++ b/packages/agents/src/mcp/utils.ts @@ -1,6 +1,7 @@ import { JSONRPCMessageSchema, type JSONRPCMessage, + type MessageExtraInfo, InitializeRequestSchema, isJSONRPCResponse, isJSONRPCNotification @@ -682,7 +683,19 @@ export const createLegacySseHandler = ( }); const messageBody = await request.json(); - const error = await agent.onSSEMcpMessage(sessionId, messageBody); + + // Build MessageExtraInfo with filtered headers + const headers = Object.fromEntries(request.headers.entries()); + + const extraInfo: MessageExtraInfo = { + requestInfo: { headers } + }; + + const error = await agent.onSSEMcpMessage( + sessionId, + messageBody, + extraInfo + ); if (error) { return new Response(error.message, { diff --git a/packages/agents/src/tests/mcp/transports/sse.test.ts b/packages/agents/src/tests/mcp/transports/sse.test.ts index a66cf29e..bca5d24e 100644 --- a/packages/agents/src/tests/mcp/transports/sse.test.ts +++ b/packages/agents/src/tests/mcp/transports/sse.test.ts @@ -1,5 +1,9 @@ import { createExecutionContext, env } from "cloudflare:test"; import { describe, expect, it } from "vitest"; +import type { + CallToolResult, + JSONRPCResponse +} from "@modelcontextprotocol/sdk/types.js"; import worker, { type Env } from "../../worker"; import { establishSSEConnection } from "../../shared/test-utils"; @@ -117,4 +121,77 @@ describe("SSE Transport", () => { expect(response.headers.get("Content-Type")).toBe("text/event-stream"); }); }); + + describe("Header and Auth Handling", () => { + it("should pass headers and session ID to transport via requestInfo", async () => { + const ctx = createExecutionContext(); + const { sessionId, reader } = await establishSSEConnection(ctx); + + // Send request with custom headers using the echoRequestInfo tool + const request = new Request(`${baseUrl}/message?sessionId=${sessionId}`, { + method: "POST", + body: JSON.stringify({ + id: "echo-headers-1", + jsonrpc: "2.0", + method: "tools/call", + params: { + name: "echoRequestInfo", + arguments: {} + } + }), + headers: { + "Content-Type": "application/json", + "x-user-id": "test-user-123", + "x-request-id": "req-456", + "x-custom-header": "custom-value" + } + }); + + const response = await worker.fetch(request, env, ctx); + expect(response.status).toBe(202); // SSE returns 202 Accepted + + // Read the response from the SSE stream + const { value } = await reader.read(); + const event = new TextDecoder().decode(value); + const lines = event.split("\n"); + expect(lines[0]).toEqual("event: message"); + + // Parse the JSON response from the data line + const dataLine = lines.find((line) => line.startsWith("data:")); + const parsed = JSON.parse( + dataLine!.replace("data: ", "") + ) as JSONRPCResponse; + expect(parsed.id).toBe("echo-headers-1"); + + // Extract the echoed request info + const result = parsed.result as CallToolResult; + const textContent = result.content?.[0]; + if (!textContent || textContent.type !== "text") { + throw new Error("Expected text content in tool result"); + } + const echoedData = JSON.parse(textContent.text); + + // Verify custom headers were passed through + expect(echoedData.hasRequestInfo).toBe(true); + expect(echoedData.headers["x-user-id"]).toBe("test-user-123"); + expect(echoedData.headers["x-request-id"]).toBe("req-456"); + expect(echoedData.headers["x-custom-header"]).toBe("custom-value"); + + // Verify that certain internal headers that the transport adds are NOT exposed + // The transport filters cf-mcp-method, cf-mcp-message, and upgrade headers + expect(echoedData.headers["cf-mcp-method"]).toBeUndefined(); + expect(echoedData.headers["cf-mcp-message"]).toBeUndefined(); + expect(echoedData.headers.upgrade).toBeUndefined(); + + // Verify standard headers are also present + expect(echoedData.headers["content-type"]).toBe("application/json"); + + // Check what properties are available in extra + expect(echoedData.availableExtraKeys).toBeDefined(); + + // Verify sessionId is passed through extra data + expect(echoedData.sessionId).toBeDefined(); + expect(echoedData.sessionId).toBe(sessionId); + }); + }); }); diff --git a/packages/agents/src/tests/mcp/transports/streamable-http.test.ts b/packages/agents/src/tests/mcp/transports/streamable-http.test.ts index ece62905..1ce6a278 100644 --- a/packages/agents/src/tests/mcp/transports/streamable-http.test.ts +++ b/packages/agents/src/tests/mcp/transports/streamable-http.test.ts @@ -596,4 +596,72 @@ describe("Streamable HTTP Transport", () => { expect(tools.some((t) => t.name === "temp-echo")).toBe(false); }); }); + + describe("Header and Auth Handling", () => { + it("should pass custom headers to transport via requestInfo", async () => { + const ctx = createExecutionContext(); + const sessionId = await initializeStreamableHTTPServer(ctx); + + // Send request with custom headers using the echoRequestInfo tool + const echoMessage: JSONRPCMessage = { + id: "echo-headers-1", + jsonrpc: "2.0", + method: "tools/call", + params: { + name: "echoRequestInfo", + arguments: {} + } + }; + + const request = new Request(baseUrl, { + body: JSON.stringify(echoMessage), + headers: { + Accept: "application/json, text/event-stream", + "Content-Type": "application/json", + "mcp-session-id": sessionId, + "x-user-id": "test-user-123", + "x-request-id": "req-456", + "x-custom-header": "custom-value" + }, + method: "POST" + }); + + const response = await worker.fetch(request, env, ctx); + expect(response.status).toBe(200); + + // Parse the SSE response + const sseText = await readSSEEvent(response); + const parsed = parseSSEData(sseText) as JSONRPCResponse; + expect(parsed.id).toBe("echo-headers-1"); + + // Extract the echoed request info + const result = parsed.result as CallToolResult; + const firstContent = result.content?.[0]; + const contentText = + firstContent?.type === "text" ? firstContent.text : undefined; + const echoedData = JSON.parse( + typeof contentText === "string" ? contentText : "{}" + ); + + // Verify custom headers were passed through + expect(echoedData.hasRequestInfo).toBe(true); + expect(echoedData.headers["x-user-id"]).toBe("test-user-123"); + expect(echoedData.headers["x-request-id"]).toBe("req-456"); + expect(echoedData.headers["x-custom-header"]).toBe("custom-value"); + + // Verify that certain internal headers that the transport adds are NOT exposed + // The transport adds cf-mcp-method and cf-mcp-message internally but should filter them + expect(echoedData.headers["cf-mcp-method"]).toBeUndefined(); + expect(echoedData.headers["cf-mcp-message"]).toBeUndefined(); + expect(echoedData.headers.upgrade).toBeUndefined(); + + // Verify standard headers are also present + expect(echoedData.headers.accept).toContain("text/event-stream"); + expect(echoedData.headers["content-type"]).toBe("application/json"); + + // Verify sessionId is passed through extra data + expect(echoedData.sessionId).toBeDefined(); + expect(echoedData.sessionId).toBe(sessionId); + }); + }); }); diff --git a/packages/agents/src/tests/worker.ts b/packages/agents/src/tests/worker.ts index 521a25e3..bb681f6b 100644 --- a/packages/agents/src/tests/worker.ts +++ b/packages/agents/src/tests/worker.ts @@ -1,4 +1,11 @@ import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js"; +import type { RequestHandlerExtra } from "@modelcontextprotocol/sdk/shared/protocol.js"; +import type { + CallToolResult, + IsomorphicHeaders, + ServerNotification, + ServerRequest +} from "@modelcontextprotocol/sdk/types.js"; import { z } from "zod"; import { McpAgent } from "../mcp/index.ts"; import { @@ -38,6 +45,19 @@ type Props = { testValue: string; }; +type ToolExtraInfo = RequestHandlerExtra; + +type EchoResponseData = { + headers: IsomorphicHeaders; + authInfo: ToolExtraInfo["authInfo"] | null; + hasRequestInfo: boolean; + hasAuthInfo: boolean; + requestId: ToolExtraInfo["requestId"]; + sessionId: string | null; + availableExtraKeys: string[]; + [key: string]: unknown; +}; + export class TestMcpAgent extends McpAgent { observability = undefined; private tempToolHandle?: { remove: () => void }; @@ -181,6 +201,55 @@ export class TestMcpAgent extends McpAgent { }; } ); + + // Echo request info for testing header and auth passthrough + this.server.tool( + "echoRequestInfo", + "Echo back request headers and auth info", + {}, + async (_args, extra: ToolExtraInfo): Promise => { + // Extract headers from requestInfo, auth from authInfo + const headers: IsomorphicHeaders = extra.requestInfo?.headers ?? {}; + const authInfo = extra.authInfo ?? null; + + // Track non-function properties available in extra + const extraRecord = extra as Record; + const extraKeys = Object.keys(extraRecord).filter( + (key) => typeof extraRecord[key] !== "function" + ); + + // Build response object with all available data + const responseData: EchoResponseData = { + headers, + authInfo, + hasRequestInfo: !!extra.requestInfo, + hasAuthInfo: !!extra.authInfo, + requestId: extra.requestId, + // Include any sessionId if it exists + sessionId: extra.sessionId ?? null, + // List all available properties in extra + availableExtraKeys: extraKeys + }; + + // Add any other properties from extra that aren't already included + extraKeys.forEach((key) => { + if ( + !["requestInfo", "authInfo", "requestId", "sessionId"].includes(key) + ) { + responseData[`extra_${key}`] = extraRecord[key]; + } + }); + + return { + content: [ + { + type: "text", + text: JSON.stringify(responseData, null, 2) + } + ] + }; + } + ); } }