from __future__ import annotations

import logging
from functools import lru_cache

import httpx
import jwt
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from app.core.config import get_settings
from app.schemas.auth import AuthentikPrincipal

logger = logging.getLogger(__name__)

# auto_error=False: a missing Authorization header yields ``None`` so that
# require_authenticated_principal() can raise its own 401 with a stable detail.
bearer_scheme = HTTPBearer(auto_error=False)


class AuthentikTokenVerifier:
    """Validate bearer JWTs and turn their claims into an ``AuthentikPrincipal``.

    HS* tokens are verified with the configured client secret; every other
    token is verified against the signing key fetched from the JWKS endpoint.
    Missing profile claims are filled in, best effort, from the userinfo
    endpoint.
    """

    def __init__(
        self,
        issuer: str | None,
        jwks_url: str | None,
        audience: str | None,
        client_secret: str | None,
        base_url: str | None,
        userinfo_endpoint: str | None,
        verify_tls: bool,
    ) -> None:
        """Store the (whitespace-stripped) settings and create the JWKS client.

        Empty or ``None`` values are stored as ``None``. When ``jwks_url`` or
        ``userinfo_endpoint`` is not given, it is derived from ``issuer`` (and,
        for userinfo, from ``base_url`` as a fallback).

        Raises:
            ValueError: if no JWKS URL is configured and none can be inferred.
        """
        self.issuer = issuer.strip() if issuer else None
        self.jwks_url = jwks_url.strip() if jwks_url else self._infer_jwks_url(self.issuer)
        self.audience = audience.strip() if audience else None
        self.client_secret = client_secret.strip() if client_secret else None
        self.base_url = base_url.strip() if base_url else None
        self.userinfo_endpoint = (
            userinfo_endpoint.strip()
            if userinfo_endpoint
            else self._infer_userinfo_endpoint(self.issuer, self.base_url)
        )
        self.verify_tls = verify_tls

        if not self.jwks_url:
            raise ValueError("AUTHENTIK_JWKS_URL or AUTHENTIK_ISSUER is required")
        # NOTE(review): verify_tls is only applied to the userinfo request; the
        # JWKS client is built with its defaults - confirm that is intended.
        self._jwk_client = jwt.PyJWKClient(self.jwks_url)

    @staticmethod
    def _infer_jwks_url(issuer: str | None) -> str | None:
        """Return ``<issuer>/jwks/`` (or the issuer itself if it already ends so)."""
        if not issuer:
            return None
        normalized = issuer.rstrip("/") + "/"
        if normalized.endswith("/jwks/"):
            return normalized
        return normalized + "jwks/"

    @staticmethod
    def _infer_userinfo_endpoint(issuer: str | None, base_url: str | None) -> str | None:
        """Derive the userinfo URL from the issuer, falling back to ``base_url``.

        Issuer rules, first match wins:
          * contains ``/realms/``        -> ``<issuer>/protocol/openid-connect/userinfo``
          * contains ``/application/o/`` -> ``<root>/application/o/userinfo/``
          * anything else                -> ``<issuer>/userinfo/``
        Without an issuer, ``<base_url>/application/o/userinfo/`` is used; with
        neither, ``None`` is returned.
        """
        if issuer:
            normalized = issuer.rstrip("/")
            if "/realms/" in normalized:
                return normalized + "/protocol/openid-connect/userinfo"
            marker = "/application/o/"
            marker_index = normalized.find(marker)
            if marker_index != -1:
                # Drop the per-application slug that follows the marker.
                return normalized[:marker_index] + marker + "userinfo/"
            return normalized + "/userinfo/"
        if base_url:
            return base_url.rstrip("/") + "/application/o/userinfo/"
        return None

    @staticmethod
    def _string_claim(data: dict, key: str) -> str | None:
        """Return ``data[key]`` when it is a string, otherwise ``None``."""
        value = data.get(key)
        return value if isinstance(value, str) else None

    @staticmethod
    def _normalize_groups(raw: object) -> list[str] | None:
        """Return the non-empty stringified items of ``raw``; ``None`` if not a list."""
        if not isinstance(raw, list):
            return None
        return [str(g) for g in raw if str(g)]

    def _enrich_from_userinfo(self, principal: AuthentikPrincipal, token: str) -> AuthentikPrincipal:
        """Fill in missing profile fields from the userinfo endpoint (best effort).

        The original ``principal`` is returned unchanged when it is already
        complete, when no endpoint is configured, when the request fails or
        returns status >= 400, when the body is not a JSON object, or when the
        response ``sub`` does not match the token's. This method never raises
        for userinfo problems.

        Token values win for ``email``, ``name`` and ``preferred_username``;
        a ``groups`` list in the response replaces the token's groups.
        """
        has_display_name = principal.name or principal.preferred_username
        if principal.email and has_display_name and principal.groups:
            return principal
        if not self.userinfo_endpoint:
            return principal

        try:
            resp = httpx.get(
                self.userinfo_endpoint,
                timeout=5,
                verify=self.verify_tls,
                headers={"Authorization": f"Bearer {token}", "Accept": "application/json"},
            )
        except Exception:
            # Deliberately broad: enrichment is optional and must not fail auth.
            logger.warning("userinfo request failed; continuing with token claims", exc_info=True)
            return principal
        if resp.status_code >= 400:
            logger.debug("userinfo endpoint returned HTTP %s", resp.status_code)
            return principal

        try:
            data = resp.json() if resp.content else {}
        except ValueError:
            # A non-JSON body (e.g. an HTML error page) must not become a 500.
            logger.warning("userinfo response is not valid JSON; continuing with token claims")
            return principal
        if not isinstance(data, dict):
            logger.warning("userinfo response is not a JSON object; continuing with token claims")
            return principal

        # Refuse to merge a profile that belongs to a different subject.
        sub = data.get("sub")
        if isinstance(sub, str) and sub and sub != principal.sub:
            return principal

        fetched_groups = self._normalize_groups(data.get("groups"))
        return AuthentikPrincipal(
            sub=principal.sub,
            email=principal.email or self._string_claim(data, "email"),
            name=principal.name or self._string_claim(data, "name"),
            preferred_username=(
                principal.preferred_username or self._string_claim(data, "preferred_username")
            ),
            groups=principal.groups if fetched_groups is None else fetched_groups,
        )

    def verify_access_token(self, token: str) -> AuthentikPrincipal:
        """Verify ``token`` and return the principal described by its claims.

        Signature and expiry are always checked; audience and issuer are
        checked only when configured.

        Raises:
            HTTPException: 401 ``missing_authentik_client_secret`` for an HS*
                token when no client secret is configured; 401
                ``invalid_bearer_token`` for any other verification failure;
                401 ``token_missing_sub`` when the token has no ``sub`` claim.
        """
        try:
            header = jwt.get_unverified_header(token)
            algorithm = str(header.get("alg", "")).upper()
            options = {
                "verify_signature": True,
                "verify_exp": True,
                "verify_aud": bool(self.audience),
                "verify_iss": bool(self.issuer),
            }
            # The unverified header only selects the key source; the allowed
            # algorithm list below pins decoding to the matching family.
            if algorithm.startswith("HS"):
                if not self.client_secret:
                    raise HTTPException(
                        status_code=status.HTTP_401_UNAUTHORIZED,
                        detail="missing_authentik_client_secret",
                    )
                key = self.client_secret
                allowed_algorithms = ["HS256", "HS384", "HS512"]
            else:
                signing_key = self._jwk_client.get_signing_key_from_jwt(token)
                key = signing_key.key
                allowed_algorithms = ["RS256", "RS384", "RS512"]
            claims = jwt.decode(
                token,
                key,
                algorithms=allowed_algorithms,
                audience=self.audience,
                issuer=self.issuer,
                options=options,
            )
        except HTTPException:
            # Keep our own specific 401 instead of masking it as a generic
            # "invalid_bearer_token" in the handler below.
            raise
        except Exception as exc:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="invalid_bearer_token",
            ) from exc

        sub = claims.get("sub")
        if not sub:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="token_missing_sub")

        principal = AuthentikPrincipal(
            sub=sub,
            email=claims.get("email"),
            name=claims.get("name"),
            preferred_username=claims.get("preferred_username"),
            # A non-list "groups" claim is treated as "no groups".
            groups=self._normalize_groups(claims.get("groups")) or [],
        )
        return self._enrich_from_userinfo(principal, token)


@lru_cache
def _get_verifier() -> AuthentikTokenVerifier:
    """Build the verifier from application settings once and reuse it."""
    settings = get_settings()
    return AuthentikTokenVerifier(
        issuer=settings.idp_issuer,
        jwks_url=settings.idp_jwks_url,
        audience=settings.idp_audience,
        client_secret=settings.idp_client_secret,
        base_url=settings.idp_base_url,
        userinfo_endpoint=settings.idp_userinfo_endpoint,
        verify_tls=settings.idp_verify_tls,
    )


def require_authenticated_principal(
    credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
) -> AuthentikPrincipal:
    """FastAPI dependency: require a valid bearer token and return its principal.

    Raises:
        HTTPException: 401 ``missing_bearer_token`` when no bearer credentials
            were supplied; otherwise whatever ``verify_access_token`` raises.
    """
    if credentials is None or credentials.scheme.lower() != "bearer":
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="missing_bearer_token")
    verifier = _get_verifier()
    return verifier.verify_access_token(credentials.credentials)