Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .changeset/stupid-plums-wear.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"agents": patch
---

fix: add session ID and header support to SSE transport

The SSE transport now properly forwards session IDs and request headers to MCP message handlers, achieving closer header parity with StreamableHTTP transport. This allows MCP servers using SSE to access request headers for session management.
14 changes: 9 additions & 5 deletions packages/agents/src/mcp/index.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import type { Server } from "@modelcontextprotocol/sdk/server/index.js";
import type { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js";
import type { Transport } from "@modelcontextprotocol/sdk/shared/transport.js";
import type { JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js";
import type {
JSONRPCMessage,
MessageExtraInfo
} from "@modelcontextprotocol/sdk/types.js";
import {
JSONRPCMessageSchema,
isJSONRPCError,
Expand Down Expand Up @@ -77,7 +80,7 @@ export abstract class McpAgent<
}

/** Get the unique WebSocket. SSE transport only. */
private getWebSocket() {
getWebSocket() {
const websockets = Array.from(this.getConnections());
if (websockets.length === 0) {
return null;
Expand All @@ -89,7 +92,7 @@ export abstract class McpAgent<
private initTransport() {
switch (this.getTransportType()) {
case "sse": {
return new McpSSETransport(() => this.getWebSocket());
return new McpSSETransport();
}
case "streamable-http": {
return new StreamableHTTPServerTransport({});
Expand Down Expand Up @@ -188,7 +191,8 @@ export abstract class McpAgent<
/** Handles MCP Messages for the legacy SSE transport. */
async onSSEMcpMessage(
_sessionId: string,
messageBody: unknown
messageBody: unknown,
extraInfo?: MessageExtraInfo
): Promise<Error | null> {
// Since we address the DO via both the protocol and the session id,
// this should never happen, but let's enforce it just in case
Expand All @@ -210,7 +214,7 @@ export abstract class McpAgent<
return null; // Message was handled by elicitation system
}

this._transport?.onmessage?.(parsedMessage);
this._transport?.onmessage?.(parsedMessage, extraInfo);
return null;
} catch (error) {
console.error("Error forwarding message to SSE:", error);
Expand Down
13 changes: 9 additions & 4 deletions packages/agents/src/mcp/transport.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,21 @@ import { MessageType } from "../ai-types";
import { MCP_HTTP_METHOD_HEADER, MCP_MESSAGE_HEADER } from "./utils";

export class McpSSETransport implements Transport {
sessionId?: string;
sessionId: string;
// Set by the server in `server.connect(transport)`
onclose?: () => void;
onerror?: (error: Error) => void;
onmessage?: (message: JSONRPCMessage) => void;
onmessage?: (message: JSONRPCMessage, extra?: MessageExtraInfo) => void;

private _getWebSocket: () => WebSocket | null;
private _started = false;
constructor(getWebSocket: () => WebSocket | null) {
this._getWebSocket = getWebSocket;
constructor() {
const { agent } = getCurrentAgent<McpAgent>();
if (!agent)
throw new Error("McpAgent was not found in Transport constructor");

this.sessionId = agent.getSessionId();
this._getWebSocket = () => agent.getWebSocket();
}

async start() {
Expand Down
20 changes: 19 additions & 1 deletion packages/agents/src/mcp/utils.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import {
JSONRPCMessageSchema,
type JSONRPCMessage,
type MessageExtraInfo,
InitializeRequestSchema,
isJSONRPCResponse,
isJSONRPCNotification
Expand Down Expand Up @@ -682,7 +683,24 @@ export const createLegacySseHandler = (
});

const messageBody = await request.json();
const error = await agent.onSSEMcpMessage(sessionId, messageBody);

// Build MessageExtraInfo with filtered headers
const headers = Object.fromEntries(request.headers.entries());

// Remove internal headers that are not part of the original request
delete headers[MCP_HTTP_METHOD_HEADER];
delete headers[MCP_MESSAGE_HEADER];
delete headers.upgrade;

const extraInfo: MessageExtraInfo = {
requestInfo: { headers }
};

const error = await agent.onSSEMcpMessage(
sessionId,
messageBody,
extraInfo
);

if (error) {
return new Response(error.message, {
Expand Down
77 changes: 77 additions & 0 deletions packages/agents/src/tests/mcp/transports/sse.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import { createExecutionContext, env } from "cloudflare:test";
import { describe, expect, it } from "vitest";
import type {
CallToolResult,
JSONRPCResponse
} from "@modelcontextprotocol/sdk/types.js";
import worker, { type Env } from "../../worker";
import { establishSSEConnection } from "../../shared/test-utils";

Expand Down Expand Up @@ -117,4 +121,77 @@ describe("SSE Transport", () => {
expect(response.headers.get("Content-Type")).toBe("text/event-stream");
});
});

describe("Header and Auth Handling", () => {
it("should pass headers and session ID to transport via requestInfo", async () => {
const ctx = createExecutionContext();
const { sessionId, reader } = await establishSSEConnection(ctx);

// Send request with custom headers using the echoRequestInfo tool
const request = new Request(`${baseUrl}/message?sessionId=${sessionId}`, {
method: "POST",
body: JSON.stringify({
id: "echo-headers-1",
jsonrpc: "2.0",
method: "tools/call",
params: {
name: "echoRequestInfo",
arguments: {}
}
}),
headers: {
"Content-Type": "application/json",
"x-user-id": "test-user-123",
"x-request-id": "req-456",
"x-custom-header": "custom-value"
}
});

const response = await worker.fetch(request, env, ctx);
expect(response.status).toBe(202); // SSE returns 202 Accepted

// Read the response from the SSE stream
const { value } = await reader.read();
const event = new TextDecoder().decode(value);
const lines = event.split("\n");
expect(lines[0]).toEqual("event: message");

// Parse the JSON response from the data line
const dataLine = lines.find((line) => line.startsWith("data:"));
const parsed = JSON.parse(
dataLine!.replace("data: ", "")
) as JSONRPCResponse;
expect(parsed.id).toBe("echo-headers-1");

// Extract the echoed request info
const result = parsed.result as CallToolResult;
const textContent = result.content?.[0];
if (!textContent || textContent.type !== "text") {
throw new Error("Expected text content in tool result");
}
const echoedData = JSON.parse(textContent.text);

// Verify custom headers were passed through
expect(echoedData.hasRequestInfo).toBe(true);
expect(echoedData.headers["x-user-id"]).toBe("test-user-123");
expect(echoedData.headers["x-request-id"]).toBe("req-456");
expect(echoedData.headers["x-custom-header"]).toBe("custom-value");

// Verify that certain internal headers that the transport adds are NOT exposed
// The transport filters cf-mcp-method, cf-mcp-message, and upgrade headers
expect(echoedData.headers["cf-mcp-method"]).toBeUndefined();
expect(echoedData.headers["cf-mcp-message"]).toBeUndefined();
expect(echoedData.headers.upgrade).toBeUndefined();

// Verify standard headers are also present
expect(echoedData.headers["content-type"]).toBe("application/json");

// Check what properties are available in extra
expect(echoedData.availableExtraKeys).toBeDefined();

// Verify sessionId is passed through extra data
expect(echoedData.sessionId).toBeDefined();
expect(echoedData.sessionId).toBe(sessionId);
});
});
});
66 changes: 66 additions & 0 deletions packages/agents/src/tests/mcp/transports/streamable-http.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -583,4 +583,70 @@ describe("Streamable HTTP Transport", () => {
expect(tools.some((t) => t.name === "temp-echo")).toBe(false);
});
});

describe("Header and Auth Handling", () => {
it("should pass custom headers to transport via requestInfo", async () => {
const ctx = createExecutionContext();
const sessionId = await initializeStreamableHTTPServer(ctx);

// Send request with custom headers using the echoRequestInfo tool
const echoMessage: JSONRPCMessage = {
id: "echo-headers-1",
jsonrpc: "2.0",
method: "tools/call",
params: {
name: "echoRequestInfo",
arguments: {}
}
};

const request = new Request(baseUrl, {
body: JSON.stringify(echoMessage),
headers: {
Accept: "application/json, text/event-stream",
"Content-Type": "application/json",
"mcp-session-id": sessionId,
"x-user-id": "test-user-123",
"x-request-id": "req-456",
"x-custom-header": "custom-value"
},
method: "POST"
});

const response = await worker.fetch(request, env, ctx);
expect(response.status).toBe(200);

// Parse the SSE response
const sseText = await readSSEEvent(response);
const parsed = parseSSEData(sseText) as JSONRPCResponse;
expect(parsed.id).toBe("echo-headers-1");

// Extract the echoed request info
const result = parsed.result as CallToolResult;
const contentText = result.content?.[0]?.text;
const echoedData = JSON.parse(
typeof contentText === "string" ? contentText : "{}"
);

// Verify custom headers were passed through
expect(echoedData.hasRequestInfo).toBe(true);
expect(echoedData.headers["x-user-id"]).toBe("test-user-123");
expect(echoedData.headers["x-request-id"]).toBe("req-456");
expect(echoedData.headers["x-custom-header"]).toBe("custom-value");

// Verify that certain internal headers that the transport adds are NOT exposed
// The transport adds cf-mcp-method and cf-mcp-message internally but should filter them
expect(echoedData.headers["cf-mcp-method"]).toBeUndefined();
expect(echoedData.headers["cf-mcp-message"]).toBeUndefined();
expect(echoedData.headers.upgrade).toBeUndefined();

// Verify standard headers are also present
expect(echoedData.headers.accept).toContain("text/event-stream");
expect(echoedData.headers["content-type"]).toBe("application/json");

// Verify sessionId is passed through extra data
expect(echoedData.sessionId).toBeDefined();
expect(echoedData.sessionId).toBe(sessionId);
});
});
});
70 changes: 69 additions & 1 deletion packages/agents/src/tests/worker.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import { McpServer } from "@modelcontextprotocol/sdk/server/mcp.js";
import type { CallToolResult } from "@modelcontextprotocol/sdk/types.js";
import type { RequestHandlerExtra } from "@modelcontextprotocol/sdk/shared/protocol.js";
import type {
CallToolResult,
IsomorphicHeaders,
ServerNotification,
ServerRequest
} from "@modelcontextprotocol/sdk/types.js";
import { z } from "zod";
import { McpAgent } from "../mcp/index.ts";
import {
Expand Down Expand Up @@ -39,6 +45,19 @@ type Props = {
testValue: string;
};

type ToolExtraInfo = RequestHandlerExtra<ServerRequest, ServerNotification>;

type EchoResponseData = {
headers: IsomorphicHeaders;
authInfo: ToolExtraInfo["authInfo"] | null;
hasRequestInfo: boolean;
hasAuthInfo: boolean;
requestId: ToolExtraInfo["requestId"];
sessionId: string | null;
availableExtraKeys: string[];
[key: string]: unknown;
};

export class TestMcpAgent extends McpAgent<Env, State, Props> {
observability = undefined;
private tempToolHandle?: { remove: () => void };
Expand Down Expand Up @@ -169,6 +188,55 @@ export class TestMcpAgent extends McpAgent<Env, State, Props> {
return { content: [{ type: "text", text: "nothing to remove" }] };
}
);

// Echo request info for testing header and auth passthrough
this.server.tool(
"echoRequestInfo",
"Echo back request headers and auth info",
{},
async (_args, extra: ToolExtraInfo): Promise<CallToolResult> => {
// Extract headers from requestInfo, auth from authInfo
const headers: IsomorphicHeaders = extra.requestInfo?.headers ?? {};
const authInfo = extra.authInfo ?? null;

// Track non-function properties available in extra
const extraRecord = extra as Record<string, unknown>;
const extraKeys = Object.keys(extraRecord).filter(
(key) => typeof extraRecord[key] !== "function"
);

// Build response object with all available data
const responseData: EchoResponseData = {
headers,
authInfo,
hasRequestInfo: !!extra.requestInfo,
hasAuthInfo: !!extra.authInfo,
requestId: extra.requestId,
// Include any sessionId if it exists
sessionId: extra.sessionId ?? null,
// List all available properties in extra
availableExtraKeys: extraKeys
};

// Add any other properties from extra that aren't already included
extraKeys.forEach((key) => {
if (
!["requestInfo", "authInfo", "requestId", "sessionId"].includes(key)
) {
responseData[`extra_${key}`] = extraRecord[key];
}
});

return {
content: [
{
type: "text",
text: JSON.stringify(responseData, null, 2)
}
]
};
}
);
}
}

Expand Down
Loading