Merge pull request #37697 from NousResearch/bb/grok-provider-desktop
feat(desktop): make xAI Grok a first-class OAuth provider in the launcher
This commit is contained in:
@ -107,8 +107,9 @@ const PROVIDER_DISPLAY: Record<string, { order: number; title: string }> = {
|
||||
anthropic: { order: 1, title: 'Anthropic Claude' },
|
||||
'openai-codex': { order: 2, title: 'OpenAI Codex / ChatGPT' },
|
||||
'minimax-oauth': { order: 3, title: 'MiniMax' },
|
||||
'claude-code': { order: 4, title: 'Claude Code' },
|
||||
'qwen-oauth': { order: 5, title: 'Qwen Code' }
|
||||
'xai-oauth': { order: 4, title: 'xAI Grok' },
|
||||
'claude-code': { order: 5, title: 'Claude Code' },
|
||||
'qwen-oauth': { order: 6, title: 'Qwen Code' }
|
||||
}
|
||||
|
||||
const assetPath = (path: string) => `${import.meta.env.BASE_URL}${path.replace(/^\/+/, '')}`
|
||||
@ -116,6 +117,7 @@ const assetPath = (path: string) => `${import.meta.env.BASE_URL}${path.replace(/
|
||||
const FLOW_SUBTITLES: Record<OAuthProvider['flow'], string> = {
|
||||
pkce: 'Opens your browser to sign in, then continues here',
|
||||
device_code: 'Opens a verification page in your browser — Hermes connects automatically',
|
||||
loopback: 'Opens your browser to sign in — Hermes connects automatically',
|
||||
external: 'Sign in once in your terminal, then come back to chat'
|
||||
}
|
||||
|
||||
@ -565,6 +567,24 @@ function FlowPanel({ ctx, flow }: { ctx: OnboardingContext; flow: OnboardingFlow
|
||||
)
|
||||
}
|
||||
|
||||
if (flow.status === 'awaiting_browser') {
|
||||
return (
|
||||
<Step title={`Sign in with ${title}`}>
|
||||
<p className="text-sm text-muted-foreground">
|
||||
We opened {title} in your browser. Authorize Hermes there and you'll be connected
|
||||
automatically — nothing to copy or paste.
|
||||
</p>
|
||||
<FlowFooter left={<DocsLink href={flow.start.auth_url}>Re-open sign-in page</DocsLink>}>
|
||||
<span className="flex items-center gap-2 text-xs text-muted-foreground">
|
||||
<Loader2 className="size-3 animate-spin" />
|
||||
Waiting for you to authorize...
|
||||
</span>
|
||||
<CancelBtn size="sm" />
|
||||
</FlowFooter>
|
||||
</Step>
|
||||
)
|
||||
}
|
||||
|
||||
if (flow.status === 'external_pending') {
|
||||
return (
|
||||
<Step title={`Sign in with ${title}`}>
|
||||
|
||||
@ -18,6 +18,7 @@ import type { ModelOptionProvider, OAuthProvider, OAuthStartResponse } from '@/t
|
||||
|
||||
type PkceStart = Extract<OAuthStartResponse, { flow: 'pkce' }>
|
||||
type DeviceStart = Extract<OAuthStartResponse, { flow: 'device_code' }>
|
||||
type LoopbackStart = Extract<OAuthStartResponse, { flow: 'loopback' }>
|
||||
|
||||
export type OnboardingMode = 'apikey' | 'oauth'
|
||||
|
||||
@ -26,6 +27,10 @@ export type OnboardingFlow =
|
||||
| { provider: OAuthProvider; status: 'starting' }
|
||||
| { code: string; provider: OAuthProvider; start: PkceStart; status: 'awaiting_user' }
|
||||
| { copied: boolean; provider: OAuthProvider; start: DeviceStart; status: 'polling' }
|
||||
// Loopback PKCE (xAI Grok): browser opens, the local backend's 127.0.0.1
|
||||
// listener catches the redirect, and we poll until the worker finishes.
|
||||
// No code to paste and no user_code to show — just a waiting state.
|
||||
| { provider: OAuthProvider; start: LoopbackStart; status: 'awaiting_browser' }
|
||||
| { provider: OAuthProvider; start: OAuthStartResponse; status: 'submitting' }
|
||||
| { copied: boolean; provider: OAuthProvider; status: 'external_pending' }
|
||||
| { provider: OAuthProvider; status: 'success' }
|
||||
@ -406,6 +411,26 @@ export async function refreshOnboarding(ctx: OnboardingContext) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Open a sign-in URL via the desktop bridge, falling back to window.open
|
||||
// when the bridge isn't present (e.g. the web dashboard / dev preview) so
|
||||
// the flow never silently stalls in a waiting state. Mirrors the pattern in
|
||||
// apps/desktop/src/app/artifacts/index.tsx.
|
||||
async function openSignInUrl(url: string) {
|
||||
if (window.hermesDesktop?.openExternal) {
|
||||
try {
|
||||
await window.hermesDesktop.openExternal(url)
|
||||
|
||||
return
|
||||
} catch {
|
||||
// Bridge present but failed (no OS handler, user denied, etc.). Fall
|
||||
// through to window.open so the sign-in URL still opens and the flow
|
||||
// doesn't strand a pending OAuth session in a waiting state.
|
||||
}
|
||||
}
|
||||
|
||||
window.open(url, '_blank', 'noopener,noreferrer')
|
||||
}
|
||||
|
||||
export async function startProviderOAuth(provider: OAuthProvider, ctx: OnboardingContext) {
|
||||
clearPoll()
|
||||
|
||||
@ -419,7 +444,8 @@ export async function startProviderOAuth(provider: OAuthProvider, ctx: Onboardin
|
||||
|
||||
try {
|
||||
const start = await startOAuthLogin(provider.id)
|
||||
await window.hermesDesktop?.openExternal(start.flow === 'pkce' ? start.auth_url : start.verification_url)
|
||||
const browserUrl = start.flow === 'device_code' ? start.verification_url : start.auth_url
|
||||
await openSignInUrl(browserUrl)
|
||||
|
||||
if (start.flow === 'pkce') {
|
||||
setFlow({ status: 'awaiting_user', provider, start, code: '' })
|
||||
@ -427,14 +453,26 @@ export async function startProviderOAuth(provider: OAuthProvider, ctx: Onboardin
|
||||
return
|
||||
}
|
||||
|
||||
if (start.flow === 'loopback') {
|
||||
// No code to paste: the redirect lands on the backend's loopback
|
||||
// listener. Just wait and poll the session until the worker finishes.
|
||||
setFlow({ status: 'awaiting_browser', provider, start })
|
||||
pollTimer = window.setInterval(() => void pollSession(provider, start, ctx), POLL_MS)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
setFlow({ status: 'polling', provider, start, copied: false })
|
||||
pollTimer = window.setInterval(() => void pollDevice(provider, start, ctx), POLL_MS)
|
||||
pollTimer = window.setInterval(() => void pollSession(provider, start, ctx), POLL_MS)
|
||||
} catch (error) {
|
||||
setFlow({ status: 'error', provider, message: `Could not start sign-in: ${errMessage(error)}` })
|
||||
}
|
||||
}
|
||||
|
||||
async function pollDevice(provider: OAuthProvider, start: DeviceStart, ctx: OnboardingContext) {
|
||||
// Poll a session-backed flow (device_code or loopback) until it resolves.
|
||||
// Both shapes only need the session_id to poll; the start is threaded
|
||||
// through to the error flow so the user can retry from the same context.
|
||||
async function pollSession(provider: OAuthProvider, start: DeviceStart | LoopbackStart, ctx: OnboardingContext) {
|
||||
try {
|
||||
const { error_message, status } = await pollOAuthSession(provider.id, start.session_id)
|
||||
|
||||
|
||||
@ -48,7 +48,7 @@ export interface OAuthProviderStatus {
|
||||
export interface OAuthProvider {
|
||||
cli_command: string
|
||||
docs_url: string
|
||||
flow: 'device_code' | 'external' | 'pkce'
|
||||
flow: 'device_code' | 'external' | 'loopback' | 'pkce'
|
||||
id: string
|
||||
name: string
|
||||
status: OAuthProviderStatus
|
||||
@ -73,6 +73,12 @@ export type OAuthStartResponse =
|
||||
user_code: string
|
||||
verification_url: string
|
||||
}
|
||||
| {
|
||||
auth_url: string
|
||||
expires_in: number
|
||||
flow: 'loopback'
|
||||
session_id: string
|
||||
}
|
||||
|
||||
export interface OAuthSubmitResponse {
|
||||
message?: string
|
||||
|
||||
@ -2973,6 +2973,17 @@ _OAUTH_PROVIDER_CATALOG: tuple[Dict[str, Any], ...] = (
|
||||
"docs_url": "https://www.minimax.io",
|
||||
"status_fn": None, # dispatched via auth.get_minimax_oauth_auth_status
|
||||
},
|
||||
{
|
||||
"id": "xai-oauth",
|
||||
"name": "xAI Grok OAuth (SuperGrok / Premium+)",
|
||||
# Loopback PKCE: the desktop's local backend binds a 127.0.0.1
|
||||
# callback server, the client opens the browser, and the redirect
|
||||
# lands back on the loopback listener — no code to copy/paste.
|
||||
"flow": "loopback",
|
||||
"cli_command": "hermes auth add xai-oauth",
|
||||
"docs_url": "https://hermes-agent.nousresearch.com/docs/guides/xai-grok-oauth",
|
||||
"status_fn": None, # dispatched via auth.get_xai_oauth_auth_status
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@ -3026,6 +3037,20 @@ def _resolve_provider_status(provider_id: str, status_fn) -> Dict[str, Any]:
|
||||
"expires_at": raw.get("expires_at"),
|
||||
"has_refresh_token": True,
|
||||
}
|
||||
if provider_id == "xai-oauth":
|
||||
raw = hauth.get_xai_oauth_auth_status()
|
||||
# source_label is meant to be a human-readable origin (auth-store
|
||||
# path / credential source), not the internal auth_mode string
|
||||
# ("oauth_pkce"). Prefer the store path, then the source slug.
|
||||
return {
|
||||
"logged_in": bool(raw.get("logged_in")),
|
||||
"source": raw.get("source") or "xai_oauth",
|
||||
"source_label": raw.get("auth_store") or raw.get("source") or "xAI Grok OAuth",
|
||||
"token_preview": _truncate_token(raw.get("api_key")),
|
||||
"expires_at": None,
|
||||
"has_refresh_token": True,
|
||||
"last_refresh": raw.get("last_refresh"),
|
||||
}
|
||||
except Exception as e:
|
||||
return {"logged_in": False, "error": str(e)}
|
||||
return {"logged_in": False}
|
||||
@ -3038,7 +3063,7 @@ async def list_oauth_providers():
|
||||
Response shape (per provider):
|
||||
id stable identifier (used in DELETE path)
|
||||
name human label
|
||||
flow "pkce" | "device_code" | "external"
|
||||
flow "pkce" | "device_code" | "external" | "loopback"
|
||||
cli_command fallback CLI command for users to run manually
|
||||
docs_url external docs/portal link for the "Learn more" link
|
||||
status:
|
||||
@ -3138,6 +3163,19 @@ async def disconnect_oauth_provider(provider_id: str, request: Request):
|
||||
# 4. On "approved" the background thread has already saved creds; UI
|
||||
# refreshes the providers list.
|
||||
#
|
||||
# Loopback PKCE (xAI Grok):
|
||||
# 1. POST /api/providers/oauth/xai-oauth/start
|
||||
# → server binds a 127.0.0.1 callback listener, builds the xAI
|
||||
# authorize URL, spawns a background worker waiting on the redirect
|
||||
# → returns { session_id, flow: "loopback", auth_url, expires_in }
|
||||
# 2. UI opens auth_url in the browser. There is NO user_code/code to
|
||||
# paste — the redirect lands back on the loopback listener.
|
||||
# 3. UI polls GET /api/providers/oauth/{provider}/poll/{session_id}
|
||||
# (same endpoint as device_code) until status != "pending".
|
||||
# 4. The worker exchanges the code, persists creds, sets "approved".
|
||||
# DELETE /sessions/{id} cancels: the worker bails before persisting
|
||||
# and the callback server is shut down to free the port immediately.
|
||||
#
|
||||
# Sessions are kept in-memory only (single-process FastAPI) and time out
|
||||
# after 15 minutes. A periodic cleanup runs on each /start call to GC
|
||||
# expired sessions so the dict doesn't grow without bound.
|
||||
@ -3521,6 +3559,220 @@ async def _start_device_code_flow(provider_id: str) -> Dict[str, Any]:
|
||||
raise HTTPException(status_code=400, detail=f"Provider {provider_id} does not support device-code flow")
|
||||
|
||||
|
||||
# xAI Grok OAuth uses a loopback-redirect PKCE flow (RFC 8252). Unlike the
|
||||
# device-code providers there is no user_code to display: the local backend
|
||||
# binds a 127.0.0.1 callback server, the client opens the authorize URL in
|
||||
# the browser, and the redirect lands back on the loopback listener. The
|
||||
# background worker waits for that callback, exchanges the code, and persists
|
||||
# the tokens exactly like `hermes auth add xai-oauth`.
|
||||
_XAI_LOOPBACK_TIMEOUT_SECONDS = 300.0
|
||||
|
||||
|
||||
def _start_xai_loopback_flow() -> Dict[str, Any]:
|
||||
"""Begin the xAI loopback PKCE flow.
|
||||
|
||||
Binds the local callback server, builds the authorize URL, and spawns a
|
||||
background worker that waits for the redirect and finishes the exchange.
|
||||
Returns the authorize URL for the client to open in the browser.
|
||||
"""
|
||||
from hermes_cli import auth as hauth
|
||||
|
||||
discovery = hauth._xai_oauth_discovery()
|
||||
server, thread, callback_result, redirect_uri = hauth._xai_start_callback_server()
|
||||
try:
|
||||
hauth._xai_validate_loopback_redirect_uri(redirect_uri)
|
||||
verifier = hauth._oauth_pkce_code_verifier()
|
||||
challenge = hauth._oauth_pkce_code_challenge(verifier)
|
||||
state = secrets.token_hex(16)
|
||||
nonce = secrets.token_hex(16)
|
||||
authorize_url = hauth._xai_oauth_build_authorize_url(
|
||||
authorization_endpoint=discovery["authorization_endpoint"],
|
||||
redirect_uri=redirect_uri,
|
||||
code_challenge=challenge,
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
)
|
||||
except Exception:
|
||||
# Binding succeeded but URL construction failed — release the socket
|
||||
# and join the serving thread so we don't leak a listener (or a
|
||||
# lingering daemon thread) on the loopback port.
|
||||
try:
|
||||
server.shutdown()
|
||||
server.server_close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
thread.join(timeout=1.0)
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
|
||||
sid, sess = _new_oauth_session("xai-oauth", "loopback")
|
||||
sess["server"] = server
|
||||
sess["thread"] = thread
|
||||
sess["callback_result"] = callback_result
|
||||
sess["redirect_uri"] = redirect_uri
|
||||
sess["verifier"] = verifier
|
||||
sess["challenge"] = challenge
|
||||
sess["state"] = state
|
||||
sess["token_endpoint"] = discovery["token_endpoint"]
|
||||
sess["discovery"] = discovery
|
||||
sess["expires_at"] = time.time() + _XAI_LOOPBACK_TIMEOUT_SECONDS
|
||||
threading.Thread(
|
||||
target=_xai_loopback_worker, args=(sid,), daemon=True,
|
||||
name=f"oauth-xai-{sid[:6]}",
|
||||
).start()
|
||||
return {
|
||||
"session_id": sid,
|
||||
"flow": "loopback",
|
||||
"auth_url": authorize_url,
|
||||
"expires_in": int(_XAI_LOOPBACK_TIMEOUT_SECONDS),
|
||||
}
|
||||
|
||||
|
||||
def _xai_loopback_worker(session_id: str) -> None:
|
||||
"""Wait for the xAI loopback callback, exchange the code, persist tokens."""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from hermes_cli import auth as hauth
|
||||
|
||||
with _oauth_sessions_lock:
|
||||
sess = _oauth_sessions.get(session_id)
|
||||
if not sess:
|
||||
return
|
||||
|
||||
def _fail(message: str) -> None:
|
||||
with _oauth_sessions_lock:
|
||||
s = _oauth_sessions.get(session_id)
|
||||
if s is not None:
|
||||
s["status"] = "error"
|
||||
s["error_message"] = message
|
||||
|
||||
def _cancelled() -> bool:
|
||||
# The session is removed from the registry when the user cancels
|
||||
# (DELETE /sessions/{id}). If that happened while we were blocked on
|
||||
# the callback or token exchange, abort instead of persisting tokens
|
||||
# the user no longer wants.
|
||||
with _oauth_sessions_lock:
|
||||
return session_id not in _oauth_sessions
|
||||
|
||||
try:
|
||||
callback = hauth._xai_wait_for_callback(
|
||||
sess["server"],
|
||||
sess["thread"],
|
||||
sess["callback_result"],
|
||||
timeout_seconds=_XAI_LOOPBACK_TIMEOUT_SECONDS,
|
||||
)
|
||||
except Exception as exc:
|
||||
_fail(f"xAI authorization timed out: {exc}")
|
||||
return
|
||||
|
||||
if _cancelled():
|
||||
return
|
||||
|
||||
if callback.get("error"):
|
||||
detail = callback.get("error_description") or callback["error"]
|
||||
_fail(f"xAI authorization failed: {detail}")
|
||||
return
|
||||
if callback.get("state") != sess["state"]:
|
||||
_fail("xAI authorization failed: state mismatch.")
|
||||
return
|
||||
code = str(callback.get("code") or "").strip()
|
||||
if not code:
|
||||
_fail("xAI authorization failed: missing authorization code.")
|
||||
return
|
||||
|
||||
try:
|
||||
payload = hauth._xai_oauth_exchange_code_for_tokens(
|
||||
token_endpoint=sess["token_endpoint"],
|
||||
code=code,
|
||||
redirect_uri=sess["redirect_uri"],
|
||||
code_verifier=sess["verifier"],
|
||||
code_challenge=sess["challenge"],
|
||||
)
|
||||
access_token = str(payload.get("access_token", "") or "").strip()
|
||||
refresh_token = str(payload.get("refresh_token", "") or "").strip()
|
||||
if not access_token or not refresh_token:
|
||||
_fail("xAI token exchange did not return the expected tokens.")
|
||||
return
|
||||
base_url = hauth._xai_validate_inference_base_url(
|
||||
os.getenv("HERMES_XAI_BASE_URL", "").strip().rstrip("/")
|
||||
or os.getenv("XAI_BASE_URL", "").strip().rstrip("/"),
|
||||
fallback=hauth.DEFAULT_XAI_OAUTH_BASE_URL,
|
||||
)
|
||||
last_refresh = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z")
|
||||
tokens = {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"id_token": str(payload.get("id_token", "") or "").strip(),
|
||||
"expires_in": payload.get("expires_in"),
|
||||
"token_type": str(payload.get("token_type") or "Bearer").strip() or "Bearer",
|
||||
}
|
||||
if _cancelled():
|
||||
return
|
||||
hauth._save_xai_oauth_tokens(
|
||||
tokens,
|
||||
discovery=sess.get("discovery"),
|
||||
redirect_uri=sess["redirect_uri"],
|
||||
last_refresh=last_refresh,
|
||||
)
|
||||
_add_xai_oauth_pool_entry(access_token, refresh_token, base_url, last_refresh)
|
||||
except Exception as exc:
|
||||
_fail(f"xAI token exchange failed: {exc}")
|
||||
return
|
||||
|
||||
with _oauth_sessions_lock:
|
||||
s = _oauth_sessions.get(session_id)
|
||||
if s is not None:
|
||||
s["status"] = "approved"
|
||||
_log.info("oauth/loopback: xai-oauth login completed (session=%s)", session_id)
|
||||
|
||||
|
||||
def _add_xai_oauth_pool_entry(
|
||||
access_token: str, refresh_token: str, base_url: str, last_refresh: str
|
||||
) -> None:
|
||||
"""Mirror `hermes auth add xai-oauth`'s credential-pool insert.
|
||||
|
||||
Best-effort: the auth-store write in _save_xai_oauth_tokens is the source
|
||||
of truth for runtime resolution; the pool entry only matters for the
|
||||
rotation strategy.
|
||||
"""
|
||||
try:
|
||||
import uuid
|
||||
|
||||
from agent.credential_pool import (
|
||||
PooledCredential,
|
||||
load_pool,
|
||||
AUTH_TYPE_OAUTH,
|
||||
SOURCE_MANUAL,
|
||||
)
|
||||
pool = load_pool("xai-oauth")
|
||||
existing = [
|
||||
e for e in pool.entries()
|
||||
if getattr(e, "source", "").startswith(f"{SOURCE_MANUAL}:dashboard_xai_pkce")
|
||||
]
|
||||
for e in existing:
|
||||
try:
|
||||
pool.remove_entry(getattr(e, "id", ""))
|
||||
except Exception:
|
||||
pass
|
||||
entry = PooledCredential(
|
||||
provider="xai-oauth",
|
||||
id=uuid.uuid4().hex[:6],
|
||||
label="dashboard PKCE",
|
||||
auth_type=AUTH_TYPE_OAUTH,
|
||||
priority=0,
|
||||
source=f"{SOURCE_MANUAL}:dashboard_xai_pkce",
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
base_url=base_url,
|
||||
last_refresh=last_refresh,
|
||||
)
|
||||
pool.add_entry(entry)
|
||||
except Exception as e:
|
||||
_log.warning("xai-oauth pool add (dashboard) failed: %s", e)
|
||||
|
||||
|
||||
def _nous_poller(session_id: str) -> None:
|
||||
"""Background poller that drives a Nous device-code flow to completion."""
|
||||
from hermes_cli.auth import (
|
||||
@ -3810,6 +4062,10 @@ async def start_oauth_login(provider_id: str, request: Request):
|
||||
return _start_anthropic_pkce()
|
||||
if catalog_entry["flow"] == "device_code":
|
||||
return await _start_device_code_flow(provider_id)
|
||||
if catalog_entry["flow"] == "loopback" and provider_id == "xai-oauth":
|
||||
return await asyncio.get_running_loop().run_in_executor(
|
||||
None, _start_xai_loopback_flow
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@ -3836,7 +4092,13 @@ async def submit_oauth_code(provider_id: str, body: OAuthSubmitBody, request: Re
|
||||
|
||||
@app.get("/api/providers/oauth/{provider_id}/poll/{session_id}")
|
||||
async def poll_oauth_session(provider_id: str, session_id: str):
|
||||
"""Poll a device-code session's status (no auth — read-only state)."""
|
||||
"""Poll a session's status (no auth — read-only state).
|
||||
|
||||
Shared by the device-code flows (Nous, OpenAI Codex, MiniMax) and the
|
||||
loopback flow (xAI Grok). Both surface progress through the same
|
||||
background-worker-updated ``status`` field, so a single poll endpoint
|
||||
serves them all.
|
||||
"""
|
||||
with _oauth_sessions_lock:
|
||||
sess = _oauth_sessions.get(session_id)
|
||||
if not sess:
|
||||
@ -3859,6 +4121,33 @@ async def cancel_oauth_session(session_id: str, request: Request):
|
||||
sess = _oauth_sessions.pop(session_id, None)
|
||||
if sess is None:
|
||||
return {"ok": False, "message": "session not found"}
|
||||
# Loopback sessions own a bound 127.0.0.1 callback server. Without an
|
||||
# explicit shutdown the worker would keep that port held until
|
||||
# _xai_wait_for_callback times out (up to 5 min). Free it immediately so
|
||||
# an orphaned listener can't block a subsequent sign-in attempt.
|
||||
if sess.get("flow") == "loopback":
|
||||
# The worker is blocked in _xai_wait_for_callback, which polls
|
||||
# callback_result rather than the server state. Flag the result as
|
||||
# cancelled so that loop returns on its next tick instead of spinning
|
||||
# until the timeout — otherwise repeated cancel/retry piles up daemon
|
||||
# threads. (_cancelled() in the worker then short-circuits before any
|
||||
# persist.)
|
||||
result = sess.get("callback_result")
|
||||
if isinstance(result, dict):
|
||||
result["error"] = result.get("error") or "cancelled"
|
||||
server = sess.get("server")
|
||||
thread = sess.get("thread")
|
||||
try:
|
||||
if server is not None:
|
||||
server.shutdown()
|
||||
server.server_close()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if thread is not None:
|
||||
thread.join(timeout=1.0)
|
||||
except Exception:
|
||||
pass
|
||||
return {"ok": True, "session_id": session_id}
|
||||
|
||||
|
||||
|
||||
@ -327,6 +327,258 @@ def test_anthropic_pkce_branch_still_works():
|
||||
assert "claude.ai" in body["auth_url"]
|
||||
|
||||
|
||||
def test_xai_oauth_listed_as_loopback_flow():
|
||||
"""xAI Grok OAuth must surface in the catalog as a first-class loopback flow."""
|
||||
resp = client.get("/api/providers/oauth", headers=HEADERS)
|
||||
assert resp.status_code == 200, resp.text
|
||||
providers = {p["id"]: p for p in resp.json()["providers"]}
|
||||
assert "xai-oauth" in providers
|
||||
assert providers["xai-oauth"]["flow"] == "loopback"
|
||||
assert "grok" in providers["xai-oauth"]["name"].lower()
|
||||
|
||||
|
||||
def test_xai_loopback_start_returns_authorize_url(monkeypatch):
|
||||
"""Start MUST bind the loopback listener and hand back an xAI authorize URL."""
|
||||
from hermes_cli import auth as auth_mod
|
||||
from hermes_cli import web_server as ws
|
||||
|
||||
class _FakeServer:
|
||||
def shutdown(self):
|
||||
pass
|
||||
|
||||
def server_close(self):
|
||||
pass
|
||||
|
||||
class _FakeThread:
|
||||
def join(self, timeout=None):
|
||||
pass
|
||||
|
||||
redirect_uri = (
|
||||
f"http://{auth_mod.XAI_OAUTH_REDIRECT_HOST}:{auth_mod.XAI_OAUTH_REDIRECT_PORT}"
|
||||
f"{auth_mod.XAI_OAUTH_REDIRECT_PATH}"
|
||||
)
|
||||
|
||||
monkeypatch.setattr(
|
||||
auth_mod,
|
||||
"_xai_oauth_discovery",
|
||||
lambda *a, **k: {
|
||||
"authorization_endpoint": "https://auth.x.ai/oauth2/auth",
|
||||
"token_endpoint": "https://auth.x.ai/oauth2/token",
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_mod,
|
||||
"_xai_start_callback_server",
|
||||
lambda *a, **k: (_FakeServer(), _FakeThread(), {"code": None, "error": None}, redirect_uri),
|
||||
)
|
||||
# Don't let the background worker run a real callback wait/exchange.
|
||||
monkeypatch.setattr(ws, "_xai_loopback_worker", lambda sid: None)
|
||||
|
||||
resp = client.post("/api/providers/oauth/xai-oauth/start", headers=HEADERS)
|
||||
assert resp.status_code == 200, resp.text
|
||||
body = resp.json()
|
||||
try:
|
||||
assert body["flow"] == "loopback"
|
||||
assert "user_code" not in body # loopback has nothing to paste/show
|
||||
assert body["auth_url"].startswith("https://auth.x.ai/oauth2/auth?")
|
||||
assert "code_challenge" in body["auth_url"]
|
||||
sess = ws._oauth_sessions[body["session_id"]]
|
||||
assert sess["provider"] == "xai-oauth"
|
||||
assert sess["flow"] == "loopback"
|
||||
finally:
|
||||
ws._oauth_sessions.pop(body["session_id"], None)
|
||||
|
||||
|
||||
def test_xai_loopback_worker_persists_tokens_on_success(monkeypatch):
|
||||
"""The worker exchanges the callback code and marks the session approved."""
|
||||
from hermes_cli import auth as auth_mod
|
||||
from hermes_cli import web_server as ws
|
||||
|
||||
saved = {}
|
||||
session_id = "xai-loopback-success-test"
|
||||
ws._oauth_sessions[session_id] = {
|
||||
"session_id": session_id,
|
||||
"provider": "xai-oauth",
|
||||
"flow": "loopback",
|
||||
"created_at": time.time(),
|
||||
"status": "pending",
|
||||
"error_message": None,
|
||||
"server": object(),
|
||||
"thread": object(),
|
||||
"callback_result": {"code": "auth-code", "state": "st"},
|
||||
"redirect_uri": "http://127.0.0.1:56121/callback",
|
||||
"verifier": "verifier",
|
||||
"challenge": "challenge",
|
||||
"state": "st",
|
||||
"token_endpoint": "https://auth.x.ai/oauth2/token",
|
||||
"discovery": {"token_endpoint": "https://auth.x.ai/oauth2/token"},
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
auth_mod,
|
||||
"_xai_wait_for_callback",
|
||||
lambda *a, **k: {"code": "auth-code", "state": "st"},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_mod,
|
||||
"_xai_oauth_exchange_code_for_tokens",
|
||||
lambda **k: {
|
||||
"access_token": "xai-access",
|
||||
"refresh_token": "xai-refresh",
|
||||
"expires_in": 3600,
|
||||
"token_type": "Bearer",
|
||||
},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
auth_mod,
|
||||
"_save_xai_oauth_tokens",
|
||||
lambda tokens, **k: saved.update(tokens),
|
||||
)
|
||||
monkeypatch.setattr(ws, "_add_xai_oauth_pool_entry", lambda *a, **k: None)
|
||||
|
||||
try:
|
||||
ws._xai_loopback_worker(session_id)
|
||||
assert ws._oauth_sessions[session_id]["status"] == "approved"
|
||||
assert saved["access_token"] == "xai-access"
|
||||
assert saved["refresh_token"] == "xai-refresh"
|
||||
finally:
|
||||
ws._oauth_sessions.pop(session_id, None)
|
||||
|
||||
|
||||
def test_xai_loopback_worker_fails_on_state_mismatch(monkeypatch):
|
||||
"""A mismatched OAuth state must fail the session, not persist tokens."""
|
||||
from hermes_cli import auth as auth_mod
|
||||
from hermes_cli import web_server as ws
|
||||
|
||||
session_id = "xai-loopback-state-test"
|
||||
ws._oauth_sessions[session_id] = {
|
||||
"session_id": session_id,
|
||||
"provider": "xai-oauth",
|
||||
"flow": "loopback",
|
||||
"created_at": time.time(),
|
||||
"status": "pending",
|
||||
"error_message": None,
|
||||
"server": object(),
|
||||
"thread": object(),
|
||||
"callback_result": {},
|
||||
"redirect_uri": "http://127.0.0.1:56121/callback",
|
||||
"verifier": "verifier",
|
||||
"challenge": "challenge",
|
||||
"state": "expected-state",
|
||||
"token_endpoint": "https://auth.x.ai/oauth2/token",
|
||||
"discovery": {},
|
||||
}
|
||||
|
||||
monkeypatch.setattr(
|
||||
auth_mod,
|
||||
"_xai_wait_for_callback",
|
||||
lambda *a, **k: {"code": "auth-code", "state": "ATTACKER-state"},
|
||||
)
|
||||
|
||||
def _boom(**kwargs):
|
||||
raise AssertionError("token exchange must not run on state mismatch")
|
||||
|
||||
monkeypatch.setattr(auth_mod, "_xai_oauth_exchange_code_for_tokens", _boom)
|
||||
|
||||
try:
|
||||
ws._xai_loopback_worker(session_id)
|
||||
sess = ws._oauth_sessions[session_id]
|
||||
assert sess["status"] == "error"
|
||||
assert "state mismatch" in sess["error_message"].lower()
|
||||
finally:
|
||||
ws._oauth_sessions.pop(session_id, None)
|
||||
|
||||
|
||||
def test_xai_loopback_worker_skips_persist_when_cancelled(monkeypatch):
|
||||
"""If the session is cancelled while waiting, the worker must not persist."""
|
||||
from hermes_cli import auth as auth_mod
|
||||
from hermes_cli import web_server as ws
|
||||
|
||||
session_id = "xai-loopback-cancel-test"
|
||||
ws._oauth_sessions[session_id] = {
|
||||
"session_id": session_id,
|
||||
"provider": "xai-oauth",
|
||||
"flow": "loopback",
|
||||
"created_at": time.time(),
|
||||
"status": "pending",
|
||||
"error_message": None,
|
||||
"server": object(),
|
||||
"thread": object(),
|
||||
"callback_result": {},
|
||||
"redirect_uri": "http://127.0.0.1:56121/callback",
|
||||
"verifier": "verifier",
|
||||
"challenge": "challenge",
|
||||
"state": "st",
|
||||
"token_endpoint": "https://auth.x.ai/oauth2/token",
|
||||
"discovery": {},
|
||||
}
|
||||
|
||||
def _wait_then_cancel(*args, **kwargs):
|
||||
# Simulate the user cancelling (DELETE /sessions/{id}) while we were
|
||||
# blocked on the callback: the session vanishes, then a valid code
|
||||
# arrives. The worker must notice and bail before persisting.
|
||||
ws._oauth_sessions.pop(session_id, None)
|
||||
return {"code": "auth-code", "state": "st"}
|
||||
|
||||
monkeypatch.setattr(auth_mod, "_xai_wait_for_callback", _wait_then_cancel)
|
||||
|
||||
def _must_not_persist(*args, **kwargs):
|
||||
raise AssertionError("tokens must not be persisted for a cancelled session")
|
||||
|
||||
monkeypatch.setattr(auth_mod, "_save_xai_oauth_tokens", _must_not_persist)
|
||||
monkeypatch.setattr(ws, "_add_xai_oauth_pool_entry", _must_not_persist)
|
||||
|
||||
# Should return cleanly without raising and without persisting.
|
||||
ws._xai_loopback_worker(session_id)
|
||||
assert session_id not in ws._oauth_sessions
|
||||
|
||||
|
||||
def test_cancel_loopback_session_shuts_down_callback_server():
|
||||
"""Cancelling a loopback session must free the bound callback port now."""
|
||||
from hermes_cli import web_server as ws
|
||||
|
||||
shutdown_calls = {"shutdown": 0, "close": 0, "join": 0}
|
||||
|
||||
class _FakeServer:
|
||||
def shutdown(self):
|
||||
shutdown_calls["shutdown"] += 1
|
||||
|
||||
def server_close(self):
|
||||
shutdown_calls["close"] += 1
|
||||
|
||||
class _FakeThread:
|
||||
def join(self, timeout=None):
|
||||
shutdown_calls["join"] += 1
|
||||
|
||||
# callback_result is the dict the worker's _xai_wait_for_callback polls.
|
||||
callback_result = {"code": None, "error": None}
|
||||
session_id = "xai-loopback-cancel-shutdown-test"
|
||||
ws._oauth_sessions[session_id] = {
|
||||
"session_id": session_id,
|
||||
"provider": "xai-oauth",
|
||||
"flow": "loopback",
|
||||
"created_at": time.time(),
|
||||
"status": "pending",
|
||||
"server": _FakeServer(),
|
||||
"thread": _FakeThread(),
|
||||
"callback_result": callback_result,
|
||||
}
|
||||
|
||||
try:
|
||||
resp = client.delete(
|
||||
f"/api/providers/oauth/sessions/{session_id}", headers=HEADERS
|
||||
)
|
||||
assert resp.status_code == 200, resp.text
|
||||
assert resp.json()["ok"] is True
|
||||
assert shutdown_calls == {"shutdown": 1, "close": 1, "join": 1}
|
||||
# The waiting worker must be signalled so it returns promptly instead
|
||||
# of spinning until the timeout.
|
||||
assert callback_result["error"] == "cancelled"
|
||||
assert session_id not in ws._oauth_sessions
|
||||
finally:
|
||||
ws._oauth_sessions.pop(session_id, None)
|
||||
|
||||
|
||||
def test_unknown_pkce_provider_rejected_cleanly():
|
||||
"""A future PKCE provider without an explicit branch must NOT silently route to Anthropic.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user