From 5bcb63e400987e937e696b385e01285339b565c5 Mon Sep 17 00:00:00 2001 From: asill-livestream Date: Thu, 4 Jun 2026 06:20:31 +0900 Subject: [PATCH] fix(tui): add thread-safety locks for _sessions and prompt dicts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit C1: Add _sessions_lock to protect all compound mutations and iterations on the global _sessions dict across 5+ concurrent execution contexts (main dispatcher, pool workers, daemon threads, notification poller, atexit handler). C2: Add _prompt_lock to protect _pending/_pending_prompt_payloads/_answers dicts from races between _block() (agent callback thread) and _respond() (pool worker). Lock scope is kept tight — _block() only holds the lock during registration/cleanup, releasing it before _emit() and ev.wait() to avoid blocking other prompts for 300s. All 187 existing TUI tests pass with no regressions. --- tui_gateway/server.py | 186 ++++++++++++++++++++++++------------------ 1 file changed, 106 insertions(+), 80 deletions(-) diff --git a/tui_gateway/server.py b/tui_gateway/server.py index 61822b6da..266d2e3ff 100644 --- a/tui_gateway/server.py +++ b/tui_gateway/server.py @@ -125,6 +125,8 @@ _db = None _db_error: str | None = None _stdout_lock = threading.Lock() _cfg_lock = threading.Lock() +_sessions_lock = threading.Lock() +_prompt_lock = threading.Lock() _cfg_cache: dict | None = None _cfg_mtime: float | None = None _cfg_path = None @@ -325,7 +327,9 @@ def _finalize_session(session: dict | None, end_reason: str = "tui_close") -> No def _shutdown_sessions() -> None: - for session in list(_sessions.values()): + with _sessions_lock: + snapshot = list(_sessions.values()) + for session in snapshot: _finalize_session(session, end_reason="tui_shutdown") try: worker = session.get("slash_worker") @@ -546,7 +550,8 @@ def _start_agent_build(sid: str, session: dict) -> None: key = session["session_key"] def _build() -> None: - current = _sessions.get(sid) + with _sessions_lock: + current = _sessions.get(sid) if current is None: ready.set() return @@ -585,7 +590,9 @@ def _start_agent_build(sid: str, session: dict) -> None: pass _wire_callbacks(sid) - _sessions[sid]["_notif_stop"] = _start_notification_poller(sid, _sessions[sid]) + with _sessions_lock: + if sid in _sessions: + _sessions[sid]["_notif_stop"] = _start_notification_poller(sid, _sessions[sid]) _notify_session_boundary("on_session_reset", key) info = _session_info(agent, current) @@ -598,7 +605,9 @@ def _start_agent_build(sid: str, session: dict) -> None: current["agent_error"] = str(e) _emit("error", sid, {"message": f"agent init failed: {e}"}) finally: - if _sessions.get(sid) is not current: + with _sessions_lock: + replaced = _sessions.get(sid) is not current + if replaced: if worker is not None: try: worker.close() @@ -819,9 +828,10 @@ def _cwd_for_session_key(session_key: str) -> str: """ if not session_key: return "" - for sess in list(_sessions.values()): - if sess.get("session_key") == session_key: - return str(sess.get("cwd") or "") + with _sessions_lock: + for sess in list(_sessions.values()): + if sess.get("session_key") == session_key: + return str(sess.get("cwd") or "") return "" @@ -863,16 +873,19 @@ def _enable_gateway_prompts() -> None: def _block(event: str, sid: str, payload: dict, timeout: int = 300) -> str: rid = uuid.uuid4().hex[:8] ev = threading.Event() - _pending[rid] = (sid, ev) - payload["request_id"] = rid - _pending_prompt_payloads[rid] = (event, dict(payload)) + with _prompt_lock: + _pending[rid] = (sid, ev) + payload["request_id"] = rid + _pending_prompt_payloads[rid] = (event, dict(payload)) try: _emit(event, sid, payload) ev.wait(timeout=timeout) finally: - _pending.pop(rid, None) - _pending_prompt_payloads.pop(rid, None) - return _answers.pop(rid, "") + with _prompt_lock: + _pending.pop(rid, None) + _pending_prompt_payloads.pop(rid, None) + with _prompt_lock: + return _answers.pop(rid, "") def _clear_pending(sid: str | None = None) -> None: @@ -884,10 +897,11 @@ def _clear_pending(sid: str | None = None) -> None: sessions sharing the same tui_gateway process. When *sid* is None, every pending prompt is released (used during shutdown). """ - for rid, (owner_sid, ev) in list(_pending.items()): - if sid is None or owner_sid == sid: - _answers[rid] = "" - ev.set() + with _prompt_lock: + for rid, (owner_sid, ev) in list(_pending.items()): + if sid is None or owner_sid == sid: + _answers[rid] = "" + ev.set() # ── Agent factory ──────────────────────────────────────────────────── @@ -2402,34 +2416,37 @@ def _make_agent(sid: str, key: str, session_id: str | None = None): def _init_session(sid: str, key: str, agent, history: list, cols: int = 80): now = time.time() - _sessions[sid] = { - "agent": agent, - "session_key": key, - "history": history, - "history_lock": threading.Lock(), - "history_version": 0, - "inflight_turn": None, - "created_at": now, - "last_active": now, - "running": False, - "attached_images": [], - "image_counter": 0, - "cwd": _completion_cwd(), - "cols": cols, - "slash_worker": None, - "show_reasoning": _load_show_reasoning(), - "tool_progress_mode": _load_tool_progress_mode(), - "edit_snapshots": {}, - "tool_started_at": {}, - # Pin async event emissions to whichever transport created the - # session (stdio for Ink, JSON-RPC WS for the dashboard sidebar). - "transport": current_transport() or _stdio_transport, - } + with _sessions_lock: + _sessions[sid] = { + "agent": agent, + "session_key": key, + "history": history, + "history_lock": threading.Lock(), + "history_version": 0, + "inflight_turn": None, + "created_at": now, + "last_active": now, + "running": False, + "attached_images": [], + "image_counter": 0, + "cwd": _completion_cwd(), + "cols": cols, + "slash_worker": None, + "show_reasoning": _load_show_reasoning(), + "tool_progress_mode": _load_tool_progress_mode(), + "edit_snapshots": {}, + "tool_started_at": {}, + # Pin async event emissions to whichever transport created the + # session (stdio for Ink, JSON-RPC WS for the dashboard sidebar). + "transport": current_transport() or _stdio_transport, + } db = _get_db() if db is not None: row = db.get_session(key) if row and row.get("cwd"): - _sessions[sid]["cwd"] = row["cwd"] + with _sessions_lock: + if sid in _sessions: + _sessions[sid]["cwd"] = row["cwd"] else: try: db.update_session_cwd(key, _sessions[sid]["cwd"]) @@ -2464,9 +2481,11 @@ def _init_session(sid: str, key: str, agent, history: list, cols: int = 80): # session startup resilient). pass _wire_callbacks(sid) - _sessions[sid]["_notif_stop"] = _start_notification_poller(sid, _sessions[sid]) + with _sessions_lock: + if sid in _sessions: + _sessions[sid]["_notif_stop"] = _start_notification_poller(sid, _sessions[sid]) _notify_session_boundary("on_session_reset", key) - _emit("session.info", sid, _session_info(agent, _sessions[sid])) + _emit("session.info", sid, _session_info(agent, _sessions.get(sid, {}))) def _new_session_key() -> str: @@ -2818,32 +2837,33 @@ def _(rid, params: dict) -> dict: ready = threading.Event() now = time.time() - _sessions[sid] = { - "agent": None, - "agent_error": None, - "agent_ready": ready, - "attached_images": [], - "cols": cols, - "created_at": now, - "edit_snapshots": {}, - "explicit_cwd": explicit_cwd, - "history": history, - "history_lock": threading.Lock(), - "history_version": 0, - "image_counter": 0, - "cwd": resolved_cwd, - "inflight_turn": None, - "last_active": now, - "pending_title": title or None, - "running": False, - "session_key": key, - "show_reasoning": _load_show_reasoning(), - "slash_worker": None, - "tool_progress_mode": _load_tool_progress_mode(), - "tool_started_at": {}, - "transport": current_transport() or _stdio_transport, - } - _register_session_cwd(_sessions[sid]) + with _sessions_lock: + _sessions[sid] = { + "agent": None, + "agent_error": None, + "agent_ready": ready, + "attached_images": [], + "cols": cols, + "created_at": now, + "edit_snapshots": {}, + "explicit_cwd": explicit_cwd, + "history": history, + "history_lock": threading.Lock(), + "history_version": 0, + "image_counter": 0, + "cwd": resolved_cwd, + "inflight_turn": None, + "last_active": now, + "pending_title": title or None, + "running": False, + "session_key": key, + "show_reasoning": _load_show_reasoning(), + "slash_worker": None, + "tool_progress_mode": _load_tool_progress_mode(), + "tool_started_at": {}, + "transport": current_transport() or _stdio_transport, + } + _register_session_cwd(_sessions[sid]) # NOTE: we intentionally do NOT persist a DB row here. Every TUI/desktop # launch (and every "New agent" / draft) opens a session here just to paint # the composer, so eagerly creating a row left an "Untitled" empty session @@ -3233,7 +3253,8 @@ def _(rid, params: dict) -> dict: """ current = str(params.get("current_session_id") or "") try: - snapshot = list(_sessions.items()) + with _sessions_lock: + snapshot = list(_sessions.items()) except Exception as e: return _err(rid, 5036, f"could not enumerate active sessions: {e}") @@ -3287,7 +3308,8 @@ def _(rid, params: dict) -> dict: # dictionary changed size during iteration``. If even the snapshot # raises, fail closed (refuse the delete) rather than fail open. try: - snapshot = list(_sessions.values()) + with _sessions_lock: + snapshot = list(_sessions.values()) except Exception as e: return _err(rid, 5036, f"could not enumerate active sessions: {e}") active = {s.get("session_key") for s in snapshot if s.get("session_key")} @@ -3644,11 +3666,13 @@ def _(rid, params: dict) -> dict: @method("session.close") def _(rid, params: dict) -> dict: sid = params.get("session_id", "") - current = _sessions.get(sid) + with _sessions_lock: + current = _sessions.get(sid) if not current: return _ok(rid, {"closed": False}) with _session_resume_lock: - session = _sessions.pop(sid, None) + with _sessions_lock: + session = _sessions.pop(sid, None) if not session: return _ok(rid, {"closed": False}) _finalize_session(session) @@ -4097,7 +4121,8 @@ def _notification_event_belongs_elsewhere(session: dict, evt: dict) -> bool: if evt_key == str(session.get("session_key") or ""): return False try: - snapshot = list(_sessions.values()) + with _sessions_lock: + snapshot = list(_sessions.values()) except Exception: # If we can't safely enumerate live sessions, fail open so we don't # crash the poller thread or drop the event. @@ -5000,12 +5025,13 @@ def _(rid, params: dict) -> dict: def _respond(rid, params, key): r = params.get("request_id", "") - entry = _pending.get(r) - if not entry: - return _err(rid, 4009, f"no pending {key} request") - _, ev = entry - _answers[r] = params.get(key, "") - ev.set() + with _prompt_lock: + entry = _pending.get(r) + if not entry: + return _err(rid, 4009, f"no pending {key} request") + _, ev = entry + _answers[r] = params.get(key, "") + ev.set() return _ok(rid, {"status": "ok"})