Speed up auth verification with JWKS/admin token caching
This commit is contained in:
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import jwt
|
import jwt
|
||||||
@@ -49,7 +50,19 @@ class ProviderTokenVerifier:
|
|||||||
if not self.jwks_url:
|
if not self.jwks_url:
|
||||||
raise ValueError("KEYCLOAK_JWKS_URL or KEYCLOAK_ISSUER is required")
|
raise ValueError("KEYCLOAK_JWKS_URL or KEYCLOAK_ISSUER is required")
|
||||||
|
|
||||||
self._jwk_client = jwt.PyJWKClient(self.jwks_url)
|
self._jwk_client = jwt.PyJWKClient(
|
||||||
|
self.jwks_url,
|
||||||
|
cache_jwk_set=True,
|
||||||
|
lifespan=600,
|
||||||
|
headers={
|
||||||
|
"Accept": "application/json",
|
||||||
|
"User-Agent": "member-ose-backend/1.0",
|
||||||
|
},
|
||||||
|
timeout=5,
|
||||||
|
)
|
||||||
|
self._admin_token_cached: str | None = None
|
||||||
|
self._admin_token_expires_at: float = 0
|
||||||
|
self._principal_cache: dict[str, tuple[float, ProviderPrincipal]] = {}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _infer_introspection_endpoint(issuer: str | None) -> str | None:
|
def _infer_introspection_endpoint(issuer: str | None) -> str | None:
|
||||||
@@ -142,6 +155,9 @@ class ProviderTokenVerifier:
|
|||||||
return self._enrich_groups_from_admin(enriched)
|
return self._enrich_groups_from_admin(enriched)
|
||||||
|
|
||||||
def _get_admin_token(self) -> str | None:
|
def _get_admin_token(self) -> str | None:
|
||||||
|
now = time.time()
|
||||||
|
if self._admin_token_cached and now < self._admin_token_expires_at:
|
||||||
|
return self._admin_token_cached
|
||||||
if (
|
if (
|
||||||
not self.base_url
|
not self.base_url
|
||||||
or not self.admin_realm
|
or not self.admin_realm
|
||||||
@@ -167,7 +183,14 @@ class ProviderTokenVerifier:
|
|||||||
if resp.status_code >= 400:
|
if resp.status_code >= 400:
|
||||||
return None
|
return None
|
||||||
token = resp.json().get("access_token")
|
token = resp.json().get("access_token")
|
||||||
return str(token) if token else None
|
expires_in = resp.json().get("expires_in")
|
||||||
|
if token:
|
||||||
|
ttl = int(expires_in) if isinstance(expires_in, int) else 30
|
||||||
|
# Keep a small buffer to avoid using near-expiry admin token.
|
||||||
|
self._admin_token_cached = str(token)
|
||||||
|
self._admin_token_expires_at = max(now + ttl - 15, now + 5)
|
||||||
|
return self._admin_token_cached
|
||||||
|
return None
|
||||||
|
|
||||||
def _enrich_groups_from_admin(self, principal: ProviderPrincipal) -> ProviderPrincipal:
|
def _enrich_groups_from_admin(self, principal: ProviderPrincipal) -> ProviderPrincipal:
|
||||||
if principal.groups:
|
if principal.groups:
|
||||||
@@ -213,6 +236,11 @@ class ProviderTokenVerifier:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def verify_access_token(self, token: str) -> ProviderPrincipal:
|
def verify_access_token(self, token: str) -> ProviderPrincipal:
|
||||||
|
now = time.time()
|
||||||
|
cached = self._principal_cache.get(token)
|
||||||
|
if cached and now < cached[0]:
|
||||||
|
return cached[1]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
header = jwt.get_unverified_header(token)
|
header = jwt.get_unverified_header(token)
|
||||||
algorithm = str(header.get("alg", "")).upper()
|
algorithm = str(header.get("alg", "")).upper()
|
||||||
@@ -262,7 +290,20 @@ class ProviderTokenVerifier:
|
|||||||
preferred_username=claims.get("preferred_username"),
|
preferred_username=claims.get("preferred_username"),
|
||||||
groups=[str(g) for g in claims.get("groups", []) if str(g)] if isinstance(claims.get("groups"), list) else [],
|
groups=[str(g) for g in claims.get("groups", []) if str(g)] if isinstance(claims.get("groups"), list) else [],
|
||||||
)
|
)
|
||||||
return self._enrich_from_userinfo(principal, token)
|
enriched = self._enrich_from_userinfo(principal, token)
|
||||||
|
|
||||||
|
exp = claims.get("exp")
|
||||||
|
if isinstance(exp, int):
|
||||||
|
cache_until = min(float(exp), now + 60)
|
||||||
|
else:
|
||||||
|
cache_until = now + 30
|
||||||
|
if cache_until > now:
|
||||||
|
self._principal_cache[token] = (cache_until, enriched)
|
||||||
|
if len(self._principal_cache) > 512:
|
||||||
|
# Simple bound to avoid unbounded memory growth.
|
||||||
|
self._principal_cache.clear()
|
||||||
|
|
||||||
|
return enriched
|
||||||
|
|
||||||
|
|
||||||
@lru_cache
|
@lru_cache
|
||||||
|
|||||||
Reference in New Issue
Block a user