fix(web): run URL SSRF checks off the event loop in async paths
Add async_is_safe_url() wrapping is_safe_url via asyncio.to_thread, and route all async SSRF call sites through it: web_extract_tool, the vision/video preflight checks, and both download redirect guards. socket.getaddrinfo blocks; calling it inline from async tool paths froze the event loop for the duration of DNS resolution. vision_tools: split _validate_image_url into _image_url_shape_ok (no DNS) + sync _validate_image_url (for sync callers/tests) + async _validate_image_url_async. Widened beyond the original PR #3691 to sibling async sites that also blocked the loop (second redirect guard, video preflight). Salvage of #3691 by @Kewe63 — surgically re-applied onto current main because the original branch was too stale to cherry-pick cleanly (would have reverted the web_crawl_tool refactor). Co-authored-by: Kewe63 <kewe.3217@gmail.com>
This commit is contained in:
@ -372,7 +372,8 @@ class TestVisionDispatchLoopSafety:
|
||||
side_effect=lambda url, dest, **kw: _write_fake_image(dest),
|
||||
),
|
||||
patch(
|
||||
"tools.vision_tools._validate_image_url",
|
||||
"tools.vision_tools._validate_image_url_async",
|
||||
new_callable=AsyncMock,
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
@ -416,7 +417,8 @@ class TestVisionDispatchLoopSafety:
|
||||
side_effect=lambda url, dest, **kw: _write_fake_image(dest),
|
||||
),
|
||||
patch(
|
||||
"tools.vision_tools._validate_image_url",
|
||||
"tools.vision_tools._validate_image_url_async",
|
||||
new_callable=AsyncMock,
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
|
||||
@ -5,6 +5,7 @@ from unittest.mock import patch
|
||||
|
||||
from tools.url_safety import (
|
||||
is_safe_url,
|
||||
async_is_safe_url,
|
||||
is_always_blocked_url,
|
||||
_is_blocked_ip,
|
||||
_global_allow_private_urls,
|
||||
@ -195,6 +196,24 @@ class TestIsSafeUrl:
|
||||
assert is_safe_url("https://multimedia.nt.qq.com.cn/download?id=123") is False
|
||||
|
||||
|
||||
class TestAsyncIsSafeUrl:
|
||||
"""async_is_safe_url must match is_safe_url (runs DNS in a thread pool)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_public_url_allowed(self):
|
||||
with patch("socket.getaddrinfo", return_value=[
|
||||
(2, 1, 6, "", ("93.184.216.34", 0)),
|
||||
]):
|
||||
assert await async_is_safe_url("https://example.com/x") is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_localhost_blocked(self):
|
||||
with patch("socket.getaddrinfo", return_value=[
|
||||
(2, 1, 6, "", ("127.0.0.1", 0)),
|
||||
]):
|
||||
assert await async_is_safe_url("http://localhost:8080/") is False
|
||||
|
||||
|
||||
class TestIsBlockedIp:
|
||||
"""Direct tests for the _is_blocked_ip helper."""
|
||||
|
||||
|
||||
@ -297,7 +297,7 @@ class TestErrorLoggingExcInfo:
|
||||
async def test_analysis_error_logs_exc_info(self, caplog):
|
||||
"""When vision_analyze_tool encounters an error, it should log with exc_info."""
|
||||
with (
|
||||
patch("tools.vision_tools._validate_image_url", return_value=True),
|
||||
patch("tools.vision_tools._validate_image_url_async", new_callable=AsyncMock, return_value=True),
|
||||
patch(
|
||||
"tools.vision_tools._download_image",
|
||||
new_callable=AsyncMock,
|
||||
@ -329,7 +329,7 @@ class TestErrorLoggingExcInfo:
|
||||
return dest
|
||||
|
||||
with (
|
||||
patch("tools.vision_tools._validate_image_url", return_value=True),
|
||||
patch("tools.vision_tools._validate_image_url_async", new_callable=AsyncMock, return_value=True),
|
||||
patch("tools.vision_tools._download_image", side_effect=fake_download),
|
||||
patch(
|
||||
"tools.vision_tools._image_to_base64_data_url",
|
||||
@ -451,7 +451,7 @@ class TestVisionSafetyGuards:
|
||||
|
||||
with (
|
||||
patch("tools.vision_tools.check_website_access", return_value=blocked),
|
||||
patch("tools.vision_tools._validate_image_url", return_value=True),
|
||||
patch("tools.vision_tools._validate_image_url_async", new_callable=AsyncMock, return_value=True),
|
||||
patch("tools.vision_tools._download_image", new_callable=AsyncMock) as mock_download,
|
||||
):
|
||||
result = json.loads(await vision_analyze_tool("https://blocked.test/cat.png", "describe"))
|
||||
@ -549,7 +549,9 @@ class TestTildeExpansion:
|
||||
img = fake_home / "test_image.png"
|
||||
img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 8)
|
||||
|
||||
# Windows expanduser() prefers USERPROFILE over HOME; POSIX uses HOME.
|
||||
monkeypatch.setenv("HOME", str(fake_home))
|
||||
monkeypatch.setenv("USERPROFILE", str(fake_home))
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_choice = MagicMock()
|
||||
@ -580,6 +582,7 @@ class TestTildeExpansion:
|
||||
fake_home = tmp_path / "fakehome"
|
||||
fake_home.mkdir()
|
||||
monkeypatch.setenv("HOME", str(fake_home))
|
||||
monkeypatch.setenv("USERPROFILE", str(fake_home))
|
||||
|
||||
result = await vision_analyze_tool(
|
||||
"~/nonexistent.png", "describe this", "test/model"
|
||||
|
||||
@ -372,7 +372,10 @@ class TestWebToolPolicy:
|
||||
from plugins.web.firecrawl import provider as firecrawl_provider
|
||||
|
||||
# Allow test URLs past SSRF check so website policy is what gets tested
|
||||
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
|
||||
async def _allow_ssrf(_url: str) -> bool:
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(web_tools, "async_is_safe_url", _allow_ssrf)
|
||||
# The per-URL website-policy gate moved into the firecrawl plugin's
|
||||
# extract() during the web-provider migration. Patch it at the new
|
||||
# location.
|
||||
@ -406,7 +409,10 @@ class TestWebToolPolicy:
|
||||
from plugins.web.firecrawl import provider as firecrawl_provider
|
||||
|
||||
# Allow test URLs past SSRF check so website policy is what gets tested
|
||||
monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True)
|
||||
async def _allow_ssrf(_url: str) -> bool:
|
||||
return True
|
||||
|
||||
monkeypatch.setattr(web_tools, "async_is_safe_url", _allow_ssrf)
|
||||
|
||||
def fake_check(url):
|
||||
if url == "https://allowed.test":
|
||||
|
||||
@ -27,6 +27,7 @@ import ipaddress
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
import asyncio
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from utils import is_truthy_value
|
||||
@ -349,3 +350,12 @@ def is_safe_url(url: str) -> bool:
|
||||
# become SSRF bypass vectors
|
||||
logger.warning("Blocked request — URL safety check error for %s: %s", url, exc)
|
||||
return False
|
||||
|
||||
|
||||
async def async_is_safe_url(url: str) -> bool:
|
||||
"""Same rules as :func:`is_safe_url`, but run the DNS work off the event loop.
|
||||
|
||||
``socket.getaddrinfo`` can block; call this from async code paths (gateway,
|
||||
``web_extract_tool``, vision download hooks) instead of ``is_safe_url``.
|
||||
"""
|
||||
return await asyncio.to_thread(is_safe_url, url)
|
||||
|
||||
@ -74,35 +74,36 @@ _VISION_DOWNLOAD_TIMEOUT = _resolve_download_timeout()
|
||||
_VISION_MAX_DOWNLOAD_BYTES = 50 * 1024 * 1024
|
||||
|
||||
|
||||
def _validate_image_url(url: str) -> bool:
|
||||
"""
|
||||
Basic validation of image URL format.
|
||||
|
||||
Args:
|
||||
url (str): The URL to validate
|
||||
|
||||
Returns:
|
||||
bool: True if URL appears to be valid, False otherwise
|
||||
"""
|
||||
def _image_url_shape_ok(url: str) -> bool:
|
||||
"""HTTP(S) shape check only (scheme, netloc). No DNS."""
|
||||
if not url or not isinstance(url, str):
|
||||
return False
|
||||
|
||||
# Basic HTTP/HTTPS URL check
|
||||
if not url.startswith(("http://", "https://")):
|
||||
return False
|
||||
|
||||
# Parse to ensure we at least have a network location; still allow URLs
|
||||
# without file extensions (e.g. CDN endpoints that redirect to images).
|
||||
parsed = urlparse(url)
|
||||
if not parsed.netloc:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _validate_image_url(url: str) -> bool:
|
||||
"""Validate image URL for sync callers and tests (SSRF via sync DNS check)."""
|
||||
if not _image_url_shape_ok(url):
|
||||
return False
|
||||
# Block private/internal addresses to prevent SSRF
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(url):
|
||||
return False
|
||||
return is_safe_url(url)
|
||||
|
||||
return True
|
||||
|
||||
async def _validate_image_url_async(url: str) -> bool:
|
||||
"""Validate remote image URL without blocking the event loop on DNS."""
|
||||
if not _image_url_shape_ok(url):
|
||||
return False
|
||||
from tools.url_safety import async_is_safe_url
|
||||
return await async_is_safe_url(url)
|
||||
|
||||
|
||||
def _detect_image_mime_type(image_path: Path) -> Optional[str]:
|
||||
@ -181,8 +182,8 @@ async def _download_image(image_url: str, destination: Path, max_retries: int =
|
||||
"""
|
||||
if response.is_redirect and response.next_request:
|
||||
redirect_url = str(response.next_request.url)
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(redirect_url):
|
||||
from tools.url_safety import async_is_safe_url
|
||||
if not await async_is_safe_url(redirect_url):
|
||||
raise ValueError(
|
||||
f"Blocked redirect to private/internal address: {redirect_url}"
|
||||
)
|
||||
@ -716,7 +717,7 @@ async def _vision_analyze_native(
|
||||
if local_path.is_file():
|
||||
temp_image_path = local_path
|
||||
should_cleanup = False
|
||||
elif _validate_image_url(image_url):
|
||||
elif await _validate_image_url_async(image_url):
|
||||
blocked = check_website_access(image_url)
|
||||
if blocked:
|
||||
return tool_error(blocked["message"], success=False)
|
||||
@ -870,7 +871,7 @@ async def vision_analyze_tool(
|
||||
logger.info("Using local image file: %s", image_url)
|
||||
temp_image_path = local_path
|
||||
should_cleanup = False # Don't delete cached/local files
|
||||
elif _validate_image_url(image_url):
|
||||
elif await _validate_image_url_async(image_url):
|
||||
# Remote URL -- download to a temporary location
|
||||
blocked = check_website_access(image_url)
|
||||
if blocked:
|
||||
@ -1265,8 +1266,8 @@ async def _download_video(video_url: str, destination: Path, max_retries: int =
|
||||
async def _ssrf_redirect_guard(response):
|
||||
if response.is_redirect and response.next_request:
|
||||
redirect_url = str(response.next_request.url)
|
||||
from tools.url_safety import is_safe_url
|
||||
if not is_safe_url(redirect_url):
|
||||
from tools.url_safety import async_is_safe_url
|
||||
if not await async_is_safe_url(redirect_url):
|
||||
raise ValueError(
|
||||
f"Blocked redirect to private/internal address: {redirect_url}"
|
||||
)
|
||||
@ -1372,7 +1373,7 @@ async def video_analyze_tool(
|
||||
logger.info("Using local video file: %s", video_url)
|
||||
temp_video_path = local_path
|
||||
should_cleanup = False
|
||||
elif _validate_image_url(video_url):
|
||||
elif await _validate_image_url_async(video_url):
|
||||
blocked = check_website_access(video_url)
|
||||
if blocked:
|
||||
raise PermissionError(blocked["message"])
|
||||
|
||||
@ -102,7 +102,7 @@ from tools.tool_backend_helpers import ( # noqa: F401
|
||||
nous_tool_gateway_unavailable_message,
|
||||
prefers_gateway,
|
||||
)
|
||||
from tools.url_safety import is_safe_url
|
||||
from tools.url_safety import async_is_safe_url
|
||||
import sys
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -934,7 +934,7 @@ async def web_extract_tool(
|
||||
safe_urls = []
|
||||
ssrf_blocked: List[Dict[str, Any]] = []
|
||||
for url in urls:
|
||||
if not is_safe_url(url):
|
||||
if not await async_is_safe_url(url):
|
||||
ssrf_blocked.append({
|
||||
"url": url, "title": "", "content": "",
|
||||
"error": "Blocked: URL targets a private or internal network address",
|
||||
|
||||
Reference in New Issue
Block a user