fix(tui): add thread-safety locks for _sessions and prompt dicts

C1: Add _sessions_lock to protect all compound mutations and iterations
    on the global _sessions dict across 5+ concurrent execution contexts
    (main dispatcher, pool workers, daemon threads, notification poller,
    atexit handler).

C2: Add _prompt_lock to protect _pending/_pending_prompt_payloads/_answers
    dicts from races between _block() (agent callback thread) and
    _respond() (pool worker).  Lock scope is kept tight — _block() only
    holds the lock during registration/cleanup, releasing it before
    _emit() and ev.wait() to avoid blocking other prompts for 300s.

All 187 existing TUI tests pass with no regressions.
This commit is contained in:
asill-livestream
2026-06-04 06:20:31 +09:00
committed by Teknium
parent 2069e78b88
commit 5bcb63e400

View File

@ -125,6 +125,8 @@ _db = None
_db_error: str | None = None _db_error: str | None = None
_stdout_lock = threading.Lock() _stdout_lock = threading.Lock()
_cfg_lock = threading.Lock() _cfg_lock = threading.Lock()
_sessions_lock = threading.Lock()
_prompt_lock = threading.Lock()
_cfg_cache: dict | None = None _cfg_cache: dict | None = None
_cfg_mtime: float | None = None _cfg_mtime: float | None = None
_cfg_path = None _cfg_path = None
@ -325,7 +327,9 @@ def _finalize_session(session: dict | None, end_reason: str = "tui_close") -> No
def _shutdown_sessions() -> None: def _shutdown_sessions() -> None:
for session in list(_sessions.values()): with _sessions_lock:
snapshot = list(_sessions.values())
for session in snapshot:
_finalize_session(session, end_reason="tui_shutdown") _finalize_session(session, end_reason="tui_shutdown")
try: try:
worker = session.get("slash_worker") worker = session.get("slash_worker")
@ -546,7 +550,8 @@ def _start_agent_build(sid: str, session: dict) -> None:
key = session["session_key"] key = session["session_key"]
def _build() -> None: def _build() -> None:
current = _sessions.get(sid) with _sessions_lock:
current = _sessions.get(sid)
if current is None: if current is None:
ready.set() ready.set()
return return
@ -585,7 +590,9 @@ def _start_agent_build(sid: str, session: dict) -> None:
pass pass
_wire_callbacks(sid) _wire_callbacks(sid)
_sessions[sid]["_notif_stop"] = _start_notification_poller(sid, _sessions[sid]) with _sessions_lock:
if sid in _sessions:
_sessions[sid]["_notif_stop"] = _start_notification_poller(sid, _sessions[sid])
_notify_session_boundary("on_session_reset", key) _notify_session_boundary("on_session_reset", key)
info = _session_info(agent, current) info = _session_info(agent, current)
@ -598,7 +605,9 @@ def _start_agent_build(sid: str, session: dict) -> None:
current["agent_error"] = str(e) current["agent_error"] = str(e)
_emit("error", sid, {"message": f"agent init failed: {e}"}) _emit("error", sid, {"message": f"agent init failed: {e}"})
finally: finally:
if _sessions.get(sid) is not current: with _sessions_lock:
replaced = _sessions.get(sid) is not current
if replaced:
if worker is not None: if worker is not None:
try: try:
worker.close() worker.close()
@ -819,9 +828,10 @@ def _cwd_for_session_key(session_key: str) -> str:
""" """
if not session_key: if not session_key:
return "" return ""
for sess in list(_sessions.values()): with _sessions_lock:
if sess.get("session_key") == session_key: for sess in list(_sessions.values()):
return str(sess.get("cwd") or "") if sess.get("session_key") == session_key:
return str(sess.get("cwd") or "")
return "" return ""
@ -863,16 +873,19 @@ def _enable_gateway_prompts() -> None:
def _block(event: str, sid: str, payload: dict, timeout: int = 300) -> str: def _block(event: str, sid: str, payload: dict, timeout: int = 300) -> str:
rid = uuid.uuid4().hex[:8] rid = uuid.uuid4().hex[:8]
ev = threading.Event() ev = threading.Event()
_pending[rid] = (sid, ev) with _prompt_lock:
payload["request_id"] = rid _pending[rid] = (sid, ev)
_pending_prompt_payloads[rid] = (event, dict(payload)) payload["request_id"] = rid
_pending_prompt_payloads[rid] = (event, dict(payload))
try: try:
_emit(event, sid, payload) _emit(event, sid, payload)
ev.wait(timeout=timeout) ev.wait(timeout=timeout)
finally: finally:
_pending.pop(rid, None) with _prompt_lock:
_pending_prompt_payloads.pop(rid, None) _pending.pop(rid, None)
return _answers.pop(rid, "") _pending_prompt_payloads.pop(rid, None)
with _prompt_lock:
return _answers.pop(rid, "")
def _clear_pending(sid: str | None = None) -> None: def _clear_pending(sid: str | None = None) -> None:
@ -884,10 +897,11 @@ def _clear_pending(sid: str | None = None) -> None:
sessions sharing the same tui_gateway process. When *sid* is sessions sharing the same tui_gateway process. When *sid* is
None, every pending prompt is released (used during shutdown). None, every pending prompt is released (used during shutdown).
""" """
for rid, (owner_sid, ev) in list(_pending.items()): with _prompt_lock:
if sid is None or owner_sid == sid: for rid, (owner_sid, ev) in list(_pending.items()):
_answers[rid] = "" if sid is None or owner_sid == sid:
ev.set() _answers[rid] = ""
ev.set()
# ── Agent factory ──────────────────────────────────────────────────── # ── Agent factory ────────────────────────────────────────────────────
@ -2402,34 +2416,37 @@ def _make_agent(sid: str, key: str, session_id: str | None = None):
def _init_session(sid: str, key: str, agent, history: list, cols: int = 80): def _init_session(sid: str, key: str, agent, history: list, cols: int = 80):
now = time.time() now = time.time()
_sessions[sid] = { with _sessions_lock:
"agent": agent, _sessions[sid] = {
"session_key": key, "agent": agent,
"history": history, "session_key": key,
"history_lock": threading.Lock(), "history": history,
"history_version": 0, "history_lock": threading.Lock(),
"inflight_turn": None, "history_version": 0,
"created_at": now, "inflight_turn": None,
"last_active": now, "created_at": now,
"running": False, "last_active": now,
"attached_images": [], "running": False,
"image_counter": 0, "attached_images": [],
"cwd": _completion_cwd(), "image_counter": 0,
"cols": cols, "cwd": _completion_cwd(),
"slash_worker": None, "cols": cols,
"show_reasoning": _load_show_reasoning(), "slash_worker": None,
"tool_progress_mode": _load_tool_progress_mode(), "show_reasoning": _load_show_reasoning(),
"edit_snapshots": {}, "tool_progress_mode": _load_tool_progress_mode(),
"tool_started_at": {}, "edit_snapshots": {},
# Pin async event emissions to whichever transport created the "tool_started_at": {},
# session (stdio for Ink, JSON-RPC WS for the dashboard sidebar). # Pin async event emissions to whichever transport created the
"transport": current_transport() or _stdio_transport, # session (stdio for Ink, JSON-RPC WS for the dashboard sidebar).
} "transport": current_transport() or _stdio_transport,
}
db = _get_db() db = _get_db()
if db is not None: if db is not None:
row = db.get_session(key) row = db.get_session(key)
if row and row.get("cwd"): if row and row.get("cwd"):
_sessions[sid]["cwd"] = row["cwd"] with _sessions_lock:
if sid in _sessions:
_sessions[sid]["cwd"] = row["cwd"]
else: else:
try: try:
db.update_session_cwd(key, _sessions[sid]["cwd"]) db.update_session_cwd(key, _sessions[sid]["cwd"])
@ -2464,9 +2481,11 @@ def _init_session(sid: str, key: str, agent, history: list, cols: int = 80):
# session startup resilient). # session startup resilient).
pass pass
_wire_callbacks(sid) _wire_callbacks(sid)
_sessions[sid]["_notif_stop"] = _start_notification_poller(sid, _sessions[sid]) with _sessions_lock:
if sid in _sessions:
_sessions[sid]["_notif_stop"] = _start_notification_poller(sid, _sessions[sid])
_notify_session_boundary("on_session_reset", key) _notify_session_boundary("on_session_reset", key)
_emit("session.info", sid, _session_info(agent, _sessions[sid])) _emit("session.info", sid, _session_info(agent, _sessions.get(sid, {})))
def _new_session_key() -> str: def _new_session_key() -> str:
@ -2818,32 +2837,33 @@ def _(rid, params: dict) -> dict:
ready = threading.Event() ready = threading.Event()
now = time.time() now = time.time()
_sessions[sid] = { with _sessions_lock:
"agent": None, _sessions[sid] = {
"agent_error": None, "agent": None,
"agent_ready": ready, "agent_error": None,
"attached_images": [], "agent_ready": ready,
"cols": cols, "attached_images": [],
"created_at": now, "cols": cols,
"edit_snapshots": {}, "created_at": now,
"explicit_cwd": explicit_cwd, "edit_snapshots": {},
"history": history, "explicit_cwd": explicit_cwd,
"history_lock": threading.Lock(), "history": history,
"history_version": 0, "history_lock": threading.Lock(),
"image_counter": 0, "history_version": 0,
"cwd": resolved_cwd, "image_counter": 0,
"inflight_turn": None, "cwd": resolved_cwd,
"last_active": now, "inflight_turn": None,
"pending_title": title or None, "last_active": now,
"running": False, "pending_title": title or None,
"session_key": key, "running": False,
"show_reasoning": _load_show_reasoning(), "session_key": key,
"slash_worker": None, "show_reasoning": _load_show_reasoning(),
"tool_progress_mode": _load_tool_progress_mode(), "slash_worker": None,
"tool_started_at": {}, "tool_progress_mode": _load_tool_progress_mode(),
"transport": current_transport() or _stdio_transport, "tool_started_at": {},
} "transport": current_transport() or _stdio_transport,
_register_session_cwd(_sessions[sid]) }
_register_session_cwd(_sessions[sid])
# NOTE: we intentionally do NOT persist a DB row here. Every TUI/desktop # NOTE: we intentionally do NOT persist a DB row here. Every TUI/desktop
# launch (and every "New agent" / draft) opens a session here just to paint # launch (and every "New agent" / draft) opens a session here just to paint
# the composer, so eagerly creating a row left an "Untitled" empty session # the composer, so eagerly creating a row left an "Untitled" empty session
@ -3233,7 +3253,8 @@ def _(rid, params: dict) -> dict:
""" """
current = str(params.get("current_session_id") or "") current = str(params.get("current_session_id") or "")
try: try:
snapshot = list(_sessions.items()) with _sessions_lock:
snapshot = list(_sessions.items())
except Exception as e: except Exception as e:
return _err(rid, 5036, f"could not enumerate active sessions: {e}") return _err(rid, 5036, f"could not enumerate active sessions: {e}")
@ -3287,7 +3308,8 @@ def _(rid, params: dict) -> dict:
# dictionary changed size during iteration``. If even the snapshot # dictionary changed size during iteration``. If even the snapshot
# raises, fail closed (refuse the delete) rather than fail open. # raises, fail closed (refuse the delete) rather than fail open.
try: try:
snapshot = list(_sessions.values()) with _sessions_lock:
snapshot = list(_sessions.values())
except Exception as e: except Exception as e:
return _err(rid, 5036, f"could not enumerate active sessions: {e}") return _err(rid, 5036, f"could not enumerate active sessions: {e}")
active = {s.get("session_key") for s in snapshot if s.get("session_key")} active = {s.get("session_key") for s in snapshot if s.get("session_key")}
@ -3644,11 +3666,13 @@ def _(rid, params: dict) -> dict:
@method("session.close") @method("session.close")
def _(rid, params: dict) -> dict: def _(rid, params: dict) -> dict:
sid = params.get("session_id", "") sid = params.get("session_id", "")
current = _sessions.get(sid) with _sessions_lock:
current = _sessions.get(sid)
if not current: if not current:
return _ok(rid, {"closed": False}) return _ok(rid, {"closed": False})
with _session_resume_lock: with _session_resume_lock:
session = _sessions.pop(sid, None) with _sessions_lock:
session = _sessions.pop(sid, None)
if not session: if not session:
return _ok(rid, {"closed": False}) return _ok(rid, {"closed": False})
_finalize_session(session) _finalize_session(session)
@ -4097,7 +4121,8 @@ def _notification_event_belongs_elsewhere(session: dict, evt: dict) -> bool:
if evt_key == str(session.get("session_key") or ""): if evt_key == str(session.get("session_key") or ""):
return False return False
try: try:
snapshot = list(_sessions.values()) with _sessions_lock:
snapshot = list(_sessions.values())
except Exception: except Exception:
# If we can't safely enumerate live sessions, fail open so we don't # If we can't safely enumerate live sessions, fail open so we don't
# crash the poller thread or drop the event. # crash the poller thread or drop the event.
@ -5000,12 +5025,13 @@ def _(rid, params: dict) -> dict:
def _respond(rid, params, key): def _respond(rid, params, key):
r = params.get("request_id", "") r = params.get("request_id", "")
entry = _pending.get(r) with _prompt_lock:
if not entry: entry = _pending.get(r)
return _err(rid, 4009, f"no pending {key} request") if not entry:
_, ev = entry return _err(rid, 4009, f"no pending {key} request")
_answers[r] = params.get(key, "") _, ev = entry
ev.set() _answers[r] = params.get(key, "")
ev.set()
return _ok(rid, {"status": "ok"}) return _ok(rid, {"status": "ok"})