refactor(web): unify main-slot model assignment base_url/context handling (#38593)

Both POST /api/model/set and the profile-model writer hand-rolled the same
provider/default/base_url/context_length reconciliation. Extract it into
_apply_main_model_assignment so the custom-vs-hosted base_url logic lives in
one place — removing the future-drift risk where one site learns about
custom base_url persistence and the other forgets.

Behavior unchanged; pinned with a direct helper unit test.
This commit is contained in:
Teknium
2026-06-03 20:25:33 -07:00
committed by GitHub
parent e2ea648a08
commit e45dd2b0e7
2 changed files with 64 additions and 27 deletions

View File

@ -616,6 +616,33 @@ class ModelAssignment(BaseModel):
base_url: str = ""
def _apply_main_model_assignment(
model_cfg: "Any", provider: str, model: str, base_url: str = ""
) -> dict:
"""Apply a main-slot model assignment to a ``model`` config dict in place.
Sets ``provider``/``default``, then reconciles ``base_url``: custom/local
providers persist the supplied endpoint URL (the runtime resolver reads
``model.base_url`` from config and ignores ``OPENAI_BASE_URL``), while every
other provider clears any stale URL so the resolver picks that provider's
own default endpoint. The hardcoded ``context_length`` override is always
dropped since the new model may have a different context window.
Returns the same dict (coerced to a fresh dict if the input wasn't one) so
callers can assign it straight back onto ``cfg["model"]``.
"""
if not isinstance(model_cfg, dict):
model_cfg = {}
model_cfg["provider"] = provider
model_cfg["default"] = model
if provider.strip().lower() == "custom" and base_url.strip():
model_cfg["base_url"] = base_url.strip()
elif model_cfg.get("base_url"):
model_cfg["base_url"] = ""
model_cfg.pop("context_length", None)
return model_cfg
_GATEWAY_HEALTH_URL = os.getenv("GATEWAY_HEALTH_URL")
try:
_GATEWAY_HEALTH_TIMEOUT = float(os.getenv("GATEWAY_HEALTH_TIMEOUT", "3"))
@ -2016,24 +2043,9 @@ async def set_model_assignment(body: ModelAssignment):
if scope == "main":
if not provider or not model:
raise HTTPException(status_code=400, detail="provider and model required for main")
model_cfg = cfg.get("model", {})
if not isinstance(model_cfg, dict):
model_cfg = {}
model_cfg["provider"] = provider
model_cfg["default"] = model
# Custom/local providers are defined by their endpoint URL, so a
# base_url must be persisted here — the runtime resolver reads
# model.base_url from config and no longer consults OPENAI_BASE_URL.
# For every other provider, clear any stale base_url so the
# resolver picks the provider's own default endpoint.
if provider.strip().lower() == "custom" and base_url:
model_cfg["base_url"] = base_url
elif "base_url" in model_cfg and model_cfg.get("base_url"):
model_cfg["base_url"] = ""
# Also clear hardcoded context_length override — new model may have
# a different context window.
if "context_length" in model_cfg:
model_cfg.pop("context_length", None)
model_cfg = _apply_main_model_assignment(
cfg.get("model", {}), provider, model, base_url
)
cfg["model"] = model_cfg
# When switching the main provider to Nous, mirror the CLI's
@ -6281,15 +6293,7 @@ def _write_profile_model(profile_dir: Path, provider: str, model: str) -> None:
token = set_hermes_home_override(str(profile_dir))
try:
cfg = load_config()
model_cfg = cfg.get("model", {})
if not isinstance(model_cfg, dict):
model_cfg = {}
model_cfg["provider"] = provider
model_cfg["default"] = model
if model_cfg.get("base_url"):
model_cfg["base_url"] = ""
model_cfg.pop("context_length", None)
cfg["model"] = model_cfg
cfg["model"] = _apply_main_model_assignment(cfg.get("model", {}), provider, model)
save_config(cfg)
finally:
reset_hermes_home_override(token)

View File

@ -1047,6 +1047,39 @@ class TestWebServerEndpoints:
assert data["ok"] is True
assert data.get("gateway_tools", []) == []
def test_apply_main_model_assignment_base_url_and_context_reconcile(self):
"""The shared main-slot assignment helper must persist base_url only for
custom providers, clear stale base_url for hosted ones, and always drop
a hardcoded context_length override. Both POST /api/model/set and
profile-model writes route through this, so the contract is pinned here."""
from hermes_cli.web_server import _apply_main_model_assignment
# Custom + base_url → persisted; stale context_length dropped.
out = _apply_main_model_assignment(
{"context_length": 8192}, "custom", "llama-3.1-8b", "http://127.0.0.1:8000/v1"
)
assert out["provider"] == "custom"
assert out["default"] == "llama-3.1-8b"
assert out["base_url"] == "http://127.0.0.1:8000/v1"
assert "context_length" not in out
# Hosted provider → stale base_url cleared (no base_url supplied).
out = _apply_main_model_assignment(
{"base_url": "http://127.0.0.1:8000/v1"}, "openrouter", "anthropic/claude-opus-4.8"
)
assert out["provider"] == "openrouter"
assert out["base_url"] == ""
# Custom WITHOUT a base_url → don't invent one, clear any stale value.
out = _apply_main_model_assignment(
{"base_url": "http://stale:1/v1"}, "custom", "m"
)
assert out["base_url"] == ""
# Non-dict input is coerced to a fresh dict (never raises).
out = _apply_main_model_assignment("not-a-dict", "custom", "m", "http://x/v1")
assert out == {"provider": "custom", "default": "m", "base_url": "http://x/v1"}
def test_parse_model_ids_handles_openai_and_bare_shapes(self):
"""Model discovery must tolerate the common /v1/models shapes and
never raise (so a slightly non-standard local endpoint still works)."""