diff --git a/tests/test_model_tools_async_bridge.py b/tests/test_model_tools_async_bridge.py index 81ffb2cc6..54fce36d2 100644 --- a/tests/test_model_tools_async_bridge.py +++ b/tests/test_model_tools_async_bridge.py @@ -372,7 +372,8 @@ class TestVisionDispatchLoopSafety: side_effect=lambda url, dest, **kw: _write_fake_image(dest), ), patch( - "tools.vision_tools._validate_image_url", + "tools.vision_tools._validate_image_url_async", + new_callable=AsyncMock, return_value=True, ), patch( @@ -416,7 +417,8 @@ class TestVisionDispatchLoopSafety: side_effect=lambda url, dest, **kw: _write_fake_image(dest), ), patch( - "tools.vision_tools._validate_image_url", + "tools.vision_tools._validate_image_url_async", + new_callable=AsyncMock, return_value=True, ), patch( diff --git a/tests/tools/test_url_safety.py b/tests/tools/test_url_safety.py index 8513a848b..a5e00dcf6 100644 --- a/tests/tools/test_url_safety.py +++ b/tests/tools/test_url_safety.py @@ -5,6 +5,7 @@ from unittest.mock import patch from tools.url_safety import ( is_safe_url, + async_is_safe_url, is_always_blocked_url, _is_blocked_ip, _global_allow_private_urls, @@ -195,6 +196,24 @@ class TestIsSafeUrl: assert is_safe_url("https://multimedia.nt.qq.com.cn/download?id=123") is False +class TestAsyncIsSafeUrl: + """async_is_safe_url must match is_safe_url (runs DNS in a thread pool).""" + + @pytest.mark.asyncio + async def test_public_url_allowed(self): + with patch("socket.getaddrinfo", return_value=[ + (2, 1, 6, "", ("93.184.216.34", 0)), + ]): + assert await async_is_safe_url("https://example.com/x") is True + + @pytest.mark.asyncio + async def test_localhost_blocked(self): + with patch("socket.getaddrinfo", return_value=[ + (2, 1, 6, "", ("127.0.0.1", 0)), + ]): + assert await async_is_safe_url("http://localhost:8080/") is False + + class TestIsBlockedIp: """Direct tests for the _is_blocked_ip helper.""" diff --git a/tests/tools/test_vision_tools.py b/tests/tools/test_vision_tools.py index 2edff071e..9373d08f2 100644 --- a/tests/tools/test_vision_tools.py +++ b/tests/tools/test_vision_tools.py @@ -297,7 +297,7 @@ class TestErrorLoggingExcInfo: async def test_analysis_error_logs_exc_info(self, caplog): """When vision_analyze_tool encounters an error, it should log with exc_info.""" with ( - patch("tools.vision_tools._validate_image_url", return_value=True), + patch("tools.vision_tools._validate_image_url_async", new_callable=AsyncMock, return_value=True), patch( "tools.vision_tools._download_image", new_callable=AsyncMock, @@ -329,7 +329,7 @@ class TestErrorLoggingExcInfo: return dest with ( - patch("tools.vision_tools._validate_image_url", return_value=True), + patch("tools.vision_tools._validate_image_url_async", new_callable=AsyncMock, return_value=True), patch("tools.vision_tools._download_image", side_effect=fake_download), patch( "tools.vision_tools._image_to_base64_data_url", @@ -451,7 +451,7 @@ class TestVisionSafetyGuards: with ( patch("tools.vision_tools.check_website_access", return_value=blocked), - patch("tools.vision_tools._validate_image_url", return_value=True), + patch("tools.vision_tools._validate_image_url_async", new_callable=AsyncMock, return_value=True), patch("tools.vision_tools._download_image", new_callable=AsyncMock) as mock_download, ): result = json.loads(await vision_analyze_tool("https://blocked.test/cat.png", "describe")) @@ -549,7 +549,9 @@ class TestTildeExpansion: img = fake_home / "test_image.png" img.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 8) + # Windows expanduser() prefers USERPROFILE over HOME; POSIX uses HOME. monkeypatch.setenv("HOME", str(fake_home)) + monkeypatch.setenv("USERPROFILE", str(fake_home)) mock_response = MagicMock() mock_choice = MagicMock() @@ -580,6 +582,7 @@ class TestTildeExpansion: fake_home = tmp_path / "fakehome" fake_home.mkdir() monkeypatch.setenv("HOME", str(fake_home)) + monkeypatch.setenv("USERPROFILE", str(fake_home)) result = await vision_analyze_tool( "~/nonexistent.png", "describe this", "test/model" diff --git a/tests/tools/test_website_policy.py b/tests/tools/test_website_policy.py index bfe222ef8..712a37286 100644 --- a/tests/tools/test_website_policy.py +++ b/tests/tools/test_website_policy.py @@ -372,7 +372,10 @@ class TestWebToolPolicy: from plugins.web.firecrawl import provider as firecrawl_provider # Allow test URLs past SSRF check so website policy is what gets tested - monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True) + async def _allow_ssrf(_url: str) -> bool: + return True + + monkeypatch.setattr(web_tools, "async_is_safe_url", _allow_ssrf) # The per-URL website-policy gate moved into the firecrawl plugin's # extract() during the web-provider migration. Patch it at the new # location. @@ -406,7 +409,10 @@ class TestWebToolPolicy: from plugins.web.firecrawl import provider as firecrawl_provider # Allow test URLs past SSRF check so website policy is what gets tested - monkeypatch.setattr(web_tools, "is_safe_url", lambda url: True) + async def _allow_ssrf(_url: str) -> bool: + return True + + monkeypatch.setattr(web_tools, "async_is_safe_url", _allow_ssrf) def fake_check(url): if url == "https://allowed.test": diff --git a/tools/url_safety.py b/tools/url_safety.py index a0ce297a9..13117d760 100644 --- a/tools/url_safety.py +++ b/tools/url_safety.py @@ -27,6 +27,7 @@ import ipaddress import logging import os import socket +import asyncio from urllib.parse import urlparse from utils import is_truthy_value @@ -349,3 +350,12 @@ def is_safe_url(url: str) -> bool: # become SSRF bypass vectors logger.warning("Blocked request — URL safety check error for %s: %s", url, exc) return False + + +async def async_is_safe_url(url: str) -> bool: + """Same rules as :func:`is_safe_url`, but run the DNS work off the event loop. + + ``socket.getaddrinfo`` can block; call this from async code paths (gateway, + ``web_extract_tool``, vision download hooks) instead of ``is_safe_url``. + """ + return await asyncio.to_thread(is_safe_url, url) diff --git a/tools/vision_tools.py b/tools/vision_tools.py index 0def28142..3187f5476 100644 --- a/tools/vision_tools.py +++ b/tools/vision_tools.py @@ -74,35 +74,36 @@ _VISION_DOWNLOAD_TIMEOUT = _resolve_download_timeout() _VISION_MAX_DOWNLOAD_BYTES = 50 * 1024 * 1024 -def _validate_image_url(url: str) -> bool: - """ - Basic validation of image URL format. - - Args: - url (str): The URL to validate - - Returns: - bool: True if URL appears to be valid, False otherwise - """ +def _image_url_shape_ok(url: str) -> bool: + """HTTP(S) shape check only (scheme, netloc). No DNS.""" if not url or not isinstance(url, str): return False - # Basic HTTP/HTTPS URL check if not url.startswith(("http://", "https://")): return False - # Parse to ensure we at least have a network location; still allow URLs # without file extensions (e.g. CDN endpoints that redirect to images). parsed = urlparse(url) if not parsed.netloc: return False + return True + +def _validate_image_url(url: str) -> bool: + """Validate image URL for sync callers and tests (SSRF via sync DNS check).""" + if not _image_url_shape_ok(url): + return False # Block private/internal addresses to prevent SSRF from tools.url_safety import is_safe_url - if not is_safe_url(url): - return False + return is_safe_url(url) - return True + +async def _validate_image_url_async(url: str) -> bool: + """Validate remote image URL without blocking the event loop on DNS.""" + if not _image_url_shape_ok(url): + return False + from tools.url_safety import async_is_safe_url + return await async_is_safe_url(url) def _detect_image_mime_type(image_path: Path) -> Optional[str]: @@ -181,8 +182,8 @@ async def _download_image(image_url: str, destination: Path, max_retries: int = """ if response.is_redirect and response.next_request: redirect_url = str(response.next_request.url) - from tools.url_safety import is_safe_url - if not is_safe_url(redirect_url): + from tools.url_safety import async_is_safe_url + if not await async_is_safe_url(redirect_url): raise ValueError( f"Blocked redirect to private/internal address: {redirect_url}" ) @@ -716,7 +717,7 @@ async def _vision_analyze_native( if local_path.is_file(): temp_image_path = local_path should_cleanup = False - elif _validate_image_url(image_url): + elif await _validate_image_url_async(image_url): blocked = check_website_access(image_url) if blocked: return tool_error(blocked["message"], success=False) @@ -870,7 +871,7 @@ async def vision_analyze_tool( logger.info("Using local image file: %s", image_url) temp_image_path = local_path should_cleanup = False # Don't delete cached/local files - elif _validate_image_url(image_url): + elif await _validate_image_url_async(image_url): # Remote URL -- download to a temporary location blocked = check_website_access(image_url) if blocked: @@ -1265,8 +1266,8 @@ async def _download_video(video_url: str, destination: Path, max_retries: int = async def _ssrf_redirect_guard(response): if response.is_redirect and response.next_request: redirect_url = str(response.next_request.url) - from tools.url_safety import is_safe_url - if not is_safe_url(redirect_url): + from tools.url_safety import async_is_safe_url + if not await async_is_safe_url(redirect_url): raise ValueError( f"Blocked redirect to private/internal address: {redirect_url}" ) @@ -1372,7 +1373,7 @@ async def video_analyze_tool( logger.info("Using local video file: %s", video_url) temp_video_path = local_path should_cleanup = False - elif _validate_image_url(video_url): + elif await _validate_image_url_async(video_url): blocked = check_website_access(video_url) if blocked: raise PermissionError(blocked["message"]) diff --git a/tools/web_tools.py b/tools/web_tools.py index 8f5275da2..a97370c48 100644 --- a/tools/web_tools.py +++ b/tools/web_tools.py @@ -102,7 +102,7 @@ from tools.tool_backend_helpers import ( # noqa: F401 nous_tool_gateway_unavailable_message, prefers_gateway, ) -from tools.url_safety import is_safe_url +from tools.url_safety import async_is_safe_url import sys logger = logging.getLogger(__name__) @@ -934,7 +934,7 @@ async def web_extract_tool( safe_urls = [] ssrf_blocked: List[Dict[str, Any]] = [] for url in urls: - if not is_safe_url(url): + if not await async_is_safe_url(url): ssrf_blocked.append({ "url": url, "title": "", "content": "", "error": "Blocked: URL targets a private or internal network address",