148 lines
5.9 KiB
Python
148 lines
5.9 KiB
Python
from __future__ import annotations
|
|
|
|
from sqlalchemy import delete, func, select
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.models.permission_group import PermissionGroup
|
|
from app.models.permission_group_member import PermissionGroupMember
|
|
from app.models.permission_group_permission import PermissionGroupPermission
|
|
|
|
|
|
class PermissionGroupsRepository:
    """Data-access helpers for permission groups, their members and their permissions.

    Every mutating method commits the wrapped session before returning. If the
    commit raises, the session is rolled back and the original exception is
    re-raised, so the same session can still be used by the caller afterwards.
    """

    def __init__(self, db: Session) -> None:
        # Session used for every query and commit issued by this repository.
        self.db = db

    def _commit(self) -> None:
        """Commit the session, rolling it back if the commit raises.

        SQLAlchemy leaves a session whose flush/commit failed in an inactive
        transaction state until ``rollback()`` is called; without this, later
        calls on the same session would fail with a pending-rollback error
        instead of doing their work.
        """
        try:
            self.db.commit()
        except Exception:
            self.db.rollback()
            raise

    @staticmethod
    def _member_criteria(group_id: str, authentik_sub: str) -> tuple:
        """Return the WHERE clauses matching *authentik_sub*'s membership row in *group_id*."""
        return (
            PermissionGroupMember.group_id == group_id,
            PermissionGroupMember.authentik_sub == authentik_sub,
        )

    @staticmethod
    def _permission_criteria(
        group_id: str,
        system: str,
        module: str,
        action: str,
        scope_type: str,
        scope_id: str,
    ) -> tuple:
        """Return the WHERE clauses identifying one exact group permission grant."""
        return (
            PermissionGroupPermission.group_id == group_id,
            PermissionGroupPermission.system == system,
            PermissionGroupPermission.module == module,
            PermissionGroupPermission.action == action,
            PermissionGroupPermission.scope_type == scope_type,
            PermissionGroupPermission.scope_id == scope_id,
        )

    def get_by_key(self, group_key: str) -> PermissionGroup | None:
        """Return the group whose ``group_key`` equals *group_key*, or ``None`` if there is none."""
        return self.db.scalar(select(PermissionGroup).where(PermissionGroup.group_key == group_key))

    def get_by_id(self, group_id: str) -> PermissionGroup | None:
        """Return the group whose ``id`` equals *group_id*, or ``None`` if there is none."""
        return self.db.scalar(select(PermissionGroup).where(PermissionGroup.id == group_id))

    def get_by_keys(self, group_keys: list[str]) -> list[PermissionGroup]:
        """Return all groups whose ``group_key`` is in *group_keys*.

        Keys with no matching group are silently absent from the result. No
        ORDER BY is applied, so result order is whatever the database returns.
        An empty *group_keys* returns ``[]`` without querying.
        """
        if not group_keys:
            return []
        stmt = select(PermissionGroup).where(PermissionGroup.group_key.in_(group_keys))
        return list(self.db.scalars(stmt).all())

    def list(self, limit: int = 100, offset: int = 0) -> tuple[list[PermissionGroup], int]:
        """Return one page of groups, newest first, plus the total number of groups.

        Args:
            limit: Maximum number of groups in the page.
            offset: Number of groups to skip before the page starts.

        Returns:
            ``(groups, total)`` where ``total`` counts every group and ignores
            ``limit`` / ``offset``.
        """
        stmt = (
            select(PermissionGroup)
            # ``id`` breaks ties between equal ``created_at`` values so that
            # consecutive pages neither repeat nor skip rows.
            .order_by(PermissionGroup.created_at.desc(), PermissionGroup.id.desc())
            .limit(limit)
            .offset(offset)
        )
        count_stmt = select(func.count()).select_from(PermissionGroup)
        # NOTE: inside method bodies ``list`` still resolves to the builtin,
        # because class-level names (such as this method) are not in scope there.
        groups = list(self.db.scalars(stmt).all())
        return groups, int(self.db.scalar(count_stmt) or 0)

    def create(self, group_key: str, name: str, status: str = "active") -> PermissionGroup:
        """Insert a new group, commit, and return it refreshed from the database."""
        item = PermissionGroup(group_key=group_key, name=name, status=status)
        self.db.add(item)
        self._commit()
        self.db.refresh(item)
        return item

    def update(self, item: PermissionGroup, *, name: str | None = None, status: str | None = None) -> PermissionGroup:
        """Set the given non-``None`` fields on *item*, commit, and return it refreshed.

        Passing ``None`` (the default) for a field leaves that field unchanged.
        """
        if name is not None:
            item.name = name
        if status is not None:
            item.status = status
        self._commit()
        self.db.refresh(item)
        return item

    def add_member_if_not_exists(self, group_id: str, authentik_sub: str) -> PermissionGroupMember:
        """Return the membership of *authentik_sub* in *group_id*, creating it if missing.

        NOTE(review): the lookup and the insert are separate statements, so two
        concurrent callers may both insert; only a database unique constraint on
        (group_id, authentik_sub) would prevent that, and none is visible here.
        """
        existing = self.db.scalar(
            select(PermissionGroupMember).where(*self._member_criteria(group_id, authentik_sub))
        )
        if existing is not None:
            return existing
        row = PermissionGroupMember(group_id=group_id, authentik_sub=authentik_sub)
        self.db.add(row)
        self._commit()
        self.db.refresh(row)
        return row

    def remove_member(self, group_id: str, authentik_sub: str) -> int:
        """Delete *authentik_sub*'s membership in *group_id* and commit.

        Returns:
            The deleted-row count reported by the driver (``0`` if it reports none).
        """
        result = self.db.execute(
            delete(PermissionGroupMember).where(*self._member_criteria(group_id, authentik_sub))
        )
        self._commit()
        return int(result.rowcount or 0)

    def list_group_keys_by_member_sub(self, authentik_sub: str) -> list[str]:
        """Return the ``group_key`` of every group *authentik_sub* belongs to, sorted ascending."""
        stmt = (
            select(PermissionGroup.group_key)
            .select_from(PermissionGroupMember)
            .join(PermissionGroup, PermissionGroup.id == PermissionGroupMember.group_id)
            .where(PermissionGroupMember.authentik_sub == authentik_sub)
            .order_by(PermissionGroup.group_key.asc())
        )
        return list(self.db.scalars(stmt).all())

    def replace_member_groups(self, authentik_sub: str, group_ids: list[str]) -> None:
        """Make *group_ids* the complete set of groups *authentik_sub* belongs to.

        Existing memberships are deleted and the new ones inserted within the
        same transaction, which is committed once at the end.
        """
        self.db.execute(delete(PermissionGroupMember).where(PermissionGroupMember.authentik_sub == authentik_sub))
        # dict.fromkeys drops repeated ids while keeping first-seen order, so a
        # duplicated id cannot create two identical membership rows.
        unique_group_ids = dict.fromkeys(group_ids)
        self.db.add_all(
            [PermissionGroupMember(group_id=group_id, authentik_sub=authentik_sub) for group_id in unique_group_ids]
        )
        self._commit()

    def grant_group_permission(
        self,
        group_id: str,
        system: str,
        module: str,
        action: str,
        scope_type: str,
        scope_id: str,
    ) -> PermissionGroupPermission:
        """Return the matching grant for *group_id*, creating and committing it if missing.

        A grant matches only when all six fields are equal. As with
        ``add_member_if_not_exists``, the lookup and the insert are not atomic.
        """
        criteria = self._permission_criteria(group_id, system, module, action, scope_type, scope_id)
        existing = self.db.scalar(select(PermissionGroupPermission).where(*criteria))
        if existing is not None:
            return existing
        row = PermissionGroupPermission(
            group_id=group_id,
            system=system,
            module=module,
            action=action,
            scope_type=scope_type,
            scope_id=scope_id,
        )
        self.db.add(row)
        self._commit()
        self.db.refresh(row)
        return row

    def list_group_permissions(self, group_id: str) -> list[PermissionGroupPermission]:
        """Return every grant of *group_id*, ordered by scope, then system, module and action."""
        stmt = (
            select(PermissionGroupPermission)
            .where(PermissionGroupPermission.group_id == group_id)
            .order_by(
                PermissionGroupPermission.scope_type.asc(),
                PermissionGroupPermission.scope_id.asc(),
                PermissionGroupPermission.system.asc(),
                PermissionGroupPermission.module.asc(),
                PermissionGroupPermission.action.asc(),
            )
        )
        return list(self.db.scalars(stmt).all())

    def revoke_group_permission(
        self,
        group_id: str,
        system: str,
        module: str,
        action: str,
        scope_type: str,
        scope_id: str,
    ) -> int:
        """Delete the grant matching all six fields and commit.

        Returns:
            The deleted-row count reported by the driver (``0`` if it reports none).
        """
        stmt = delete(PermissionGroupPermission).where(
            *self._permission_criteria(group_id, system, module, action, scope_type, scope_id)
        )
        result = self.db.execute(stmt)
        self._commit()
        return int(result.rowcount or 0)
|