refactor(state): compute last_active ordering at SQL level via recursive CTE

Follow-up to the previous commit. Replace the post-fetch Python re-sort (which
required dropping LIMIT/OFFSET from SQL and scanning every session row) with a
recursive CTE that walks compression-continuation chains and computes
effective_last_active per root at SQL level. The outer query can then ORDER BY
+ LIMIT efficiently, and the Python projection loop no longer has to handle
ordering.

This preserves the correctness win (old compression roots whose live tip was
touched recently surface correctly) without the O(N) scan, which matters for
users with thousands of sessions.

Adds a regression test pinning the compression-tip case at limit=1 — the
stress case that any bounded-oversample shortcut would get wrong.

Co-authored-by: simbam99 <simbamax99@gmail.com>
This commit is contained in:
Teknium
2026-04-30 20:03:33 -07:00
parent 142b4bf3ce
commit 5089c55e0b
2 changed files with 139 additions and 40 deletions

View File

@ -955,9 +955,12 @@ class SessionDB:
raw root rows (useful for admin/debug UIs).
Pass ``order_by_last_active=True`` to sort by most-recent activity
instead of original conversation start time. This is computed after
compression-tip projection so "recent sessions" surfaces the live tip
of a compressed conversation in the correct slot.
instead of original conversation start time. For compression chains,
the "most-recent activity" is taken from the live tip (not the root),
so an old conversation that was compressed and continued recently
surfaces in the correct slot. Ordering is computed at SQL level via
a recursive CTE that walks compression-continuation edges, so LIMIT
and OFFSET still apply efficiently.
"""
where_clauses = []
params = []
@ -985,33 +988,80 @@ class SessionDB:
params.extend(exclude_sources)
where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
order_sql = (
"ORDER BY last_active DESC, s.started_at DESC, s.id DESC"
if order_by_last_active
else "ORDER BY s.started_at DESC"
)
limit_sql = ""
if not order_by_last_active:
limit_sql = "LIMIT ? OFFSET ?"
if order_by_last_active:
# Compute effective_last_active by walking each surfaced session's
# compression-continuation chain forward in SQL and taking the MAX
# timestamp across the chain. This lets us ORDER BY + LIMIT at SQL
# level instead of fetching every row and sorting in Python, while
# still surfacing old compression roots whose live tip is fresh.
#
# The CTE seeds from rows the outer WHERE admits (roots + branch
# children), then recursively joins forward through
# compression-continuation edges using the same criteria as
# get_compression_tip (parent.end_reason='compression' AND
# child.started_at >= parent.ended_at).
query = f"""
WITH RECURSIVE chain(root_id, cur_id) AS (
SELECT s.id, s.id FROM sessions s {where_sql}
UNION ALL
SELECT c.root_id, child.id
FROM chain c
JOIN sessions parent ON parent.id = c.cur_id
JOIN sessions child ON child.parent_session_id = c.cur_id
WHERE parent.end_reason = 'compression'
AND child.started_at >= parent.ended_at
),
chain_max AS (
SELECT
root_id,
MAX(COALESCE(
(SELECT MAX(m.timestamp) FROM messages m WHERE m.session_id = cur_id),
(SELECT started_at FROM sessions ss WHERE ss.id = cur_id)
)) AS effective_last_active
FROM chain
GROUP BY root_id
)
SELECT s.*,
COALESCE(
(SELECT SUBSTR(REPLACE(REPLACE(m.content, X'0A', ' '), X'0D', ' '), 1, 63)
FROM messages m
WHERE m.session_id = s.id AND m.role = 'user' AND m.content IS NOT NULL
ORDER BY m.timestamp, m.id LIMIT 1),
''
) AS _preview_raw,
COALESCE(
(SELECT MAX(m2.timestamp) FROM messages m2 WHERE m2.session_id = s.id),
s.started_at
) AS last_active,
COALESCE(cm.effective_last_active, s.started_at) AS _effective_last_active
FROM sessions s
LEFT JOIN chain_max cm ON cm.root_id = s.id
{where_sql}
ORDER BY _effective_last_active DESC, s.started_at DESC, s.id DESC
LIMIT ? OFFSET ?
"""
# WHERE params apply twice (CTE seed + outer select).
params = params + params + [limit, offset]
else:
query = f"""
SELECT s.*,
COALESCE(
(SELECT SUBSTR(REPLACE(REPLACE(m.content, X'0A', ' '), X'0D', ' '), 1, 63)
FROM messages m
WHERE m.session_id = s.id AND m.role = 'user' AND m.content IS NOT NULL
ORDER BY m.timestamp, m.id LIMIT 1),
''
) AS _preview_raw,
COALESCE(
(SELECT MAX(m2.timestamp) FROM messages m2 WHERE m2.session_id = s.id),
s.started_at
) AS last_active
FROM sessions s
{where_sql}
ORDER BY s.started_at DESC
LIMIT ? OFFSET ?
"""
params.extend([limit, offset])
query = f"""
SELECT s.*,
COALESCE(
(SELECT SUBSTR(REPLACE(REPLACE(m.content, X'0A', ' '), X'0D', ' '), 1, 63)
FROM messages m
WHERE m.session_id = s.id AND m.role = 'user' AND m.content IS NOT NULL
ORDER BY m.timestamp, m.id LIMIT 1),
''
) AS _preview_raw,
COALESCE(
(SELECT MAX(m2.timestamp) FROM messages m2 WHERE m2.session_id = s.id),
s.started_at
) AS last_active
FROM sessions s
{where_sql}
{order_sql}
{limit_sql}
"""
with self._lock:
cursor = self._conn.execute(query, params)
rows = cursor.fetchall()
@ -1025,6 +1075,8 @@ class SessionDB:
s["preview"] = text + ("..." if len(raw) > 60 else "")
else:
s["preview"] = ""
# Drop the internal ordering column so callers see a clean dict.
s.pop("_effective_last_active", None)
sessions.append(s)
# Project compression roots forward to their tips. Each row whose
@ -1061,17 +1113,6 @@ class SessionDB:
projected.append(merged)
sessions = projected
if order_by_last_active:
sessions.sort(
key=lambda s: (
s.get("last_active") or s.get("started_at") or 0,
s.get("started_at") or 0,
s.get("id") or "",
),
reverse=True,
)
sessions = sessions[offset:offset + limit]
return sessions
def _get_session_rich_row(self, session_id: str) -> Optional[Dict[str, Any]]:

View File

@ -1752,6 +1752,64 @@ class TestListSessionsRich:
s["id"] for s in db.list_sessions_rich(limit=5, order_by_last_active=True)
] == ["old", "new"]
def test_order_by_last_active_uses_compression_tip_activity(self, db):
"""A compression root whose tip was touched recently must rank above
a newer uncompressed session, even when that tip activity lives in a
different row and the outer LIMIT could otherwise cut it.
This is the case that forced SQL-level chain walking: a naive "cap
the SQL fetch at limit*K" optimization would drop the old root off
the SQL page before post-projection could promote it.
"""
t0 = 1709500000.0
db.create_session("root1", "cli")
with db._lock:
db._conn.execute("UPDATE sessions SET started_at=? WHERE id=?", (t0, "root1"))
db._conn.execute(
"UPDATE sessions SET ended_at=?, end_reason=? WHERE id=?",
(t0 + 100, "compression", "root1"),
)
db.append_message("root1", "user", "old ask")
# Continuation tip created after root ended; last activity much later.
db.create_session("tip1", "cli", parent_session_id="root1")
with db._lock:
db._conn.execute("UPDATE sessions SET started_at=? WHERE id=?", (t0 + 101, "tip1"))
db.append_message("tip1", "user", "latest message")
# Bunch of newer, uncompressed sessions — fresher start_at but older
# last activity than the tip. Explicitly pin message timestamps so
# they don't pick up wall-clock from append_message.
for i in range(5):
sid = f"newer{i}"
db.create_session(sid, "cli")
with db._lock:
db._conn.execute(
"UPDATE sessions SET started_at=? WHERE id=?",
(t0 + 500 + i, sid),
)
db.append_message(sid, "user", f"msg {i}")
with db._lock:
db._conn.execute(
"UPDATE messages SET timestamp=? WHERE session_id=? AND content=?",
(t0 + 500 + i, sid, f"msg {i}"),
)
# Tip activity timestamp is the latest thing in the DB.
with db._lock:
db._conn.execute(
"UPDATE messages SET timestamp=? WHERE session_id=? AND content=?",
(t0 + 10_000, "tip1", "latest message"),
)
db._conn.commit()
# limit=1 is the stress test: the old root must win the single slot.
top = db.list_sessions_rich(limit=1, order_by_last_active=True)
assert len(top) == 1
# Projection surfaces the tip's id in the root's slot.
assert top[0]["id"] == "tip1"
assert top[0]["_lineage_root_id"] == "root1"
def test_rich_list_includes_title(self, db):
db.create_session("s1", "cli")
db.set_session_title("s1", "refactoring auth")