Speed up auth verification with JWKS/admin token caching

This commit is contained in:
Chris
2026-04-03 02:20:54 +08:00
parent 418a7b7099
commit fa624127c8

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
from functools import lru_cache
import logging
import time
import httpx
import jwt
@@ -49,7 +50,19 @@ class ProviderTokenVerifier:
if not self.jwks_url:
raise ValueError("KEYCLOAK_JWKS_URL or KEYCLOAK_ISSUER is required")
self._jwk_client = jwt.PyJWKClient(self.jwks_url)
self._jwk_client = jwt.PyJWKClient(
self.jwks_url,
cache_jwk_set=True,
lifespan=600,
headers={
"Accept": "application/json",
"User-Agent": "member-ose-backend/1.0",
},
timeout=5,
)
self._admin_token_cached: str | None = None
self._admin_token_expires_at: float = 0
self._principal_cache: dict[str, tuple[float, ProviderPrincipal]] = {}
@staticmethod
def _infer_introspection_endpoint(issuer: str | None) -> str | None:
@@ -142,6 +155,9 @@ class ProviderTokenVerifier:
return self._enrich_groups_from_admin(enriched)
def _get_admin_token(self) -> str | None:
now = time.time()
if self._admin_token_cached and now < self._admin_token_expires_at:
return self._admin_token_cached
if (
not self.base_url
or not self.admin_realm
@@ -167,7 +183,14 @@ class ProviderTokenVerifier:
if resp.status_code >= 400:
return None
token = resp.json().get("access_token")
return str(token) if token else None
expires_in = resp.json().get("expires_in")
if token:
ttl = int(expires_in) if isinstance(expires_in, int) else 30
# Keep a small buffer to avoid using near-expiry admin token.
self._admin_token_cached = str(token)
self._admin_token_expires_at = max(now + ttl - 15, now + 5)
return self._admin_token_cached
return None
def _enrich_groups_from_admin(self, principal: ProviderPrincipal) -> ProviderPrincipal:
if principal.groups:
@@ -213,6 +236,11 @@ class ProviderTokenVerifier:
)
def verify_access_token(self, token: str) -> ProviderPrincipal:
now = time.time()
cached = self._principal_cache.get(token)
if cached and now < cached[0]:
return cached[1]
try:
header = jwt.get_unverified_header(token)
algorithm = str(header.get("alg", "")).upper()
@@ -262,7 +290,20 @@ class ProviderTokenVerifier:
preferred_username=claims.get("preferred_username"),
groups=[str(g) for g in claims.get("groups", []) if str(g)] if isinstance(claims.get("groups"), list) else [],
)
return self._enrich_from_userinfo(principal, token)
enriched = self._enrich_from_userinfo(principal, token)
exp = claims.get("exp")
if isinstance(exp, int):
cache_until = min(float(exp), now + 60)
else:
cache_until = now + 30
if cache_until > now:
self._principal_cache[token] = (cache_until, enriched)
if len(self._principal_cache) > 512:
# Simple bound to avoid unbounded memory growth.
self._principal_cache.clear()
return enriched
@lru_cache