diff --git a/hermes_cli/web_server.py b/hermes_cli/web_server.py index 825b5d2b1..c412d0c00 100644 --- a/hermes_cli/web_server.py +++ b/hermes_cli/web_server.py @@ -616,6 +616,33 @@ class ModelAssignment(BaseModel): base_url: str = "" +def _apply_main_model_assignment( + model_cfg: "Any", provider: str, model: str, base_url: str = "" +) -> dict: + """Apply a main-slot model assignment to a ``model`` config dict in place. + + Sets ``provider``/``default``, then reconciles ``base_url``: custom/local + providers persist the supplied endpoint URL (the runtime resolver reads + ``model.base_url`` from config and ignores ``OPENAI_BASE_URL``), while every + other provider clears any stale URL so the resolver picks that provider's + own default endpoint. The hardcoded ``context_length`` override is always + dropped since the new model may have a different context window. + + Returns the same dict (coerced to a fresh dict if the input wasn't one) so + callers can assign it straight back onto ``cfg["model"]``. + """ + if not isinstance(model_cfg, dict): + model_cfg = {} + model_cfg["provider"] = provider + model_cfg["default"] = model + if provider.strip().lower() == "custom" and base_url.strip(): + model_cfg["base_url"] = base_url.strip() + elif model_cfg.get("base_url"): + model_cfg["base_url"] = "" + model_cfg.pop("context_length", None) + return model_cfg + + _GATEWAY_HEALTH_URL = os.getenv("GATEWAY_HEALTH_URL") try: _GATEWAY_HEALTH_TIMEOUT = float(os.getenv("GATEWAY_HEALTH_TIMEOUT", "3")) @@ -2016,24 +2043,9 @@ async def set_model_assignment(body: ModelAssignment): if scope == "main": if not provider or not model: raise HTTPException(status_code=400, detail="provider and model required for main") - model_cfg = cfg.get("model", {}) - if not isinstance(model_cfg, dict): - model_cfg = {} - model_cfg["provider"] = provider - model_cfg["default"] = model - # Custom/local providers are defined by their endpoint URL, so a - # base_url must be persisted here — the runtime resolver reads - # model.base_url from config and no longer consults OPENAI_BASE_URL. - # For every other provider, clear any stale base_url so the - # resolver picks the provider's own default endpoint. - if provider.strip().lower() == "custom" and base_url: - model_cfg["base_url"] = base_url - elif "base_url" in model_cfg and model_cfg.get("base_url"): - model_cfg["base_url"] = "" - # Also clear hardcoded context_length override — new model may have - # a different context window. - if "context_length" in model_cfg: - model_cfg.pop("context_length", None) + model_cfg = _apply_main_model_assignment( + cfg.get("model", {}), provider, model, base_url + ) cfg["model"] = model_cfg # When switching the main provider to Nous, mirror the CLI's @@ -6281,15 +6293,7 @@ def _write_profile_model(profile_dir: Path, provider: str, model: str) -> None: token = set_hermes_home_override(str(profile_dir)) try: cfg = load_config() - model_cfg = cfg.get("model", {}) - if not isinstance(model_cfg, dict): - model_cfg = {} - model_cfg["provider"] = provider - model_cfg["default"] = model - if model_cfg.get("base_url"): - model_cfg["base_url"] = "" - model_cfg.pop("context_length", None) - cfg["model"] = model_cfg + cfg["model"] = _apply_main_model_assignment(cfg.get("model", {}), provider, model) save_config(cfg) finally: reset_hermes_home_override(token) diff --git a/tests/hermes_cli/test_web_server.py b/tests/hermes_cli/test_web_server.py index 570fe4bc7..16feba323 100644 --- a/tests/hermes_cli/test_web_server.py +++ b/tests/hermes_cli/test_web_server.py @@ -1047,6 +1047,39 @@ class TestWebServerEndpoints: assert data["ok"] is True assert data.get("gateway_tools", []) == [] + def test_apply_main_model_assignment_base_url_and_context_reconcile(self): + """The shared main-slot assignment helper must persist base_url only for + custom providers, clear stale base_url for hosted ones, and always drop + a hardcoded context_length override. Both POST /api/model/set and + profile-model writes route through this, so the contract is pinned here.""" + from hermes_cli.web_server import _apply_main_model_assignment + + # Custom + base_url → persisted; stale context_length dropped. + out = _apply_main_model_assignment( + {"context_length": 8192}, "custom", "llama-3.1-8b", "http://127.0.0.1:8000/v1" + ) + assert out["provider"] == "custom" + assert out["default"] == "llama-3.1-8b" + assert out["base_url"] == "http://127.0.0.1:8000/v1" + assert "context_length" not in out + + # Hosted provider → stale base_url cleared (no base_url supplied). + out = _apply_main_model_assignment( + {"base_url": "http://127.0.0.1:8000/v1"}, "openrouter", "anthropic/claude-opus-4.8" + ) + assert out["provider"] == "openrouter" + assert out["base_url"] == "" + + # Custom WITHOUT a base_url → don't invent one, clear any stale value. + out = _apply_main_model_assignment( + {"base_url": "http://stale:1/v1"}, "custom", "m" + ) + assert out["base_url"] == "" + + # Non-dict input is coerced to a fresh dict (never raises). + out = _apply_main_model_assignment("not-a-dict", "custom", "m", "http://x/v1") + assert out == {"provider": "custom", "default": "m", "base_url": "http://x/v1"} + def test_parse_model_ids_handles_openai_and_bare_shapes(self): """Model discovery must tolerate the common /v1/models shapes and never raise (so a slightly non-standard local endpoint still works)."""