fix(tui): add thread-safety locks for _sessions and prompt dicts
C1: Add _sessions_lock to protect all compound mutations and iterations
on the global _sessions dict across 5+ concurrent execution contexts
(main dispatcher, pool workers, daemon threads, notification poller,
atexit handler).
C2: Add _prompt_lock to protect the _pending, _pending_prompt_payloads and
_answers dicts from races between _block() (agent callback thread) and
_respond() (pool worker). Lock scope is kept tight: _block() holds the
lock only during registration and cleanup, and releases it before
_emit() and ev.wait(), so a waiting prompt cannot block other prompts
for up to the 300s default timeout.
All 187 existing TUI tests pass with no regressions.
This commit is contained in:
committed by
Teknium
parent
2069e78b88
commit
5bcb63e400
@ -125,6 +125,8 @@ _db = None
|
|||||||
_db_error: str | None = None
|
_db_error: str | None = None
|
||||||
_stdout_lock = threading.Lock()
|
_stdout_lock = threading.Lock()
|
||||||
_cfg_lock = threading.Lock()
|
_cfg_lock = threading.Lock()
|
||||||
|
_sessions_lock = threading.Lock()
|
||||||
|
_prompt_lock = threading.Lock()
|
||||||
_cfg_cache: dict | None = None
|
_cfg_cache: dict | None = None
|
||||||
_cfg_mtime: float | None = None
|
_cfg_mtime: float | None = None
|
||||||
_cfg_path = None
|
_cfg_path = None
|
||||||
@ -325,7 +327,9 @@ def _finalize_session(session: dict | None, end_reason: str = "tui_close") -> No
|
|||||||
|
|
||||||
|
|
||||||
def _shutdown_sessions() -> None:
|
def _shutdown_sessions() -> None:
|
||||||
for session in list(_sessions.values()):
|
with _sessions_lock:
|
||||||
|
snapshot = list(_sessions.values())
|
||||||
|
for session in snapshot:
|
||||||
_finalize_session(session, end_reason="tui_shutdown")
|
_finalize_session(session, end_reason="tui_shutdown")
|
||||||
try:
|
try:
|
||||||
worker = session.get("slash_worker")
|
worker = session.get("slash_worker")
|
||||||
@ -546,7 +550,8 @@ def _start_agent_build(sid: str, session: dict) -> None:
|
|||||||
key = session["session_key"]
|
key = session["session_key"]
|
||||||
|
|
||||||
def _build() -> None:
|
def _build() -> None:
|
||||||
current = _sessions.get(sid)
|
with _sessions_lock:
|
||||||
|
current = _sessions.get(sid)
|
||||||
if current is None:
|
if current is None:
|
||||||
ready.set()
|
ready.set()
|
||||||
return
|
return
|
||||||
@ -585,7 +590,9 @@ def _start_agent_build(sid: str, session: dict) -> None:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
_wire_callbacks(sid)
|
_wire_callbacks(sid)
|
||||||
_sessions[sid]["_notif_stop"] = _start_notification_poller(sid, _sessions[sid])
|
with _sessions_lock:
|
||||||
|
if sid in _sessions:
|
||||||
|
_sessions[sid]["_notif_stop"] = _start_notification_poller(sid, _sessions[sid])
|
||||||
_notify_session_boundary("on_session_reset", key)
|
_notify_session_boundary("on_session_reset", key)
|
||||||
|
|
||||||
info = _session_info(agent, current)
|
info = _session_info(agent, current)
|
||||||
@ -598,7 +605,9 @@ def _start_agent_build(sid: str, session: dict) -> None:
|
|||||||
current["agent_error"] = str(e)
|
current["agent_error"] = str(e)
|
||||||
_emit("error", sid, {"message": f"agent init failed: {e}"})
|
_emit("error", sid, {"message": f"agent init failed: {e}"})
|
||||||
finally:
|
finally:
|
||||||
if _sessions.get(sid) is not current:
|
with _sessions_lock:
|
||||||
|
replaced = _sessions.get(sid) is not current
|
||||||
|
if replaced:
|
||||||
if worker is not None:
|
if worker is not None:
|
||||||
try:
|
try:
|
||||||
worker.close()
|
worker.close()
|
||||||
@ -819,9 +828,10 @@ def _cwd_for_session_key(session_key: str) -> str:
|
|||||||
"""
|
"""
|
||||||
if not session_key:
|
if not session_key:
|
||||||
return ""
|
return ""
|
||||||
for sess in list(_sessions.values()):
|
with _sessions_lock:
|
||||||
if sess.get("session_key") == session_key:
|
for sess in list(_sessions.values()):
|
||||||
return str(sess.get("cwd") or "")
|
if sess.get("session_key") == session_key:
|
||||||
|
return str(sess.get("cwd") or "")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
@ -863,16 +873,19 @@ def _enable_gateway_prompts() -> None:
|
|||||||
def _block(event: str, sid: str, payload: dict, timeout: int = 300) -> str:
|
def _block(event: str, sid: str, payload: dict, timeout: int = 300) -> str:
|
||||||
rid = uuid.uuid4().hex[:8]
|
rid = uuid.uuid4().hex[:8]
|
||||||
ev = threading.Event()
|
ev = threading.Event()
|
||||||
_pending[rid] = (sid, ev)
|
with _prompt_lock:
|
||||||
payload["request_id"] = rid
|
_pending[rid] = (sid, ev)
|
||||||
_pending_prompt_payloads[rid] = (event, dict(payload))
|
payload["request_id"] = rid
|
||||||
|
_pending_prompt_payloads[rid] = (event, dict(payload))
|
||||||
try:
|
try:
|
||||||
_emit(event, sid, payload)
|
_emit(event, sid, payload)
|
||||||
ev.wait(timeout=timeout)
|
ev.wait(timeout=timeout)
|
||||||
finally:
|
finally:
|
||||||
_pending.pop(rid, None)
|
with _prompt_lock:
|
||||||
_pending_prompt_payloads.pop(rid, None)
|
_pending.pop(rid, None)
|
||||||
return _answers.pop(rid, "")
|
_pending_prompt_payloads.pop(rid, None)
|
||||||
|
with _prompt_lock:
|
||||||
|
return _answers.pop(rid, "")
|
||||||
|
|
||||||
|
|
||||||
def _clear_pending(sid: str | None = None) -> None:
|
def _clear_pending(sid: str | None = None) -> None:
|
||||||
@ -884,10 +897,11 @@ def _clear_pending(sid: str | None = None) -> None:
|
|||||||
sessions sharing the same tui_gateway process. When *sid* is
|
sessions sharing the same tui_gateway process. When *sid* is
|
||||||
None, every pending prompt is released (used during shutdown).
|
None, every pending prompt is released (used during shutdown).
|
||||||
"""
|
"""
|
||||||
for rid, (owner_sid, ev) in list(_pending.items()):
|
with _prompt_lock:
|
||||||
if sid is None or owner_sid == sid:
|
for rid, (owner_sid, ev) in list(_pending.items()):
|
||||||
_answers[rid] = ""
|
if sid is None or owner_sid == sid:
|
||||||
ev.set()
|
_answers[rid] = ""
|
||||||
|
ev.set()
|
||||||
|
|
||||||
|
|
||||||
# ── Agent factory ────────────────────────────────────────────────────
|
# ── Agent factory ────────────────────────────────────────────────────
|
||||||
@ -2402,34 +2416,37 @@ def _make_agent(sid: str, key: str, session_id: str | None = None):
|
|||||||
|
|
||||||
def _init_session(sid: str, key: str, agent, history: list, cols: int = 80):
|
def _init_session(sid: str, key: str, agent, history: list, cols: int = 80):
|
||||||
now = time.time()
|
now = time.time()
|
||||||
_sessions[sid] = {
|
with _sessions_lock:
|
||||||
"agent": agent,
|
_sessions[sid] = {
|
||||||
"session_key": key,
|
"agent": agent,
|
||||||
"history": history,
|
"session_key": key,
|
||||||
"history_lock": threading.Lock(),
|
"history": history,
|
||||||
"history_version": 0,
|
"history_lock": threading.Lock(),
|
||||||
"inflight_turn": None,
|
"history_version": 0,
|
||||||
"created_at": now,
|
"inflight_turn": None,
|
||||||
"last_active": now,
|
"created_at": now,
|
||||||
"running": False,
|
"last_active": now,
|
||||||
"attached_images": [],
|
"running": False,
|
||||||
"image_counter": 0,
|
"attached_images": [],
|
||||||
"cwd": _completion_cwd(),
|
"image_counter": 0,
|
||||||
"cols": cols,
|
"cwd": _completion_cwd(),
|
||||||
"slash_worker": None,
|
"cols": cols,
|
||||||
"show_reasoning": _load_show_reasoning(),
|
"slash_worker": None,
|
||||||
"tool_progress_mode": _load_tool_progress_mode(),
|
"show_reasoning": _load_show_reasoning(),
|
||||||
"edit_snapshots": {},
|
"tool_progress_mode": _load_tool_progress_mode(),
|
||||||
"tool_started_at": {},
|
"edit_snapshots": {},
|
||||||
# Pin async event emissions to whichever transport created the
|
"tool_started_at": {},
|
||||||
# session (stdio for Ink, JSON-RPC WS for the dashboard sidebar).
|
# Pin async event emissions to whichever transport created the
|
||||||
"transport": current_transport() or _stdio_transport,
|
# session (stdio for Ink, JSON-RPC WS for the dashboard sidebar).
|
||||||
}
|
"transport": current_transport() or _stdio_transport,
|
||||||
|
}
|
||||||
db = _get_db()
|
db = _get_db()
|
||||||
if db is not None:
|
if db is not None:
|
||||||
row = db.get_session(key)
|
row = db.get_session(key)
|
||||||
if row and row.get("cwd"):
|
if row and row.get("cwd"):
|
||||||
_sessions[sid]["cwd"] = row["cwd"]
|
with _sessions_lock:
|
||||||
|
if sid in _sessions:
|
||||||
|
_sessions[sid]["cwd"] = row["cwd"]
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
db.update_session_cwd(key, _sessions[sid]["cwd"])
|
db.update_session_cwd(key, _sessions[sid]["cwd"])
|
||||||
@ -2464,9 +2481,11 @@ def _init_session(sid: str, key: str, agent, history: list, cols: int = 80):
|
|||||||
# session startup resilient).
|
# session startup resilient).
|
||||||
pass
|
pass
|
||||||
_wire_callbacks(sid)
|
_wire_callbacks(sid)
|
||||||
_sessions[sid]["_notif_stop"] = _start_notification_poller(sid, _sessions[sid])
|
with _sessions_lock:
|
||||||
|
if sid in _sessions:
|
||||||
|
_sessions[sid]["_notif_stop"] = _start_notification_poller(sid, _sessions[sid])
|
||||||
_notify_session_boundary("on_session_reset", key)
|
_notify_session_boundary("on_session_reset", key)
|
||||||
_emit("session.info", sid, _session_info(agent, _sessions[sid]))
|
_emit("session.info", sid, _session_info(agent, _sessions.get(sid, {})))
|
||||||
|
|
||||||
|
|
||||||
def _new_session_key() -> str:
|
def _new_session_key() -> str:
|
||||||
@ -2818,32 +2837,33 @@ def _(rid, params: dict) -> dict:
|
|||||||
ready = threading.Event()
|
ready = threading.Event()
|
||||||
now = time.time()
|
now = time.time()
|
||||||
|
|
||||||
_sessions[sid] = {
|
with _sessions_lock:
|
||||||
"agent": None,
|
_sessions[sid] = {
|
||||||
"agent_error": None,
|
"agent": None,
|
||||||
"agent_ready": ready,
|
"agent_error": None,
|
||||||
"attached_images": [],
|
"agent_ready": ready,
|
||||||
"cols": cols,
|
"attached_images": [],
|
||||||
"created_at": now,
|
"cols": cols,
|
||||||
"edit_snapshots": {},
|
"created_at": now,
|
||||||
"explicit_cwd": explicit_cwd,
|
"edit_snapshots": {},
|
||||||
"history": history,
|
"explicit_cwd": explicit_cwd,
|
||||||
"history_lock": threading.Lock(),
|
"history": history,
|
||||||
"history_version": 0,
|
"history_lock": threading.Lock(),
|
||||||
"image_counter": 0,
|
"history_version": 0,
|
||||||
"cwd": resolved_cwd,
|
"image_counter": 0,
|
||||||
"inflight_turn": None,
|
"cwd": resolved_cwd,
|
||||||
"last_active": now,
|
"inflight_turn": None,
|
||||||
"pending_title": title or None,
|
"last_active": now,
|
||||||
"running": False,
|
"pending_title": title or None,
|
||||||
"session_key": key,
|
"running": False,
|
||||||
"show_reasoning": _load_show_reasoning(),
|
"session_key": key,
|
||||||
"slash_worker": None,
|
"show_reasoning": _load_show_reasoning(),
|
||||||
"tool_progress_mode": _load_tool_progress_mode(),
|
"slash_worker": None,
|
||||||
"tool_started_at": {},
|
"tool_progress_mode": _load_tool_progress_mode(),
|
||||||
"transport": current_transport() or _stdio_transport,
|
"tool_started_at": {},
|
||||||
}
|
"transport": current_transport() or _stdio_transport,
|
||||||
_register_session_cwd(_sessions[sid])
|
}
|
||||||
|
_register_session_cwd(_sessions[sid])
|
||||||
# NOTE: we intentionally do NOT persist a DB row here. Every TUI/desktop
|
# NOTE: we intentionally do NOT persist a DB row here. Every TUI/desktop
|
||||||
# launch (and every "New agent" / draft) opens a session here just to paint
|
# launch (and every "New agent" / draft) opens a session here just to paint
|
||||||
# the composer, so eagerly creating a row left an "Untitled" empty session
|
# the composer, so eagerly creating a row left an "Untitled" empty session
|
||||||
@ -3233,7 +3253,8 @@ def _(rid, params: dict) -> dict:
|
|||||||
"""
|
"""
|
||||||
current = str(params.get("current_session_id") or "")
|
current = str(params.get("current_session_id") or "")
|
||||||
try:
|
try:
|
||||||
snapshot = list(_sessions.items())
|
with _sessions_lock:
|
||||||
|
snapshot = list(_sessions.items())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return _err(rid, 5036, f"could not enumerate active sessions: {e}")
|
return _err(rid, 5036, f"could not enumerate active sessions: {e}")
|
||||||
|
|
||||||
@ -3287,7 +3308,8 @@ def _(rid, params: dict) -> dict:
|
|||||||
# dictionary changed size during iteration``. If even the snapshot
|
# dictionary changed size during iteration``. If even the snapshot
|
||||||
# raises, fail closed (refuse the delete) rather than fail open.
|
# raises, fail closed (refuse the delete) rather than fail open.
|
||||||
try:
|
try:
|
||||||
snapshot = list(_sessions.values())
|
with _sessions_lock:
|
||||||
|
snapshot = list(_sessions.values())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return _err(rid, 5036, f"could not enumerate active sessions: {e}")
|
return _err(rid, 5036, f"could not enumerate active sessions: {e}")
|
||||||
active = {s.get("session_key") for s in snapshot if s.get("session_key")}
|
active = {s.get("session_key") for s in snapshot if s.get("session_key")}
|
||||||
@ -3644,11 +3666,13 @@ def _(rid, params: dict) -> dict:
|
|||||||
@method("session.close")
|
@method("session.close")
|
||||||
def _(rid, params: dict) -> dict:
|
def _(rid, params: dict) -> dict:
|
||||||
sid = params.get("session_id", "")
|
sid = params.get("session_id", "")
|
||||||
current = _sessions.get(sid)
|
with _sessions_lock:
|
||||||
|
current = _sessions.get(sid)
|
||||||
if not current:
|
if not current:
|
||||||
return _ok(rid, {"closed": False})
|
return _ok(rid, {"closed": False})
|
||||||
with _session_resume_lock:
|
with _session_resume_lock:
|
||||||
session = _sessions.pop(sid, None)
|
with _sessions_lock:
|
||||||
|
session = _sessions.pop(sid, None)
|
||||||
if not session:
|
if not session:
|
||||||
return _ok(rid, {"closed": False})
|
return _ok(rid, {"closed": False})
|
||||||
_finalize_session(session)
|
_finalize_session(session)
|
||||||
@ -4097,7 +4121,8 @@ def _notification_event_belongs_elsewhere(session: dict, evt: dict) -> bool:
|
|||||||
if evt_key == str(session.get("session_key") or ""):
|
if evt_key == str(session.get("session_key") or ""):
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
snapshot = list(_sessions.values())
|
with _sessions_lock:
|
||||||
|
snapshot = list(_sessions.values())
|
||||||
except Exception:
|
except Exception:
|
||||||
# If we can't safely enumerate live sessions, fail open so we don't
|
# If we can't safely enumerate live sessions, fail open so we don't
|
||||||
# crash the poller thread or drop the event.
|
# crash the poller thread or drop the event.
|
||||||
@ -5000,12 +5025,13 @@ def _(rid, params: dict) -> dict:
|
|||||||
|
|
||||||
def _respond(rid, params, key):
|
def _respond(rid, params, key):
|
||||||
r = params.get("request_id", "")
|
r = params.get("request_id", "")
|
||||||
entry = _pending.get(r)
|
with _prompt_lock:
|
||||||
if not entry:
|
entry = _pending.get(r)
|
||||||
return _err(rid, 4009, f"no pending {key} request")
|
if not entry:
|
||||||
_, ev = entry
|
return _err(rid, 4009, f"no pending {key} request")
|
||||||
_answers[r] = params.get(key, "")
|
_, ev = entry
|
||||||
ev.set()
|
_answers[r] = params.get(key, "")
|
||||||
|
ev.set()
|
||||||
return _ok(rid, {"status": "ok"})
|
return _ok(rid, {"status": "ok"})
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user