80 lines
2.8 KiB
Python
80 lines
2.8 KiB
Python
import hashlib
|
|
import hmac
|
|
from datetime import datetime, timezone
|
|
|
|
from fastapi import Depends, Header, HTTPException, Request, status
|
|
from passlib.context import CryptContext
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.db.session import get_db
|
|
from app.models.api_client import ApiClient
|
|
|
|
# Shared passlib context for API-key hashing. passlib hashes new secrets with
# the first listed scheme (argon2) and can still verify bcrypt hashes;
# deprecated="auto" marks every non-default scheme as deprecated.
pwd_context = CryptContext(
    schemes=["argon2", "bcrypt"],
    deprecated="auto",
)
|
|
|
|
|
|
def hash_api_key(plain_key: str) -> str:
    """Hash *plain_key* for storage as an API client's ``api_key_hash``.

    The hash is produced by the module-level ``pwd_context`` using its default
    scheme, and can later be checked with ``_verify_api_key``.
    """
    hashed = pwd_context.hash(plain_key)
    return hashed
|
|
|
|
|
|
def _verify_api_key(plain_key: str, stored_hash: str) -> bool:
|
|
# Support sha256:<hex> for bootstrap, and bcrypt/argon2 for production.
|
|
if stored_hash.startswith("sha256:"):
|
|
hex_hash = hashlib.sha256(plain_key.encode("utf-8")).hexdigest()
|
|
return hmac.compare_digest(stored_hash.removeprefix("sha256:"), hex_hash)
|
|
|
|
try:
|
|
return pwd_context.verify(plain_key, stored_hash)
|
|
except Exception:
|
|
return False
|
|
|
|
|
|
def _is_expired(expires_at: datetime | None) -> bool:
|
|
if expires_at is None:
|
|
return False
|
|
now = datetime.now(timezone.utc)
|
|
if expires_at.tzinfo is None:
|
|
expires_at = expires_at.replace(tzinfo=timezone.utc)
|
|
return expires_at <= now
|
|
|
|
|
|
def _path_has_prefix(path: str, prefix: str) -> bool:
    """Return True if *path* equals *prefix* or lies beneath it as a path segment.

    A prefix ending in "/" is matched with a plain ``startswith``; otherwise
    "/api/v1" matches "/api/v1" and "/api/v1/..." but not "/api/v10".
    """
    if prefix.endswith("/"):
        return path.startswith(prefix)
    return path == prefix or path.startswith(prefix + "/")


def _check_request_whitelist(client: ApiClient, request: Request) -> None:
    """Enforce *client*'s origin, IP and path allowlists for *request*.

    Each allowlist is enforced only when it is non-empty on *client*.

    Raises:
        HTTPException: 403 with detail ``origin_not_allowed``,
            ``ip_not_allowed`` or ``path_not_allowed`` on the first failed check.
    """
    origin = request.headers.get("origin")
    client_ip = request.client.host if request.client else None
    path = request.url.path

    # A request without an Origin header (typical of non-browser callers) is
    # not checked. The header is client-supplied, so this check only
    # constrains well-behaved browsers.
    if client.allowed_origins and origin and origin not in client.allowed_origins:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="origin_not_allowed")

    # Fail closed: if an IP allowlist is configured but the peer address is
    # unknown, deny instead of skipping the check.
    # NOTE(review): request.client.host is the direct peer; behind a reverse
    # proxy it is the proxy's address unless forwarded headers are trusted --
    # confirm for this deployment.
    if client.allowed_ips and (client_ip is None or client_ip not in client.allowed_ips):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="ip_not_allowed")

    if client.allowed_paths and not any(
        _path_has_prefix(path, prefix) for prefix in client.allowed_paths
    ):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="path_not_allowed")
|
|
|
|
|
|
def require_api_client(
    request: Request,
    x_client_key: str = Header(alias="X-Client-Key"),
    x_api_key: str = Header(alias="X-API-Key"),
    db: Session = Depends(get_db),
) -> ApiClient:
    """Authenticate the calling API client (FastAPI dependency).

    Steps, in order:

    1. Load the ``ApiClient`` whose ``client_key`` equals ``X-Client-Key``.
       An unknown client, or one whose status is not ``"active"``, gets 401
       ``invalid_client``.
    2. Verify ``X-API-Key`` against the stored ``api_key_hash``. A mismatch
       gets 401 ``invalid_api_key``.
    3. Reject an expired client with 401 ``client_expired``. This runs only
       after the secret is verified, so a caller who merely knows a client
       key cannot learn whether that client has expired.
    4. Enforce the origin / IP / path allowlists (403 on failure).
    5. Stamp ``last_used_at`` with the current UTC time and commit.

    Both headers are declared without defaults, so FastAPI's request
    validation rejects a request missing either one before this body runs.

    Returns:
        The authenticated ``ApiClient`` row.

    Raises:
        HTTPException: 401 or 403 as described above.
    """
    client = db.scalar(select(ApiClient).where(ApiClient.client_key == x_client_key))

    if client is None or client.status != "active":
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid_client")

    if not _verify_api_key(x_api_key, client.api_key_hash):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid_api_key")

    # Checked only once the secret is proven, so unauthenticated callers
    # cannot probe which client keys have expired.
    if _is_expired(client.expires_at):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="client_expired")

    _check_request_whitelist(client, request)

    # NOTE(review): this writes and commits on every authenticated request;
    # under heavy traffic consider throttling last_used_at updates.
    client.last_used_at = datetime.now(timezone.utc)
    db.commit()

    return client
|