fix(mcp): fail fast on HTML content-type instead of waiting full connect_timeout
A misconfigured MCP server URL that returns text/html (e.g. pointing at a web app root instead of an MCP endpoint) causes the MCP SDK to block for the full connect_timeout (default 60 s) before surfacing CancelledError. Add a lightweight HEAD pre-flight check that detects text/html responses in ≤5 s and raises ConnectionError with an actionable message. Non-HTML responses, missing headers, and network errors pass through silently so the normal MCP handshake proceeds unaffected. Fixes #36052
This commit is contained in:
137
tests/tools/test_mcp_preflight_content_type.py
Normal file
137
tests/tools/test_mcp_preflight_content_type.py
Normal file
@ -0,0 +1,137 @@
|
||||
"""Tests for _MCPServer._preflight_content_type early-fail behaviour."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from tools.mcp_tool import MCPServerTask
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def server():
|
||||
"""Return a minimal MCPServerTask instance (bypasses __init__ complexity)."""
|
||||
s = MCPServerTask.__new__(MCPServerTask)
|
||||
s.name = "test-server"
|
||||
return s
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HTML response → ConnectionError
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preflight_rejects_html(server):
|
||||
"""A text/html response must raise ConnectionError immediately."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.headers = {"content-type": "text/html; charset=utf-8"}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.head = AsyncMock(return_value=mock_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
with pytest.raises(ConnectionError, match="text/html"):
|
||||
await server._preflight_content_type("https://example.com")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preflight_rejects_html_on_get_fallback(server):
|
||||
"""When HEAD returns 405, fall back to GET — still reject HTML."""
|
||||
head_response = MagicMock()
|
||||
head_response.status_code = 405
|
||||
|
||||
get_response = MagicMock()
|
||||
get_response.status_code = 200
|
||||
get_response.headers = {"content-type": "text/html"}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.head = AsyncMock(return_value=head_response)
|
||||
mock_client.get = AsyncMock(return_value=get_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
with pytest.raises(ConnectionError, match="text/html"):
|
||||
await server._preflight_content_type("https://example.com")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Non-HTML responses → silent pass-through
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preflight_accepts_json(server):
|
||||
"""application/json must NOT raise."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.headers = {"content-type": "application/json"}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.head = AsyncMock(return_value=mock_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
# Should not raise
|
||||
await server._preflight_content_type("https://mcp-server.example.com/mcp")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preflight_accepts_no_content_type(server):
|
||||
"""Missing Content-Type header must NOT raise."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.headers = {}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.head = AsyncMock(return_value=mock_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
await server._preflight_content_type("https://mcp-server.example.com/mcp")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preflight_swallows_network_errors(server):
|
||||
"""Network errors / timeouts must silently pass through."""
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.head = AsyncMock(side_effect=TimeoutError("connect timed out"))
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
# Should not raise — let the real MCP handshake deal with it
|
||||
await server._preflight_content_type("https://unreachable.example.com")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preflight_passes_headers_and_verify(server):
|
||||
"""Custom headers and ssl_verify are forwarded to the probe client."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.headers = {"content-type": "application/json"}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.head = AsyncMock(return_value=mock_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("httpx.AsyncClient", return_value=mock_client) as client_cls:
|
||||
await server._preflight_content_type(
|
||||
"https://mcp.example.com/mcp",
|
||||
headers={"Authorization": "Bearer tok"},
|
||||
ssl_verify=False,
|
||||
)
|
||||
# Verify the client was created with ssl_verify=False
|
||||
client_cls.assert_called_once()
|
||||
call_kwargs = client_cls.call_args
|
||||
assert call_kwargs.kwargs.get("verify") is False
|
||||
@ -1457,6 +1457,54 @@ class MCPServerTask:
|
||||
# PID-reuse can't surface stale pgroup state later.
|
||||
_stdio_pgids.pop(pid, None)
|
||||
|
||||
@staticmethod
|
||||
async def _preflight_content_type(
|
||||
url: str,
|
||||
*,
|
||||
headers: Optional[dict] = None,
|
||||
ssl_verify: bool = True,
|
||||
timeout: float = 5.0,
|
||||
) -> None:
|
||||
"""Quick content-type probe before handing *url* to the MCP SDK.
|
||||
|
||||
A misconfigured ``mcp_servers.<name>.url`` that points at a plain web
|
||||
app (returning ``text/html``) causes the MCP SDK to sit on the
|
||||
connection for the full ``connect_timeout`` (default 60 s) before
|
||||
surfacing ``CancelledError``. A cheap HEAD request lets us detect
|
||||
this in ≤ 5 s and raise immediately with an actionable message.
|
||||
|
||||
Non-HTML responses (``application/json``, missing header, network
|
||||
errors) silently pass through so the normal MCP handshake proceeds.
|
||||
"""
|
||||
try:
|
||||
import httpx as _httpx
|
||||
|
||||
probe_headers = dict(headers) if headers else {}
|
||||
# HEAD is idempotent and lightweight; fall back to GET if the
|
||||
# server rejects HEAD (405 Method Not Allowed).
|
||||
async with _httpx.AsyncClient(
|
||||
verify=ssl_verify,
|
||||
follow_redirects=True,
|
||||
timeout=_httpx.Timeout(timeout),
|
||||
) as client:
|
||||
resp = await client.head(url, headers=probe_headers)
|
||||
if resp.status_code == 405:
|
||||
resp = await client.get(url, headers=probe_headers)
|
||||
ct = resp.headers.get("content-type", "")
|
||||
if "text/html" in ct.lower():
|
||||
raise ConnectionError(
|
||||
f"MCP server '{url}' returned Content-Type: {ct}. "
|
||||
"This looks like a regular web page, not an MCP endpoint. "
|
||||
"Verify the URL points to an MCP Streamable HTTP or SSE "
|
||||
"endpoint (e.g. https://host/mcp, not https://host/)."
|
||||
)
|
||||
except ConnectionError:
|
||||
raise
|
||||
except Exception:
|
||||
# Network errors, timeouts, etc. — let the real MCP handshake
|
||||
# deal with them; this is just a best-effort early check.
|
||||
pass
|
||||
|
||||
async def _run_http(self, config: dict):
|
||||
"""Run the server using HTTP/StreamableHTTP transport."""
|
||||
if not _MCP_HTTP_AVAILABLE:
|
||||
@ -1467,6 +1515,14 @@ class MCPServerTask:
|
||||
)
|
||||
|
||||
url = config["url"]
|
||||
# Pre-flight: reject obvious non-MCP endpoints (e.g. a web app
|
||||
# returning HTML) in seconds instead of waiting the full
|
||||
# connect_timeout (default 60 s).
|
||||
await self._preflight_content_type(
|
||||
url,
|
||||
headers=dict(config.get("headers") or {}),
|
||||
ssl_verify=config.get("ssl_verify", True),
|
||||
)
|
||||
headers = dict(config.get("headers") or {})
|
||||
# Some MCP servers require MCP-Protocol-Version on the initial
|
||||
# initialize request and reject session-less POSTs otherwise.
|
||||
|
||||
Reference in New Issue
Block a user