Extends the existing /undo command from a single in-memory exchange removal into a full rewind: back up N user turns (default 1), soft-delete the truncated rows in SessionDB (active=0, kept for audit, hidden from re-prompts and search), notify memory providers, and prefill the composer with the backed-up message text for editing — CLI and TUI. Reuses the SessionDB rewind primitives, the on_session_switch(rewound=True) memory hook, and the TUI command.dispatch prefill payload from SaguaroDev's #21910 work, wired to /undo [N] instead of a separate /rewind picker. - cli.py: undo_last(n, prefill) — in-memory truncate + SQLite soft-delete + agent surgery (system-prompt invalidate, flush-index reset) + memory notify + editable buffer prefill; /undo dispatch parses optional count; checkpoint-rollback caller passes prefill=False - tui_gateway/server.py: command.dispatch undo branch (was rewind) parses count, picks Nth-from-last user turn, clamps to oldest - commands.py: /undo gains [N] args_hint - tests: rename + expand TUI suite (multi-turn, clamp, invalid-count) - release.py: AUTHOR_MAP entry for SaguaroDev Co-authored-by: SaguaroDev <74339271+SaguaroDev@users.noreply.github.com>
184 lines
6.5 KiB
Python
184 lines
6.5 KiB
Python
"""Tests for /undo handling in tui_gateway.
|
|
|
|
The TUI routes ``/undo`` through ``command.dispatch`` (it's in
|
|
``_PENDING_INPUT_COMMANDS`` because the CLI handler queues input the
|
|
slash-worker subprocess can't read). The server handles it directly,
|
|
mutates SessionDB to soft-delete rows, refreshes the in-memory session
|
|
history, fires the memory-provider hook with ``rewound=True``, and
|
|
returns ``{"type": "prefill", "message": <text>, "notice": ...}`` so
|
|
the Ink client drops the message into the composer for editing.
|
|
|
|
``/undo N`` backs up N user turns at once (default 1). See issue #21910.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import importlib
|
|
import threading
|
|
from pathlib import Path
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from hermes_state import SessionDB
|
|
|
|
|
|
@pytest.fixture()
|
|
def hermes_home(tmp_path, monkeypatch):
|
|
home = tmp_path / ".hermes"
|
|
home.mkdir()
|
|
monkeypatch.setattr(Path, "home", lambda: tmp_path)
|
|
monkeypatch.setenv("HERMES_HOME", str(home))
|
|
yield home
|
|
|
|
|
|
@pytest.fixture()
|
|
def server(hermes_home):
|
|
with patch.dict(
|
|
"sys.modules",
|
|
{
|
|
"hermes_cli.env_loader": MagicMock(),
|
|
"hermes_cli.banner": MagicMock(),
|
|
},
|
|
):
|
|
mod = importlib.import_module("tui_gateway.server")
|
|
yield mod
|
|
mod._sessions.clear()
|
|
mod._pending.clear()
|
|
mod._answers.clear()
|
|
mod._methods.clear()
|
|
importlib.reload(mod)
|
|
|
|
|
|
@pytest.fixture()
|
|
def db(hermes_home):
|
|
return SessionDB(db_path=hermes_home / "state.db")
|
|
|
|
|
|
@pytest.fixture()
|
|
def session_with_history(server, db):
|
|
"""Build a session with 3 user turns + assistant replies persisted in DB."""
|
|
sid = "sid-undo"
|
|
session_key = "tui-undo-1"
|
|
db.create_session(session_key, source="tui")
|
|
for i in range(1, 4):
|
|
db.append_message(session_key, "user", f"question {i}")
|
|
db.append_message(session_key, "assistant", f"answer {i}")
|
|
history = db.get_messages_as_conversation(session_key)
|
|
agent = MagicMock()
|
|
agent._memory_manager = MagicMock()
|
|
agent._last_flushed_db_idx = len(history)
|
|
s = {
|
|
"session_key": session_key,
|
|
"history": list(history),
|
|
"history_lock": threading.Lock(),
|
|
"history_version": 0,
|
|
"running": False,
|
|
"agent": agent,
|
|
"attached_images": [],
|
|
"cols": 120,
|
|
}
|
|
server._sessions[sid] = s
|
|
# Wire the DB cache so _get_db() returns our fixture.
|
|
server._db = db
|
|
return sid, session_key, s, agent
|
|
|
|
|
|
def _call(server, method, **params):
|
|
return server._methods[method](1, params)
|
|
|
|
|
|
def test_undo_returns_prefill_with_target_text(server, session_with_history):
|
|
sid, session_key, s, agent = session_with_history
|
|
resp = _call(server, "command.dispatch", session_id=sid, name="undo", arg="")
|
|
result = resp["result"]
|
|
assert result["type"] == "prefill"
|
|
# Default /undo backs up one user turn — "question 3"
|
|
assert result["message"] == "question 3"
|
|
assert "Undid" in result["notice"]
|
|
|
|
|
|
def test_undo_truncates_in_memory_history(server, session_with_history, db):
|
|
sid, session_key, s, agent = session_with_history
|
|
_call(server, "command.dispatch", session_id=sid, name="undo", arg="")
|
|
# After undoing to "question 3", active history should be 4 rows:
|
|
# user q1, asst a1, user q2, asst a2
|
|
assert len(s["history"]) == 4
|
|
roles = [m["role"] for m in s["history"]]
|
|
assert roles == ["user", "assistant", "user", "assistant"]
|
|
# version bumped
|
|
assert s["history_version"] == 1
|
|
|
|
|
|
def test_undo_n_backs_up_multiple_turns(server, session_with_history, db):
|
|
"""/undo 2 backs up two user turns to "question 2"."""
|
|
sid, session_key, s, agent = session_with_history
|
|
resp = _call(server, "command.dispatch", session_id=sid, name="undo", arg="2")
|
|
result = resp["result"]
|
|
assert result["type"] == "prefill"
|
|
assert result["message"] == "question 2"
|
|
assert "2 turns" in result["notice"]
|
|
# Active history truncated to user q1 + asst a1
|
|
assert len(s["history"]) == 2
|
|
assert [m["role"] for m in s["history"]] == ["user", "assistant"]
|
|
|
|
|
|
def test_undo_n_clamps_to_oldest_turn(server, session_with_history, db):
|
|
"""/undo with N larger than the number of user turns backs up to the oldest."""
|
|
sid, session_key, s, agent = session_with_history
|
|
resp = _call(server, "command.dispatch", session_id=sid, name="undo", arg="99")
|
|
result = resp["result"]
|
|
assert result["message"] == "question 1"
|
|
assert len(s["history"]) == 0
|
|
|
|
|
|
def test_undo_rejects_invalid_count(server, session_with_history):
|
|
sid, _, _, _ = session_with_history
|
|
resp = _call(server, "command.dispatch", session_id=sid, name="undo", arg="abc")
|
|
assert "error" in resp
|
|
assert "invalid count" in resp["error"]["message"].lower()
|
|
|
|
|
|
def test_undo_soft_deletes_rows_in_db(server, session_with_history, db):
|
|
sid, session_key, _, _ = session_with_history
|
|
_call(server, "command.dispatch", session_id=sid, name="undo", arg="")
|
|
# All rows still present
|
|
all_rows = db.get_messages(session_key, include_inactive=True)
|
|
assert len(all_rows) == 6
|
|
# 2 inactive (the "question 3" row + its trailing "answer 3").
|
|
active = [r for r in all_rows if r["active"] == 1]
|
|
assert len(active) == 4
|
|
# rewind_count bumped
|
|
sess = db.get_session(session_key)
|
|
assert sess["rewind_count"] == 1
|
|
|
|
|
|
def test_undo_notifies_memory_provider(server, session_with_history):
|
|
sid, session_key, _, agent = session_with_history
|
|
_call(server, "command.dispatch", session_id=sid, name="undo", arg="")
|
|
agent._memory_manager.on_session_switch.assert_called_once()
|
|
args, kwargs = agent._memory_manager.on_session_switch.call_args
|
|
assert args[0] == session_key
|
|
assert kwargs["rewound"] is True
|
|
assert kwargs["reset"] is False
|
|
|
|
|
|
def test_undo_refuses_when_session_busy(server, session_with_history):
|
|
sid, _, s, _ = session_with_history
|
|
s["running"] = True
|
|
resp = _call(server, "command.dispatch", session_id=sid, name="undo", arg="")
|
|
assert "error" in resp
|
|
assert "busy" in resp["error"]["message"].lower()
|
|
|
|
|
|
def test_undo_errors_when_no_active_session(server):
|
|
resp = _call(server, "command.dispatch", session_id="no-such-sid", name="undo", arg="")
|
|
assert "error" in resp
|
|
assert "no active session" in resp["error"]["message"].lower()
|
|
|
|
|
|
def test_undo_in_pending_input_commands(server):
|
|
"""Registry sanity: /undo must be in _PENDING_INPUT_COMMANDS so
|
|
slash.exec rejects it and the TUI falls through to command.dispatch."""
|
|
assert "undo" in server._PENDING_INPUT_COMMANDS
|