Skip to content
This repository was archived by the owner on May 6, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,8 @@ If you have an SSE MCP server available at `http://localhost:4000/tidewave/mcp`,
## Configuration

`mcp-proxy` either accepts the SSE URL as argument or using the environment variable `SSE_URL`. For debugging purposes, you can also pass `--debug`, which will log debug messages on stderr.

Other supported flags:

* `--max-disconnected-time` the maximum amount of time for trying to reconnect while disconnected. When not set, defaults to infinity.
* `--receive-timeout` the maximum amount of time to wait for an individual reply from the MCP server in milliseconds. Defaults to 60000 (60 seconds).
11 changes: 9 additions & 2 deletions lib/mcp_proxy.ex
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@ defmodule McpProxy do
@doc false
def main(args) do
{opts, args} =
OptionParser.parse!(args, strict: [debug: :boolean, max_disconnected_time: :integer])
OptionParser.parse!(args,
strict: [
debug: :boolean,
max_disconnected_time: :integer,
receive_timeout: :integer
]
)

base_url =
case args do
Expand All @@ -23,8 +29,9 @@ defmodule McpProxy do
end

Application.ensure_all_started(:req)
Application.put_all_env(mcp_proxy: opts)

{:ok, handler} = SSE.start_link({base_url, opts})
{:ok, handler} = SSE.start_link(base_url)
ref = Process.monitor(handler)

receive do
Expand Down
100 changes: 73 additions & 27 deletions lib/mcp_proxy/sse.ex
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ defmodule McpProxy.SSE do
end

@impl true
def init({sse_url, opts}) do
def init(sse_url) do
{:ok,
%{
url: sse_url,
endpoint: nil,
debug: Keyword.get(opts, :debug, false),
max_disconnected_time: Keyword.get(opts, :max_disconnected_time),
debug: Application.get_env(:mcp_proxy, :debug, false),
max_disconnected_time: Application.get_env(:mcp_proxy, :max_disconnected_time),
disconnected_since: nil,
state: :connecting,
connect_tries: 0,
Expand Down Expand Up @@ -263,11 +263,17 @@ defmodule McpProxy.SSE do
end

# regular events
def handle_info({:sse_event, {:message, event}}, state),
do: handle_event(event, state, &IO.puts(Jason.encode!(&1)))
def handle_info({:sse_event, {:message, event}}, state) do
handle_event(event, state, {&IO.puts(Jason.encode!(&1)), fn _ -> :ok end})
end

def handle_info({:io_event, event}, state),
do: handle_event(event, state, &forward_message(&1, state.endpoint, state.debug))
def handle_info({:io_event, event}, state) do
handle_event(
event,
state,
{&forward_message(&1, state.endpoint, state.debug), &IO.puts(Jason.encode!(&1))}
)
end

# whenever the HTTP process dies, we try to reconnect
def handle_info({:DOWN, _ref, :process, http_pid, _reason}, %{http_pid: http_pid} = state) do
Expand Down Expand Up @@ -327,52 +333,92 @@ defmodule McpProxy.SSE do
Enum.reduce(messages, state, fn message, state -> handle_event(message, state, handler) end)
end

defp handle_event(%{"jsonrpc" => "2.0", "id" => request_id} = event, state, handler)
defp handle_event(%{"jsonrpc" => "2.0", "id" => request_id} = event, state, {req, resp})
when is_request(event) do
# whenever we get a request from the client (OR the server!)
# we generate a random ID to prevent duplicate IDs, for example when
# a reconnected server decides to send a ping and always starts with ID 0
new_id = random_id!()
event = Map.put(event, "id", new_id)

handler.(event)
state =
case req.(event) do
:ok ->
%{state | id_map: Map.put(state.id_map, new_id, request_id)}

{:reply_error, reply} ->
resp.(Map.put(reply, "id", request_id))
# we don't store the new_id when we already replied to prevent duplicates;
# instead, we'll log an error if a server reply is received later and we already
# replied
state
end

{:noreply, %{state | id_map: Map.put(state.id_map, new_id, request_id)}}
{:noreply, state}
end

defp handle_event(%{"jsonrpc" => "2.0", "id" => response_id} = event, state, handler)
defp handle_event(%{"jsonrpc" => "2.0", "id" => response_id} = event, state, {handler, _})
when is_response(event) do
# whenever we receive a response (from the client or server)
# we fetch the original ID from the id map to present the expected
# ID in the reply
original_id = Map.fetch!(state.id_map, response_id)
event = Map.put(event, "id", original_id)
case state.id_map do
%{^response_id => original_id} ->
event = Map.put(event, "id", original_id)

handler.(event)

{:noreply, %{state | id_map: Map.delete(state.id_map, response_id)}}

handler.(event)
_ ->
Logger.error(
"Did not find original ID for response: #{response_id}. Discarding response!"
)

{:noreply, %{state | id_map: Map.delete(state.id_map, response_id)}}
{:noreply, state}
end
end

# no id, so must be a notification that we can just forward as is
defp handle_event(%{"jsonrpc" => "2.0"} = event, state, handler) do
handler.(event)
defp handle_event(%{"jsonrpc" => "2.0"} = event, state, {req, _}) do
req.(event)

{:noreply, state}
end

## other helpers

defp forward_message(message, endpoint, debug) do
try do
if debug, do: Logger.debug("Forwarding request to server: #{inspect(message)}")
Req.post!(endpoint, json: message)
rescue
error ->
Logger.error(
"Failed to forward message: #{Exception.format(:error, error, __STACKTRACE__)}"
)

# TODO: store message and replay later?
if debug, do: Logger.debug("Forwarding request to server: #{inspect(message)}")

case Req.post(endpoint,
json: message,
receive_timeout: Application.get_env(:mcp_proxy, :receive_timeout, 60_000)
) do
{:ok, %{status: status}} when status in 200..299 ->
:ok

{:ok, %{status: status}} ->
{:reply_error,
%{
jsonrpc: "2.0",
error: %{
code: -32011,
message: "Failed to forward request. Request failed with status code: #{status}"
}
}}

{:error, %Req.TransportError{reason: reason}} ->
Logger.error("Failed to forward message #{inspect(message)}:\n#{inspect(reason)}")

{:reply_error,
%{
jsonrpc: "2.0",
error: %{
code: -32011,
message: "Failed to forward request. Reason: #{inspect(reason)}"
}
}}
end
end

Expand Down
33 changes: 31 additions & 2 deletions test/mcp_proxy_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ defmodule McpProxyTest do

alias McpProxy.SSEServer

setup do
setup context do
{result, _output} =
with_io(:stderr, fn ->
_server_pid = start_supervised!(SSEServer, restart: :temporary)
Expand All @@ -17,7 +17,10 @@ defmodule McpProxyTest do
_main_pid =
spawn_link(fn ->
Process.group_leader(self(), parent)
McpProxy.main(["http://localhost:#{port}/sse", "--debug"])

McpProxy.main(
["http://localhost:#{port}/sse", "--debug"] ++ (context[:extra_args] || [])
)
end)

assert_receive {:io_request, io_pid, reply_as, {:get_line, :unicode, []}}
Expand Down Expand Up @@ -116,6 +119,32 @@ defmodule McpProxyTest do
assert io =~ "Flushing buffer"
end

@tag extra_args: ["--receive-timeout", "50"]
test "handles receive timeout", %{io_pid: io_pid, reply_as: reply_as} do
capture_io(:stderr, fn ->
send_message(
io_pid,
reply_as,
%{
jsonrpc: "2.0",
id: "call-1",
method: "tools/call",
params: %{"name" => "sleep", "arguments" => %{"time" => 100}}
}
)

assert_receive {:io_request, _io_pid, _reply_as, {:get_line, :unicode, []}}
assert_receive {:io_request, put_pid, put_reply_as, {:put_chars, :unicode, json}}, 100
send(put_pid, {:io_reply, put_reply_as, :ok})

assert %{"id" => "call-1", "error" => %{"message" => message}} = Jason.decode!(json)
assert message =~ "Failed to forward request"

# wait an extra 100 milliseconds for the log about discarding a duplicate response
Process.sleep(100)
end) =~ "Discarding!"
end

defp send_message(io_pid, reply_as, json) do
send(io_pid, {:io_reply, reply_as, Jason.encode_to_iodata!(json)})
end
Expand Down
22 changes: 18 additions & 4 deletions test/support/sse_server.ex
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ defmodule McpProxy.SSEServer do
post "/message" do
with %{query_params: %{"session_id" => session_id}} <- conn,
[{pid, _}] <- Registry.lookup(McpProxy.SSEServer.Registry, session_id) do
GenServer.cast(pid, {:request, conn.body_params})
GenServer.call(pid, {:request, conn.body_params})

conn
|> put_resp_content_type("application/json")
Expand All @@ -93,16 +93,16 @@ defmodule McpProxy.SSEServer do
end

@impl GenServer
def handle_cast({:request, message}, state) do
def handle_call({:request, message}, _from, state) do
result = handle_message(message)

if result do
case Plug.Conn.chunk(state.conn, ["event: message\ndata: ", result]) do
{:ok, conn} -> {:noreply, %{state | conn: conn}}
{:ok, conn} -> {:reply, :ok, %{state | conn: conn}}
{:error, :closed} -> {:stop, :shutdown, state}
end
else
{:noreply, state}
{:reply, :ok, state}
end
end

Expand Down Expand Up @@ -164,6 +164,20 @@ defmodule McpProxy.SSEServer do
})
end

defp handle_message(%{
"id" => request_id,
"method" => "tools/call",
"params" => %{"name" => "sleep", "arguments" => %{"time" => time}}
}) do
Process.sleep(time)

Jason.encode_to_iodata!(%{
jsonrpc: "2.0",
id: request_id,
result: %{content: [%{text: "Ok"}]}
})
end

defp handle_message(%{"id" => id, "method" => other} = _message) do
Jason.encode_to_iodata!(%{
jsonrpc: "2.0",
Expand Down