fix(tui): add thread-safety locks for _sessions and prompt dicts

C1: Add _sessions_lock to protect all compound mutations and iterations
    on the global _sessions dict across 5+ concurrent execution contexts
    (main dispatcher, pool workers, daemon threads, notification poller,
    atexit handler).

C2: Add _prompt_lock to protect _pending/_pending_prompt_payloads/_answers
    dicts from races between _block() (agent callback thread) and
    _respond() (pool worker).  Lock scope is kept tight — _block() only
    holds the lock during registration/cleanup, releasing it before
    _emit() and ev.wait() to avoid blocking other prompts for 300s.

All 187 existing TUI tests pass with no regressions.
This commit is contained in:
asill-livestream
2026-06-04 06:20:31 +09:00
committed by Teknium
parent 2069e78b88
commit 5bcb63e400

View File

@ -125,6 +125,8 @@ _db = None
_db_error: str | None = None
_stdout_lock = threading.Lock()
_cfg_lock = threading.Lock()
_sessions_lock = threading.Lock()
_prompt_lock = threading.Lock()
_cfg_cache: dict | None = None
_cfg_mtime: float | None = None
_cfg_path = None
@ -325,7 +327,9 @@ def _finalize_session(session: dict | None, end_reason: str = "tui_close") -> No
def _shutdown_sessions() -> None:
for session in list(_sessions.values()):
with _sessions_lock:
snapshot = list(_sessions.values())
for session in snapshot:
_finalize_session(session, end_reason="tui_shutdown")
try:
worker = session.get("slash_worker")
@ -546,7 +550,8 @@ def _start_agent_build(sid: str, session: dict) -> None:
key = session["session_key"]
def _build() -> None:
current = _sessions.get(sid)
with _sessions_lock:
current = _sessions.get(sid)
if current is None:
ready.set()
return
@ -585,7 +590,9 @@ def _start_agent_build(sid: str, session: dict) -> None:
pass
_wire_callbacks(sid)
_sessions[sid]["_notif_stop"] = _start_notification_poller(sid, _sessions[sid])
with _sessions_lock:
if sid in _sessions:
_sessions[sid]["_notif_stop"] = _start_notification_poller(sid, _sessions[sid])
_notify_session_boundary("on_session_reset", key)
info = _session_info(agent, current)
@ -598,7 +605,9 @@ def _start_agent_build(sid: str, session: dict) -> None:
current["agent_error"] = str(e)
_emit("error", sid, {"message": f"agent init failed: {e}"})
finally:
if _sessions.get(sid) is not current:
with _sessions_lock:
replaced = _sessions.get(sid) is not current
if replaced:
if worker is not None:
try:
worker.close()
@ -819,9 +828,10 @@ def _cwd_for_session_key(session_key: str) -> str:
"""
if not session_key:
return ""
for sess in list(_sessions.values()):
if sess.get("session_key") == session_key:
return str(sess.get("cwd") or "")
with _sessions_lock:
for sess in list(_sessions.values()):
if sess.get("session_key") == session_key:
return str(sess.get("cwd") or "")
return ""
@ -863,16 +873,19 @@ def _enable_gateway_prompts() -> None:
def _block(event: str, sid: str, payload: dict, timeout: int = 300) -> str:
rid = uuid.uuid4().hex[:8]
ev = threading.Event()
_pending[rid] = (sid, ev)
payload["request_id"] = rid
_pending_prompt_payloads[rid] = (event, dict(payload))
with _prompt_lock:
_pending[rid] = (sid, ev)
payload["request_id"] = rid
_pending_prompt_payloads[rid] = (event, dict(payload))
try:
_emit(event, sid, payload)
ev.wait(timeout=timeout)
finally:
_pending.pop(rid, None)
_pending_prompt_payloads.pop(rid, None)
return _answers.pop(rid, "")
with _prompt_lock:
_pending.pop(rid, None)
_pending_prompt_payloads.pop(rid, None)
with _prompt_lock:
return _answers.pop(rid, "")
def _clear_pending(sid: str | None = None) -> None:
@ -884,10 +897,11 @@ def _clear_pending(sid: str | None = None) -> None:
sessions sharing the same tui_gateway process. When *sid* is
None, every pending prompt is released (used during shutdown).
"""
for rid, (owner_sid, ev) in list(_pending.items()):
if sid is None or owner_sid == sid:
_answers[rid] = ""
ev.set()
with _prompt_lock:
for rid, (owner_sid, ev) in list(_pending.items()):
if sid is None or owner_sid == sid:
_answers[rid] = ""
ev.set()
# ── Agent factory ────────────────────────────────────────────────────
@ -2402,34 +2416,37 @@ def _make_agent(sid: str, key: str, session_id: str | None = None):
def _init_session(sid: str, key: str, agent, history: list, cols: int = 80):
now = time.time()
_sessions[sid] = {
"agent": agent,
"session_key": key,
"history": history,
"history_lock": threading.Lock(),
"history_version": 0,
"inflight_turn": None,
"created_at": now,
"last_active": now,
"running": False,
"attached_images": [],
"image_counter": 0,
"cwd": _completion_cwd(),
"cols": cols,
"slash_worker": None,
"show_reasoning": _load_show_reasoning(),
"tool_progress_mode": _load_tool_progress_mode(),
"edit_snapshots": {},
"tool_started_at": {},
# Pin async event emissions to whichever transport created the
# session (stdio for Ink, JSON-RPC WS for the dashboard sidebar).
"transport": current_transport() or _stdio_transport,
}
with _sessions_lock:
_sessions[sid] = {
"agent": agent,
"session_key": key,
"history": history,
"history_lock": threading.Lock(),
"history_version": 0,
"inflight_turn": None,
"created_at": now,
"last_active": now,
"running": False,
"attached_images": [],
"image_counter": 0,
"cwd": _completion_cwd(),
"cols": cols,
"slash_worker": None,
"show_reasoning": _load_show_reasoning(),
"tool_progress_mode": _load_tool_progress_mode(),
"edit_snapshots": {},
"tool_started_at": {},
# Pin async event emissions to whichever transport created the
# session (stdio for Ink, JSON-RPC WS for the dashboard sidebar).
"transport": current_transport() or _stdio_transport,
}
db = _get_db()
if db is not None:
row = db.get_session(key)
if row and row.get("cwd"):
_sessions[sid]["cwd"] = row["cwd"]
with _sessions_lock:
if sid in _sessions:
_sessions[sid]["cwd"] = row["cwd"]
else:
try:
db.update_session_cwd(key, _sessions[sid]["cwd"])
@ -2464,9 +2481,11 @@ def _init_session(sid: str, key: str, agent, history: list, cols: int = 80):
# session startup resilient).
pass
_wire_callbacks(sid)
_sessions[sid]["_notif_stop"] = _start_notification_poller(sid, _sessions[sid])
with _sessions_lock:
if sid in _sessions:
_sessions[sid]["_notif_stop"] = _start_notification_poller(sid, _sessions[sid])
_notify_session_boundary("on_session_reset", key)
_emit("session.info", sid, _session_info(agent, _sessions[sid]))
_emit("session.info", sid, _session_info(agent, _sessions.get(sid, {})))
def _new_session_key() -> str:
@ -2818,32 +2837,33 @@ def _(rid, params: dict) -> dict:
ready = threading.Event()
now = time.time()
_sessions[sid] = {
"agent": None,
"agent_error": None,
"agent_ready": ready,
"attached_images": [],
"cols": cols,
"created_at": now,
"edit_snapshots": {},
"explicit_cwd": explicit_cwd,
"history": history,
"history_lock": threading.Lock(),
"history_version": 0,
"image_counter": 0,
"cwd": resolved_cwd,
"inflight_turn": None,
"last_active": now,
"pending_title": title or None,
"running": False,
"session_key": key,
"show_reasoning": _load_show_reasoning(),
"slash_worker": None,
"tool_progress_mode": _load_tool_progress_mode(),
"tool_started_at": {},
"transport": current_transport() or _stdio_transport,
}
_register_session_cwd(_sessions[sid])
with _sessions_lock:
_sessions[sid] = {
"agent": None,
"agent_error": None,
"agent_ready": ready,
"attached_images": [],
"cols": cols,
"created_at": now,
"edit_snapshots": {},
"explicit_cwd": explicit_cwd,
"history": history,
"history_lock": threading.Lock(),
"history_version": 0,
"image_counter": 0,
"cwd": resolved_cwd,
"inflight_turn": None,
"last_active": now,
"pending_title": title or None,
"running": False,
"session_key": key,
"show_reasoning": _load_show_reasoning(),
"slash_worker": None,
"tool_progress_mode": _load_tool_progress_mode(),
"tool_started_at": {},
"transport": current_transport() or _stdio_transport,
}
_register_session_cwd(_sessions[sid])
# NOTE: we intentionally do NOT persist a DB row here. Every TUI/desktop
# launch (and every "New agent" / draft) opens a session here just to paint
# the composer, so eagerly creating a row left an "Untitled" empty session
@ -3233,7 +3253,8 @@ def _(rid, params: dict) -> dict:
"""
current = str(params.get("current_session_id") or "")
try:
snapshot = list(_sessions.items())
with _sessions_lock:
snapshot = list(_sessions.items())
except Exception as e:
return _err(rid, 5036, f"could not enumerate active sessions: {e}")
@ -3287,7 +3308,8 @@ def _(rid, params: dict) -> dict:
# dictionary changed size during iteration``. If even the snapshot
# raises, fail closed (refuse the delete) rather than fail open.
try:
snapshot = list(_sessions.values())
with _sessions_lock:
snapshot = list(_sessions.values())
except Exception as e:
return _err(rid, 5036, f"could not enumerate active sessions: {e}")
active = {s.get("session_key") for s in snapshot if s.get("session_key")}
@ -3644,11 +3666,13 @@ def _(rid, params: dict) -> dict:
@method("session.close")
def _(rid, params: dict) -> dict:
sid = params.get("session_id", "")
current = _sessions.get(sid)
with _sessions_lock:
current = _sessions.get(sid)
if not current:
return _ok(rid, {"closed": False})
with _session_resume_lock:
session = _sessions.pop(sid, None)
with _sessions_lock:
session = _sessions.pop(sid, None)
if not session:
return _ok(rid, {"closed": False})
_finalize_session(session)
@ -4097,7 +4121,8 @@ def _notification_event_belongs_elsewhere(session: dict, evt: dict) -> bool:
if evt_key == str(session.get("session_key") or ""):
return False
try:
snapshot = list(_sessions.values())
with _sessions_lock:
snapshot = list(_sessions.values())
except Exception:
# If we can't safely enumerate live sessions, fail open so we don't
# crash the poller thread or drop the event.
@ -5000,12 +5025,13 @@ def _(rid, params: dict) -> dict:
def _respond(rid, params, key):
r = params.get("request_id", "")
entry = _pending.get(r)
if not entry:
return _err(rid, 4009, f"no pending {key} request")
_, ev = entry
_answers[r] = params.get(key, "")
ev.set()
with _prompt_lock:
entry = _pending.get(r)
if not entry:
return _err(rid, 4009, f"no pending {key} request")
_, ev = entry
_answers[r] = params.get(key, "")
ev.set()
return _ok(rid, {"status": "ok"})