diff --git a/gateway/run.py b/gateway/run.py index 55d1131f1..9d3f1019e 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -10124,6 +10124,45 @@ class GatewayRunner: return "\n".join(lines) + def _sibling_thread_run_keys(self, source: SessionSource, own_key: str) -> list: + """Find running-agent keys for OTHER participants in the same thread. + + Only applies when the message originates in a thread. In per-user + thread mode (``thread_sessions_per_user=True``) each participant gets + an isolated session key of the form + ``agent:main:{platform}:{chat_type}:{chat_id}:{thread_id}:{user_id}``, + so a run started by another user is invisible to the caller's own + ``/stop``. This returns the keys of any *actually running* agents + (not the pending sentinel, not the caller's own key) whose key shares + the caller's ``{chat_id}:{thread_id}`` prefix. + + Returns an empty list when the source is not in a thread, or when no + sibling runs exist — callers must still gate on authorization. + """ + thread_id = getattr(source, "thread_id", None) + chat_id = getattr(source, "chat_id", None) + if not thread_id or not chat_id: + return [] + platform = source.platform.value + chat_type = getattr(source, "chat_type", None) or "" + # Prefix that every per-user key in this thread shares, up to and + # including the thread_id segment. Matching either the exact + # shared-thread key or any key with a further (user_id) segment + # (prefix + ":") avoids cross-matching an unrelated thread whose id + # merely starts with this one. + prefix = ":".join( + ["agent:main", platform, chat_type, str(chat_id), str(thread_id)] + ) + matches = [] + for key, agent in list(self._running_agents.items()): + if key == own_key: + continue + if agent is _AGENT_PENDING_SENTINEL or not agent: + continue + if key == prefix or key.startswith(prefix + ":"): + matches.append(key) + return matches + async def _handle_stop_command(self, event: MessageEvent) -> Union[str, EphemeralReply]: """Handle /stop command - interrupt a running agent. @@ -10160,8 +10199,31 @@ class GatewayRunner: invalidation_reason="stop_command_handler", ) return EphemeralReply(t("gateway.stop.stopped")) - else: - return t("gateway.stop.no_active") + + # No run under the caller's own session key. In a per-user thread + # (thread_sessions_per_user=True) each participant is isolated even + # inside one shared thread, so a run another user started lives under + # a different key. Authorized users should still be able to /stop it + # (#bernard-thread-stop). Fall back to interrupting any running + # agent(s) that share this thread, gated on authorization. + sibling_keys = self._sibling_thread_run_keys(source, session_key) + if sibling_keys and self._is_user_authorized(source): + for sibling_key in sibling_keys: + await self._interrupt_and_clear_session( + sibling_key, + source, + interrupt_reason=_INTERRUPT_REASON_STOP, + invalidation_reason="stop_command_thread_sibling", + ) + logger.info( + "STOP (thread sibling) by %s — interrupted %d run(s) in thread: %s", + session_key, + len(sibling_keys), + ", ".join(sibling_keys), + ) + return EphemeralReply(t("gateway.stop.stopped")) + + return t("gateway.stop.no_active") async def _handle_platform_command(self, event: MessageEvent) -> str: """Handle ``/platform list|pause|resume [name]`` — surface and diff --git a/tests/gateway/test_stop_thread_sibling.py b/tests/gateway/test_stop_thread_sibling.py new file mode 100644 index 000000000..d8076ba6e --- /dev/null +++ b/tests/gateway/test_stop_thread_sibling.py @@ -0,0 +1,158 @@ +"""Regression tests: /stop can interrupt a sibling participant's run in a +per-user thread. + +When ``thread_sessions_per_user=True``, each participant in a thread gets an +isolated session key (``...:{thread_id}:{user_id}``). A run another user +started lives under a different key, so the caller's own ``/stop`` used to find +nothing and reply "no active task to stop". Authorized users should be able to +stop any run in the same thread. +""" + +import pytest + +from gateway.run import GatewayRunner, _AGENT_PENDING_SENTINEL, _INTERRUPT_REASON_STOP +from gateway.session import SessionSource, build_session_key +from gateway.platforms.base import Platform, MessageEvent, MessageType + + +class _FakeAgent: + pass + + +def _thread_source(uid, thread_id="thr1", chat_id="chan1"): + return SessionSource( + platform=Platform.DISCORD, + chat_type="forum", + chat_id=chat_id, + thread_id=thread_id, + user_id=uid, + ) + + +def _per_user_key(uid, thread_id="thr1", chat_id="chan1"): + return build_session_key( + _thread_source(uid, thread_id, chat_id), + thread_sessions_per_user=True, + ) + + +# --------------------------------------------------------------------------- +# _sibling_thread_run_keys +# --------------------------------------------------------------------------- + + +def test_sibling_finds_other_users_run_in_same_thread(): + runner = object.__new__(GatewayRunner) + key_a = _per_user_key("userA") + key_b = _per_user_key("userB") + runner._running_agents = {key_b: _FakeAgent()} + assert runner._sibling_thread_run_keys(_thread_source("userA"), key_a) == [key_b] + + +def test_sibling_excludes_callers_own_key(): + runner = object.__new__(GatewayRunner) + key_a = _per_user_key("userA") + key_b = _per_user_key("userB") + runner._running_agents = {key_a: _FakeAgent(), key_b: _FakeAgent()} + assert runner._sibling_thread_run_keys(_thread_source("userA"), key_a) == [key_b] + + +def test_sibling_skips_pending_sentinel(): + runner = object.__new__(GatewayRunner) + key_a = _per_user_key("userA") + key_b = _per_user_key("userB") + runner._running_agents = {key_b: _AGENT_PENDING_SENTINEL} + assert runner._sibling_thread_run_keys(_thread_source("userA"), key_a) == [] + + +def test_sibling_does_not_match_different_thread_same_chat(): + # thr1 caller must not match a run in thr11 (prefix-collision guard). + runner = object.__new__(GatewayRunner) + key_a = _per_user_key("userA", thread_id="thr1") + key_b_other = _per_user_key("userB", thread_id="thr11") + runner._running_agents = {key_b_other: _FakeAgent()} + assert runner._sibling_thread_run_keys(_thread_source("userA"), key_a) == [] + + +def test_sibling_returns_empty_for_non_thread_source(): + # Non-thread group/channel must NOT trigger the cross-user fallback. + runner = object.__new__(GatewayRunner) + nonthread = SessionSource( + platform=Platform.DISCORD, chat_type="group", chat_id="chan1", user_id="userA" + ) + grp_b = build_session_key( + SessionSource( + platform=Platform.DISCORD, chat_type="group", chat_id="chan1", user_id="userB" + ) + ) + runner._running_agents = {grp_b: _FakeAgent()} + assert runner._sibling_thread_run_keys(nonthread, "agent:main:discord:group:chan1:userA") == [] + + +# --------------------------------------------------------------------------- +# _handle_stop_command fallback path +# --------------------------------------------------------------------------- + + +class _StoreEntry: + def __init__(self, session_key): + self.session_key = session_key + + +class _FakeStore: + def __init__(self, session_key): + self._key = session_key + + def get_or_create_session(self, source): + return _StoreEntry(self._key) + + +@pytest.mark.asyncio +async def test_stop_interrupts_sibling_thread_run_when_authorized(monkeypatch): + runner = object.__new__(GatewayRunner) + key_a = _per_user_key("userA") + key_b = _per_user_key("userB") + runner._running_agents = {key_b: _FakeAgent()} + runner.session_store = _FakeStore(key_a) + + interrupted = [] + + async def _fake_interrupt(session_key, source, *, interrupt_reason, invalidation_reason): + interrupted.append((session_key, interrupt_reason, invalidation_reason)) + + runner._interrupt_and_clear_session = _fake_interrupt + runner._is_user_authorized = lambda source: True + + event = MessageEvent( + text="/stop", message_type=MessageType.TEXT, source=_thread_source("userA") + ) + result = await runner._handle_stop_command(event) + + assert interrupted == [(key_b, _INTERRUPT_REASON_STOP, "stop_command_thread_sibling")] + # EphemeralReply or str — both carry the "stopped" message, not "no_active". + assert "no active" not in str(getattr(result, "text", result)).lower() + + +@pytest.mark.asyncio +async def test_stop_does_not_interrupt_sibling_when_unauthorized(monkeypatch): + runner = object.__new__(GatewayRunner) + key_a = _per_user_key("userA") + key_b = _per_user_key("userB") + runner._running_agents = {key_b: _FakeAgent()} + runner.session_store = _FakeStore(key_a) + + interrupted = [] + + async def _fake_interrupt(session_key, source, *, interrupt_reason, invalidation_reason): + interrupted.append(session_key) + + runner._interrupt_and_clear_session = _fake_interrupt + runner._is_user_authorized = lambda source: False + + event = MessageEvent( + text="/stop", message_type=MessageType.TEXT, source=_thread_source("userA") + ) + result = await runner._handle_stop_command(event) + + assert interrupted == [] + assert "no active" in str(getattr(result, "text", result)).lower()