fix(gateway): /stop can interrupt a sibling participant's run in a per-user thread (#35959)

In a per-user thread (thread_sessions_per_user=True), each participant
gets an isolated session key (...:{thread_id}:{user_id}). A run another
user started lives under a different key, so the caller's own /stop found
nothing and replied 'no active task to stop'.

When /stop finds no run under the caller's own key, fall back to
interrupting any running agent(s) sharing the caller's thread prefix
({chat_id}:{thread_id}), gated on _is_user_authorized. Thread-only — the
fallback returns [] for non-thread channels, and a prefix-collision guard
prevents thr1 from matching thr11.
This commit is contained in:
Teknium
2026-05-31 09:29:03 -07:00
committed by GitHub
parent de4f40ed02
commit 1044d9f25d
2 changed files with 222 additions and 2 deletions

View File

@ -10124,6 +10124,45 @@ class GatewayRunner:
return "\n".join(lines)
def _sibling_thread_run_keys(self, source: SessionSource, own_key: str) -> list:
"""Find running-agent keys for OTHER participants in the same thread.
Only applies when the message originates in a thread. In per-user
thread mode (``thread_sessions_per_user=True``) each participant gets
an isolated session key of the form
``agent:main:{platform}:{chat_type}:{chat_id}:{thread_id}:{user_id}``,
so a run started by another user is invisible to the caller's own
``/stop``. This returns the keys of any *actually running* agents
(not the pending sentinel, not the caller's own key) whose key shares
the caller's ``{chat_id}:{thread_id}`` prefix.
Returns an empty list when the source is not in a thread, or when no
sibling runs exist — callers must still gate on authorization.
"""
thread_id = getattr(source, "thread_id", None)
chat_id = getattr(source, "chat_id", None)
if not thread_id or not chat_id:
return []
platform = source.platform.value
chat_type = getattr(source, "chat_type", None) or ""
# Prefix that every per-user key in this thread shares, up to and
# including the thread_id segment. Matching either the exact
# shared-thread key or any key with a further (user_id) segment
# (prefix + ":") avoids cross-matching an unrelated thread whose id
# merely starts with this one.
prefix = ":".join(
["agent:main", platform, chat_type, str(chat_id), str(thread_id)]
)
matches = []
for key, agent in list(self._running_agents.items()):
if key == own_key:
continue
if agent is _AGENT_PENDING_SENTINEL or not agent:
continue
if key == prefix or key.startswith(prefix + ":"):
matches.append(key)
return matches
async def _handle_stop_command(self, event: MessageEvent) -> Union[str, EphemeralReply]:
"""Handle /stop command - interrupt a running agent.
@ -10160,8 +10199,31 @@ class GatewayRunner:
invalidation_reason="stop_command_handler",
)
return EphemeralReply(t("gateway.stop.stopped"))
else:
return t("gateway.stop.no_active")
# No run under the caller's own session key. In a per-user thread
# (thread_sessions_per_user=True) each participant is isolated even
# inside one shared thread, so a run another user started lives under
# a different key. Authorized users should still be able to /stop it
# (#bernard-thread-stop). Fall back to interrupting any running
# agent(s) that share this thread, gated on authorization.
sibling_keys = self._sibling_thread_run_keys(source, session_key)
if sibling_keys and self._is_user_authorized(source):
for sibling_key in sibling_keys:
await self._interrupt_and_clear_session(
sibling_key,
source,
interrupt_reason=_INTERRUPT_REASON_STOP,
invalidation_reason="stop_command_thread_sibling",
)
logger.info(
"STOP (thread sibling) by %s — interrupted %d run(s) in thread: %s",
session_key,
len(sibling_keys),
", ".join(sibling_keys),
)
return EphemeralReply(t("gateway.stop.stopped"))
return t("gateway.stop.no_active")
async def _handle_platform_command(self, event: MessageEvent) -> str:
"""Handle ``/platform list|pause|resume [name]`` — surface and

View File

@ -0,0 +1,158 @@
"""Regression tests: /stop can interrupt a sibling participant's run in a
per-user thread.
When ``thread_sessions_per_user=True``, each participant in a thread gets an
isolated session key (``...:{thread_id}:{user_id}``). A run another user
started lives under a different key, so the caller's own ``/stop`` used to find
nothing and reply "no active task to stop". Authorized users should be able to
stop any run in the same thread.
"""
import pytest
from gateway.run import GatewayRunner, _AGENT_PENDING_SENTINEL, _INTERRUPT_REASON_STOP
from gateway.session import SessionSource, build_session_key
from gateway.platforms.base import Platform, MessageEvent, MessageType
class _FakeAgent:
pass
def _thread_source(uid, thread_id="thr1", chat_id="chan1"):
return SessionSource(
platform=Platform.DISCORD,
chat_type="forum",
chat_id=chat_id,
thread_id=thread_id,
user_id=uid,
)
def _per_user_key(uid, thread_id="thr1", chat_id="chan1"):
return build_session_key(
_thread_source(uid, thread_id, chat_id),
thread_sessions_per_user=True,
)
# ---------------------------------------------------------------------------
# _sibling_thread_run_keys
# ---------------------------------------------------------------------------
def test_sibling_finds_other_users_run_in_same_thread():
runner = object.__new__(GatewayRunner)
key_a = _per_user_key("userA")
key_b = _per_user_key("userB")
runner._running_agents = {key_b: _FakeAgent()}
assert runner._sibling_thread_run_keys(_thread_source("userA"), key_a) == [key_b]
def test_sibling_excludes_callers_own_key():
runner = object.__new__(GatewayRunner)
key_a = _per_user_key("userA")
key_b = _per_user_key("userB")
runner._running_agents = {key_a: _FakeAgent(), key_b: _FakeAgent()}
assert runner._sibling_thread_run_keys(_thread_source("userA"), key_a) == [key_b]
def test_sibling_skips_pending_sentinel():
runner = object.__new__(GatewayRunner)
key_a = _per_user_key("userA")
key_b = _per_user_key("userB")
runner._running_agents = {key_b: _AGENT_PENDING_SENTINEL}
assert runner._sibling_thread_run_keys(_thread_source("userA"), key_a) == []
def test_sibling_does_not_match_different_thread_same_chat():
# thr1 caller must not match a run in thr11 (prefix-collision guard).
runner = object.__new__(GatewayRunner)
key_a = _per_user_key("userA", thread_id="thr1")
key_b_other = _per_user_key("userB", thread_id="thr11")
runner._running_agents = {key_b_other: _FakeAgent()}
assert runner._sibling_thread_run_keys(_thread_source("userA"), key_a) == []
def test_sibling_returns_empty_for_non_thread_source():
# Non-thread group/channel must NOT trigger the cross-user fallback.
runner = object.__new__(GatewayRunner)
nonthread = SessionSource(
platform=Platform.DISCORD, chat_type="group", chat_id="chan1", user_id="userA"
)
grp_b = build_session_key(
SessionSource(
platform=Platform.DISCORD, chat_type="group", chat_id="chan1", user_id="userB"
)
)
runner._running_agents = {grp_b: _FakeAgent()}
assert runner._sibling_thread_run_keys(nonthread, "agent:main:discord:group:chan1:userA") == []
# ---------------------------------------------------------------------------
# _handle_stop_command fallback path
# ---------------------------------------------------------------------------
class _StoreEntry:
def __init__(self, session_key):
self.session_key = session_key
class _FakeStore:
def __init__(self, session_key):
self._key = session_key
def get_or_create_session(self, source):
return _StoreEntry(self._key)
@pytest.mark.asyncio
async def test_stop_interrupts_sibling_thread_run_when_authorized(monkeypatch):
runner = object.__new__(GatewayRunner)
key_a = _per_user_key("userA")
key_b = _per_user_key("userB")
runner._running_agents = {key_b: _FakeAgent()}
runner.session_store = _FakeStore(key_a)
interrupted = []
async def _fake_interrupt(session_key, source, *, interrupt_reason, invalidation_reason):
interrupted.append((session_key, interrupt_reason, invalidation_reason))
runner._interrupt_and_clear_session = _fake_interrupt
runner._is_user_authorized = lambda source: True
event = MessageEvent(
text="/stop", message_type=MessageType.TEXT, source=_thread_source("userA")
)
result = await runner._handle_stop_command(event)
assert interrupted == [(key_b, _INTERRUPT_REASON_STOP, "stop_command_thread_sibling")]
# EphemeralReply or str — both carry the "stopped" message, not "no_active".
assert "no active" not in str(getattr(result, "text", result)).lower()
@pytest.mark.asyncio
async def test_stop_does_not_interrupt_sibling_when_unauthorized(monkeypatch):
runner = object.__new__(GatewayRunner)
key_a = _per_user_key("userA")
key_b = _per_user_key("userB")
runner._running_agents = {key_b: _FakeAgent()}
runner.session_store = _FakeStore(key_a)
interrupted = []
async def _fake_interrupt(session_key, source, *, interrupt_reason, invalidation_reason):
interrupted.append(session_key)
runner._interrupt_and_clear_session = _fake_interrupt
runner._is_user_authorized = lambda source: False
event = MessageEvent(
text="/stop", message_type=MessageType.TEXT, source=_thread_source("userA")
)
result = await runner._handle_stop_command(event)
assert interrupted == []
assert "no active" in str(getattr(result, "text", result)).lower()