from sqlalchemy import and_, delete, literal, or_, select
from sqlalchemy.orm import Session

from app.models.company import Company
from app.models.module import Module
from app.models.permission_group_member import PermissionGroupMember
from app.models.permission_group_permission import PermissionGroupPermission
from app.models.site import Site
from app.models.user_scope_permission import UserScopePermission


class PermissionsRepository:
    """Data access for user permission grants.

    Reads merge grants assigned directly to a user (``UserScopePermission``)
    with grants inherited through permission-group membership
    (``PermissionGroupPermission`` joined to ``PermissionGroupMember``).
    Write methods only touch direct grants and commit the session.
    """

    def __init__(self, db: Session) -> None:
        self.db = db

    @staticmethod
    def _system_key_from_module(module_key: object) -> str | None:
        """Return the part of a dotted module key before the first ``"."``.

        Returns ``None`` when ``module_key`` is not a string or has no dot.
        """
        if isinstance(module_key, str) and "." in module_key:
            return module_key.split(".", 1)[0]
        return None

    @staticmethod
    def _scope_condition(scope_type: str, company_id: str | None, site_id: str | None):
        """Build the clause that pins a direct grant to its scope target.

        A ``"company"`` scope is matched on ``company_id``; every other scope
        type is matched on ``site_id``. Both ``create_if_not_exists`` and
        ``revoke`` use this, so any grant the former can create, the latter
        can remove. A ``None`` id compares as ``IS NULL`` in SQLAlchemy.
        """
        if scope_type == "company":
            return UserScopePermission.company_id == company_id
        return UserScopePermission.site_id == site_id

    def list_by_user(self, user_id: str, authentik_sub: str) -> list[tuple[str, str, str | None, str, str]]:
        """Return the de-duplicated permissions of a user.

        Args:
            user_id: Value of ``UserScopePermission.user_id`` identifying the
                user's direct grants.
            authentik_sub: Value of ``PermissionGroupMember.authentik_sub``
                identifying the groups whose grants the user inherits.

        Returns:
            Tuples ``(scope_type, scope_id, system_key, module_key, action)``.
            ``scope_id`` is ``""`` when the underlying value is NULL and
            ``system_key`` may be ``None``. Direct grants are listed before
            group grants; exact duplicate tuples are dropped (first kept).
            No ORDER BY is applied, so order within each source is whatever
            the database returns.
        """
        # Each row is tagged with its source so both result sets can be
        # processed in one loop despite their differing column meanings.
        direct_stmt = (
            select(
                literal("direct"),
                UserScopePermission.scope_type,
                Company.company_key,
                Site.site_key,
                Module.module_key,
                UserScopePermission.action,
            )
            .select_from(UserScopePermission)
            .join(Module, Module.id == UserScopePermission.module_id)
            # Outer joins: a grant references a company or a site, not necessarily both.
            .join(Company, Company.id == UserScopePermission.company_id, isouter=True)
            .join(Site, Site.id == UserScopePermission.site_id, isouter=True)
            .where(UserScopePermission.user_id == user_id)
        )
        group_stmt = (
            select(
                literal("group"),
                PermissionGroupPermission.scope_type,
                PermissionGroupPermission.scope_id,
                PermissionGroupPermission.system,
                PermissionGroupPermission.module,
                PermissionGroupPermission.action,
            )
            .select_from(PermissionGroupPermission)
            .join(PermissionGroupMember, PermissionGroupMember.group_id == PermissionGroupPermission.group_id)
            .where(PermissionGroupMember.authentik_sub == authentik_sub)
        )

        rows = self.db.execute(direct_stmt).all() + self.db.execute(group_stmt).all()

        result: list[tuple[str, str, str | None, str, str]] = []
        seen: set[tuple[str, str, str | None, str, str]] = set()
        for row in rows:
            if row[0] == "group":
                # Group grants store scope id, system and module directly.
                # NOTE(review): direct grants below report company/site *keys*
                # as scope_id, while this is the raw scope_id column; confirm
                # both use the same identifier space.
                _, scope_type, scope_id, system_key, module_key, action = row
            else:
                _, scope_type, company_key, site_key, module_key, action = row
                scope_id = company_key if scope_type == "company" else site_key
                system_key = self._system_key_from_module(module_key)
            key = (scope_type, scope_id or "", system_key, module_key, action)
            if key not in seen:
                seen.add(key)
                result.append(key)
        return result

    def create_if_not_exists(
        self,
        user_id: str,
        module_id: str,
        action: str,
        scope_type: str,
        company_id: str | None,
        site_id: str | None,
    ) -> UserScopePermission:
        """Return the matching direct grant, inserting and committing it if absent.

        An existing row matches on user, module, action, scope type and the
        scope target chosen by ``_scope_condition`` (``company_id`` for
        ``"company"`` scope, otherwise ``site_id``). A new row stores both
        ``company_id`` and ``site_id`` as given.

        NOTE(review): the lookup and insert are separate statements, so two
        concurrent calls can both insert unless a database unique constraint
        (not visible here) prevents it.
        """
        existing = self.db.scalar(
            select(UserScopePermission).where(
                and_(
                    UserScopePermission.user_id == user_id,
                    UserScopePermission.module_id == module_id,
                    UserScopePermission.action == action,
                    UserScopePermission.scope_type == scope_type,
                    self._scope_condition(scope_type, company_id, site_id),
                )
            )
        )
        if existing is not None:
            return existing

        item = UserScopePermission(
            user_id=user_id,
            module_id=module_id,
            action=action,
            scope_type=scope_type,
            company_id=company_id,
            site_id=site_id,
        )
        self.db.add(item)
        self.db.commit()
        # Reload DB-populated columns (e.g. generated id) after the commit.
        self.db.refresh(item)
        return item

    def revoke(
        self,
        user_id: str,
        module_id: str,
        action: str,
        scope_type: str,
        company_id: str | None,
        site_id: str | None,
    ) -> int:
        """Delete matching direct grants, commit, and return the deleted row count.

        Rows are matched exactly as ``create_if_not_exists`` matches them
        (``company_id`` for ``"company"`` scope, otherwise ``site_id``), so
        every grant that method can create can be revoked here.
        """
        stmt = delete(UserScopePermission).where(
            UserScopePermission.user_id == user_id,
            UserScopePermission.module_id == module_id,
            UserScopePermission.action == action,
            UserScopePermission.scope_type == scope_type,
            self._scope_condition(scope_type, company_id, site_id),
        )
        result = self.db.execute(stmt)
        self.db.commit()
        # rowcount may be None/-1 on some drivers; normalize a missing value to 0.
        return int(result.rowcount or 0)