diff --git a/hermes_state.py b/hermes_state.py
index ca802994a..0a7b86195 100644
--- a/hermes_state.py
+++ b/hermes_state.py
@@ -263,6 +263,7 @@ CREATE TABLE IF NOT EXISTS sessions (
     handoff_state TEXT,
     handoff_platform TEXT,
     handoff_error TEXT,
+    rewind_count INTEGER NOT NULL DEFAULT 0,
     FOREIGN KEY (parent_session_id) REFERENCES sessions(id)
 );
 
@@ -283,7 +284,8 @@ CREATE TABLE IF NOT EXISTS messages (
     codex_reasoning_items TEXT,
     codex_message_items TEXT,
     platform_message_id TEXT,
-    observed INTEGER DEFAULT 0
+    observed INTEGER DEFAULT 0,
+    active INTEGER NOT NULL DEFAULT 1
 );
 
 CREATE TABLE IF NOT EXISTS state_meta (
@@ -305,6 +307,18 @@ CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id, timestam
 CREATE INDEX IF NOT EXISTS idx_compression_locks_expires ON compression_locks(expires_at);
 """
 
+# Indexes that reference columns added in later schema versions must be
+# created AFTER _reconcile_columns() has had a chance to ADD them on
+# existing databases. SCHEMA_SQL above is run by sqlite executescript
+# which would otherwise fail on legacy DBs ("no such column: active").
+#
+# The trailing key column is ``id`` (not ``timestamp``) because the
+# active-row readers filter on (session_id, active) and ORDER BY id.
+DEFERRED_INDEX_SQL = """
+CREATE INDEX IF NOT EXISTS idx_messages_session_active
+    ON messages(session_id, active, id);
+"""
+
 FTS_SQL = """
 CREATE VIRTUAL TABLE IF NOT EXISTS messages_fts USING fts5(
     content
@@ -745,6 +759,10 @@ class SessionDB:
         except sqlite3.OperationalError as exc:
             logger.debug("idx_messages_platform_msg_id create skipped: %s", exc)
 
+        # Deferred indexes that reference the reconciler-added ``active``
+        # column (idx_messages_session_active) — same ordering constraint.
+        cursor.executescript(DEFERRED_INDEX_SQL)
+
         fts5_available = self._sqlite_supports_fts5(cursor)
         fts_migrations_complete = True
         if not fts5_available:
@@ -844,6 +862,18 @@ class SessionDB:
                         fts_migrations_complete = False
                 else:
                     fts_migrations_complete = False
+            if current_version < 12:
+                # v12: messages.active flag for rewind/undo soft-deletion.
+                # The declarative _reconcile_columns() above adds the
+                # column itself; this UPDATE is belt-and-suspenders to
+                # ensure any rows that pre-existed the ADD COLUMN have
+                # active=1 rather than NULL.
+                try:
+                    cursor.execute(
+                        "UPDATE messages SET active = 1 WHERE active IS NULL"
+                    )
+                except sqlite3.OperationalError as exc:
+                    logger.debug("messages.active backfill skipped: %s", exc)
             if current_version < SCHEMA_VERSION and fts_migrations_complete:
                 cursor.execute(
                     "UPDATE schema_version SET version = ?",
@@ -1970,11 +2000,24 @@ class SessionDB:
 
         self._execute_write(_do)
 
-    def get_messages(self, session_id: str) -> List[Dict[str, Any]]:
-        """Load all messages for a session, ordered by insertion order."""
+    def get_messages(
+        self, session_id: str, include_inactive: bool = False
+    ) -> List[Dict[str, Any]]:
+        """Load messages for a session in insertion order.
+
+        By default only active messages are returned. Pass
+        ``include_inactive=True`` to load soft-deleted rows (e.g. for
+        audit / debug views of rewound history). See
+        :meth:`rewind_to_message` for the soft-delete mechanic.
+
+        Ordered by AUTOINCREMENT id (true insertion order) rather than
+        timestamp — see c03acca50 for the WSL2 clock-regression rationale.
+        """
+        active_clause = "" if include_inactive else " AND active = 1"
         with self._lock:
             cursor = self._conn.execute(
-                "SELECT * FROM messages WHERE session_id = ? ORDER BY id",
+                "SELECT * FROM messages WHERE session_id = ?"
+                f"{active_clause} ORDER BY id",
                 (session_id,),
             )
             rows = cursor.fetchall()
@@ -2256,23 +2299,32 @@ class SessionDB:
         return session_id
 
     def get_messages_as_conversation(
-        self, session_id: str, include_ancestors: bool = False
+        self,
+        session_id: str,
+        include_ancestors: bool = False,
+        include_inactive: bool = False,
     ) -> List[Dict[str, Any]]:
         """
         Load messages in the OpenAI conversation format (role + content dicts).
         Used by the gateway to restore conversation history.
+
+        By default only active messages are returned. Pass
+        ``include_inactive=True`` to load soft-deleted (rewound) rows
+        as well. See :meth:`rewind_to_message`.
         """
         session_ids = [session_id]
         if include_ancestors:
             session_ids = self._session_lineage_root_to_tip(session_id)
 
+        active_clause = "" if include_inactive else " AND active = 1"
         with self._lock:
             placeholders = ",".join("?" for _ in session_ids)
             rows = self._conn.execute(
                 "SELECT role, content, tool_call_id, tool_calls, tool_name, "
                 "finish_reason, reasoning, reasoning_content, reasoning_details, "
                 "codex_reasoning_items, codex_message_items, platform_message_id, observed "
-                f"FROM messages WHERE session_id IN ({placeholders}) ORDER BY id",
+                f"FROM messages WHERE session_id IN ({placeholders})"
+                f"{active_clause} ORDER BY id",
                 tuple(session_ids),
             ).fetchall()
 
@@ -2370,6 +2422,165 @@ class SessionDB:
             return False
         return False
 
+    # =========================================================================
+    # Rewind (soft-delete) — see /rewind slash command + issue #21910
+    # =========================================================================
+
+    def rewind_to_message(
+        self, session_id: str, target_message_id: int
+    ) -> Dict[str, Any]:
+        """Soft-delete all messages with id >= ``target_message_id`` in *session_id*.
+
+        The target message itself becomes inactive as well so the caller
+        can pre-fill it as the next user prompt without it appearing
+        twice in the replayed transcript. Rewound rows are kept on
+        disk with ``active=0`` for audit / forensic inspection — use
+        :meth:`get_messages` with ``include_inactive=True`` to see them.
+
+        Returns a dict::
+
+            {
+                "rewound_count": int,    # number of rows newly flipped to active=0
+                "target_message": dict,  # full row dict of the target
+                "new_head_id": int|None  # id of the last still-active row, or None
+            }
+
+        Raises ``ValueError`` if the target message does not exist in
+        *session_id* or if its role is not ``"user"``.
+
+        Always increments ``sessions.rewind_count`` — even when the
+        target is already inactive — so the counter accurately reflects
+        the number of rewind operations performed against the session.
+        Idempotent on the ``active`` flag: re-rewinding past the same
+        target is a no-op on row state but still bumps the counter.
+        """
+
+        # 1) Validate target up-front (read-only, outside the write txn).
+        with self._lock:
+            row = self._conn.execute(
+                "SELECT * FROM messages WHERE id = ? AND session_id = ?",
+                (target_message_id, session_id),
+            ).fetchone()
+        if row is None:
+            raise ValueError(
+                f"message {target_message_id} not found in session {session_id}"
+            )
+        target_row = dict(row)
+        if target_row.get("role") != "user":
+            raise ValueError(
+                f"rewind target must be a 'user' message (got role="
+                f"{target_row.get('role')!r}, id={target_message_id})"
+            )
+
+        # Decode content for callers (prefill the prompt buffer).
+        target_row["content"] = self._decode_content(target_row.get("content"))
+
+        def _do(conn):
+            # Single range UPDATE rather than collecting ids into an
+            # ``IN (?, ?, ...)`` list: one bound parameter per row overflows
+            # SQLite's host-parameter limit (999 on builds older than 3.32)
+            # when the rewound tail is long.
+            flipped = conn.execute(
+                "UPDATE messages SET active = 0 "
+                "WHERE session_id = ? AND id >= ? AND active = 1",
+                (session_id, target_message_id),
+            ).rowcount
+            conn.execute(
+                "UPDATE sessions SET rewind_count = COALESCE(rewind_count, 0) + 1 "
+                "WHERE id = ?",
+                (session_id,),
+            )
+            # 2) New head = largest still-active row id in the session, read
+            # on the same connection right after the flip.
+            head_row = conn.execute(
+                "SELECT MAX(id) FROM messages WHERE session_id = ? AND active = 1",
+                (session_id,),
+            ).fetchone()
+            return flipped, (head_row[0] if head_row else None)
+
+        rewound_count, new_head_id = self._execute_write(_do)
+
+        return {
+            "rewound_count": rewound_count,
+            "target_message": target_row,
+            "new_head_id": new_head_id,
+        }
+
+    def restore_rewound(self, session_id: str, since_message_id: int) -> int:
+        """Mark inactive messages with id >= *since_message_id* active again.
+
+        Returns the number of rows flipped back to ``active=1``.
+        Intended for undo-of-rewind and test cleanup; not wired to a
+        slash command in v1.
+        """
+        def _do(conn):
+            # Range UPDATE (no per-row ``IN (?, ...)`` list) so the statement
+            # stays within SQLite's bound-parameter limit for any tail size.
+            return conn.execute(
+                "UPDATE messages SET active = 1 "
+                "WHERE session_id = ? AND id >= ? AND active = 0",
+                (session_id, since_message_id),
+            ).rowcount
+
+        return self._execute_write(_do)
+
+    def list_recent_user_messages(
+        self,
+        session_id: str,
+        limit: int = 20,
+        include_inactive: bool = False,
+    ) -> List[Dict[str, Any]]:
+        """Return the *limit* most-recent user messages, newest first.
+
+        Each entry is a dict with keys ``id``, ``timestamp``, ``preview``.
+        ``preview`` is the message text with runs of whitespace collapsed
+        to single spaces, capped at 80 characters (longer text is cut to
+        77 characters plus ``"..."``). Multimodal content contributes only
+        its text parts, or ``"[multimodal content]"`` when it has none.
+        Used by the /rewind slash command picker.
+
+        By default only active messages are returned. Pass
+        ``include_inactive=True`` to list rewound rows as well.
+        """
+        active_clause = "" if include_inactive else " AND active = 1"
+        with self._lock:
+            cursor = self._conn.execute(
+                "SELECT id, timestamp, content FROM messages "
+                "WHERE session_id = ? AND role = 'user'"
+                f"{active_clause} "
+                "ORDER BY id DESC LIMIT ?",
+                (session_id, int(limit)),
+            )
+            rows = cursor.fetchall()
+
+        result: List[Dict[str, Any]] = []
+        for row in rows:
+            decoded = self._decode_content(row["content"])
+            if isinstance(decoded, list):
+                # Multimodal — flatten text parts.
+                text_parts = [
+                    p.get("text", "") for p in decoded
+                    if isinstance(p, dict) and p.get("type") == "text"
+                ]
+                preview = " ".join(t for t in text_parts if t).strip()
+                if not preview:
+                    preview = "[multimodal content]"
+            elif isinstance(decoded, str):
+                preview = decoded
+            else:
+                preview = ""
+            preview = " ".join(preview.split())  # collapse whitespace
+            if len(preview) > 80:
+                preview = preview[:77] + "..."
+            result.append(
+                {
+                    "id": row["id"],
+                    "timestamp": row["timestamp"],
+                    "preview": preview,
+                }
+            )
+        return result
+
     # =========================================================================
     # Search
     # =========================================================================
@@ -2467,6 +2678,7 @@ class SessionDB:
         limit: int = 20,
         offset: int = 0,
         sort: str = None,
+        include_inactive: bool = False,
     ) -> List[Dict[str, Any]]:
         """
         Full-text search across session messages using FTS5.
@@ -2488,6 +2700,9 @@ class SessionDB:
         The short-CJK LIKE fallback already orders by timestamp DESC and
         ignores ``sort``. The trigram CJK path honours ``sort`` like the
         main FTS5 path.
+
+        Rewound (``active=0``) rows are excluded by default. Pass
+        ``include_inactive=True`` to search every row.
         """
         if not self._fts_enabled:
             return []
@@ -2521,6 +2736,8 @@ class SessionDB:
         # Build WHERE clauses dynamically
         where_clauses = ["messages_fts MATCH ?"]
         params: list = [query]
+        if not include_inactive:
+            where_clauses.append("m.active = 1")
 
         if source_filter is not None:
             source_placeholders = ",".join("?" for _ in source_filter)
@@ -2600,6 +2817,8 @@ class SessionDB:
             trigram_query = " ".join(parts)
             tri_where = ["messages_fts_trigram MATCH ?"]
             tri_params: list = [trigram_query]
+            if not include_inactive:
+                tri_where.append("m.active = 1")
             if source_filter is not None:
                 tri_where.append(f"s.source IN ({','.join('?' for _ in source_filter)})")
                 tri_params.extend(source_filter)