from __future__ import annotations

from sqlalchemy import delete, func, select
from sqlalchemy.orm import Session

from app.models.permission_group import PermissionGroup
from app.models.permission_group_member import PermissionGroupMember
from app.models.permission_group_permission import PermissionGroupPermission


class PermissionGroupsRepository:
    """Data access for permission groups, their members and their permission grants.

    Every mutating method commits the wrapped session before returning, so each
    call is its own unit of work.
    """

    def __init__(self, db: Session) -> None:
        self.db = db

    # ------------------------------------------------------------------ groups

    def get_by_key(self, group_key: str) -> PermissionGroup | None:
        """Return the group whose ``group_key`` equals *group_key*, or ``None``."""
        return self.db.scalar(select(PermissionGroup).where(PermissionGroup.group_key == group_key))

    def get_by_id(self, group_id: str) -> PermissionGroup | None:
        """Return the group whose ``id`` equals *group_id*, or ``None``."""
        return self.db.scalar(select(PermissionGroup).where(PermissionGroup.id == group_id))

    def get_by_keys(self, group_keys: list[str]) -> list[PermissionGroup]:
        """Return every group whose ``group_key`` is in *group_keys*.

        An empty input returns ``[]`` without querying. Keys with no matching
        group are simply absent from the result. No ordering is applied.
        """
        if not group_keys:
            return []
        stmt = select(PermissionGroup).where(PermissionGroup.group_key.in_(group_keys))
        return list(self.db.scalars(stmt).all())

    def list(self, limit: int = 100, offset: int = 0) -> tuple[list[PermissionGroup], int]:
        """Return one page of groups, newest first, plus the total number of groups.

        Args:
            limit: Maximum number of groups in the returned page.
            offset: Number of groups to skip before the page starts.

        Returns:
            ``(groups, total)``. ``total`` counts all groups and ignores
            *limit*/*offset*.
        """
        # NOTE: this method name shadows the ``list`` builtin in the class namespace
        # only; method bodies (including this one) still resolve ``list`` to the builtin.
        stmt = (
            select(PermissionGroup)
            # ``id`` breaks ties between rows with equal ``created_at``, so
            # consecutive limit/offset pages neither repeat nor skip groups.
            .order_by(PermissionGroup.created_at.desc(), PermissionGroup.id.desc())
            .limit(limit)
            .offset(offset)
        )
        count_stmt = select(func.count()).select_from(PermissionGroup)
        items = list(self.db.scalars(stmt).all())
        total = int(self.db.scalar(count_stmt) or 0)
        return items, total

    def create(self, group_key: str, name: str, status: str = "active") -> PermissionGroup:
        """Insert a new group, commit, and return it refreshed from the database."""
        item = PermissionGroup(group_key=group_key, name=name, status=status)
        self.db.add(item)
        self.db.commit()
        self.db.refresh(item)
        return item

    def update(self, item: PermissionGroup, *, name: str | None = None, status: str | None = None) -> PermissionGroup:
        """Apply the given field changes to *item*, commit, and return it refreshed.

        A field whose argument is ``None`` is left unchanged. The session is
        committed even when no field was changed.
        """
        if name is not None:
            item.name = name
        if status is not None:
            item.status = status
        self.db.commit()
        self.db.refresh(item)
        return item

    # ----------------------------------------------------------------- members

    def add_member_if_not_exists(self, group_id: str, authentik_sub: str) -> PermissionGroupMember:
        """Return the membership of *authentik_sub* in *group_id*, creating it if absent.

        An existing row is returned untouched and nothing is committed. Otherwise
        a new row is inserted, committed and refreshed.
        """
        existing = self.db.scalar(
            select(PermissionGroupMember).where(
                PermissionGroupMember.group_id == group_id,
                PermissionGroupMember.authentik_sub == authentik_sub,
            )
        )
        if existing:
            return existing
        # NOTE(review): check-then-insert is not atomic; two concurrent callers can
        # both miss ``existing``. Confirm a unique constraint backs (group_id, authentik_sub).
        row = PermissionGroupMember(group_id=group_id, authentik_sub=authentik_sub)
        self.db.add(row)
        self.db.commit()
        self.db.refresh(row)
        return row

    def remove_member(self, group_id: str, authentik_sub: str) -> int:
        """Delete the membership of *authentik_sub* in *group_id*, commit, and return rows deleted."""
        result = self.db.execute(
            delete(PermissionGroupMember).where(
                PermissionGroupMember.group_id == group_id,
                PermissionGroupMember.authentik_sub == authentik_sub,
            )
        )
        self.db.commit()
        return int(result.rowcount or 0)

    def list_group_keys_by_member_sub(self, authentik_sub: str) -> list[str]:
        """Return the keys of all groups *authentik_sub* belongs to, sorted ascending."""
        stmt = (
            select(PermissionGroup.group_key)
            .select_from(PermissionGroupMember)
            .join(PermissionGroup, PermissionGroup.id == PermissionGroupMember.group_id)
            .where(PermissionGroupMember.authentik_sub == authentik_sub)
            .order_by(PermissionGroup.group_key.asc())
        )
        # scalars() yields the single selected column directly instead of 1-tuples.
        return list(self.db.scalars(stmt).all())

    def replace_member_groups(self, authentik_sub: str, group_ids: list[str]) -> None:
        """Make *group_ids* the complete set of groups *authentik_sub* belongs to.

        All existing memberships of the subject are deleted and the new ones are
        added, then the session is committed once. An empty *group_ids* removes
        every membership. Repeated ids are collapsed so each membership row is
        inserted only once.
        """
        self.db.execute(delete(PermissionGroupMember).where(PermissionGroupMember.authentik_sub == authentik_sub))
        # dict.fromkeys de-duplicates while preserving the caller's order; without it
        # a repeated id would insert duplicate membership rows (or hit a unique constraint).
        for group_id in dict.fromkeys(group_ids):
            self.db.add(PermissionGroupMember(group_id=group_id, authentik_sub=authentik_sub))
        self.db.commit()

    # ------------------------------------------------------------- permissions

    @staticmethod
    def _permission_criteria(
        group_id: str,
        system: str,
        module: str,
        action: str,
        scope_type: str,
        scope_id: str,
    ):
        """Build the equality filters that identify one group permission row."""
        return (
            PermissionGroupPermission.group_id == group_id,
            PermissionGroupPermission.system == system,
            PermissionGroupPermission.module == module,
            PermissionGroupPermission.action == action,
            PermissionGroupPermission.scope_type == scope_type,
            PermissionGroupPermission.scope_id == scope_id,
        )

    def grant_group_permission(
        self,
        group_id: str,
        system: str,
        module: str,
        action: str,
        scope_type: str,
        scope_id: str,
    ) -> PermissionGroupPermission:
        """Return the matching permission row for the group, creating it if absent.

        A row is matched on all six fields. An existing row is returned untouched
        and nothing is committed. Otherwise a new row is inserted, committed and
        refreshed.
        """
        criteria = self._permission_criteria(group_id, system, module, action, scope_type, scope_id)
        existing = self.db.scalar(select(PermissionGroupPermission).where(*criteria))
        if existing:
            return existing
        # NOTE(review): check-then-insert is not atomic under concurrency; confirm a
        # unique constraint covers these six columns.
        row = PermissionGroupPermission(
            group_id=group_id,
            system=system,
            module=module,
            action=action,
            scope_type=scope_type,
            scope_id=scope_id,
        )
        self.db.add(row)
        self.db.commit()
        self.db.refresh(row)
        return row

    def list_group_permissions(self, group_id: str) -> list[PermissionGroupPermission]:
        """Return all permission rows of *group_id*.

        Rows are ordered by scope_type, scope_id, system, module, then action
        (all ascending).
        """
        stmt = (
            select(PermissionGroupPermission)
            .where(PermissionGroupPermission.group_id == group_id)
            .order_by(
                PermissionGroupPermission.scope_type.asc(),
                PermissionGroupPermission.scope_id.asc(),
                PermissionGroupPermission.system.asc(),
                PermissionGroupPermission.module.asc(),
                PermissionGroupPermission.action.asc(),
            )
        )
        return list(self.db.scalars(stmt).all())

    def revoke_group_permission(
        self,
        group_id: str,
        system: str,
        module: str,
        action: str,
        scope_type: str,
        scope_id: str,
    ) -> int:
        """Delete the permission row matching all six fields, commit, and return rows deleted."""
        criteria = self._permission_criteria(group_id, system, module, action, scope_type, scope_id)
        result = self.db.execute(delete(PermissionGroupPermission).where(*criteria))
        self.db.commit()
        return int(result.rowcount or 0)