fix(web): run URL SSRF checks off the event loop in async paths

Add async_is_safe_url() wrapping is_safe_url via asyncio.to_thread, and route
all async SSRF call sites through it: web_extract_tool, the vision/video
preflight checks, and both download redirect guards. socket.getaddrinfo blocks;
calling it inline from async tool paths froze the event loop for the duration of
DNS resolution.

vision_tools: split _validate_image_url into _image_url_shape_ok (no DNS) +
sync _validate_image_url (for sync callers/tests) + async _validate_image_url_async.

Widened beyond the original PR #3691 to sibling async sites that also blocked
the loop (second redirect guard, video preflight).

Salvage of #3691 by @Kewe63 — surgically re-applied onto current main because
the original branch was too stale to cherry-pick cleanly (a direct cherry-pick
would have reverted the web_crawl_tool refactor).

Co-authored-by: Kewe63 <kewe.3217@gmail.com>
This commit is contained in:
kewe63
2026-06-04 05:57:11 -07:00
committed by Teknium
parent 46b2afc56b
commit c60952ba94
7 changed files with 72 additions and 31 deletions

View File

@ -372,7 +372,8 @@ class TestVisionDispatchLoopSafety:
side_effect=lambda url, dest, **kw: _write_fake_image(dest),
),
patch(
"tools.vision_tools._validate_image_url",
"tools.vision_tools._validate_image_url_async",
new_callable=AsyncMock,
return_value=True,
),
patch(
@ -416,7 +417,8 @@ class TestVisionDispatchLoopSafety:
side_effect=lambda url, dest, **kw: _write_fake_image(dest),
),
patch(
"tools.vision_tools._validate_image_url",
"tools.vision_tools._validate_image_url_async",
new_callable=AsyncMock,
return_value=True,
),
patch(

View File

@ -5,6 +5,7 @@ from unittest.mock import patch
from tools.url_safety import (
is_safe_url,
async_is_safe_url,
is_always_blocked_url,
_is_blocked_ip,
_global_allow_private_urls,
@ -195,6 +196,24 @@ class TestIsSafeUrl:
assert is_safe_url("https://multimedia.nt.qq.com.cn/download?id=123") is False
class TestAsyncIsSafeUrl:
    """Checks for async_is_safe_url, the thread-offloaded form of is_safe_url."""

    @pytest.mark.asyncio
    async def test_public_url_allowed(self):
        # Mocked resolver answer: one AF_INET/TCP record for a public address.
        resolved = [(2, 1, 6, "", ("93.184.216.34", 0))]
        with patch("socket.getaddrinfo", return_value=resolved):
            verdict = await async_is_safe_url("https://example.com/x")
        assert verdict is True

    @pytest.mark.asyncio
    async def test_localhost_blocked(self):
        # Mocked resolver answer: the hostname resolves to loopback.
        resolved = [(2, 1, 6, "", ("127.0.0.1", 0))]
        with patch("socket.getaddrinfo", return_value=resolved):
            verdict = await async_is_safe_url("http://localhost:8080/")
        assert verdict is False
class TestIsBlockedIp:
"""Direct tests for the _is_blocked_ip helper."""

View File

@ -297,7 +297,7 @@ class TestErrorLoggingExcInfo:
async def test_analysis_error_logs_exc_info(self, caplog):
"""When vision_analyze_tool encounters an error, it should log with exc_info."""
with (
patch("tools.vision_tools._validate_image_url", return_value=True),
patch("tools.vision_tools._validate_image_url_async", new_callable=AsyncMock, return_value=True),
patch(
"tools.vision_tools._download_image",
new_callable=AsyncMock,
@ -329,7 +329,7 @@ class TestErrorLoggingExcInfo:
return dest
with (
patch("tools.vision_tools._validate_image_url", return_value=True),
patch("tools.vision_tools._validate_image_url_async", new_callable=AsyncMock, return_value=True),
patch("tools.vision_tools._download_image", side_effect=fake_download),
patch(
"tools.vision_tools._image_to_base64_data_url",
@ -451,7 +451,7 @@ class TestVisionSafetyGuards:
with (
patch("tools.vision_tools.check_website_access", return_value=blocked),
patch("tools.vision_tools._validate_image_url", return_value=True),
patch("tools.vision_tools._validate_image_url_async", new_callable=AsyncMock, return_value=True),
patch("tools.vision_tools._download_image", new_callable=AsyncMock) as mock_download,
):
result = json.loads(await vision_analyze_tool("https://blocked.test/cat.png", "describe"))
@ -549,7 +549,9 @@ class TestTildeExpansion:
img = fake_home / "test_image.png"
img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 8)
# Windows expanduser() prefers USERPROFILE over HOME; POSIX uses HOME.
monkeypatch.setenv("HOME", str(fake_home))
monkeypatch.setenv("USERPROFILE", str(fake_home))
mock_response = MagicMock()
mock_choice = MagicMock()
@ -580,6 +582,7 @@ class TestTildeExpansion:
fake_home = tmp_path / "fakehome"
fake_home.mkdir()
monkeypatch.setenv("HOME", str(fake_home))
monkeypatch.setenv("USERPROFILE", str(fake_home))
result = await vision_analyze_tool(
"~/nonexistent.png", "describe this", "test/model"

View File

@ -372,7 +372,10 @@ class TestWebToolPolicy:
from plugins.web.firecrawl import provider as firecrawl_provider
# Allow test URLs past SSRF check so website policy is what gets tested
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
async def _allow_ssrf(_url: str) -> bool:
return True
monkeypatch.setattr(web_tools, "async_is_safe_url", _allow_ssrf)
# The per-URL website-policy gate moved into the firecrawl plugin's
# extract() during the web-provider migration. Patch it at the new
# location.
@ -406,7 +409,10 @@ class TestWebToolPolicy:
from plugins.web.firecrawl import provider as firecrawl_provider
# Allow test URLs past SSRF check so website policy is what gets tested
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
async def _allow_ssrf(_url: str) -> bool:
return True
monkeypatch.setattr(web_tools, "async_is_safe_url", _allow_ssrf)
def fake_check(url):
if url == "https://allowed.test":

View File

@ -27,6 +27,7 @@ import ipaddress
import logging
import os
import socket
import asyncio
from urllib.parse import urlparse
from utils import is_truthy_value
@ -349,3 +350,12 @@ def is_safe_url(url: str) -> bool:
# become SSRF bypass vectors
logger.warning("Blocked request — URL safety check error for %s: %s", url, exc)
return False
async def async_is_safe_url(url: str) -> bool:
    """Apply the :func:`is_safe_url` rules without doing the DNS work on the event loop.

    ``socket.getaddrinfo`` can block, so async code paths (gateway,
    ``web_extract_tool``, vision download hooks) should await this coroutine
    instead of calling ``is_safe_url`` directly.

    Args:
        url: The URL to check.

    Returns:
        bool: Whatever ``is_safe_url(url)`` returns, computed via
        ``asyncio.to_thread``.
    """
    verdict = await asyncio.to_thread(is_safe_url, url)
    return verdict

View File

@ -74,35 +74,36 @@ _VISION_DOWNLOAD_TIMEOUT = _resolve_download_timeout()
_VISION_MAX_DOWNLOAD_BYTES = 50 * 1024 * 1024
def _validate_image_url(url: str) -> bool:
"""
Basic validation of image URL format.
Args:
url (str): The URL to validate
Returns:
bool: True if URL appears to be valid, False otherwise
"""
def _image_url_shape_ok(url: str) -> bool:
"""HTTP(S) shape check only (scheme, netloc). No DNS."""
if not url or not isinstance(url, str):
return False
# Basic HTTP/HTTPS URL check
if not url.startswith(("http://", "https://")):
return False
# Parse to ensure we at least have a network location; still allow URLs
# without file extensions (e.g. CDN endpoints that redirect to images).
parsed = urlparse(url)
if not parsed.netloc:
return False
return True
def _validate_image_url(url: str) -> bool:
"""Validate image URL for sync callers and tests (SSRF via sync DNS check)."""
if not _image_url_shape_ok(url):
return False
# Block private/internal addresses to prevent SSRF
from tools.url_safety import is_safe_url
if not is_safe_url(url):
return False
return is_safe_url(url)
return True
async def _validate_image_url_async(url: str) -> bool:
    """Validate a remote image URL without blocking the event loop on DNS.

    Args:
        url: Candidate URL.

    Returns:
        bool: False when ``_image_url_shape_ok`` rejects the URL; otherwise the
        awaited result of ``async_is_safe_url(url)``.
    """
    if _image_url_shape_ok(url):
        from tools.url_safety import async_is_safe_url

        return await async_is_safe_url(url)
    return False
def _detect_image_mime_type(image_path: Path) -> Optional[str]:
@ -181,8 +182,8 @@ async def _download_image(image_url: str, destination: Path, max_retries: int =
"""
if response.is_redirect and response.next_request:
redirect_url = str(response.next_request.url)
from tools.url_safety import is_safe_url
if not is_safe_url(redirect_url):
from tools.url_safety import async_is_safe_url
if not await async_is_safe_url(redirect_url):
raise ValueError(
f"Blocked redirect to private/internal address: {redirect_url}"
)
@ -716,7 +717,7 @@ async def _vision_analyze_native(
if local_path.is_file():
temp_image_path = local_path
should_cleanup = False
elif _validate_image_url(image_url):
elif await _validate_image_url_async(image_url):
blocked = check_website_access(image_url)
if blocked:
return tool_error(blocked["message"], success=False)
@ -870,7 +871,7 @@ async def vision_analyze_tool(
logger.info("Using local image file: %s", image_url)
temp_image_path = local_path
should_cleanup = False # Don't delete cached/local files
elif _validate_image_url(image_url):
elif await _validate_image_url_async(image_url):
# Remote URL -- download to a temporary location
blocked = check_website_access(image_url)
if blocked:
@ -1265,8 +1266,8 @@ async def _download_video(video_url: str, destination: Path, max_retries: int =
async def _ssrf_redirect_guard(response):
if response.is_redirect and response.next_request:
redirect_url = str(response.next_request.url)
from tools.url_safety import is_safe_url
if not is_safe_url(redirect_url):
from tools.url_safety import async_is_safe_url
if not await async_is_safe_url(redirect_url):
raise ValueError(
f"Blocked redirect to private/internal address: {redirect_url}"
)
@ -1372,7 +1373,7 @@ async def video_analyze_tool(
logger.info("Using local video file: %s", video_url)
temp_video_path = local_path
should_cleanup = False
elif _validate_image_url(video_url):
elif await _validate_image_url_async(video_url):
blocked = check_website_access(video_url)
if blocked:
raise PermissionError(blocked["message"])

View File

@ -102,7 +102,7 @@ from tools.tool_backend_helpers import ( # noqa: F401
nous_tool_gateway_unavailable_message,
prefers_gateway,
)
from tools.url_safety import is_safe_url
from tools.url_safety import async_is_safe_url
import sys
logger = logging.getLogger(__name__)
@ -934,7 +934,7 @@ async def web_extract_tool(
safe_urls = []
ssrf_blocked: List[Dict[str, Any]] = []
for url in urls:
if not is_safe_url(url):
if not await async_is_safe_url(url):
ssrf_blocked.append({
"url": url, "title": "", "content": "",
"error": "Blocked: URL targets a private or internal network address",