feat(web): wire local/custom endpoints into model assignment
The runtime resolver reads model.base_url from config and ignores the OPENAI_BASE_URL env var, so a self-hosted endpoint could not be configured from the GUI. Two changes enable it:

- POST /api/model/set accepts an optional base_url and persists it as model.base_url when provider=custom (still clearing stale base_url for hosted providers).
- POST /api/providers/validate now returns the model ids a custom endpoint advertises at /v1/models, so the GUI can auto-pick a default without asking the user to type a model name.

Refs desktop onboarding "Local / custom endpoint" bug.
This commit is contained in:
@ -608,6 +608,12 @@ class ModelAssignment(BaseModel):
|
||||
provider: str
|
||||
model: str
|
||||
task: str = ""
|
||||
# Optional OpenAI-compatible endpoint URL. Only honored for custom/local
|
||||
# providers on the main slot — lets the GUI configure a self-hosted endpoint
|
||||
# (vLLM, llama.cpp, Ollama, …) that needs no API key. The runtime resolver
|
||||
# reads model.base_url from config (it ignores OPENAI_BASE_URL), so this is
|
||||
# the path that actually wires a local endpoint into resolution.
|
||||
base_url: str = ""
|
||||
|
||||
|
||||
_GATEWAY_HEALTH_URL = os.getenv("GATEWAY_HEALTH_URL")
|
||||
@ -1954,6 +1960,7 @@ async def set_model_assignment(body: ModelAssignment):
|
||||
provider = (body.provider or "").strip()
|
||||
model = (body.model or "").strip()
|
||||
task = (body.task or "").strip().lower()
|
||||
base_url = (body.base_url or "").strip()
|
||||
|
||||
if scope not in {"main", "auxiliary"}:
|
||||
raise HTTPException(status_code=400, detail="scope must be 'main' or 'auxiliary'")
|
||||
@ -1969,8 +1976,14 @@ async def set_model_assignment(body: ModelAssignment):
|
||||
model_cfg = {}
|
||||
model_cfg["provider"] = provider
|
||||
model_cfg["default"] = model
|
||||
# Clear stale base_url so the resolver picks the provider's own default.
|
||||
if "base_url" in model_cfg and model_cfg.get("base_url"):
|
||||
# Custom/local providers are defined by their endpoint URL, so a
|
||||
# base_url must be persisted here — the runtime resolver reads
|
||||
# model.base_url from config and no longer consults OPENAI_BASE_URL.
|
||||
# For every other provider, clear any stale base_url so the
|
||||
# resolver picks the provider's own default endpoint.
|
||||
if provider.strip().lower() == "custom" and base_url:
|
||||
model_cfg["base_url"] = base_url
|
||||
elif "base_url" in model_cfg and model_cfg.get("base_url"):
|
||||
model_cfg["base_url"] = ""
|
||||
# Also clear hardcoded context_length override — new model may have
|
||||
# a different context window.
|
||||
@ -2013,6 +2026,7 @@ async def set_model_assignment(body: ModelAssignment):
|
||||
"scope": "main",
|
||||
"provider": provider,
|
||||
"model": model,
|
||||
"base_url": model_cfg.get("base_url", ""),
|
||||
"gateway_tools": gateway_tools,
|
||||
}
|
||||
|
||||
@ -2181,6 +2195,33 @@ _CREDENTIAL_PROBES: dict[str, tuple[str, str]] = {
|
||||
}
|
||||
|
||||
|
||||
def _parse_model_ids(resp: "Any") -> List[str]:
    """Extract model ids from an OpenAI-compatible ``/v1/models`` response.

    Accepted payload shapes:

    * ``{"data": [{"id": ...}, ...]}`` (OpenAI / vLLM / llama.cpp),
    * ``{"data": ["id", ...]}``,
    * a bare top-level list of either kind of item.

    Args:
        resp: Response-like object exposing ``is_success`` and ``json()``
            (for example an ``httpx.Response``).

    Returns:
        The stripped, non-empty string ids in first-seen order, without
        duplicates. Returns ``[]`` on any parse/HTTP error so a slightly
        non-standard endpoint never hard-blocks.
    """
    try:
        if not resp.is_success:
            return []
        payload = resp.json()
    except Exception:
        # Deliberately broad: this is a best-effort probe of an arbitrary
        # user-supplied endpoint, so any failure means "no ids advertised".
        return []

    data = payload.get("data") if isinstance(payload, dict) else payload
    if not isinstance(data, list):
        return []

    # A dict is used as an ordered set: keeps first-seen order, drops repeats.
    unique_ids: dict = {}
    for item in data:
        raw = item.get("id") if isinstance(item, dict) else item
        # Only real strings are usable model names; coercing nested
        # containers, booleans or None through str() would fabricate ids
        # such as "True" or "['x']".
        if not isinstance(raw, str):
            continue
        mid = raw.strip()
        if mid:
            unique_ids.setdefault(mid, None)
    return list(unique_ids)
|
||||
|
||||
|
||||
@app.post("/api/providers/validate")
|
||||
async def validate_provider_credential(body: EnvVarUpdate, request: Request):
|
||||
"""Live-probe a provider credential before it's saved.
|
||||
@ -2199,13 +2240,15 @@ async def validate_provider_credential(body: EnvVarUpdate, request: Request):
|
||||
return {"ok": False, "reachable": True, "message": "Enter a value first."}
|
||||
|
||||
# Local / custom endpoint: validate connectivity, not auth — any HTTP
|
||||
# response (even 401) proves the endpoint is up.
|
||||
# response (even 401) proves the endpoint is up. Also surface the model
|
||||
# ids the endpoint advertises (OpenAI ``/v1/models`` shape) so the GUI can
|
||||
# auto-pick a default without asking the user to type a model name.
|
||||
if key == "OPENAI_BASE_URL":
|
||||
url = value.rstrip("/") + "/models"
|
||||
try:
|
||||
with httpx.Client(timeout=httpx.Timeout(8.0)) as client:
|
||||
client.get(url)
|
||||
return {"ok": True, "reachable": True, "message": ""}
|
||||
resp = client.get(url)
|
||||
return {"ok": True, "reachable": True, "message": "", "models": _parse_model_ids(resp)}
|
||||
except Exception:
|
||||
return {"ok": False, "reachable": False, "message": f"Could not reach {url}."}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user