feat(state): add messages.active flag + rewind primitives (#21910)

Schema v12 adds:
- messages.active (default 1) — soft-delete flag for /rewind
- sessions.rewind_count (default 0) — audit counter
- idx_messages_session_active deferred index

New SessionDB methods:
- rewind_to_message(session_id, target_message_id) — soft-deletes rows
  >= target_id, refuses non-user targets, increments rewind_count
- restore_rewound(session_id, since_message_id) — undo for stretch goal
- list_recent_user_messages — picker source

Existing methods get include_inactive kwarg (default False):
- get_messages, get_messages_as_conversation, search_messages.
  Rewound rows excluded from session_search by default — opt-in for audit.

The deferred index pattern (DEFERRED_INDEX_SQL run after _reconcile_columns)
avoids 'no such column: active' on legacy pre-v12 databases, since
executescript(SCHEMA_SQL) runs before column reconciliation.
This commit is contained in:
SaguaroDev
2026-05-10 18:16:32 -04:00
committed by Teknium
parent 6c73e8ffaa
commit 3e59be0c41

View File

@ -263,6 +263,7 @@ CREATE TABLE IF NOT EXISTS sessions (
handoff_state TEXT,
handoff_platform TEXT,
handoff_error TEXT,
rewind_count INTEGER NOT NULL DEFAULT 0,
FOREIGN KEY (parent_session_id) REFERENCES sessions(id)
);
@ -283,7 +284,8 @@ CREATE TABLE IF NOT EXISTS messages (
codex_reasoning_items TEXT,
codex_message_items TEXT,
platform_message_id TEXT,
observed INTEGER DEFAULT 0
observed INTEGER DEFAULT 0,
active INTEGER NOT NULL DEFAULT 1
);
CREATE TABLE IF NOT EXISTS state_meta (
@ -305,6 +307,15 @@ CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id, timestam
CREATE INDEX IF NOT EXISTS idx_compression_locks_expires ON compression_locks(expires_at);
"""
# Indexes that reference columns added in later schema versions must be
# created AFTER _reconcile_columns() has had a chance to ADD them on
# existing databases. SCHEMA_SQL above is run by sqlite executescript
# which would otherwise fail on legacy DBs ("no such column: active").
DEFERRED_INDEX_SQL = """
CREATE INDEX IF NOT EXISTS idx_messages_session_active
ON messages(session_id, active, timestamp);
"""
FTS_SQL = """
CREATE VIRTUAL TABLE IF NOT EXISTS messages_fts USING fts5(
content
@ -745,6 +756,10 @@ class SessionDB:
except sqlite3.OperationalError as exc:
logger.debug("idx_messages_platform_msg_id create skipped: %s", exc)
# Deferred indexes that reference the reconciler-added ``active``
# column (idx_messages_session_active) — same ordering constraint.
cursor.executescript(DEFERRED_INDEX_SQL)
fts5_available = self._sqlite_supports_fts5(cursor)
fts_migrations_complete = True
if not fts5_available:
@ -844,6 +859,18 @@ class SessionDB:
fts_migrations_complete = False
else:
fts_migrations_complete = False
if current_version < 12:
# v12: messages.active flag for rewind/undo soft-deletion.
# The declarative reconcile_columns() above adds the
# column itself; this UPDATE is belt-and-suspenders to
# ensure any rows that pre-existed the ADD COLUMN have
# active=1 rather than NULL.
try:
cursor.execute(
"UPDATE messages SET active = 1 WHERE active IS NULL"
)
except sqlite3.OperationalError:
pass
if current_version < SCHEMA_VERSION and fts_migrations_complete:
cursor.execute(
"UPDATE schema_version SET version = ?",
@ -1970,11 +1997,24 @@ class SessionDB:
self._execute_write(_do)
def get_messages(self, session_id: str) -> List[Dict[str, Any]]:
"""Load all messages for a session, ordered by insertion order."""
def get_messages(
self, session_id: str, include_inactive: bool = False
) -> List[Dict[str, Any]]:
"""Load messages for a session in insertion order.
By default only active messages are returned. Pass
``include_inactive=True`` to load soft-deleted rows (e.g. for
audit / debug views of rewound history). See
:meth:`rewind_to_message` for the soft-delete mechanic.
Ordered by AUTOINCREMENT id (true insertion order) rather than
timestamp — see c03acca50 for the WSL2 clock-regression rationale.
"""
active_clause = "" if include_inactive else " AND active = 1"
with self._lock:
cursor = self._conn.execute(
"SELECT * FROM messages WHERE session_id = ? ORDER BY id",
"SELECT * FROM messages WHERE session_id = ?"
f"{active_clause} ORDER BY id",
(session_id,),
)
rows = cursor.fetchall()
@ -2256,23 +2296,32 @@ class SessionDB:
return session_id
def get_messages_as_conversation(
self, session_id: str, include_ancestors: bool = False
self,
session_id: str,
include_ancestors: bool = False,
include_inactive: bool = False,
) -> List[Dict[str, Any]]:
"""
Load messages in the OpenAI conversation format (role + content dicts).
Used by the gateway to restore conversation history.
By default only active messages are returned. Pass
``include_inactive=True`` to load soft-deleted (rewound) rows
as well. See :meth:`rewind_to_message`.
"""
session_ids = [session_id]
if include_ancestors:
session_ids = self._session_lineage_root_to_tip(session_id)
active_clause = "" if include_inactive else " AND active = 1"
with self._lock:
placeholders = ",".join("?" for _ in session_ids)
rows = self._conn.execute(
"SELECT role, content, tool_call_id, tool_calls, tool_name, "
"finish_reason, reasoning, reasoning_content, reasoning_details, "
"codex_reasoning_items, codex_message_items, platform_message_id, observed "
f"FROM messages WHERE session_id IN ({placeholders}) ORDER BY id",
f"FROM messages WHERE session_id IN ({placeholders})"
f"{active_clause} ORDER BY id",
tuple(session_ids),
).fetchall()
@ -2370,6 +2419,175 @@ class SessionDB:
return False
return False
# =========================================================================
# Rewind (soft-delete) — see /rewind slash command + issue #21910
# =========================================================================
def rewind_to_message(
self, session_id: str, target_message_id: int
) -> Dict[str, Any]:
"""Soft-delete all messages with id >= ``target_message_id`` in *session_id*.
The target message itself becomes inactive as well so the caller
can pre-fill it as the next user prompt without it appearing
twice in the replayed transcript. Rewound rows are kept on
disk with ``active=0`` for audit / forensic inspection — use
:meth:`get_messages` with ``include_inactive=True`` to see them.
Returns a dict::
{
"rewound_count": int, # number of rows newly flipped to active=0
"target_message": dict, # full row dict of the target
"new_head_id": int|None # id of the last still-active row, or None
}
Raises ``ValueError`` if the target message does not exist in
*session_id* or if its role is not ``"user"``.
Always increments ``sessions.rewind_count`` — even when the
target is already inactive — so the counter accurately reflects
the number of rewind operations performed against the session.
Idempotent on the ``active`` flag: re-rewinding past the same
target is a no-op on row state but still bumps the counter.
"""
# 1) Validate target up-front (read-only, outside the write txn).
with self._lock:
row = self._conn.execute(
"SELECT * FROM messages WHERE id = ? AND session_id = ?",
(target_message_id, session_id),
).fetchone()
if row is None:
raise ValueError(
f"message {target_message_id} not found in session {session_id}"
)
target_row = dict(row)
if target_row.get("role") != "user":
raise ValueError(
f"rewind target must be a 'user' message (got role="
f"{target_row.get('role')!r}, id={target_message_id})"
)
# Decode content for callers (prefill the prompt buffer).
target_row["content"] = self._decode_content(target_row.get("content"))
rewound: List[int] = []
def _do(conn):
cursor = conn.execute(
"SELECT id FROM messages "
"WHERE session_id = ? AND id >= ? AND active = 1",
(session_id, target_message_id),
)
ids = [r[0] for r in cursor.fetchall()]
if ids:
placeholders = ",".join("?" for _ in ids)
conn.execute(
f"UPDATE messages SET active = 0 WHERE id IN ({placeholders})",
ids,
)
conn.execute(
"UPDATE sessions SET rewind_count = COALESCE(rewind_count, 0) + 1 "
"WHERE id = ?",
(session_id,),
)
return ids
rewound = self._execute_write(_do)
# 2) Compute new head id (largest still-active row id in session).
with self._lock:
head_row = self._conn.execute(
"SELECT MAX(id) FROM messages WHERE session_id = ? AND active = 1",
(session_id,),
).fetchone()
new_head_id = head_row[0] if head_row and head_row[0] is not None else None
return {
"rewound_count": len(rewound),
"target_message": target_row,
"new_head_id": new_head_id,
}
def restore_rewound(self, session_id: str, since_message_id: int) -> int:
"""Mark inactive messages with id >= *since_message_id* active again.
Returns the number of rows flipped back to ``active=1``.
Intended for undo-of-rewind and test cleanup; not wired to a
slash command in v1.
"""
def _do(conn):
cursor = conn.execute(
"SELECT id FROM messages "
"WHERE session_id = ? AND id >= ? AND active = 0",
(session_id, since_message_id),
)
ids = [r[0] for r in cursor.fetchall()]
if ids:
placeholders = ",".join("?" for _ in ids)
conn.execute(
f"UPDATE messages SET active = 1 WHERE id IN ({placeholders})",
ids,
)
return len(ids)
return self._execute_write(_do)
def list_recent_user_messages(
self,
session_id: str,
limit: int = 20,
include_inactive: bool = False,
) -> List[Dict[str, Any]]:
"""Return the *limit* most-recent user messages, newest first.
Each entry is a dict with keys ``id``, ``timestamp``, ``preview``.
``preview`` is the first 80 characters of the message content
(with line breaks collapsed to spaces). Used by the /rewind
slash command picker.
By default only active messages are returned.
"""
active_clause = "" if include_inactive else " AND active = 1"
with self._lock:
cursor = self._conn.execute(
"SELECT id, timestamp, content FROM messages "
"WHERE session_id = ? AND role = 'user'"
f"{active_clause} "
"ORDER BY id DESC LIMIT ?",
(session_id, int(limit)),
)
rows = cursor.fetchall()
result: List[Dict[str, Any]] = []
for row in rows:
decoded = self._decode_content(row["content"])
if isinstance(decoded, list):
# Multimodal — flatten text parts.
text_parts = [
p.get("text", "") for p in decoded
if isinstance(p, dict) and p.get("type") == "text"
]
preview = " ".join(t for t in text_parts if t).strip()
if not preview:
preview = "[multimodal content]"
elif isinstance(decoded, str):
preview = decoded
else:
preview = ""
preview = " ".join(preview.split()) # collapse whitespace
if len(preview) > 80:
preview = preview[:77] + "..."
result.append(
{
"id": row["id"],
"timestamp": row["timestamp"],
"preview": preview,
}
)
return result
# =========================================================================
# Search
# =========================================================================
@ -2467,6 +2685,7 @@ class SessionDB:
limit: int = 20,
offset: int = 0,
sort: str = None,
include_inactive: bool = False,
) -> List[Dict[str, Any]]:
"""
Full-text search across session messages using FTS5.
@ -2488,6 +2707,9 @@ class SessionDB:
The short-CJK LIKE fallback already orders by timestamp DESC and
ignores ``sort``. The trigram CJK path honours ``sort`` like the main
FTS5 path.
Rewound (``active=0``) rows are excluded by default. Pass
``include_inactive=True`` to search every row.
"""
if not self._fts_enabled:
return []
@ -2521,6 +2743,8 @@ class SessionDB:
# Build WHERE clauses dynamically
where_clauses = ["messages_fts MATCH ?"]
params: list = [query]
if not include_inactive:
where_clauses.append("m.active = 1")
if source_filter is not None:
source_placeholders = ",".join("?" for _ in source_filter)
@ -2600,6 +2824,8 @@ class SessionDB:
trigram_query = " ".join(parts)
tri_where = ["messages_fts_trigram MATCH ?"]
tri_params: list = [trigram_query]
if not include_inactive:
tri_where.append("m.active = 1")
if source_filter is not None:
tri_where.append(f"s.source IN ({','.join('?' for _ in source_filter)})")
tri_params.extend(source_filter)