diff --git a/tests/tui_gateway/test_rewind_command.py b/tests/tui_gateway/test_rewind_command.py new file mode 100644 index 000000000..ae2de14e2 --- /dev/null +++ b/tests/tui_gateway/test_rewind_command.py @@ -0,0 +1,154 @@ +"""Tests for /rewind handling in tui_gateway. + +The TUI routes ``/rewind`` through ``command.dispatch`` (it's in +``_PENDING_INPUT_COMMANDS`` because the CLI handler queues input the +slash-worker subprocess can't read). The server handles it directly, +mutates SessionDB to soft-delete rows, refreshes the in-memory session +history, fires the memory-provider hook with ``rewound=True``, and +returns ``{"type": "prefill", "message": , "notice": ...}`` so +the Ink client drops the message into the composer for editing. +See issue #21910. +""" + +from __future__ import annotations + +import importlib +import threading +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from hermes_state import SessionDB + + +@pytest.fixture() +def hermes_home(tmp_path, monkeypatch): + home = tmp_path / ".hermes" + home.mkdir() + monkeypatch.setattr(Path, "home", lambda: tmp_path) + monkeypatch.setenv("HERMES_HOME", str(home)) + yield home + + +@pytest.fixture() +def server(hermes_home): + with patch.dict( + "sys.modules", + { + "hermes_cli.env_loader": MagicMock(), + "hermes_cli.banner": MagicMock(), + }, + ): + mod = importlib.import_module("tui_gateway.server") + yield mod + mod._sessions.clear() + mod._pending.clear() + mod._answers.clear() + mod._methods.clear() + importlib.reload(mod) + + +@pytest.fixture() +def db(hermes_home): + return SessionDB(db_path=hermes_home / "state.db") + + +@pytest.fixture() +def session_with_history(server, db): + """Build a session with 3 user turns + assistant replies persisted in DB.""" + sid = "sid-rewind" + session_key = "tui-rewind-1" + db.create_session(session_key, source="tui") + for i in range(1, 4): + db.append_message(session_key, "user", f"question {i}") + db.append_message(session_key, "assistant", f"answer {i}") + history = db.get_messages_as_conversation(session_key) + agent = MagicMock() + agent._memory_manager = MagicMock() + agent._last_flushed_db_idx = len(history) + s = { + "session_key": session_key, + "history": list(history), + "history_lock": threading.Lock(), + "history_version": 0, + "running": False, + "agent": agent, + "attached_images": [], + "cols": 120, + } + server._sessions[sid] = s + # Wire the DB cache so _get_db() returns our fixture. + server._db = db + return sid, session_key, s, agent + + +def _call(server, method, **params): + return server._methods[method](1, params) + + +def test_rewind_returns_prefill_with_target_text(server, session_with_history): + sid, session_key, s, agent = session_with_history + resp = _call(server, "command.dispatch", session_id=sid, name="rewind", arg="") + result = resp["result"] + assert result["type"] == "prefill" + # v1 auto-picks the most recent user turn — "question 3" + assert result["message"] == "question 3" + assert "Rewound" in result["notice"] + + +def test_rewind_truncates_in_memory_history(server, session_with_history, db): + sid, session_key, s, agent = session_with_history + _call(server, "command.dispatch", session_id=sid, name="rewind", arg="") + # After rewinding to "question 3", active history should be 4 rows: + # user q1, asst a1, user q2, asst a2 + assert len(s["history"]) == 4 + roles = [m["role"] for m in s["history"]] + assert roles == ["user", "assistant", "user", "assistant"] + # version bumped + assert s["history_version"] == 1 + + +def test_rewind_soft_deletes_rows_in_db(server, session_with_history, db): + sid, session_key, _, _ = session_with_history + _call(server, "command.dispatch", session_id=sid, name="rewind", arg="") + # All rows still present + all_rows = db.get_messages(session_key, include_inactive=True) + assert len(all_rows) == 6 + # 2 inactive (the "question 3" row + its trailing siblings — here just + # "question 3" + "answer 3", since target was the q3 user row). + active = [r for r in all_rows if r["active"] == 1] + assert len(active) == 4 + # rewind_count bumped + sess = db.get_session(session_key) + assert sess["rewind_count"] == 1 + + +def test_rewind_notifies_memory_provider(server, session_with_history): + sid, session_key, _, agent = session_with_history + _call(server, "command.dispatch", session_id=sid, name="rewind", arg="") + agent._memory_manager.on_session_switch.assert_called_once() + args, kwargs = agent._memory_manager.on_session_switch.call_args + assert args[0] == session_key + assert kwargs["rewound"] is True + assert kwargs["reset"] is False + + +def test_rewind_refuses_when_session_busy(server, session_with_history): + sid, _, s, _ = session_with_history + s["running"] = True + resp = _call(server, "command.dispatch", session_id=sid, name="rewind", arg="") + assert "error" in resp + assert "busy" in resp["error"]["message"].lower() + + +def test_rewind_errors_when_no_active_session(server): + resp = _call(server, "command.dispatch", session_id="no-such-sid", name="rewind", arg="") + assert "error" in resp + assert "no active session" in resp["error"]["message"].lower() + + +def test_rewind_in_pending_input_commands(server): + """Registry sanity: /rewind must be in _PENDING_INPUT_COMMANDS so + slash.exec rejects it and the TUI falls through to command.dispatch.""" + assert "rewind" in server._PENDING_INPUT_COMMANDS diff --git a/tui_gateway/server.py b/tui_gateway/server.py index 700010cf5..26be2db38 100644 --- a/tui_gateway/server.py +++ b/tui_gateway/server.py @@ -5496,6 +5496,7 @@ _PENDING_INPUT_COMMANDS: frozenset[str] = frozenset( "steer", "plan", "goal", + "rewind", } ) @@ -5881,6 +5882,94 @@ def _(rid, params: dict) -> dict: {"type": "send", "notice": notice, "message": state.goal}, ) + if name == "rewind": + # /rewind: pick the most-recent user message and prefill the + # composer with its text after soft-deleting everything that + # came after it in the transcript. v1 auto-picks the latest + # user turn (Claude-Code-style single-step undo); a multi-step + # picker UI is tracked as a follow-up to issue #21910. + if not session: + return _err(rid, 4001, "no active session to rewind") + if session.get("running"): + return _err( + rid, 4009, "session busy — /interrupt the current turn before /rewind" + ) + db = _get_db() + if db is None: + return _db_unavailable_error(rid, code=5008) + session_key = session.get("session_key", "") + if not session_key: + return _err(rid, 4001, "no session key for rewind") + try: + recents = db.list_recent_user_messages(session_key, limit=10) + except Exception as e: + return _err(rid, 5008, f"rewind: failed to load history: {e}") + if not recents: + return _err(rid, 4018, "no user messages to rewind to") + # v1: auto-pick the most recent user turn. The Ink UI does not + # yet host a dedicated picker overlay (#21910 follow-up). + target_id = recents[0]["id"] + try: + result = db.rewind_to_message(session_key, target_id) + except ValueError as e: + return _err(rid, 4004, f"rewind: {e}") + except Exception as e: + return _err(rid, 5008, f"rewind: {e}") + # Reload the active-only transcript into the in-memory session + # history so subsequent turns see the truncated view. + try: + active = db.get_messages_as_conversation(session_key) + except Exception: + active = [] + with session["history_lock"]: + session["history"] = list(active) + session["history_version"] = int(session.get("history_version", 0)) + 1 + # Notify memory providers — same hook /branch fires, plus the + # rewound flag so providers caching per-turn document state + # know to invalidate. See #6672 + #21910. + agent = session.get("agent") + if agent is not None: + mm = getattr(agent, "_memory_manager", None) + if mm is not None: + try: + mm.on_session_switch( + session_key, + parent_session_id="", + reset=False, + rewound=True, + ) + except Exception: + pass + if hasattr(agent, "_invalidate_system_prompt"): + try: + agent._invalidate_system_prompt() + except Exception: + pass + if hasattr(agent, "_last_flushed_db_idx"): + try: + agent._last_flushed_db_idx = len(active) + except Exception: + pass + target_msg = result.get("target_message") or {} + target_text = target_msg.get("content") or "" + if isinstance(target_text, list): + parts = [ + p.get("text", "") for p in target_text + if isinstance(p, dict) and p.get("type") == "text" + ] + target_text = "\n".join(t for t in parts if t) + if not isinstance(target_text, str): + target_text = "" + rewound_count = result.get("rewound_count", 0) + notice = ( + f"↶ Rewound {rewound_count} message(s). " + "Edit and resubmit, or send a new message." + ) + return _ok( + rid, + {"type": "prefill", "message": target_text, "notice": notice}, + ) + if name in {"snapshot", "snap"}: subcommand = arg.split(maxsplit=1)[0].lower() if arg else "" if subcommand in {"restore", "rewind"}: diff --git a/ui-tui/src/__tests__/asCommandDispatch.test.ts b/ui-tui/src/__tests__/asCommandDispatch.test.ts index dfa759517..5dac25fab 100644 --- a/ui-tui/src/__tests__/asCommandDispatch.test.ts +++ b/ui-tui/src/__tests__/asCommandDispatch.test.ts @@ -15,6 +15,15 @@ describe('asCommandDispatch', () => { type: 'send', message: 'hello world' }) + expect(asCommandDispatch({ type: 'prefill', message: 'edit me' })).toEqual({ + type: 'prefill', + message: 'edit me' + }) + expect(asCommandDispatch({ type: 'prefill', message: 'edit me', notice: '↶ rewound' })).toEqual({ + type: 'prefill', + message: 'edit me', + notice: '↶ rewound' + }) }) it('rejects malformed payloads', () => { @@ -23,5 +32,7 @@ describe('asCommandDispatch', () => { expect(asCommandDispatch({ type: 'skill', name: 1 })).toBeNull() expect(asCommandDispatch({ type: 'send' })).toBeNull() expect(asCommandDispatch({ type: 'send', message: 42 })).toBeNull() + expect(asCommandDispatch({ type: 'prefill' })).toBeNull() + expect(asCommandDispatch({ type: 'prefill', message: 42 })).toBeNull() }) }) diff --git a/ui-tui/src/app/createSlashHandler.ts b/ui-tui/src/app/createSlashHandler.ts index 0164ef0d5..71e2536d8 100644 --- a/ui-tui/src/app/createSlashHandler.ts +++ b/ui-tui/src/app/createSlashHandler.ts @@ -119,6 +119,19 @@ export function createSlashHandler(ctx: SlashHandlerContext): (cmd: string) => b } return d.message?.trim() ? send(d.message) : sys(`/${parsed.name}: empty message`) } + + if (d.type === 'prefill') { + // /rewind returns prefill: drop the chosen text into the + // composer so the user can edit and resubmit, instead of + // submitting it immediately like 'send'. + if (d.notice?.trim()) { + sys(d.notice) + } + if (d.message) { + ctx.composer.setInput(d.message) + } + return + } }) .catch(guardedErr) }) diff --git a/ui-tui/src/gatewayTypes.ts b/ui-tui/src/gatewayTypes.ts index 447dec3ea..c56a1aebf 100644 --- a/ui-tui/src/gatewayTypes.ts +++ b/ui-tui/src/gatewayTypes.ts @@ -48,6 +48,7 @@ export type CommandDispatchResponse = | { target: string; type: 'alias' } | { message?: string; name: string; type: 'skill' } | { message: string; notice?: string; type: 'send' } + | { message: string; notice?: string; type: 'prefill' } // ── Config ─────────────────────────────────────────────────────────── diff --git a/ui-tui/src/lib/rpc.ts b/ui-tui/src/lib/rpc.ts index 81dc70318..76862f073 100644 --- a/ui-tui/src/lib/rpc.ts +++ b/ui-tui/src/lib/rpc.ts @@ -34,6 +34,14 @@ export const asCommandDispatch = (value: unknown): CommandDispatchResponse | nul } } + if (t === 'prefill' && typeof o.message === 'string') { + return { + type: 'prefill', + message: o.message, + notice: typeof o.notice === 'string' ? o.notice : undefined, + } + } + return null }