fix(gateway): allow /stop to interrupt a sibling participant's run in a per-user thread (#35959)
In a per-user thread (thread_sessions_per_user=True), each participant
gets an isolated session key (...:{thread_id}:{user_id}). A run another
user started lives under a different key, so the caller's own /stop found
nothing and replied 'no active task to stop'.
When /stop finds no run under the caller's own key, fall back to
interrupting any running agent(s) sharing the caller's thread prefix
({chat_id}:{thread_id}), gated on _is_user_authorized. Thread-only — the
fallback returns [] for non-thread channels, and a prefix-collision guard
prevents thr1 from matching thr11.
This commit is contained in:
@ -10124,6 +10124,45 @@ class GatewayRunner:
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def _sibling_thread_run_keys(self, source: SessionSource, own_key: str) -> list:
    """Collect running-agent keys owned by OTHER participants of this thread.

    With ``thread_sessions_per_user=True`` every participant of a thread
    is tracked under its own session key, shaped like
    ``agent:main:{platform}:{chat_type}:{chat_id}:{thread_id}:{user_id}``.
    A ``/stop`` issued by one participant therefore cannot see a run that
    a different participant started. This helper scans
    ``self._running_agents`` for such runs.

    Args:
        source: Origin of the ``/stop`` message; its ``platform``,
            ``chat_type``, ``chat_id`` and ``thread_id`` define the
            thread prefix that candidate keys must share.
        own_key: The caller's own session key, always excluded.

    Returns:
        Keys of entries that are actually running (neither the pending
        sentinel nor a falsy placeholder) and that equal the thread
        prefix or extend it with a further ``:``-separated segment.
        Empty when ``source`` has no ``thread_id`` or ``chat_id``, or
        when nothing matches. No authorization check happens here —
        callers must gate on it themselves.
    """
    thread_id = getattr(source, "thread_id", None)
    chat_id = getattr(source, "chat_id", None)
    if not (thread_id and chat_id):
        # Not a thread message: the cross-user fallback never applies.
        return []

    # Everything up to and including the thread_id segment.
    thread_prefix = ":".join(
        (
            "agent:main",
            source.platform.value,
            getattr(source, "chat_type", None) or "",
            str(chat_id),
            str(thread_id),
        )
    )
    # Requiring the ":" separator after the prefix keeps a thread whose
    # id merely starts with this one (thr1 vs thr11) from matching.
    per_user_prefix = thread_prefix + ":"

    # Iterate over a snapshot of the mapping, as the original did.
    return [
        run_key
        for run_key, running in list(self._running_agents.items())
        if run_key != own_key
        and running is not _AGENT_PENDING_SENTINEL
        and running
        and (run_key == thread_prefix or run_key.startswith(per_user_prefix))
    ]
|
||||
|
||||
async def _handle_stop_command(self, event: MessageEvent) -> Union[str, EphemeralReply]:
|
||||
"""Handle /stop command - interrupt a running agent.
|
||||
|
||||
@ -10160,8 +10199,31 @@ class GatewayRunner:
|
||||
invalidation_reason="stop_command_handler",
|
||||
)
|
||||
return EphemeralReply(t("gateway.stop.stopped"))
|
||||
else:
|
||||
return t("gateway.stop.no_active")
|
||||
|
||||
# No run under the caller's own session key. In a per-user thread
|
||||
# (thread_sessions_per_user=True) each participant is isolated even
|
||||
# inside one shared thread, so a run another user started lives under
|
||||
# a different key. Authorized users should still be able to /stop it
|
||||
# (#bernard-thread-stop). Fall back to interrupting any running
|
||||
# agent(s) that share this thread, gated on authorization.
|
||||
sibling_keys = self._sibling_thread_run_keys(source, session_key)
|
||||
if sibling_keys and self._is_user_authorized(source):
|
||||
for sibling_key in sibling_keys:
|
||||
await self._interrupt_and_clear_session(
|
||||
sibling_key,
|
||||
source,
|
||||
interrupt_reason=_INTERRUPT_REASON_STOP,
|
||||
invalidation_reason="stop_command_thread_sibling",
|
||||
)
|
||||
logger.info(
|
||||
"STOP (thread sibling) by %s — interrupted %d run(s) in thread: %s",
|
||||
session_key,
|
||||
len(sibling_keys),
|
||||
", ".join(sibling_keys),
|
||||
)
|
||||
return EphemeralReply(t("gateway.stop.stopped"))
|
||||
|
||||
return t("gateway.stop.no_active")
|
||||
|
||||
async def _handle_platform_command(self, event: MessageEvent) -> str:
|
||||
"""Handle ``/platform list|pause|resume [name]`` — surface and
|
||||
|
||||
158
tests/gateway/test_stop_thread_sibling.py
Normal file
158
tests/gateway/test_stop_thread_sibling.py
Normal file
@ -0,0 +1,158 @@
|
||||
"""Regression tests: /stop can interrupt a sibling participant's run in a
|
||||
per-user thread.
|
||||
|
||||
When ``thread_sessions_per_user=True``, each participant in a thread gets an
|
||||
isolated session key (``...:{thread_id}:{user_id}``). A run another user
|
||||
started lives under a different key, so the caller's own ``/stop`` used to find
|
||||
nothing and reply "no active task to stop". Authorized users should be able to
|
||||
stop any run in the same thread.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.run import GatewayRunner, _AGENT_PENDING_SENTINEL, _INTERRUPT_REASON_STOP
|
||||
from gateway.session import SessionSource, build_session_key
|
||||
from gateway.platforms.base import Platform, MessageEvent, MessageType
|
||||
|
||||
|
||||
class _FakeAgent:
|
||||
pass
|
||||
|
||||
|
||||
def _thread_source(uid, thread_id="thr1", chat_id="chan1"):
    """Build a Discord forum ``SessionSource`` for user ``uid`` in a thread."""
    source_fields = {
        "platform": Platform.DISCORD,
        "chat_type": "forum",
        "chat_id": chat_id,
        "thread_id": thread_id,
        "user_id": uid,
    }
    return SessionSource(**source_fields)
|
||||
|
||||
|
||||
def _per_user_key(uid, thread_id="thr1", chat_id="chan1"):
    """Return the per-user session key for ``uid`` in the given thread."""
    source = _thread_source(uid, thread_id, chat_id)
    return build_session_key(source, thread_sessions_per_user=True)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _sibling_thread_run_keys
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_sibling_finds_other_users_run_in_same_thread():
    """A run under userB's key is reported when userA queries the thread."""
    caller_key = _per_user_key("userA")
    sibling_key = _per_user_key("userB")
    runner = object.__new__(GatewayRunner)
    runner._running_agents = {sibling_key: _FakeAgent()}

    found = runner._sibling_thread_run_keys(_thread_source("userA"), caller_key)

    assert found == [sibling_key]
|
||||
|
||||
|
||||
def test_sibling_excludes_callers_own_key():
    """The caller's own running key never appears among the siblings."""
    caller_key = _per_user_key("userA")
    sibling_key = _per_user_key("userB")
    runner = object.__new__(GatewayRunner)
    runner._running_agents = {
        caller_key: _FakeAgent(),
        sibling_key: _FakeAgent(),
    }

    found = runner._sibling_thread_run_keys(_thread_source("userA"), caller_key)

    assert found == [sibling_key]
|
||||
|
||||
|
||||
def test_sibling_skips_pending_sentinel():
    """An entry holding only the pending sentinel is not a running sibling."""
    caller_key = _per_user_key("userA")
    sibling_key = _per_user_key("userB")
    runner = object.__new__(GatewayRunner)
    runner._running_agents = {sibling_key: _AGENT_PENDING_SENTINEL}

    found = runner._sibling_thread_run_keys(_thread_source("userA"), caller_key)

    assert found == []
|
||||
|
||||
|
||||
def test_sibling_does_not_match_different_thread_same_chat():
    """Prefix-collision guard: a thr1 caller must not match a run in thr11."""
    caller_key = _per_user_key("userA", thread_id="thr1")
    other_thread_key = _per_user_key("userB", thread_id="thr11")
    runner = object.__new__(GatewayRunner)
    runner._running_agents = {other_thread_key: _FakeAgent()}

    found = runner._sibling_thread_run_keys(_thread_source("userA"), caller_key)

    assert found == []
|
||||
|
||||
|
||||
def test_sibling_returns_empty_for_non_thread_source():
    """A non-thread group/channel must NOT trigger the cross-user fallback."""

    def _group_source(uid):
        # Same chat for both users, but no thread_id on the source.
        return SessionSource(
            platform=Platform.DISCORD,
            chat_type="group",
            chat_id="chan1",
            user_id=uid,
        )

    runner = object.__new__(GatewayRunner)
    caller_source = _group_source("userA")
    other_user_key = build_session_key(_group_source("userB"))
    runner._running_agents = {other_user_key: _FakeAgent()}

    found = runner._sibling_thread_run_keys(
        caller_source, "agent:main:discord:group:chan1:userA"
    )

    assert found == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _handle_stop_command fallback path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _StoreEntry:
|
||||
def __init__(self, session_key):
|
||||
self.session_key = session_key
|
||||
|
||||
|
||||
class _FakeStore:
|
||||
def __init__(self, session_key):
|
||||
self._key = session_key
|
||||
|
||||
def get_or_create_session(self, source):
|
||||
return _StoreEntry(self._key)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_stop_interrupts_sibling_thread_run_when_authorized(monkeypatch):
    """An authorized caller's /stop interrupts a sibling's run in the thread."""
    caller_key = _per_user_key("userA")
    sibling_key = _per_user_key("userB")
    calls = []

    async def _record_interrupt(session_key, source, *, interrupt_reason, invalidation_reason):
        calls.append((session_key, interrupt_reason, invalidation_reason))

    runner = object.__new__(GatewayRunner)
    # Only userB has a run; the store resolves the caller to userA's key.
    runner._running_agents = {sibling_key: _FakeAgent()}
    runner.session_store = _FakeStore(caller_key)
    runner._interrupt_and_clear_session = _record_interrupt
    runner._is_user_authorized = lambda source: True

    stop_event = MessageEvent(
        text="/stop",
        message_type=MessageType.TEXT,
        source=_thread_source("userA"),
    )
    reply = await runner._handle_stop_command(stop_event)

    expected_calls = [
        (sibling_key, _INTERRUPT_REASON_STOP, "stop_command_thread_sibling")
    ]
    assert calls == expected_calls
    # EphemeralReply or str — both carry the "stopped" message, not "no_active".
    reply_text = str(getattr(reply, "text", reply))
    assert "no active" not in reply_text.lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_stop_does_not_interrupt_sibling_when_unauthorized(monkeypatch):
    """An unauthorized caller's /stop leaves a sibling's run untouched."""
    caller_key = _per_user_key("userA")
    sibling_key = _per_user_key("userB")
    interrupted_keys = []

    async def _record_interrupt(session_key, source, *, interrupt_reason, invalidation_reason):
        interrupted_keys.append(session_key)

    runner = object.__new__(GatewayRunner)
    # Only userB has a run; the store resolves the caller to userA's key.
    runner._running_agents = {sibling_key: _FakeAgent()}
    runner.session_store = _FakeStore(caller_key)
    runner._interrupt_and_clear_session = _record_interrupt
    runner._is_user_authorized = lambda source: False

    stop_event = MessageEvent(
        text="/stop",
        message_type=MessageType.TEXT,
        source=_thread_source("userA"),
    )
    reply = await runner._handle_stop_command(stop_event)

    assert interrupted_keys == []
    reply_text = str(getattr(reply, "text", reply))
    assert "no active" in reply_text.lower()
|
||||
Reference in New Issue
Block a user