from __future__ import annotations

from functools import lru_cache
import logging
import time

import httpx
import jwt
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from app.core.config import get_settings
from app.schemas.auth import ProviderPrincipal

bearer_scheme = HTTPBearer(auto_error=False)
logger = logging.getLogger(__name__)


class ProviderTokenVerifier:
    """Verify bearer access tokens and build a ``ProviderPrincipal`` from them.

    Verification is attempted locally first (JWT signature check against the
    JWKS endpoint, or against the client secret for HS* tokens). When local
    verification fails, the token is sent to the issuer's introspection
    endpoint as a fallback. Missing profile fields and groups are then filled
    in, best-effort, from the userinfo endpoint and from an admin REST
    endpoint of the form ``/admin/realms/{realm}/users/{sub}/groups``.

    Verified principals are cached per raw token for at most 60 seconds.
    """

    def __init__(
        self,
        issuer: str | None,
        jwks_url: str | None,
        audience: str | None,
        client_id: str | None,
        client_secret: str | None,
        base_url: str | None,
        userinfo_endpoint: str | None,
        verify_tls: bool,
        realm: str | None,
        admin_realm: str | None,
        admin_client_id: str | None,
        admin_client_secret: str | None,
    ) -> None:
        """Normalise the configuration and create the JWKS client.

        All string settings are stripped of surrounding whitespace; empty
        values become ``None``. ``jwks_url`` and ``userinfo_endpoint`` are
        derived from ``issuer`` (or ``base_url``) when not given explicitly,
        and ``admin_realm`` defaults to ``realm``.

        Raises:
            ValueError: if no JWKS URL is configured and none can be inferred
                from the issuer.
        """
        self.issuer = issuer.strip() if issuer else None
        self.jwks_url = jwks_url.strip() if jwks_url else self._infer_jwks_url(self.issuer)
        self.audience = audience.strip() if audience else None
        self.client_id = client_id.strip() if client_id else None
        self.client_secret = client_secret.strip() if client_secret else None
        self.base_url = base_url.strip() if base_url else None
        self.realm = realm.strip() if realm else None
        self.admin_realm = admin_realm.strip() if admin_realm else self.realm
        self.admin_client_id = admin_client_id.strip() if admin_client_id else None
        self.admin_client_secret = admin_client_secret.strip() if admin_client_secret else None
        self.userinfo_endpoint = (
            userinfo_endpoint.strip()
            if userinfo_endpoint
            else self._infer_userinfo_endpoint(self.issuer, self.base_url)
        )
        self.verify_tls = verify_tls

        if not self.jwks_url:
            raise ValueError("KEYCLOAK_JWKS_URL or KEYCLOAK_ISSUER is required")

        self._jwk_client = jwt.PyJWKClient(
            self.jwks_url,
            cache_jwk_set=True,
            lifespan=600,
            headers={
                "Accept": "application/json",
                "User-Agent": "member-ose-backend/1.0",
            },
            timeout=5,
        )
        # Cached admin (client-credentials) token and its local expiry time.
        self._admin_token_cached: str | None = None
        self._admin_token_expires_at: float = 0
        # Raw token -> (cache-valid-until epoch seconds, principal).
        self._principal_cache: dict[str, tuple[float, ProviderPrincipal]] = {}

    @staticmethod
    def _parse_json(resp: httpx.Response, default: object) -> object | None:
        """Decode the JSON body of ``resp`` without raising.

        Returns ``default`` when the body is empty and ``None`` when the body
        is not valid JSON, so callers can treat a malformed upstream response
        as a soft failure instead of an unhandled exception.
        """
        if not resp.content:
            return default
        try:
            return resp.json()
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
            return None

    @staticmethod
    def _infer_introspection_endpoint(issuer: str | None) -> str | None:
        """Derive the token introspection URL from a ``.../realms/<name>`` issuer."""
        if not issuer:
            return None
        normalized = issuer.rstrip("/")
        if "/realms/" in normalized:
            return normalized + "/protocol/openid-connect/token/introspect"
        return None

    def _introspect_token(self, token: str) -> dict | None:
        """Ask the issuer whether ``token`` is active.

        Returns the introspection payload when the issuer reports the token as
        active, otherwise ``None``. Also returns ``None`` when introspection is
        not configured (no endpoint, client id or client secret), the request
        fails, or the response is not a JSON object.
        """
        endpoint = self._infer_introspection_endpoint(self.issuer)
        if not endpoint or not self.client_id or not self.client_secret:
            return None
        try:
            resp = httpx.post(
                endpoint,
                timeout=8,
                verify=self.verify_tls,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                data={
                    "token": token,
                    "client_id": self.client_id,
                    "client_secret": self.client_secret,
                },
            )
        except Exception as exc:
            logger.debug("token introspection request failed: %s", exc)
            return None
        if resp.status_code >= 400:
            return None
        data = self._parse_json(resp, default={})
        if not isinstance(data, dict) or not data.get("active"):
            return None
        return data

    @staticmethod
    def _infer_jwks_url(issuer: str | None) -> str | None:
        """Derive the JWKS (certs) URL from the issuer, or ``None`` without one."""
        if not issuer:
            return None
        return issuer.rstrip("/") + "/protocol/openid-connect/certs"

    @staticmethod
    def _infer_userinfo_endpoint(issuer: str | None, base_url: str | None) -> str | None:
        """Derive the userinfo URL from the issuer, falling back to ``base_url``."""
        if issuer:
            return issuer.rstrip("/") + "/protocol/openid-connect/userinfo"
        if base_url:
            # NOTE(review): this fallback always targets the "master" realm and
            # ignores the configured realm - confirm that is intended.
            return base_url.rstrip("/") + "/realms/master/protocol/openid-connect/userinfo"
        return None

    def _enrich_from_userinfo(self, principal: ProviderPrincipal, token: str) -> ProviderPrincipal:
        """Fill missing profile fields and groups from the userinfo endpoint.

        Values already present on ``principal`` win over userinfo values,
        except ``groups``, which is replaced whenever userinfo returns a list.
        On any failure (no endpoint, request error, HTTP error, malformed
        body, or a ``sub`` that does not match) the principal is passed
        unchanged to the admin-API group lookup.
        """
        if principal.email and (principal.name or principal.preferred_username) and principal.groups:
            return principal
        if not self.userinfo_endpoint:
            return self._enrich_groups_from_admin(principal)
        try:
            resp = httpx.get(
                self.userinfo_endpoint,
                timeout=5,
                verify=self.verify_tls,
                headers={"Authorization": f"Bearer {token}", "Accept": "application/json"},
            )
        except Exception as exc:
            logger.debug("userinfo request failed: %s", exc)
            return self._enrich_groups_from_admin(principal)
        if resp.status_code >= 400:
            return self._enrich_groups_from_admin(principal)

        data = self._parse_json(resp, default={})
        if not isinstance(data, dict):
            # Malformed or non-object body: nothing usable to merge.
            return self._enrich_groups_from_admin(principal)

        # Never merge profile data that belongs to a different subject.
        sub = data.get("sub")
        if isinstance(sub, str) and sub and sub != principal.sub:
            return self._enrich_groups_from_admin(principal)

        email = principal.email or (data.get("email") if isinstance(data.get("email"), str) else None)
        name = principal.name or (data.get("name") if isinstance(data.get("name"), str) else None)
        preferred_username = principal.preferred_username or (
            data.get("preferred_username") if isinstance(data.get("preferred_username"), str) else None
        )
        groups = principal.groups
        payload_groups = data.get("groups")
        if isinstance(payload_groups, list):
            groups = [str(g) for g in payload_groups if str(g)]

        enriched = ProviderPrincipal(
            sub=principal.sub,
            email=email,
            name=name,
            preferred_username=preferred_username,
            groups=groups,
        )
        return self._enrich_groups_from_admin(enriched)

    def _get_admin_token(self) -> str | None:
        """Return a client-credentials admin token, using the cached one if valid.

        Returns ``None`` when admin credentials are not fully configured or
        when the token request fails or yields no usable ``access_token``.
        """
        now = time.time()
        if self._admin_token_cached and now < self._admin_token_expires_at:
            return self._admin_token_cached
        if (
            not self.base_url
            or not self.admin_realm
            or not self.admin_client_id
            or not self.admin_client_secret
        ):
            return None

        # Strip a trailing slash so a base URL such as "https://idp/" does not
        # produce "https://idp//realms/...".
        base = self.base_url.rstrip("/")
        token_endpoint = f"{base}/realms/{self.admin_realm}/protocol/openid-connect/token"
        try:
            resp = httpx.post(
                token_endpoint,
                data={
                    "grant_type": "client_credentials",
                    "client_id": self.admin_client_id,
                    "client_secret": self.admin_client_secret,
                },
                timeout=6,
                verify=self.verify_tls,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
            )
        except Exception as exc:
            logger.debug("admin token request failed: %s", exc)
            return None
        if resp.status_code >= 400:
            return None

        payload = self._parse_json(resp, default={})
        if not isinstance(payload, dict):
            return None
        token = payload.get("access_token")
        expires_in = payload.get("expires_in")
        if token:
            ttl = int(expires_in) if isinstance(expires_in, int) else 30
            # Keep a small buffer to avoid using near-expiry admin token.
            self._admin_token_cached = str(token)
            self._admin_token_expires_at = max(now + ttl - 15, now + 5)
            return self._admin_token_cached
        return None

    def _enrich_groups_from_admin(self, principal: ProviderPrincipal) -> ProviderPrincipal:
        """Look up the principal's groups via the admin REST API when it has none.

        Each returned group contributes its ``path`` if present, otherwise its
        ``name``. The principal is returned unchanged when it already has
        groups, when the admin API is not configured, or when the lookup fails
        or yields no groups.
        """
        if principal.groups:
            return principal
        if not self.base_url or not self.realm:
            return principal
        admin_token = self._get_admin_token()
        if not admin_token:
            return principal

        base = self.base_url.rstrip("/")
        try:
            resp = httpx.get(
                f"{base}/admin/realms/{self.realm}/users/{principal.sub}/groups",
                timeout=6,
                verify=self.verify_tls,
                headers={"Authorization": f"Bearer {admin_token}", "Accept": "application/json"},
            )
        except Exception as exc:
            logger.debug("admin group lookup failed: %s", exc)
            return principal
        if resp.status_code >= 400:
            return principal

        payload = self._parse_json(resp, default=[])
        groups: list[str] = []
        if isinstance(payload, list):
            for item in payload:
                if not isinstance(item, dict):
                    continue
                path = item.get("path")
                name = item.get("name")
                if isinstance(path, str) and path:
                    groups.append(path)
                elif isinstance(name, str) and name:
                    groups.append(name)
        if not groups:
            return principal
        return ProviderPrincipal(
            sub=principal.sub,
            email=principal.email,
            name=principal.name,
            preferred_username=principal.preferred_username,
            groups=groups,
        )

    def verify_access_token(self, token: str) -> ProviderPrincipal:
        """Validate ``token`` and return the enriched principal.

        The token is verified locally as a JWT (HS* with the client secret,
        otherwise RS* with a key from the JWKS endpoint). If that fails for
        any reason, token introspection is tried as a fallback.

        Raises:
            HTTPException: 401 ``missing_idp_client_secret`` for an HS* token
                when no client secret is configured, 401
                ``invalid_bearer_token`` when both local verification and
                introspection fail, and 401 ``token_missing_sub`` when the
                claims carry no ``sub``.
        """
        now = time.time()
        cached = self._principal_cache.get(token)
        if cached and now < cached[0]:
            return cached[1]

        try:
            header = jwt.get_unverified_header(token)
            algorithm = str(header.get("alg", "")).upper()
            options = {
                "verify_signature": True,
                "verify_exp": True,
                "verify_aud": bool(self.audience),
                "verify_iss": bool(self.issuer),
            }
            if algorithm.startswith("HS"):
                if not self.client_secret:
                    raise HTTPException(
                        status_code=status.HTTP_401_UNAUTHORIZED,
                        detail="missing_idp_client_secret",
                    )
                key = self.client_secret
                allowed_algorithms = ["HS256", "HS384", "HS512"]
            else:
                signing_key = self._jwk_client.get_signing_key_from_jwt(token)
                key = signing_key.key
                allowed_algorithms = ["RS256", "RS384", "RS512"]
            claims = jwt.decode(
                token,
                key,
                algorithms=allowed_algorithms,
                audience=self.audience,
                issuer=self.issuer,
                options=options,
            )
        except HTTPException:
            # Deliberate 401 raised above; without this clause the generic
            # handler below would swallow it and replace its detail.
            raise
        except Exception as exc:
            claims = self._introspect_token(token)
            if claims:
                logger.debug("jwt verify failed, used introspection fallback: %s", exc)
            else:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid_bearer_token"
                ) from exc

        sub = claims.get("sub")
        if not sub:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="token_missing_sub")

        principal = ProviderPrincipal(
            sub=sub,
            email=claims.get("email"),
            name=claims.get("name"),
            preferred_username=claims.get("preferred_username"),
            groups=[str(g) for g in claims.get("groups", []) if str(g)]
            if isinstance(claims.get("groups"), list)
            else [],
        )
        enriched = self._enrich_from_userinfo(principal, token)

        # Cache for at most 60 s and never past the token's own expiry.
        # ``exp`` may legitimately be a float, so accept any real number.
        exp = claims.get("exp")
        if isinstance(exp, (int, float)):
            cache_until = min(float(exp), now + 60)
        else:
            cache_until = now + 30
        if cache_until > now:
            self._principal_cache[token] = (cache_until, enriched)
            if len(self._principal_cache) > 512:
                # Simple bound to avoid unbounded memory growth.
                self._principal_cache.clear()
        return enriched


@lru_cache
def _get_verifier() -> ProviderTokenVerifier:
    """Build the verifier from application settings once and reuse it."""
    settings = get_settings()
    return ProviderTokenVerifier(
        issuer=settings.idp_issuer,
        jwks_url=settings.idp_jwks_url,
        audience=settings.idp_audience,
        client_id=settings.idp_client_id,
        client_secret=settings.idp_client_secret,
        base_url=settings.idp_base_url,
        userinfo_endpoint=settings.idp_userinfo_endpoint,
        verify_tls=settings.idp_verify_tls,
        realm=settings.keycloak_realm,
        admin_realm=settings.keycloak_admin_realm,
        admin_client_id=settings.keycloak_admin_client_id,
        admin_client_secret=settings.keycloak_admin_client_secret,
    )


def require_authenticated_principal(
    credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
) -> ProviderPrincipal:
    """FastAPI dependency that resolves the caller's principal from a bearer token.

    Raises:
        HTTPException: 401 ``missing_bearer_token`` when no bearer credentials
            are supplied; any 401 raised by ``verify_access_token`` otherwise.
    """
    if credentials is None or credentials.scheme.lower() != "bearer":
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="missing_bearer_token")
    verifier = _get_verifier()
    return verifier.verify_access_token(credentials.credentials)