fix(run_agent): gate concurrent checkpoint preflight on block_result (fixes #34827)
In the concurrent tool-execution path, checkpoint preflight (write_file, patch, destructive terminal) fired BEFORE plugin guardrail block_result was computed. A blocked write_file could still dirty checkpoint state (doc_modified_this_turn, _last_write_file_call_id, turn_counter). Move checkpoint preflight to AFTER block_result computation, gated on `if block_result is None:` — matching the invariant the sequential path already enforces.
This commit is contained in:
@ -180,28 +180,9 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Checkpoint for file-mutating tools
|
||||
if function_name in {"write_file", "patch"} and agent._checkpoint_mgr.enabled:
|
||||
try:
|
||||
file_path = function_args.get("path", "")
|
||||
if file_path:
|
||||
work_dir = agent._checkpoint_mgr.get_working_dir_for_path(file_path)
|
||||
agent._checkpoint_mgr.ensure_checkpoint(work_dir, f"before {function_name}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Checkpoint before destructive terminal commands
|
||||
if function_name == "terminal" and agent._checkpoint_mgr.enabled:
|
||||
try:
|
||||
cmd = function_args.get("command", "")
|
||||
if _is_destructive_command(cmd):
|
||||
cwd = function_args.get("workdir") or os.getenv("TERMINAL_CWD", os.getcwd())
|
||||
agent._checkpoint_mgr.ensure_checkpoint(
|
||||
cwd, f"before terminal: {cmd[:60]}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ── Block evaluation (BEFORE checkpoint preflight) ───────────
|
||||
# We must know whether the tool will execute before touching
|
||||
# checkpoint state (dedup slot, real snapshots).
|
||||
block_result = None
|
||||
blocked_by_guardrail = False
|
||||
if _ts_scope_block is not None:
|
||||
@ -224,6 +205,30 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
|
||||
block_result = agent._guardrail_block_result(guardrail_decision)
|
||||
blocked_by_guardrail = True
|
||||
|
||||
# ── Checkpoint preflight (only for tools that will execute) ──
|
||||
if block_result is None:
|
||||
# Checkpoint for file-mutating tools
|
||||
if function_name in {"write_file", "patch"} and agent._checkpoint_mgr.enabled:
|
||||
try:
|
||||
file_path = function_args.get("path", "")
|
||||
if file_path:
|
||||
work_dir = agent._checkpoint_mgr.get_working_dir_for_path(file_path)
|
||||
agent._checkpoint_mgr.ensure_checkpoint(work_dir, f"before {function_name}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Checkpoint before destructive terminal commands
|
||||
if function_name == "terminal" and agent._checkpoint_mgr.enabled:
|
||||
try:
|
||||
cmd = function_args.get("command", "")
|
||||
if _is_destructive_command(cmd):
|
||||
cwd = function_args.get("workdir") or os.getenv("TERMINAL_CWD", os.getcwd())
|
||||
agent._checkpoint_mgr.ensure_checkpoint(
|
||||
cwd, f"before terminal: {cmd[:60]}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
parsed_calls.append((tool_call, function_name, function_args, block_result, blocked_by_guardrail))
|
||||
|
||||
# ── Logging / callbacks ──────────────────────────────────────────
|
||||
|
||||
@ -2543,6 +2543,122 @@ class TestConcurrentToolExecution:
|
||||
assert json.loads(result) == {"error": "Blocked"}
|
||||
assert agent._turns_since_memory == 5
|
||||
|
||||
def test_concurrent_blocked_write_skips_checkpoint(self, agent, monkeypatch):
|
||||
"""Concurrent path: blocked write_file should not trigger checkpoint."""
|
||||
tc1 = _mock_tool_call(name="write_file",
|
||||
arguments='{"path":"test.txt","content":"hello"}',
|
||||
call_id="c1")
|
||||
tc2 = _mock_tool_call(name="read_file",
|
||||
arguments='{"path":"other.py"}',
|
||||
call_id="c2")
|
||||
mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
|
||||
messages = []
|
||||
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.plugins.get_pre_tool_call_block_message",
|
||||
lambda *args, **kwargs: "Blocked" if args[0] == "write_file" else None,
|
||||
)
|
||||
|
||||
agent._checkpoint_mgr.enabled = True
|
||||
|
||||
def fake_handle(name, args, task_id, **kwargs):
|
||||
return f"result_{name}"
|
||||
|
||||
with patch("run_agent.handle_function_call", side_effect=fake_handle):
|
||||
with patch.object(agent._checkpoint_mgr, "ensure_checkpoint") as cp_mock:
|
||||
agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
|
||||
|
||||
cp_mock.assert_not_called()
|
||||
|
||||
def test_concurrent_blocked_patch_skips_checkpoint(self, agent, monkeypatch):
|
||||
"""Concurrent path: blocked patch should not trigger checkpoint."""
|
||||
tc1 = _mock_tool_call(name="patch",
|
||||
arguments='{"path":"f.py","old":"a","new":"b"}',
|
||||
call_id="c1")
|
||||
tc2 = _mock_tool_call(name="read_file",
|
||||
arguments='{"path":"other.py"}',
|
||||
call_id="c2")
|
||||
mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
|
||||
messages = []
|
||||
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.plugins.get_pre_tool_call_block_message",
|
||||
lambda *args, **kwargs: "Blocked" if args[0] == "patch" else None,
|
||||
)
|
||||
|
||||
agent._checkpoint_mgr.enabled = True
|
||||
|
||||
def fake_handle(name, args, task_id, **kwargs):
|
||||
return f"result_{name}"
|
||||
|
||||
with patch("run_agent.handle_function_call", side_effect=fake_handle):
|
||||
with patch.object(agent._checkpoint_mgr, "ensure_checkpoint") as cp_mock:
|
||||
agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
|
||||
|
||||
cp_mock.assert_not_called()
|
||||
|
||||
def test_concurrent_blocked_terminal_skips_checkpoint(self, agent, monkeypatch):
|
||||
"""Concurrent path: blocked terminal should not trigger checkpoint."""
|
||||
tc1 = _mock_tool_call(name="terminal",
|
||||
arguments='{"command":"rm -rf /tmp/foo"}',
|
||||
call_id="c1")
|
||||
tc2 = _mock_tool_call(name="read_file",
|
||||
arguments='{"path":"other.py"}',
|
||||
call_id="c2")
|
||||
mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
|
||||
messages = []
|
||||
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.plugins.get_pre_tool_call_block_message",
|
||||
lambda *args, **kwargs: "Blocked" if args[0] == "terminal" else None,
|
||||
)
|
||||
|
||||
agent._checkpoint_mgr.enabled = True
|
||||
|
||||
def fake_handle(name, args, task_id, **kwargs):
|
||||
return f"result_{name}"
|
||||
|
||||
with patch("run_agent.handle_function_call", side_effect=fake_handle):
|
||||
with patch.object(agent._checkpoint_mgr, "ensure_checkpoint") as cp_mock:
|
||||
with patch("agent.tool_executor._is_destructive_command", return_value=True):
|
||||
agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
|
||||
|
||||
cp_mock.assert_not_called()
|
||||
|
||||
def test_concurrent_blocked_write_does_not_steal_slot_from_allowed_write(self, agent, monkeypatch):
|
||||
"""When write_file is blocked, its dedup slot must not be consumed,
|
||||
so a subsequent allowed write_file for the same path still checkpoints."""
|
||||
tc1 = _mock_tool_call(name="write_file",
|
||||
arguments='{"path":"dup.txt","content":"blocked"}',
|
||||
call_id="c1")
|
||||
tc2 = _mock_tool_call(name="write_file",
|
||||
arguments='{"path":"dup.txt","content":"allowed"}',
|
||||
call_id="c2")
|
||||
mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
|
||||
messages = []
|
||||
|
||||
call_count = {"n": 0}
|
||||
def block_first_only(*args, **kwargs):
|
||||
call_count["n"] += 1
|
||||
return "Blocked" if call_count["n"] == 1 else None
|
||||
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.plugins.get_pre_tool_call_block_message",
|
||||
block_first_only,
|
||||
)
|
||||
|
||||
agent._checkpoint_mgr.enabled = True
|
||||
|
||||
def fake_handle(name, args, task_id, **kwargs):
|
||||
return f"result_{name}"
|
||||
|
||||
with patch("run_agent.handle_function_call", side_effect=fake_handle):
|
||||
with patch.object(agent._checkpoint_mgr, "ensure_checkpoint") as cp_mock:
|
||||
agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
|
||||
|
||||
# Second (allowed) write must checkpoint even though first was blocked.
|
||||
cp_mock.assert_called_once()
|
||||
|
||||
|
||||
class TestPathsOverlap:
|
||||
"""Unit tests for the _paths_overlap helper."""
|
||||
|
||||
Reference in New Issue
Block a user