fix(gateway): preserve WhatsApp pairing approvals across JID/LID alias flips

This commit is contained in:
QuenVix
2026-05-23 10:25:35 +03:00
committed by Teknium
parent 3127a41cb1
commit 52a368fa72
2 changed files with 136 additions and 10 deletions

View File

@ -28,6 +28,10 @@ import time
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from gateway.whatsapp_identity import (
expand_whatsapp_aliases,
normalize_whatsapp_identifier,
)
from hermes_constants import get_hermes_dir from hermes_constants import get_hermes_dir
from utils import atomic_replace from utils import atomic_replace
@ -110,12 +114,40 @@ class PairingStore:
def _save_json(self, path: Path, data: dict) -> None: def _save_json(self, path: Path, data: dict) -> None:
_secure_write(path, json.dumps(data, indent=2, ensure_ascii=False)) _secure_write(path, json.dumps(data, indent=2, ensure_ascii=False))
def _normalize_user_id(self, platform: str, user_id: str) -> str:
"""Normalize platform-specific user IDs before persisting them."""
raw_user_id = str(user_id or "").strip()
if platform == "whatsapp":
return normalize_whatsapp_identifier(raw_user_id) or raw_user_id
return raw_user_id
def _user_id_aliases(self, platform: str, user_id: str) -> set[str]:
"""Return all known equivalent user IDs for auth/rate-limit checks."""
raw_user_id = str(user_id or "").strip()
if not raw_user_id:
return set()
aliases = {raw_user_id, self._normalize_user_id(platform, raw_user_id)}
if platform == "whatsapp":
aliases.update(expand_whatsapp_aliases(raw_user_id))
aliases.discard("")
return aliases
def _user_ids_match(self, platform: str, left: str, right: str) -> bool:
"""Return True when two user IDs represent the same principal."""
left_aliases = self._user_id_aliases(platform, left)
right_aliases = self._user_id_aliases(platform, right)
return bool(left_aliases and right_aliases and (left_aliases & right_aliases))
# ----- Approved users ----- # ----- Approved users -----
def is_approved(self, platform: str, user_id: str) -> bool: def is_approved(self, platform: str, user_id: str) -> bool:
"""Check if a user is approved (paired) on a platform.""" """Check if a user is approved (paired) on a platform."""
approved = self._load_json(self._approved_path(platform)) approved = self._load_json(self._approved_path(platform))
return user_id in approved for approved_user_id in approved:
if self._user_ids_match(platform, approved_user_id, user_id):
return True
return False
def list_approved(self, platform: str = None) -> list: def list_approved(self, platform: str = None) -> list:
"""List approved users, optionally filtered by platform.""" """List approved users, optionally filtered by platform."""
@ -130,7 +162,16 @@ class PairingStore:
def _approve_user(self, platform: str, user_id: str, user_name: str = "") -> None: def _approve_user(self, platform: str, user_id: str, user_name: str = "") -> None:
"""Add a user to the approved list. Must be called under self._lock.""" """Add a user to the approved list. Must be called under self._lock."""
approved = self._load_json(self._approved_path(platform)) approved = self._load_json(self._approved_path(platform))
approved[user_id] = { normalized_user_id = self._normalize_user_id(platform, user_id)
duplicate_ids = [
approved_user_id
for approved_user_id in approved
if self._user_ids_match(platform, approved_user_id, normalized_user_id)
]
for approved_user_id in duplicate_ids:
del approved[approved_user_id]
approved[normalized_user_id] = {
"user_name": user_name, "user_name": user_name,
"approved_at": time.time(), "approved_at": time.time(),
} }
@ -141,8 +182,14 @@ class PairingStore:
path = self._approved_path(platform) path = self._approved_path(platform)
with self._lock: with self._lock:
approved = self._load_json(path) approved = self._load_json(path)
if user_id in approved: matching_ids = [
del approved[user_id] approved_user_id
for approved_user_id in approved
if self._user_ids_match(platform, approved_user_id, user_id)
]
if matching_ids:
for approved_user_id in matching_ids:
del approved[approved_user_id]
self._save_json(path, approved) self._save_json(path, approved)
return True return True
return False return False
@ -170,6 +217,7 @@ class PairingStore:
""" """
with self._lock: with self._lock:
self._cleanup_expired(platform) self._cleanup_expired(platform)
normalized_user_id = self._normalize_user_id(platform, user_id)
# Check lockout # Check lockout
if self._is_locked_out(platform): if self._is_locked_out(platform):
@ -198,7 +246,7 @@ class PairingStore:
pending[entry_id] = { pending[entry_id] = {
"hash": code_hash, "hash": code_hash,
"salt": salt.hex(), "salt": salt.hex(),
"user_id": user_id, "user_id": normalized_user_id,
"user_name": user_name, "user_name": user_name,
"created_at": time.time(), "created_at": time.time(),
} }
@ -326,15 +374,20 @@ class PairingStore:
def _is_rate_limited(self, platform: str, user_id: str) -> bool: def _is_rate_limited(self, platform: str, user_id: str) -> bool:
"""Check if a user has requested a code too recently.""" """Check if a user has requested a code too recently."""
limits = self._load_json(self._rate_limit_path()) limits = self._load_json(self._rate_limit_path())
key = f"{platform}:{user_id}" for alias in self._user_id_aliases(platform, user_id):
last_request = limits.get(key, 0) key = f"{platform}:{alias}"
return (time.time() - last_request) < RATE_LIMIT_SECONDS last_request = limits.get(key, 0)
if (time.time() - last_request) < RATE_LIMIT_SECONDS:
return True
return False
def _record_rate_limit(self, platform: str, user_id: str) -> None: def _record_rate_limit(self, platform: str, user_id: str) -> None:
"""Record the time of a pairing request for rate limiting.""" """Record the time of a pairing request for rate limiting."""
limits = self._load_json(self._rate_limit_path()) limits = self._load_json(self._rate_limit_path())
key = f"{platform}:{user_id}" now = time.time()
limits[key] = time.time() for alias in self._user_id_aliases(platform, user_id):
key = f"{platform}:{alias}"
limits[key] = now
self._save_json(self._rate_limit_path(), limits) self._save_json(self._rate_limit_path(), limits)
def _is_locked_out(self, platform: str) -> bool: def _is_locked_out(self, platform: str) -> bool:

View File

@ -2,10 +2,13 @@
import json import json
import os import os
import sys
import time import time
from pathlib import Path from pathlib import Path
from unittest.mock import patch from unittest.mock import patch
import pytest
from gateway.pairing import ( from gateway.pairing import (
PairingStore, PairingStore,
ALPHABET, ALPHABET,
@ -37,6 +40,10 @@ class TestSecureWrite:
assert target.exists() assert target.exists()
assert json.loads(target.read_text()) == {"hello": "world"} assert json.loads(target.read_text()) == {"hello": "world"}
@pytest.mark.skipif(
sys.platform.startswith("win"),
reason="POSIX file modes are not enforced on Windows",
)
def test_sets_file_permissions(self, tmp_path): def test_sets_file_permissions(self, tmp_path):
target = tmp_path / "secret.json" target = tmp_path / "secret.json"
_secure_write(target, "data") _secure_write(target, "data")
@ -305,6 +312,23 @@ class TestRateLimiting:
assert isinstance(code2, str) and len(code2) == CODE_LENGTH assert isinstance(code2, str) and len(code2) == CODE_LENGTH
assert code2 != code1 assert code2 != code1
def test_whatsapp_alias_flip_hits_same_rate_limit(self, tmp_path, monkeypatch):
mapping_dir = tmp_path / "whatsapp" / "session"
mapping_dir.mkdir(parents=True, exist_ok=True)
(mapping_dir / "lid-mapping-999999999999999.json").write_text(
json.dumps("15551234567@s.whatsapp.net"),
encoding="utf-8",
)
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
store = PairingStore()
code1 = store.generate_code("whatsapp", "15551234567@s.whatsapp.net")
code2 = store.generate_code("whatsapp", "999999999999999@lid")
assert isinstance(code1, str) and len(code1) == CODE_LENGTH
assert code2 is None
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Max pending limit # Max pending limit
@ -397,6 +421,55 @@ class TestApprovalFlow:
result = store.approve_code("telegram", "INVALIDCODE") result = store.approve_code("telegram", "INVALIDCODE")
assert result is None assert result is None
def test_whatsapp_approved_user_survives_alias_flip(self, tmp_path, monkeypatch):
mapping_dir = tmp_path / "whatsapp" / "session"
mapping_dir.mkdir(parents=True, exist_ok=True)
(mapping_dir / "lid-mapping-999999999999999.json").write_text(
json.dumps("15551234567@s.whatsapp.net"),
encoding="utf-8",
)
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
store = PairingStore()
code = store.generate_code("whatsapp", "15551234567@s.whatsapp.net", "Alice")
store.approve_code("whatsapp", code)
assert store.is_approved("whatsapp", "15551234567@s.whatsapp.net") is True
assert store.is_approved("whatsapp", "999999999999999@lid") is True
approved = store.list_approved("whatsapp")
assert len(approved) == 1
assert approved[0]["user_id"] == "15551234567"
def test_whatsapp_legacy_raw_jid_approval_survives_alias_flip(self, tmp_path, monkeypatch):
mapping_dir = tmp_path / "whatsapp" / "session"
mapping_dir.mkdir(parents=True, exist_ok=True)
(mapping_dir / "lid-mapping-999999999999999.json").write_text(
json.dumps("15551234567@s.whatsapp.net"),
encoding="utf-8",
)
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
approved_path = tmp_path / "whatsapp-approved.json"
approved_path.write_text(
json.dumps(
{
"15551234567@s.whatsapp.net": {
"user_name": "Legacy Alice",
"approved_at": time.time(),
}
},
indent=2,
),
encoding="utf-8",
)
with patch("gateway.pairing.PAIRING_DIR", tmp_path):
store = PairingStore()
assert store.is_approved("whatsapp", "999999999999999@lid") is True
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Lockout after failed attempts # Lockout after failed attempts