Files
member-backend/app/security/api_client_auth.py

80 lines
2.8 KiB
Python

import hashlib
import hmac
from datetime import datetime, timezone
from fastapi import Depends, Header, HTTPException, Request, status
from passlib.context import CryptContext
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.db.session import get_db
from app.models.api_client import ApiClient
# Hashing context for stored API keys. passlib uses the first listed scheme
# (argon2) for new hashes and can still verify bcrypt hashes; with
# deprecated="auto" every scheme other than the default is treated as
# deprecated by passlib.
pwd_context = CryptContext(deprecated="auto", schemes=["argon2", "bcrypt"])
def hash_api_key(plain_key: str) -> str:
    """Hash *plain_key* for storage using the module-level ``pwd_context``.

    Args:
        plain_key: The raw API key as issued to the client.

    Returns:
        The hash string produced by ``pwd_context.hash``; this is the value
        ``_verify_api_key`` later checks presented keys against.
    """
    digest = pwd_context.hash(plain_key)
    return digest
def _verify_api_key(plain_key: str, stored_hash: str) -> bool:
# Support sha256:<hex> for bootstrap, and bcrypt/argon2 for production.
if stored_hash.startswith("sha256:"):
hex_hash = hashlib.sha256(plain_key.encode("utf-8")).hexdigest()
return hmac.compare_digest(stored_hash.removeprefix("sha256:"), hex_hash)
try:
return pwd_context.verify(plain_key, stored_hash)
except Exception:
return False
def _is_expired(expires_at: datetime | None) -> bool:
if expires_at is None:
return False
now = datetime.now(timezone.utc)
if expires_at.tzinfo is None:
expires_at = expires_at.replace(tzinfo=timezone.utc)
return expires_at <= now
def _path_matches_prefix(path: str, prefix: str) -> bool:
    """Return True if *path* equals *prefix* or lies beneath it at a ``/`` boundary."""
    if path == prefix:
        return True
    # Require a segment separator so "/api/v1" does not admit "/api/v10" or "/api/v1-admin".
    base = prefix if prefix.endswith("/") else prefix + "/"
    return path.startswith(base)


def _check_request_whitelist(client: ApiClient, request: Request) -> None:
    """Enforce the client's optional origin, IP and path allow-lists.

    Each list is enforced only when it is non-empty on ``client``.

    * Origin: the ``Origin`` request header must be in ``allowed_origins``.
      A request without an ``Origin`` header is not rejected by this check,
      so it only constrains callers that send the header.
    * IP: ``request.client.host`` must be in ``allowed_ips``. If the peer
      address is unknown, the request is rejected (fail closed).
    * Path: ``request.url.path`` must equal an entry of ``allowed_paths`` or
      lie beneath it at a ``/`` boundary.

    Raises:
        HTTPException: 403 with detail ``origin_not_allowed``,
            ``ip_not_allowed`` or ``path_not_allowed``.
    """
    origin = request.headers.get("origin")
    client_ip = request.client.host if request.client else None
    path = request.url.path

    if client.allowed_origins and origin and origin not in client.allowed_origins:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="origin_not_allowed")

    # NOTE(review): request.client.host is the direct peer; behind a reverse
    # proxy it is the proxy's address unless forwarded headers are trusted
    # by the server -- confirm against the deployment.
    if client.allowed_ips and (not client_ip or client_ip not in client.allowed_ips):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="ip_not_allowed")

    if client.allowed_paths and not any(
        _path_matches_prefix(path, prefix) for prefix in client.allowed_paths
    ):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="path_not_allowed")
def require_api_client(
    request: Request,
    x_client_key: str = Header(alias="X-Client-Key"),
    x_api_key: str = Header(alias="X-API-Key"),
    db: Session = Depends(get_db),
) -> ApiClient:
    """FastAPI dependency that authenticates the calling API client.

    Looks up the ``ApiClient`` whose ``client_key`` equals the
    ``X-Client-Key`` header, verifies the ``X-API-Key`` secret against the
    stored hash, rejects expired clients, applies the client's request
    allow-lists, then stamps ``last_used_at`` and commits.

    The secret is verified *before* expiry is reported, so a caller who
    knows only a client key cannot learn whether that client has expired.

    Returns:
        The authenticated ``ApiClient`` row.

    Raises:
        HTTPException: 401 ``invalid_client`` (unknown or non-active client),
            401 ``invalid_api_key`` (secret mismatch), 401 ``client_expired``
            (correct secret but expired client), or 403 from the allow-list
            check.
    """
    stmt = select(ApiClient).where(ApiClient.client_key == x_client_key)
    client = db.scalar(stmt)

    # NOTE(review): for an unknown client key no hash verification runs, so
    # response timing can still hint whether a client key exists.
    if client is None or client.status != "active":
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid_client")

    # Authenticate first; only a caller holding the secret learns about expiry.
    if not _verify_api_key(x_api_key, client.api_key_hash):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid_api_key")

    if _is_expired(client.expires_at):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="client_expired")

    _check_request_whitelist(client, request)

    client.last_used_at = datetime.now(timezone.utc)
    db.commit()
    return client