"""Bearer-token authentication for FastAPI routes backed by a Keycloak-style identity provider."""
from __future__ import annotations

import logging
import time
from functools import lru_cache
from urllib.parse import quote

import httpx
import jwt
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from app.core.config import get_settings
from app.schemas.auth import ProviderPrincipal
|
|
|
|
# Module-level logger for this authentication module.
logger = logging.getLogger(__name__)

# FastAPI bearer-credential extractor. auto_error=False means a request without
# usable bearer credentials reaches the dependency as None, so
# require_authenticated_principal decides the 401 response itself.
bearer_scheme = HTTPBearer(auto_error=False)
|
|
|
|
|
|
class ProviderTokenVerifier:
    """Authenticate bearer access tokens issued by a Keycloak-style identity provider.

    Verification flow (see ``verify_access_token``):

    1. Verify the token locally as a JWT: ``HS*`` tokens with the configured
       client secret, all other algorithms with the JWKS signing key
       (restricted to RS256/RS384/RS512).
    2. If local verification fails, fall back to the provider's token
       introspection endpoint (only when client credentials are configured),
       applying the same issuer/audience policy as step 1.
    3. Fill in missing profile data from the userinfo endpoint and missing
       groups from the admin REST API.
    4. Enforce the configured realm-role requirement, then cache the principal
       briefly, keyed by the raw token string.
    """

    # The principal cache is cleared outright once it exceeds this many entries.
    # A single clear() is used instead of a filtered rebuild because iterating
    # the dict would be unsafe if it were mutated concurrently.
    _PRINCIPAL_CACHE_MAX = 512
    # Upper bound (seconds) on how long a verified principal is reused.
    _PRINCIPAL_CACHE_TTL = 60
    # Cache lifetime (seconds) when the claims carry no numeric ``exp``.
    _PRINCIPAL_CACHE_FALLBACK_TTL = 30

    def __init__(
        self,
        issuer: str | None,
        jwks_url: str | None,
        audience: str | None,
        client_id: str | None,
        client_secret: str | None,
        base_url: str | None,
        userinfo_endpoint: str | None,
        verify_tls: bool,
        realm: str | None,
        admin_realm: str | None,
        admin_client_id: str | None,
        admin_client_secret: str | None,
        member_required_realm_roles: list[str],
    ) -> None:
        """Store normalized IdP settings and create the JWKS client.

        String settings are whitespace-stripped, and blank values are treated
        as unset so the inferred defaults apply.

        Args:
            issuer: Expected ``iss`` claim (checked when set). It is also the
                base for the inferred JWKS, userinfo and introspection
                endpoints.
            jwks_url: JWKS endpoint. Defaults to
                ``<issuer>/protocol/openid-connect/certs``.
            audience: Required ``aud`` value. The audience is not checked
                when unset.
            client_id: Client used to call the introspection endpoint.
            client_secret: Secret for introspection, and the HMAC key for
                ``HS*``-signed tokens.
            base_url: Server root for admin-API calls (trailing ``/``
                removed). Also the fallback base for the userinfo endpoint.
            userinfo_endpoint: Userinfo URL. Inferred from issuer/base_url
                when unset.
            verify_tls: Passed to httpx as ``verify`` on every outbound
                request.
            realm: Realm whose users are queried for group memberships.
            admin_realm: Realm used to obtain the admin token. Defaults to
                ``realm``.
            admin_client_id: Client-credentials client for the admin token.
            admin_client_secret: Secret for ``admin_client_id``.
            member_required_realm_roles: If non-empty, a principal must hold
                at least one of these realm roles (compared
                case-insensitively).

        Raises:
            ValueError: If neither ``jwks_url`` nor ``issuer`` is provided.
        """
        self.issuer = self._clean(issuer)
        self.jwks_url = self._clean(jwks_url) or self._infer_jwks_url(self.issuer)
        self.audience = self._clean(audience)
        self.client_id = self._clean(client_id)
        self.client_secret = self._clean(client_secret)
        cleaned_base_url = self._clean(base_url)
        # Admin URLs are built as f"{base_url}/...", so drop any trailing slash.
        self.base_url = (cleaned_base_url.rstrip("/") or None) if cleaned_base_url else None
        self.realm = self._clean(realm)
        self.admin_realm = self._clean(admin_realm) or self.realm
        self.admin_client_id = self._clean(admin_client_id)
        self.admin_client_secret = self._clean(admin_client_secret)
        self.userinfo_endpoint = self._clean(userinfo_endpoint) or self._infer_userinfo_endpoint(
            self.issuer, self.base_url
        )
        self.verify_tls = verify_tls

        if not self.jwks_url:
            raise ValueError("KEYCLOAK_JWKS_URL or KEYCLOAK_ISSUER is required")

        self._jwk_client = jwt.PyJWKClient(
            self.jwks_url,
            cache_jwk_set=True,
            lifespan=600,
            headers={
                "Accept": "application/json",
                "User-Agent": "member-ose-backend/1.0",
            },
            timeout=5,
        )
        self._admin_token_cached: str | None = None
        self._admin_token_expires_at: float = 0
        # raw token -> (cache-until epoch seconds, principal)
        self._principal_cache: dict[str, tuple[float, ProviderPrincipal]] = {}
        self.member_required_realm_roles = {
            r.strip().lower() for r in (member_required_realm_roles or ()) if r and r.strip()
        }

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _clean(value: str | None) -> str | None:
        """Return *value* with surrounding whitespace removed, or None if blank."""
        if value is None:
            return None
        stripped = value.strip()
        return stripped or None

    @staticmethod
    def _as_str(value: object) -> str | None:
        """Return *value* if it is a ``str``, otherwise None."""
        return value if isinstance(value, str) else None

    @staticmethod
    def _parse_json(resp: httpx.Response, default: object) -> object:
        """Decode a response body as JSON without letting decode errors escape.

        Returns *default* for an empty body and ``None`` when the body is not
        valid JSON, so callers can treat a malformed IdP response as a failed
        lookup instead of surfacing an unhandled error.
        """
        if not resp.content:
            return default
        try:
            return resp.json()
        except ValueError:  # json.JSONDecodeError / UnicodeDecodeError
            return None

    @staticmethod
    def _infer_introspection_endpoint(issuer: str | None) -> str | None:
        """Derive the introspection URL from a ``.../realms/<realm>`` issuer, else None."""
        if not issuer:
            return None
        normalized = issuer.rstrip("/")
        if "/realms/" in normalized:
            return normalized + "/protocol/openid-connect/token/introspect"
        return None

    @staticmethod
    def _infer_jwks_url(issuer: str | None) -> str | None:
        """Derive the JWKS URL from the issuer, or None without an issuer."""
        if not issuer:
            return None
        return issuer.rstrip("/") + "/protocol/openid-connect/certs"

    @staticmethod
    def _infer_userinfo_endpoint(issuer: str | None, base_url: str | None) -> str | None:
        """Derive the userinfo URL from the issuer, else from base_url's ``master`` realm."""
        if issuer:
            return issuer.rstrip("/") + "/protocol/openid-connect/userinfo"
        if base_url:
            return base_url.rstrip("/") + "/realms/master/protocol/openid-connect/userinfo"
        return None

    # ------------------------------------------------------- token validation

    def _introspection_claims_match(self, claims: dict) -> bool:
        """Apply the issuer/audience policy of the JWT path to introspected claims.

        Without this check the introspection fallback would accept any active
        token in the realm, including tokens that the JWT path rejected for
        carrying the wrong audience.
        """
        if self.issuer:
            iss = claims.get("iss")
            # The endpoint is derived from the issuer, so a missing iss is tolerated.
            if iss is not None and iss != self.issuer:
                return False
        if self.audience:
            aud = claims.get("aud")
            if isinstance(aud, str):
                audiences = [aud]
            elif isinstance(aud, list):
                audiences = aud
            else:
                audiences = []
            if self.audience not in audiences:
                return False
        return True

    def _introspect_token(self, token: str) -> dict | None:
        """Check *token* with the provider's token introspection endpoint.

        Used as a fallback when local JWT verification fails.

        Returns:
            The introspection payload when the provider reports the token as
            ``active`` and its issuer/audience match this verifier's
            configuration. Otherwise None, which also covers a missing
            endpoint or client credentials, a failed request, or a non-JSON
            response.
        """
        endpoint = self._infer_introspection_endpoint(self.issuer)
        if not endpoint or not self.client_id or not self.client_secret:
            return None
        try:
            resp = httpx.post(
                endpoint,
                timeout=8,
                verify=self.verify_tls,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                data={
                    "token": token,
                    "client_id": self.client_id,
                    "client_secret": self.client_secret,
                },
            )
        except Exception as exc:  # best-effort fallback: any transport failure means "not verified"
            logger.warning("token introspection request failed: %s", exc)
            return None
        if resp.status_code >= 400:
            logger.warning("token introspection rejected with HTTP %s", resp.status_code)
            return None
        data = self._parse_json(resp, {})
        if not isinstance(data, dict) or not data.get("active"):
            return None
        if not self._introspection_claims_match(data):
            return None
        return data

    def _decode_jwt(self, token: str) -> dict:
        """Verify *token* locally as a JWT and return its claims.

        ``HS*`` tokens are checked with the configured client secret. All
        other tokens use the JWKS signing key and must be RS256/RS384/RS512.
        Expiry is always checked, while audience and issuer are checked only
        when configured.

        Raises:
            HTTPException: 401 ``missing_idp_client_secret`` for an ``HS*``
                token when no client secret is configured.
            Exception: Any PyJWT/JWKS error when the token does not verify.
        """
        header = jwt.get_unverified_header(token)
        algorithm = str(header.get("alg", "")).upper()
        if algorithm.startswith("HS"):
            if not self.client_secret:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="missing_idp_client_secret",
                )
            key = self.client_secret
            allowed_algorithms = ["HS256", "HS384", "HS512"]
        else:
            key = self._jwk_client.get_signing_key_from_jwt(token).key
            allowed_algorithms = ["RS256", "RS384", "RS512"]

        return jwt.decode(
            token,
            key,
            algorithms=allowed_algorithms,
            audience=self.audience,
            issuer=self.issuer,
            options={
                "verify_signature": True,
                "verify_exp": True,
                "verify_aud": bool(self.audience),
                "verify_iss": bool(self.issuer),
            },
        )

    @staticmethod
    def _principal_from_claims(sub: str, claims: dict) -> ProviderPrincipal:
        """Build a principal from verified JWT or introspection claims.

        ``groups`` and ``realm_access.roles`` are used only when they are
        lists. Any other shape yields an empty list, so a string value is not
        iterated character by character into bogus entries.
        """
        groups_claim = claims.get("groups")
        groups = [str(g) for g in groups_claim if str(g)] if isinstance(groups_claim, list) else []

        realm_access = claims.get("realm_access")
        roles_claim = realm_access.get("roles") if isinstance(realm_access, dict) else None
        realm_roles = [str(r) for r in roles_claim if str(r)] if isinstance(roles_claim, list) else []

        return ProviderPrincipal(
            sub=sub,
            email=claims.get("email"),
            name=claims.get("name"),
            preferred_username=claims.get("preferred_username"),
            groups=groups,
            realm_roles=realm_roles,
        )

    # -------------------------------------------------------------- enrichment

    def _enrich_from_userinfo(self, principal: ProviderPrincipal, token: str) -> ProviderPrincipal:
        """Fill missing profile fields and groups from the userinfo endpoint.

        Returns *principal* unchanged when it already has an email, a display
        name (``name`` or ``preferred_username``) and groups. Values already on
        the principal take precedence over userinfo values. A userinfo
        ``groups`` list replaces the principal's groups. A response whose
        ``sub`` differs from the principal's is ignored. Every other path ends
        with admin-API group enrichment.
        """
        if principal.email and (principal.name or principal.preferred_username) and principal.groups:
            return principal
        if not self.userinfo_endpoint:
            return self._enrich_groups_from_admin(principal)

        try:
            resp = httpx.get(
                self.userinfo_endpoint,
                timeout=5,
                verify=self.verify_tls,
                headers={"Authorization": f"Bearer {token}", "Accept": "application/json"},
            )
        except Exception as exc:  # best-effort enrichment
            logger.warning("userinfo request failed: %s", exc)
            return self._enrich_groups_from_admin(principal)
        if resp.status_code >= 400:
            logger.debug("userinfo request returned HTTP %s", resp.status_code)
            return self._enrich_groups_from_admin(principal)

        data = self._parse_json(resp, {})
        if not isinstance(data, dict):
            return self._enrich_groups_from_admin(principal)
        sub = data.get("sub")
        if isinstance(sub, str) and sub and sub != principal.sub:
            # Userinfo describes a different subject; do not mix identities.
            return self._enrich_groups_from_admin(principal)

        groups = principal.groups
        payload_groups = data.get("groups")
        if isinstance(payload_groups, list):
            groups = [str(g) for g in payload_groups if str(g)]

        enriched = ProviderPrincipal(
            sub=principal.sub,
            email=principal.email or self._as_str(data.get("email")),
            name=principal.name or self._as_str(data.get("name")),
            preferred_username=principal.preferred_username or self._as_str(data.get("preferred_username")),
            groups=groups,
            realm_roles=principal.realm_roles,
        )
        return self._enrich_groups_from_admin(enriched)

    def _get_admin_token(self) -> str | None:
        """Return a client-credentials admin token, reusing a cached one until near expiry.

        Returns None when admin settings are incomplete or the token request
        fails.
        """
        now = time.time()
        if self._admin_token_cached and now < self._admin_token_expires_at:
            return self._admin_token_cached
        if not (self.base_url and self.admin_realm and self.admin_client_id and self.admin_client_secret):
            return None

        token_endpoint = f"{self.base_url}/realms/{self.admin_realm}/protocol/openid-connect/token"
        try:
            resp = httpx.post(
                token_endpoint,
                data={
                    "grant_type": "client_credentials",
                    "client_id": self.admin_client_id,
                    "client_secret": self.admin_client_secret,
                },
                timeout=6,
                verify=self.verify_tls,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
            )
        except Exception as exc:  # best-effort enrichment
            logger.warning("admin token request failed: %s", exc)
            return None
        if resp.status_code >= 400:
            logger.warning("admin token request rejected with HTTP %s", resp.status_code)
            return None

        payload = self._parse_json(resp, {})
        if not isinstance(payload, dict):
            return None
        token = payload.get("access_token")
        if not token:
            return None
        expires_in = payload.get("expires_in")
        ttl = int(expires_in) if isinstance(expires_in, (int, float)) else 30
        self._admin_token_cached = str(token)
        # Keep a small buffer to avoid using a near-expiry admin token.
        self._admin_token_expires_at = max(now + ttl - 15, now + 5)
        return self._admin_token_cached

    def _enrich_groups_from_admin(self, principal: ProviderPrincipal) -> ProviderPrincipal:
        """Fill in groups from the admin REST API when the principal has none.

        Each group's ``path`` is used, or its ``name`` when there is no path.
        Returns *principal* unchanged when it already has groups, admin access
        is not configured or fails, or the lookup yields no groups.
        """
        if principal.groups or not self.base_url or not self.realm:
            return principal
        admin_token = self._get_admin_token()
        if not admin_token:
            return principal

        # Encode the user id so it is always a single path segment.
        user_id = quote(str(principal.sub), safe="")
        try:
            resp = httpx.get(
                f"{self.base_url}/admin/realms/{self.realm}/users/{user_id}/groups",
                timeout=6,
                verify=self.verify_tls,
                headers={"Authorization": f"Bearer {admin_token}", "Accept": "application/json"},
            )
        except Exception as exc:  # best-effort enrichment
            logger.warning("admin group lookup failed: %s", exc)
            return principal
        if resp.status_code >= 400:
            logger.debug("admin group lookup returned HTTP %s", resp.status_code)
            return principal

        payload = self._parse_json(resp, [])
        groups: list[str] = []
        if isinstance(payload, list):
            for item in payload:
                if not isinstance(item, dict):
                    continue
                path = item.get("path")
                name = item.get("name")
                if isinstance(path, str) and path:
                    groups.append(path)
                elif isinstance(name, str) and name:
                    groups.append(name)
        if not groups:
            return principal

        return ProviderPrincipal(
            sub=principal.sub,
            email=principal.email,
            name=principal.name,
            preferred_username=principal.preferred_username,
            groups=groups,
            realm_roles=principal.realm_roles,
        )

    # ------------------------------------------------------------ authorization

    def _require_member_role(self, principal: ProviderPrincipal) -> None:
        """Raise 403 ``member_forbidden`` unless the principal holds a required realm role.

        Does nothing when no roles are required. Comparison is
        case-insensitive.
        """
        if not self.member_required_realm_roles:
            return
        user_roles = {r.strip().lower() for r in principal.realm_roles if isinstance(r, str) and r.strip()}
        if not user_roles.intersection(self.member_required_realm_roles):
            raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="member_forbidden")

    def verify_access_token(self, token: str) -> ProviderPrincipal:
        """Authenticate *token* and return the (enriched) principal it represents.

        Successful results are cached per token string until the earlier of
        the token's ``exp`` and 60 s from now. When ``exp`` is absent or
        non-numeric, the result is cached for 30 s.

        Raises:
            HTTPException: 401 ``missing_idp_client_secret``,
                ``invalid_bearer_token`` or ``token_missing_sub``; 403
                ``member_forbidden`` from the realm-role check.
        """
        now = time.time()
        cached = self._principal_cache.get(token)
        if cached and now < cached[0]:
            return cached[1]

        try:
            claims = self._decode_jwt(token)
        except HTTPException:
            # Configuration errors carry their own detail; introspection would
            # fail too because it needs the same client secret.
            raise
        except Exception as exc:
            claims = self._introspect_token(token)
            if not claims:
                raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid_bearer_token") from exc
            logger.debug("jwt verify failed, used introspection fallback: %s", exc)

        sub = claims.get("sub")
        if not sub:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="token_missing_sub")

        enriched = self._enrich_from_userinfo(self._principal_from_claims(sub, claims), token)
        self._require_member_role(enriched)

        exp = claims.get("exp")
        if isinstance(exp, (int, float)):
            cache_until = min(float(exp), now + self._PRINCIPAL_CACHE_TTL)
        else:
            cache_until = now + self._PRINCIPAL_CACHE_FALLBACK_TTL
        if cache_until > now:
            self._principal_cache[token] = (cache_until, enriched)
            if len(self._principal_cache) > self._PRINCIPAL_CACHE_MAX:
                # Simple bound to avoid unbounded memory growth.
                self._principal_cache.clear()

        return enriched
|
|
|
|
|
|
@lru_cache
def _get_verifier() -> ProviderTokenVerifier:
    """Return the shared ProviderTokenVerifier built from application settings.

    ``lru_cache`` memoises this zero-argument call. Settings are therefore
    read, and the verifier (with its JWKS client and in-memory caches)
    constructed, on the first call only; later calls reuse that instance.
    """
    cfg = get_settings()
    verifier_options = dict(
        issuer=cfg.idp_issuer,
        jwks_url=cfg.idp_jwks_url,
        audience=cfg.idp_audience,
        client_id=cfg.idp_client_id,
        client_secret=cfg.idp_client_secret,
        base_url=cfg.idp_base_url,
        userinfo_endpoint=cfg.idp_userinfo_endpoint,
        verify_tls=cfg.idp_verify_tls,
        realm=cfg.keycloak_realm,
        admin_realm=cfg.keycloak_admin_realm,
        admin_client_id=cfg.keycloak_admin_client_id,
        admin_client_secret=cfg.keycloak_admin_client_secret,
        member_required_realm_roles=cfg.member_required_realm_roles,
    )
    return ProviderTokenVerifier(**verifier_options)
|
|
|
|
|
|
def require_authenticated_principal(
    credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
) -> ProviderPrincipal:
    """FastAPI dependency that resolves the request's bearer token to a principal.

    Raises:
        HTTPException: 401 ``missing_bearer_token`` when no credentials were
            supplied or their scheme is not ``bearer`` (case-insensitive). Any
            error from ``ProviderTokenVerifier.verify_access_token`` propagates
            unchanged.
    """
    has_bearer_token = credentials is not None and credentials.scheme.lower() == "bearer"
    if not has_bearer_token:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="missing_bearer_token")

    return _get_verifier().verify_access_token(credentials.credentials)
|