from __future__ import annotations from sqlalchemy import delete, func, select from sqlalchemy.orm import Session from app.models.permission_group import PermissionGroup from app.models.permission_group_member import PermissionGroupMember from app.models.permission_group_permission import PermissionGroupPermission from app.models.user import User class PermissionGroupsRepository: def __init__(self, db: Session) -> None: self.db = db def get_by_key(self, group_key: str) -> PermissionGroup | None: return self.db.scalar(select(PermissionGroup).where(PermissionGroup.group_key == group_key)) def get_by_id(self, group_id: str) -> PermissionGroup | None: return self.db.scalar(select(PermissionGroup).where(PermissionGroup.id == group_id)) def get_by_keys(self, group_keys: list[str]) -> list[PermissionGroup]: if not group_keys: return [] stmt = select(PermissionGroup).where(PermissionGroup.group_key.in_(group_keys)) return list(self.db.scalars(stmt).all()) def list(self, limit: int = 100, offset: int = 0) -> tuple[list[PermissionGroup], int]: stmt = select(PermissionGroup).order_by(PermissionGroup.created_at.desc()).limit(limit).offset(offset) count_stmt = select(func.count()).select_from(PermissionGroup) return list(self.db.scalars(stmt).all()), int(self.db.scalar(count_stmt) or 0) def create(self, group_key: str, name: str, status: str = "active") -> PermissionGroup: item = PermissionGroup(group_key=group_key, name=name, status=status) self.db.add(item) self.db.commit() self.db.refresh(item) return item def update(self, item: PermissionGroup, *, name: str | None = None, status: str | None = None) -> PermissionGroup: if name is not None: item.name = name if status is not None: item.status = status self.db.commit() self.db.refresh(item) return item def add_member_if_not_exists(self, group_id: str, authentik_sub: str) -> PermissionGroupMember: existing = self.db.scalar( select(PermissionGroupMember).where( PermissionGroupMember.group_id == group_id, PermissionGroupMember.authentik_sub == authentik_sub ) ) if existing: return existing row = PermissionGroupMember(group_id=group_id, authentik_sub=authentik_sub) self.db.add(row) self.db.commit() self.db.refresh(row) return row def remove_member(self, group_id: str, authentik_sub: str) -> int: result = self.db.execute( delete(PermissionGroupMember).where( PermissionGroupMember.group_id == group_id, PermissionGroupMember.authentik_sub == authentik_sub ) ) self.db.commit() return int(result.rowcount or 0) def list_group_keys_by_member_sub(self, authentik_sub: str) -> list[str]: stmt = ( select(PermissionGroup.group_key) .select_from(PermissionGroupMember) .join(PermissionGroup, PermissionGroup.id == PermissionGroupMember.group_id) .where(PermissionGroupMember.authentik_sub == authentik_sub) .order_by(PermissionGroup.group_key.asc()) ) return [row[0] for row in self.db.execute(stmt).all()] def replace_member_groups(self, authentik_sub: str, group_ids: list[str]) -> None: self.db.execute(delete(PermissionGroupMember).where(PermissionGroupMember.authentik_sub == authentik_sub)) for group_id in group_ids: self.db.add(PermissionGroupMember(group_id=group_id, authentik_sub=authentik_sub)) self.db.commit() def grant_group_permission( self, group_id: str, system: str, module: str, action: str, scope_type: str, scope_id: str, ) -> PermissionGroupPermission: where = [ PermissionGroupPermission.group_id == group_id, PermissionGroupPermission.system == system, PermissionGroupPermission.module == module, PermissionGroupPermission.action == action, PermissionGroupPermission.scope_type == scope_type, PermissionGroupPermission.scope_id == scope_id, ] existing = self.db.scalar(select(PermissionGroupPermission).where(*where)) if existing: return existing row = PermissionGroupPermission( group_id=group_id, system=system, module=module, action=action, scope_type=scope_type, scope_id=scope_id, ) self.db.add(row) self.db.commit() self.db.refresh(row) return row def list_group_permissions(self, group_id: str) -> list[PermissionGroupPermission]: stmt = ( select(PermissionGroupPermission) .where(PermissionGroupPermission.group_id == group_id) .order_by(PermissionGroupPermission.scope_type.asc(), PermissionGroupPermission.scope_id.asc(), PermissionGroupPermission.system.asc(), PermissionGroupPermission.module.asc(), PermissionGroupPermission.action.asc()) ) return list(self.db.scalars(stmt).all()) def replace_group_bindings( self, *, group_id: str, site_keys: list[str], system_keys: list[str], module_keys: list[str], member_subs: list[str], actions: list[str], ) -> None: normalized_sites = list(dict.fromkeys([s for s in site_keys if s])) normalized_actions = [a for a in list(dict.fromkeys(actions)) if a in {"view", "edit"}] normalized_member_subs = list(dict.fromkeys([s for s in member_subs if s])) normalized_systems = set([s for s in system_keys if s]) module_pairs = [] for pair in module_keys: if "|" not in pair: continue system_key, module_key = pair.split("|", 1) if not system_key or not module_key: continue module_pairs.append((system_key, module_key)) normalized_systems.add(system_key) self.db.execute(delete(PermissionGroupPermission).where(PermissionGroupPermission.group_id == group_id)) self.db.execute(delete(PermissionGroupMember).where(PermissionGroupMember.group_id == group_id)) for sub in normalized_member_subs: self.db.add(PermissionGroupMember(group_id=group_id, authentik_sub=sub)) for site_key in normalized_sites: for action in normalized_actions: for system_key in sorted(normalized_systems): module_names = [m for s, m in module_pairs if s == system_key] or ["__system__"] for module_name in module_names: self.db.add( PermissionGroupPermission( group_id=group_id, system=system_key, module=module_name, action=action, scope_type="site", scope_id=site_key, ) ) self.db.commit() def get_group_binding_snapshot(self, group_id: str, group_key: str) -> dict: permissions = self.list_group_permissions(group_id) site_keys = sorted({p.scope_id for p in permissions if p.scope_type == "site"}) system_keys = sorted({p.system for p in permissions}) actions = sorted({p.action for p in permissions if p.action in {"view", "edit"}}) module_keys = sorted( { f"{p.system}|{p.module}" for p in permissions if p.module and p.module != "__system__" } ) member_subs = sorted(self.list_group_member_subs(group_id)) return { "group_key": group_key, "site_keys": site_keys, "system_keys": system_keys, "module_keys": module_keys, "member_subs": member_subs, "actions": actions, } def list_group_member_subs(self, group_id: str) -> list[str]: stmt = ( select(PermissionGroupMember.authentik_sub) .where(PermissionGroupMember.group_id == group_id) .order_by(PermissionGroupMember.authentik_sub.asc()) ) return [row[0] for row in self.db.execute(stmt).all()] def list_system_groups(self, system_key: str) -> list[PermissionGroup]: stmt = ( select(PermissionGroup) .join(PermissionGroupPermission, PermissionGroupPermission.group_id == PermissionGroup.id) .where(PermissionGroupPermission.system == system_key) .order_by(PermissionGroup.name.asc()) .distinct() ) return list(self.db.scalars(stmt).all()) def list_system_members(self, system_key: str) -> list[User]: stmt = ( select(User) .join(PermissionGroupMember, PermissionGroupMember.authentik_sub == User.authentik_sub) .join(PermissionGroupPermission, PermissionGroupPermission.group_id == PermissionGroupMember.group_id) .where(PermissionGroupPermission.system == system_key) .order_by(User.email.asc(), User.authentik_sub.asc()) .distinct() ) return list(self.db.scalars(stmt).all()) def list_module_groups(self, system_key: str, module_name: str) -> list[PermissionGroup]: stmt = ( select(PermissionGroup) .join(PermissionGroupPermission, PermissionGroupPermission.group_id == PermissionGroup.id) .where(PermissionGroupPermission.system == system_key, PermissionGroupPermission.module == module_name) .order_by(PermissionGroup.name.asc()) .distinct() ) return list(self.db.scalars(stmt).all()) def list_module_members(self, system_key: str, module_name: str) -> list[User]: stmt = ( select(User) .join(PermissionGroupMember, PermissionGroupMember.authentik_sub == User.authentik_sub) .join(PermissionGroupPermission, PermissionGroupPermission.group_id == PermissionGroupMember.group_id) .where(PermissionGroupPermission.system == system_key, PermissionGroupPermission.module == module_name) .order_by(User.email.asc(), User.authentik_sub.asc()) .distinct() ) return list(self.db.scalars(stmt).all()) def revoke_group_permission( self, group_id: str, system: str, module: str, action: str, scope_type: str, scope_id: str, ) -> int: stmt = delete(PermissionGroupPermission).where( PermissionGroupPermission.group_id == group_id, PermissionGroupPermission.system == system, PermissionGroupPermission.module == module, PermissionGroupPermission.action == action, PermissionGroupPermission.scope_type == scope_type, PermissionGroupPermission.scope_id == scope_id, ) result = self.db.execute(stmt) self.db.commit() return int(result.rowcount or 0)