from __future__ import annotations

from sqlalchemy import func, or_, select
from sqlalchemy.orm import Session

from app.models.api_client import ApiClient

# Escape character used when embedding user-supplied search text in LIKE/ILIKE
# patterns, so literal "%" and "_" in the keyword are not treated as wildcards.
_LIKE_ESCAPE = "\\"


def _escape_like(value: str) -> str:
    """Return *value* with LIKE metacharacters escaped using ``_LIKE_ESCAPE``."""
    # Escape the escape character itself first so the later replacements
    # are not double-escaped.
    return (
        value.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", _LIKE_ESCAPE + "%")
        .replace("_", _LIKE_ESCAPE + "_")
    )


class ApiClientsRepository:
    """Data-access helpers for ``ApiClient`` rows, bound to one SQLAlchemy session.

    The mutating methods (``create``, ``update``, ``delete``) commit the
    session themselves, so any other pending changes in the same session are
    committed along with them.
    """

    def __init__(self, db: Session) -> None:
        self.db = db

    def get_by_key(self, client_key: str) -> ApiClient | None:
        """Return the client whose ``client_key`` equals *client_key*, or ``None``."""
        return self.db.scalar(select(ApiClient).where(ApiClient.client_key == client_key))

    def list(
        self,
        *,
        keyword: str | None = None,
        status: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> tuple[list[ApiClient], int]:
        """Return one page of clients plus the total number of matching rows.

        Args:
            keyword: Case-insensitive substring matched against ``client_key``
                or ``name``. It is matched literally: ``%`` and ``_`` are not
                wildcards. Empty/``None`` disables the filter.
            status: Exact ``status`` value to filter on. Empty/``None``
                disables the filter.
            limit: Maximum number of rows in the returned page.
            offset: Number of matching rows to skip before the page starts.

        Returns:
            ``(page, total)``. *page* is ordered newest ``created_at`` first,
            with ``client_key`` as a tie-breaker so pagination is stable.
            *total* counts all rows matching the filters, ignoring
            limit/offset.
        """
        conditions = []
        if keyword:
            pattern = f"%{_escape_like(keyword)}%"
            conditions.append(
                or_(
                    ApiClient.client_key.ilike(pattern, escape=_LIKE_ESCAPE),
                    ApiClient.name.ilike(pattern, escape=_LIKE_ESCAPE),
                )
            )
        if status:
            conditions.append(ApiClient.status == status)

        stmt = select(ApiClient)
        count_stmt = select(func.count()).select_from(ApiClient)
        if conditions:
            # The page query and the count query must share the same filters.
            stmt = stmt.where(*conditions)
            count_stmt = count_stmt.where(*conditions)

        # Secondary key keeps the order deterministic when created_at ties,
        # so rows are not duplicated or skipped across pages.
        stmt = (
            stmt.order_by(ApiClient.created_at.desc(), ApiClient.client_key.desc())
            .limit(limit)
            .offset(offset)
        )
        # `list` below is the builtin: the method name lives on the class,
        # which is not an enclosing scope for this function body.
        items = list(self.db.scalars(stmt).all())
        total = int(self.db.scalar(count_stmt) or 0)
        return items, total

    def create(
        self,
        *,
        client_key: str,
        name: str,
        status: str,
        api_key_hash: str,
        allowed_origins: list[str],
        allowed_ips: list[str],
        allowed_paths: list[str],
        rate_limit_per_min: int | None,
        expires_at,
    ) -> ApiClient:
        """Insert a new client, commit, and return it refreshed from the database.

        ``expires_at`` is stored as given; presumably a datetime or ``None``.
        TODO confirm against the ``ApiClient`` model.
        """
        item = ApiClient(
            client_key=client_key,
            name=name,
            status=status,
            api_key_hash=api_key_hash,
            allowed_origins=allowed_origins,
            allowed_ips=allowed_ips,
            allowed_paths=allowed_paths,
            rate_limit_per_min=rate_limit_per_min,
            expires_at=expires_at,
        )
        self.db.add(item)
        self.db.commit()
        self.db.refresh(item)
        return item

    def update(
        self,
        item: ApiClient,
        *,
        name: str | None = None,
        status: str | None = None,
        api_key_hash: str | None = None,
        allowed_origins: list[str] | None = None,
        allowed_ips: list[str] | None = None,
        allowed_paths: list[str] | None = None,
        rate_limit_per_min: int | None = None,
        expires_at=None,
    ) -> ApiClient:
        """Apply the given field changes to *item*, commit, and return it refreshed.

        Each argument left as ``None`` leaves the corresponding attribute
        unchanged. Consequently this method cannot reset a nullable field
        (e.g. ``rate_limit_per_min`` or ``expires_at``) back to ``None``.
        Empty lists, however, are applied, since only ``None`` is skipped.
        """
        changes = {
            "name": name,
            "status": status,
            "api_key_hash": api_key_hash,
            "allowed_origins": allowed_origins,
            "allowed_ips": allowed_ips,
            "allowed_paths": allowed_paths,
            "rate_limit_per_min": rate_limit_per_min,
            "expires_at": expires_at,
        }
        for attr, value in changes.items():
            if value is not None:
                setattr(item, attr, value)
        self.db.commit()
        self.db.refresh(item)
        return item

    def delete(self, item: ApiClient) -> None:
        """Delete *item* and commit the session."""
        self.db.delete(item)
        self.db.commit()