fix(tui): add thread-safety locks for _sessions and prompt dicts
C1: Add _sessions_lock to protect compound mutations and iterations of
the global _sessions dict, which is accessed from at least five
concurrent execution contexts (main dispatcher, pool workers, daemon
threads, notification poller, atexit handler).
C2: Add _prompt_lock to protect the _pending/_pending_prompt_payloads/_answers
dicts from races between _block() (agent callback thread) and
_respond() (pool worker); _clear_pending() takes the same lock. Lock
scope is kept tight — _block() only holds the lock during registration
and cleanup, releasing it before _emit() and ev.wait() so that one
waiting prompt cannot block other prompts for up to its timeout
(300s by default).
All 187 existing TUI tests pass with no regressions.
This commit is contained in:
committed by
Teknium
parent
2069e78b88
commit
5bcb63e400
@ -125,6 +125,8 @@ _db = None
|
||||
_db_error: str | None = None
|
||||
_stdout_lock = threading.Lock()
|
||||
_cfg_lock = threading.Lock()
|
||||
_sessions_lock = threading.Lock()
|
||||
_prompt_lock = threading.Lock()
|
||||
_cfg_cache: dict | None = None
|
||||
_cfg_mtime: float | None = None
|
||||
_cfg_path = None
|
||||
@ -325,7 +327,9 @@ def _finalize_session(session: dict | None, end_reason: str = "tui_close") -> No
|
||||
|
||||
|
||||
def _shutdown_sessions() -> None:
|
||||
for session in list(_sessions.values()):
|
||||
with _sessions_lock:
|
||||
snapshot = list(_sessions.values())
|
||||
for session in snapshot:
|
||||
_finalize_session(session, end_reason="tui_shutdown")
|
||||
try:
|
||||
worker = session.get("slash_worker")
|
||||
@ -546,7 +550,8 @@ def _start_agent_build(sid: str, session: dict) -> None:
|
||||
key = session["session_key"]
|
||||
|
||||
def _build() -> None:
|
||||
current = _sessions.get(sid)
|
||||
with _sessions_lock:
|
||||
current = _sessions.get(sid)
|
||||
if current is None:
|
||||
ready.set()
|
||||
return
|
||||
@ -585,7 +590,9 @@ def _start_agent_build(sid: str, session: dict) -> None:
|
||||
pass
|
||||
|
||||
_wire_callbacks(sid)
|
||||
_sessions[sid]["_notif_stop"] = _start_notification_poller(sid, _sessions[sid])
|
||||
with _sessions_lock:
|
||||
if sid in _sessions:
|
||||
_sessions[sid]["_notif_stop"] = _start_notification_poller(sid, _sessions[sid])
|
||||
_notify_session_boundary("on_session_reset", key)
|
||||
|
||||
info = _session_info(agent, current)
|
||||
@ -598,7 +605,9 @@ def _start_agent_build(sid: str, session: dict) -> None:
|
||||
current["agent_error"] = str(e)
|
||||
_emit("error", sid, {"message": f"agent init failed: {e}"})
|
||||
finally:
|
||||
if _sessions.get(sid) is not current:
|
||||
with _sessions_lock:
|
||||
replaced = _sessions.get(sid) is not current
|
||||
if replaced:
|
||||
if worker is not None:
|
||||
try:
|
||||
worker.close()
|
||||
@ -819,9 +828,10 @@ def _cwd_for_session_key(session_key: str) -> str:
|
||||
"""
|
||||
if not session_key:
|
||||
return ""
|
||||
for sess in list(_sessions.values()):
|
||||
if sess.get("session_key") == session_key:
|
||||
return str(sess.get("cwd") or "")
|
||||
with _sessions_lock:
|
||||
for sess in list(_sessions.values()):
|
||||
if sess.get("session_key") == session_key:
|
||||
return str(sess.get("cwd") or "")
|
||||
return ""
|
||||
|
||||
|
||||
@ -863,16 +873,19 @@ def _enable_gateway_prompts() -> None:
|
||||
def _block(event: str, sid: str, payload: dict, timeout: int = 300) -> str:
|
||||
rid = uuid.uuid4().hex[:8]
|
||||
ev = threading.Event()
|
||||
_pending[rid] = (sid, ev)
|
||||
payload["request_id"] = rid
|
||||
_pending_prompt_payloads[rid] = (event, dict(payload))
|
||||
with _prompt_lock:
|
||||
_pending[rid] = (sid, ev)
|
||||
payload["request_id"] = rid
|
||||
_pending_prompt_payloads[rid] = (event, dict(payload))
|
||||
try:
|
||||
_emit(event, sid, payload)
|
||||
ev.wait(timeout=timeout)
|
||||
finally:
|
||||
_pending.pop(rid, None)
|
||||
_pending_prompt_payloads.pop(rid, None)
|
||||
return _answers.pop(rid, "")
|
||||
with _prompt_lock:
|
||||
_pending.pop(rid, None)
|
||||
_pending_prompt_payloads.pop(rid, None)
|
||||
with _prompt_lock:
|
||||
return _answers.pop(rid, "")
|
||||
|
||||
|
||||
def _clear_pending(sid: str | None = None) -> None:
|
||||
@ -884,10 +897,11 @@ def _clear_pending(sid: str | None = None) -> None:
|
||||
sessions sharing the same tui_gateway process. When *sid* is
|
||||
None, every pending prompt is released (used during shutdown).
|
||||
"""
|
||||
for rid, (owner_sid, ev) in list(_pending.items()):
|
||||
if sid is None or owner_sid == sid:
|
||||
_answers[rid] = ""
|
||||
ev.set()
|
||||
with _prompt_lock:
|
||||
for rid, (owner_sid, ev) in list(_pending.items()):
|
||||
if sid is None or owner_sid == sid:
|
||||
_answers[rid] = ""
|
||||
ev.set()
|
||||
|
||||
|
||||
# ── Agent factory ────────────────────────────────────────────────────
|
||||
@ -2402,34 +2416,37 @@ def _make_agent(sid: str, key: str, session_id: str | None = None):
|
||||
|
||||
def _init_session(sid: str, key: str, agent, history: list, cols: int = 80):
|
||||
now = time.time()
|
||||
_sessions[sid] = {
|
||||
"agent": agent,
|
||||
"session_key": key,
|
||||
"history": history,
|
||||
"history_lock": threading.Lock(),
|
||||
"history_version": 0,
|
||||
"inflight_turn": None,
|
||||
"created_at": now,
|
||||
"last_active": now,
|
||||
"running": False,
|
||||
"attached_images": [],
|
||||
"image_counter": 0,
|
||||
"cwd": _completion_cwd(),
|
||||
"cols": cols,
|
||||
"slash_worker": None,
|
||||
"show_reasoning": _load_show_reasoning(),
|
||||
"tool_progress_mode": _load_tool_progress_mode(),
|
||||
"edit_snapshots": {},
|
||||
"tool_started_at": {},
|
||||
# Pin async event emissions to whichever transport created the
|
||||
# session (stdio for Ink, JSON-RPC WS for the dashboard sidebar).
|
||||
"transport": current_transport() or _stdio_transport,
|
||||
}
|
||||
with _sessions_lock:
|
||||
_sessions[sid] = {
|
||||
"agent": agent,
|
||||
"session_key": key,
|
||||
"history": history,
|
||||
"history_lock": threading.Lock(),
|
||||
"history_version": 0,
|
||||
"inflight_turn": None,
|
||||
"created_at": now,
|
||||
"last_active": now,
|
||||
"running": False,
|
||||
"attached_images": [],
|
||||
"image_counter": 0,
|
||||
"cwd": _completion_cwd(),
|
||||
"cols": cols,
|
||||
"slash_worker": None,
|
||||
"show_reasoning": _load_show_reasoning(),
|
||||
"tool_progress_mode": _load_tool_progress_mode(),
|
||||
"edit_snapshots": {},
|
||||
"tool_started_at": {},
|
||||
# Pin async event emissions to whichever transport created the
|
||||
# session (stdio for Ink, JSON-RPC WS for the dashboard sidebar).
|
||||
"transport": current_transport() or _stdio_transport,
|
||||
}
|
||||
db = _get_db()
|
||||
if db is not None:
|
||||
row = db.get_session(key)
|
||||
if row and row.get("cwd"):
|
||||
_sessions[sid]["cwd"] = row["cwd"]
|
||||
with _sessions_lock:
|
||||
if sid in _sessions:
|
||||
_sessions[sid]["cwd"] = row["cwd"]
|
||||
else:
|
||||
try:
|
||||
db.update_session_cwd(key, _sessions[sid]["cwd"])
|
||||
@ -2464,9 +2481,11 @@ def _init_session(sid: str, key: str, agent, history: list, cols: int = 80):
|
||||
# session startup resilient).
|
||||
pass
|
||||
_wire_callbacks(sid)
|
||||
_sessions[sid]["_notif_stop"] = _start_notification_poller(sid, _sessions[sid])
|
||||
with _sessions_lock:
|
||||
if sid in _sessions:
|
||||
_sessions[sid]["_notif_stop"] = _start_notification_poller(sid, _sessions[sid])
|
||||
_notify_session_boundary("on_session_reset", key)
|
||||
_emit("session.info", sid, _session_info(agent, _sessions[sid]))
|
||||
_emit("session.info", sid, _session_info(agent, _sessions.get(sid, {})))
|
||||
|
||||
|
||||
def _new_session_key() -> str:
|
||||
@ -2818,32 +2837,33 @@ def _(rid, params: dict) -> dict:
|
||||
ready = threading.Event()
|
||||
now = time.time()
|
||||
|
||||
_sessions[sid] = {
|
||||
"agent": None,
|
||||
"agent_error": None,
|
||||
"agent_ready": ready,
|
||||
"attached_images": [],
|
||||
"cols": cols,
|
||||
"created_at": now,
|
||||
"edit_snapshots": {},
|
||||
"explicit_cwd": explicit_cwd,
|
||||
"history": history,
|
||||
"history_lock": threading.Lock(),
|
||||
"history_version": 0,
|
||||
"image_counter": 0,
|
||||
"cwd": resolved_cwd,
|
||||
"inflight_turn": None,
|
||||
"last_active": now,
|
||||
"pending_title": title or None,
|
||||
"running": False,
|
||||
"session_key": key,
|
||||
"show_reasoning": _load_show_reasoning(),
|
||||
"slash_worker": None,
|
||||
"tool_progress_mode": _load_tool_progress_mode(),
|
||||
"tool_started_at": {},
|
||||
"transport": current_transport() or _stdio_transport,
|
||||
}
|
||||
_register_session_cwd(_sessions[sid])
|
||||
with _sessions_lock:
|
||||
_sessions[sid] = {
|
||||
"agent": None,
|
||||
"agent_error": None,
|
||||
"agent_ready": ready,
|
||||
"attached_images": [],
|
||||
"cols": cols,
|
||||
"created_at": now,
|
||||
"edit_snapshots": {},
|
||||
"explicit_cwd": explicit_cwd,
|
||||
"history": history,
|
||||
"history_lock": threading.Lock(),
|
||||
"history_version": 0,
|
||||
"image_counter": 0,
|
||||
"cwd": resolved_cwd,
|
||||
"inflight_turn": None,
|
||||
"last_active": now,
|
||||
"pending_title": title or None,
|
||||
"running": False,
|
||||
"session_key": key,
|
||||
"show_reasoning": _load_show_reasoning(),
|
||||
"slash_worker": None,
|
||||
"tool_progress_mode": _load_tool_progress_mode(),
|
||||
"tool_started_at": {},
|
||||
"transport": current_transport() or _stdio_transport,
|
||||
}
|
||||
_register_session_cwd(_sessions[sid])
|
||||
# NOTE: we intentionally do NOT persist a DB row here. Every TUI/desktop
|
||||
# launch (and every "New agent" / draft) opens a session here just to paint
|
||||
# the composer, so eagerly creating a row left an "Untitled" empty session
|
||||
@ -3233,7 +3253,8 @@ def _(rid, params: dict) -> dict:
|
||||
"""
|
||||
current = str(params.get("current_session_id") or "")
|
||||
try:
|
||||
snapshot = list(_sessions.items())
|
||||
with _sessions_lock:
|
||||
snapshot = list(_sessions.items())
|
||||
except Exception as e:
|
||||
return _err(rid, 5036, f"could not enumerate active sessions: {e}")
|
||||
|
||||
@ -3287,7 +3308,8 @@ def _(rid, params: dict) -> dict:
|
||||
# dictionary changed size during iteration``. If even the snapshot
|
||||
# raises, fail closed (refuse the delete) rather than fail open.
|
||||
try:
|
||||
snapshot = list(_sessions.values())
|
||||
with _sessions_lock:
|
||||
snapshot = list(_sessions.values())
|
||||
except Exception as e:
|
||||
return _err(rid, 5036, f"could not enumerate active sessions: {e}")
|
||||
active = {s.get("session_key") for s in snapshot if s.get("session_key")}
|
||||
@ -3644,11 +3666,13 @@ def _(rid, params: dict) -> dict:
|
||||
@method("session.close")
|
||||
def _(rid, params: dict) -> dict:
|
||||
sid = params.get("session_id", "")
|
||||
current = _sessions.get(sid)
|
||||
with _sessions_lock:
|
||||
current = _sessions.get(sid)
|
||||
if not current:
|
||||
return _ok(rid, {"closed": False})
|
||||
with _session_resume_lock:
|
||||
session = _sessions.pop(sid, None)
|
||||
with _sessions_lock:
|
||||
session = _sessions.pop(sid, None)
|
||||
if not session:
|
||||
return _ok(rid, {"closed": False})
|
||||
_finalize_session(session)
|
||||
@ -4097,7 +4121,8 @@ def _notification_event_belongs_elsewhere(session: dict, evt: dict) -> bool:
|
||||
if evt_key == str(session.get("session_key") or ""):
|
||||
return False
|
||||
try:
|
||||
snapshot = list(_sessions.values())
|
||||
with _sessions_lock:
|
||||
snapshot = list(_sessions.values())
|
||||
except Exception:
|
||||
# If we can't safely enumerate live sessions, fail open so we don't
|
||||
# crash the poller thread or drop the event.
|
||||
@ -5000,12 +5025,13 @@ def _(rid, params: dict) -> dict:
|
||||
|
||||
def _respond(rid, params, key):
|
||||
r = params.get("request_id", "")
|
||||
entry = _pending.get(r)
|
||||
if not entry:
|
||||
return _err(rid, 4009, f"no pending {key} request")
|
||||
_, ev = entry
|
||||
_answers[r] = params.get(key, "")
|
||||
ev.set()
|
||||
with _prompt_lock:
|
||||
entry = _pending.get(r)
|
||||
if not entry:
|
||||
return _err(rid, 4009, f"no pending {key} request")
|
||||
_, ev = entry
|
||||
_answers[r] = params.get(key, "")
|
||||
ev.set()
|
||||
return _ok(rid, {"status": "ok"})
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user