diff --git a/app/security/idp_jwt.py b/app/security/idp_jwt.py index 2a74956..5953c9e 100644 --- a/app/security/idp_jwt.py +++ b/app/security/idp_jwt.py @@ -2,6 +2,7 @@ from __future__ import annotations from functools import lru_cache import logging +import time import httpx import jwt @@ -49,7 +50,19 @@ class ProviderTokenVerifier: if not self.jwks_url: raise ValueError("KEYCLOAK_JWKS_URL or KEYCLOAK_ISSUER is required") - self._jwk_client = jwt.PyJWKClient(self.jwks_url) + self._jwk_client = jwt.PyJWKClient( + self.jwks_url, + cache_jwk_set=True, + lifespan=600, + headers={ + "Accept": "application/json", + "User-Agent": "member-ose-backend/1.0", + }, + timeout=5, + ) + self._admin_token_cached: str | None = None + self._admin_token_expires_at: float = 0 + self._principal_cache: dict[str, tuple[float, ProviderPrincipal]] = {} @staticmethod def _infer_introspection_endpoint(issuer: str | None) -> str | None: @@ -142,6 +155,9 @@ class ProviderTokenVerifier: return self._enrich_groups_from_admin(enriched) def _get_admin_token(self) -> str | None: + now = time.time() + if self._admin_token_cached and now < self._admin_token_expires_at: + return self._admin_token_cached if ( not self.base_url or not self.admin_realm @@ -167,7 +183,14 @@ class ProviderTokenVerifier: if resp.status_code >= 400: return None token = resp.json().get("access_token") - return str(token) if token else None + expires_in = resp.json().get("expires_in") + if token: + ttl = int(expires_in) if isinstance(expires_in, int) else 30 + # Keep a small buffer to avoid using near-expiry admin token. + self._admin_token_cached = str(token) + self._admin_token_expires_at = max(now + ttl - 15, now + 5) + return self._admin_token_cached + return None def _enrich_groups_from_admin(self, principal: ProviderPrincipal) -> ProviderPrincipal: if principal.groups: @@ -213,6 +236,11 @@ class ProviderTokenVerifier: ) def verify_access_token(self, token: str) -> ProviderPrincipal: + now = time.time() + cached = self._principal_cache.get(token) + if cached and now < cached[0]: + return cached[1] + try: header = jwt.get_unverified_header(token) algorithm = str(header.get("alg", "")).upper() @@ -262,7 +290,20 @@ class ProviderTokenVerifier: preferred_username=claims.get("preferred_username"), groups=[str(g) for g in claims.get("groups", []) if str(g)] if isinstance(claims.get("groups"), list) else [], ) - return self._enrich_from_userinfo(principal, token) + enriched = self._enrich_from_userinfo(principal, token) + + exp = claims.get("exp") + if isinstance(exp, int): + cache_until = min(float(exp), now + 60) + else: + cache_until = now + 30 + if cache_until > now: + self._principal_cache[token] = (cache_until, enriched) + if len(self._principal_cache) > 512: + # Simple bound to avoid unbounded memory growth. + self._principal_cache.clear() + + return enriched @lru_cache