Speed up auth verification with JWKS/admin token caching
This commit is contained in:
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
import logging
|
||||
import time
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
@@ -49,7 +50,19 @@ class ProviderTokenVerifier:
|
||||
if not self.jwks_url:
|
||||
raise ValueError("KEYCLOAK_JWKS_URL or KEYCLOAK_ISSUER is required")
|
||||
|
||||
self._jwk_client = jwt.PyJWKClient(self.jwks_url)
|
||||
self._jwk_client = jwt.PyJWKClient(
|
||||
self.jwks_url,
|
||||
cache_jwk_set=True,
|
||||
lifespan=600,
|
||||
headers={
|
||||
"Accept": "application/json",
|
||||
"User-Agent": "member-ose-backend/1.0",
|
||||
},
|
||||
timeout=5,
|
||||
)
|
||||
self._admin_token_cached: str | None = None
|
||||
self._admin_token_expires_at: float = 0
|
||||
self._principal_cache: dict[str, tuple[float, ProviderPrincipal]] = {}
|
||||
|
||||
@staticmethod
|
||||
def _infer_introspection_endpoint(issuer: str | None) -> str | None:
|
||||
@@ -142,6 +155,9 @@ class ProviderTokenVerifier:
|
||||
return self._enrich_groups_from_admin(enriched)
|
||||
|
||||
def _get_admin_token(self) -> str | None:
|
||||
now = time.time()
|
||||
if self._admin_token_cached and now < self._admin_token_expires_at:
|
||||
return self._admin_token_cached
|
||||
if (
|
||||
not self.base_url
|
||||
or not self.admin_realm
|
||||
@@ -167,7 +183,14 @@ class ProviderTokenVerifier:
|
||||
if resp.status_code >= 400:
|
||||
return None
|
||||
token = resp.json().get("access_token")
|
||||
return str(token) if token else None
|
||||
expires_in = resp.json().get("expires_in")
|
||||
if token:
|
||||
ttl = int(expires_in) if isinstance(expires_in, int) else 30
|
||||
# Keep a small buffer to avoid using near-expiry admin token.
|
||||
self._admin_token_cached = str(token)
|
||||
self._admin_token_expires_at = max(now + ttl - 15, now + 5)
|
||||
return self._admin_token_cached
|
||||
return None
|
||||
|
||||
def _enrich_groups_from_admin(self, principal: ProviderPrincipal) -> ProviderPrincipal:
|
||||
if principal.groups:
|
||||
@@ -213,6 +236,11 @@ class ProviderTokenVerifier:
|
||||
)
|
||||
|
||||
def verify_access_token(self, token: str) -> ProviderPrincipal:
|
||||
now = time.time()
|
||||
cached = self._principal_cache.get(token)
|
||||
if cached and now < cached[0]:
|
||||
return cached[1]
|
||||
|
||||
try:
|
||||
header = jwt.get_unverified_header(token)
|
||||
algorithm = str(header.get("alg", "")).upper()
|
||||
@@ -262,7 +290,20 @@ class ProviderTokenVerifier:
|
||||
preferred_username=claims.get("preferred_username"),
|
||||
groups=[str(g) for g in claims.get("groups", []) if str(g)] if isinstance(claims.get("groups"), list) else [],
|
||||
)
|
||||
return self._enrich_from_userinfo(principal, token)
|
||||
enriched = self._enrich_from_userinfo(principal, token)
|
||||
|
||||
exp = claims.get("exp")
|
||||
if isinstance(exp, int):
|
||||
cache_until = min(float(exp), now + 60)
|
||||
else:
|
||||
cache_until = now + 30
|
||||
if cache_until > now:
|
||||
self._principal_cache[token] = (cache_until, enriched)
|
||||
if len(self._principal_cache) > 512:
|
||||
# Simple bound to avoid unbounded memory growth.
|
||||
self._principal_cache.clear()
|
||||
|
||||
return enriched
|
||||
|
||||
|
||||
@lru_cache
|
||||
|
||||
Reference in New Issue
Block a user