refactor: align backend with company-site-member schema and system-level RBAC groups
This commit is contained in:
35
app/repositories/companies_repo.py
Normal file
35
app/repositories/companies_repo.py
Normal file
@@ -0,0 +1,35 @@
|
||||
from sqlalchemy import func, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.company import Company
|
||||
|
||||
|
||||
class CompaniesRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def get_by_key(self, company_key: str) -> Company | None:
|
||||
stmt = select(Company).where(Company.company_key == company_key)
|
||||
return self.db.scalar(stmt)
|
||||
|
||||
def get_by_id(self, company_id: str) -> Company | None:
|
||||
stmt = select(Company).where(Company.id == company_id)
|
||||
return self.db.scalar(stmt)
|
||||
|
||||
def list(self, keyword: str | None = None, limit: int = 100, offset: int = 0) -> tuple[list[Company], int]:
|
||||
stmt = select(Company)
|
||||
count_stmt = select(func.count()).select_from(Company)
|
||||
if keyword:
|
||||
pattern = f"%{keyword}%"
|
||||
cond = or_(Company.company_key.ilike(pattern), Company.name.ilike(pattern))
|
||||
stmt = stmt.where(cond)
|
||||
count_stmt = count_stmt.where(cond)
|
||||
stmt = stmt.order_by(Company.created_at.desc()).limit(limit).offset(offset)
|
||||
return list(self.db.scalars(stmt).all()), int(self.db.scalar(count_stmt) or 0)
|
||||
|
||||
def create(self, company_key: str, name: str, status: str = "active") -> Company:
|
||||
item = Company(company_key=company_key, name=name, status=status)
|
||||
self.db.add(item)
|
||||
self.db.commit()
|
||||
self.db.refresh(item)
|
||||
return item
|
||||
@@ -1,43 +0,0 @@
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.member_organization import MemberOrganization
|
||||
from app.models.organization import Organization
|
||||
|
||||
|
||||
class MemberOrganizationsRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def list_organizations_by_member_id(self, member_id: str) -> list[Organization]:
|
||||
stmt = (
|
||||
select(Organization)
|
||||
.join(MemberOrganization, MemberOrganization.organization_id == Organization.id)
|
||||
.where(MemberOrganization.member_id == member_id)
|
||||
.order_by(Organization.name.asc())
|
||||
)
|
||||
return list(self.db.scalars(stmt).all())
|
||||
|
||||
def add_if_not_exists(self, member_id: str, organization_id: str) -> MemberOrganization:
|
||||
stmt = select(MemberOrganization).where(
|
||||
MemberOrganization.member_id == member_id,
|
||||
MemberOrganization.organization_id == organization_id,
|
||||
)
|
||||
existing = self.db.scalar(stmt)
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
relation = MemberOrganization(member_id=member_id, organization_id=organization_id)
|
||||
self.db.add(relation)
|
||||
self.db.commit()
|
||||
self.db.refresh(relation)
|
||||
return relation
|
||||
|
||||
def remove(self, member_id: str, organization_id: str) -> int:
|
||||
stmt = delete(MemberOrganization).where(
|
||||
MemberOrganization.member_id == member_id,
|
||||
MemberOrganization.organization_id == organization_id,
|
||||
)
|
||||
result = self.db.execute(stmt)
|
||||
self.db.commit()
|
||||
return int(result.rowcount or 0)
|
||||
26
app/repositories/modules_repo.py
Normal file
26
app/repositories/modules_repo.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.module import Module
|
||||
|
||||
|
||||
class ModulesRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def get_by_key(self, module_key: str) -> Module | None:
|
||||
stmt = select(Module).where(Module.module_key == module_key)
|
||||
return self.db.scalar(stmt)
|
||||
|
||||
def list(self, limit: int = 200, offset: int = 0) -> tuple[list[Module], int]:
|
||||
stmt = select(Module)
|
||||
count_stmt = select(func.count()).select_from(Module)
|
||||
stmt = stmt.order_by(Module.created_at.desc()).limit(limit).offset(offset)
|
||||
return list(self.db.scalars(stmt).all()), int(self.db.scalar(count_stmt) or 0)
|
||||
|
||||
def create(self, module_key: str, name: str, status: str = "active") -> Module:
|
||||
item = Module(module_key=module_key, name=name, status=status)
|
||||
self.db.add(item)
|
||||
self.db.commit()
|
||||
self.db.refresh(item)
|
||||
return item
|
||||
@@ -1,67 +0,0 @@
|
||||
from sqlalchemy import func, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.organization import Organization
|
||||
|
||||
|
||||
class OrganizationsRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def list(
|
||||
self,
|
||||
keyword: str | None = None,
|
||||
status: str | None = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[Organization], int]:
|
||||
stmt = select(Organization)
|
||||
count_stmt = select(func.count()).select_from(Organization)
|
||||
|
||||
if keyword:
|
||||
pattern = f"%{keyword}%"
|
||||
cond = or_(Organization.org_code.ilike(pattern), Organization.name.ilike(pattern))
|
||||
stmt = stmt.where(cond)
|
||||
count_stmt = count_stmt.where(cond)
|
||||
|
||||
if status:
|
||||
stmt = stmt.where(Organization.status == status)
|
||||
count_stmt = count_stmt.where(Organization.status == status)
|
||||
|
||||
stmt = stmt.order_by(Organization.created_at.desc()).limit(limit).offset(offset)
|
||||
items = list(self.db.scalars(stmt).all())
|
||||
total = int(self.db.scalar(count_stmt) or 0)
|
||||
return items, total
|
||||
|
||||
def get_by_id(self, org_id: str) -> Organization | None:
|
||||
stmt = select(Organization).where(Organization.id == org_id)
|
||||
return self.db.scalar(stmt)
|
||||
|
||||
def get_by_code(self, org_code: str) -> Organization | None:
|
||||
stmt = select(Organization).where(Organization.org_code == org_code)
|
||||
return self.db.scalar(stmt)
|
||||
|
||||
def create(self, org_code: str, name: str, tax_id: str | None, status: str = "active") -> Organization:
|
||||
org = Organization(org_code=org_code, name=name, tax_id=tax_id, status=status)
|
||||
self.db.add(org)
|
||||
self.db.commit()
|
||||
self.db.refresh(org)
|
||||
return org
|
||||
|
||||
def update(
|
||||
self,
|
||||
org: Organization,
|
||||
*,
|
||||
name: str | None = None,
|
||||
tax_id: str | None = None,
|
||||
status: str | None = None,
|
||||
) -> Organization:
|
||||
if name is not None:
|
||||
org.name = name
|
||||
if tax_id is not None:
|
||||
org.tax_id = tax_id
|
||||
if status is not None:
|
||||
org.status = status
|
||||
self.db.commit()
|
||||
self.db.refresh(org)
|
||||
return org
|
||||
106
app/repositories/permission_groups_repo.py
Normal file
106
app/repositories/permission_groups_repo.py
Normal file
@@ -0,0 +1,106 @@
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.permission_group import PermissionGroup
|
||||
from app.models.permission_group_member import PermissionGroupMember
|
||||
from app.models.permission_group_permission import PermissionGroupPermission
|
||||
|
||||
|
||||
class PermissionGroupsRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def get_by_key(self, group_key: str) -> PermissionGroup | None:
|
||||
return self.db.scalar(select(PermissionGroup).where(PermissionGroup.group_key == group_key))
|
||||
|
||||
def get_by_id(self, group_id: str) -> PermissionGroup | None:
|
||||
return self.db.scalar(select(PermissionGroup).where(PermissionGroup.id == group_id))
|
||||
|
||||
def list(self, limit: int = 100, offset: int = 0) -> tuple[list[PermissionGroup], int]:
|
||||
stmt = select(PermissionGroup).order_by(PermissionGroup.created_at.desc()).limit(limit).offset(offset)
|
||||
count_stmt = select(func.count()).select_from(PermissionGroup)
|
||||
return list(self.db.scalars(stmt).all()), int(self.db.scalar(count_stmt) or 0)
|
||||
|
||||
def create(self, group_key: str, name: str, status: str = "active") -> PermissionGroup:
|
||||
item = PermissionGroup(group_key=group_key, name=name, status=status)
|
||||
self.db.add(item)
|
||||
self.db.commit()
|
||||
self.db.refresh(item)
|
||||
return item
|
||||
|
||||
def add_member_if_not_exists(self, group_id: str, authentik_sub: str) -> PermissionGroupMember:
|
||||
existing = self.db.scalar(
|
||||
select(PermissionGroupMember).where(
|
||||
PermissionGroupMember.group_id == group_id, PermissionGroupMember.authentik_sub == authentik_sub
|
||||
)
|
||||
)
|
||||
if existing:
|
||||
return existing
|
||||
row = PermissionGroupMember(group_id=group_id, authentik_sub=authentik_sub)
|
||||
self.db.add(row)
|
||||
self.db.commit()
|
||||
self.db.refresh(row)
|
||||
return row
|
||||
|
||||
def remove_member(self, group_id: str, authentik_sub: str) -> int:
|
||||
result = self.db.execute(
|
||||
delete(PermissionGroupMember).where(
|
||||
PermissionGroupMember.group_id == group_id, PermissionGroupMember.authentik_sub == authentik_sub
|
||||
)
|
||||
)
|
||||
self.db.commit()
|
||||
return int(result.rowcount or 0)
|
||||
|
||||
def grant_group_permission(
|
||||
self,
|
||||
group_id: str,
|
||||
system: str,
|
||||
module: str,
|
||||
action: str,
|
||||
scope_type: str,
|
||||
scope_id: str,
|
||||
) -> PermissionGroupPermission:
|
||||
where = [
|
||||
PermissionGroupPermission.group_id == group_id,
|
||||
PermissionGroupPermission.system == system,
|
||||
PermissionGroupPermission.module == module,
|
||||
PermissionGroupPermission.action == action,
|
||||
PermissionGroupPermission.scope_type == scope_type,
|
||||
PermissionGroupPermission.scope_id == scope_id,
|
||||
]
|
||||
existing = self.db.scalar(select(PermissionGroupPermission).where(*where))
|
||||
if existing:
|
||||
return existing
|
||||
row = PermissionGroupPermission(
|
||||
group_id=group_id,
|
||||
system=system,
|
||||
module=module,
|
||||
action=action,
|
||||
scope_type=scope_type,
|
||||
scope_id=scope_id,
|
||||
)
|
||||
self.db.add(row)
|
||||
self.db.commit()
|
||||
self.db.refresh(row)
|
||||
return row
|
||||
|
||||
def revoke_group_permission(
|
||||
self,
|
||||
group_id: str,
|
||||
system: str,
|
||||
module: str,
|
||||
action: str,
|
||||
scope_type: str,
|
||||
scope_id: str,
|
||||
) -> int:
|
||||
stmt = delete(PermissionGroupPermission).where(
|
||||
PermissionGroupPermission.group_id == group_id,
|
||||
PermissionGroupPermission.system == system,
|
||||
PermissionGroupPermission.module == module,
|
||||
PermissionGroupPermission.action == action,
|
||||
PermissionGroupPermission.scope_type == scope_type,
|
||||
PermissionGroupPermission.scope_id == scope_id,
|
||||
)
|
||||
result = self.db.execute(stmt)
|
||||
self.db.commit()
|
||||
return int(result.rowcount or 0)
|
||||
@@ -1,42 +1,93 @@
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy import and_, delete, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.permission import Permission
|
||||
from app.models.company import Company
|
||||
from app.models.module import Module
|
||||
from app.models.permission_group_member import PermissionGroupMember
|
||||
from app.models.permission_group_permission import PermissionGroupPermission
|
||||
from app.models.site import Site
|
||||
from app.models.user_scope_permission import UserScopePermission
|
||||
|
||||
|
||||
class PermissionsRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def list_by_user_id(self, user_id: str) -> list[Permission]:
|
||||
stmt = select(Permission).where(Permission.user_id == user_id)
|
||||
return list(self.db.scalars(stmt).all())
|
||||
def list_by_user(self, user_id: str, authentik_sub: str) -> list[tuple[str, str, str | None, str, str]]:
|
||||
direct_stmt = (
|
||||
select(
|
||||
UserScopePermission.scope_type,
|
||||
Company.company_key,
|
||||
Site.site_key,
|
||||
Module.module_key,
|
||||
UserScopePermission.action,
|
||||
)
|
||||
.select_from(UserScopePermission)
|
||||
.join(Module, Module.id == UserScopePermission.module_id)
|
||||
.join(Company, Company.id == UserScopePermission.company_id, isouter=True)
|
||||
.join(Site, Site.id == UserScopePermission.site_id, isouter=True)
|
||||
.where(UserScopePermission.user_id == user_id)
|
||||
)
|
||||
group_stmt = (
|
||||
select(
|
||||
PermissionGroupPermission.scope_type,
|
||||
PermissionGroupPermission.scope_id,
|
||||
PermissionGroupPermission.system,
|
||||
PermissionGroupPermission.module,
|
||||
PermissionGroupPermission.action,
|
||||
)
|
||||
.select_from(PermissionGroupPermission)
|
||||
.join(PermissionGroupMember, PermissionGroupMember.group_id == PermissionGroupPermission.group_id)
|
||||
.where(PermissionGroupMember.authentik_sub == authentik_sub)
|
||||
)
|
||||
rows = self.db.execute(direct_stmt).all() + self.db.execute(group_stmt).all()
|
||||
result: list[tuple[str, str, str | None, str, str]] = []
|
||||
dedup = set()
|
||||
for row in rows:
|
||||
if len(row) == 5:
|
||||
scope_type, scope_id, system_key, module_key, action = row
|
||||
else:
|
||||
scope_type, company_key, site_key, module_key, action = row
|
||||
scope_id = company_key if scope_type == "company" else site_key
|
||||
system_key = module_key.split(".", 1)[0] if isinstance(module_key, str) and "." in module_key else None
|
||||
key = (scope_type, scope_id or "", system_key, module_key, action)
|
||||
if key in dedup:
|
||||
continue
|
||||
dedup.add(key)
|
||||
result.append(key)
|
||||
return result
|
||||
|
||||
def create_if_not_exists(
|
||||
self,
|
||||
user_id: str,
|
||||
scope_type: str,
|
||||
scope_id: str,
|
||||
module: str,
|
||||
module_id: str,
|
||||
action: str,
|
||||
) -> Permission:
|
||||
stmt = select(Permission).where(
|
||||
Permission.user_id == user_id,
|
||||
Permission.scope_type == scope_type,
|
||||
Permission.scope_id == scope_id,
|
||||
Permission.module == module,
|
||||
Permission.action == action,
|
||||
)
|
||||
existing = self.db.scalar(stmt)
|
||||
scope_type: str,
|
||||
company_id: str | None,
|
||||
site_id: str | None,
|
||||
) -> UserScopePermission:
|
||||
where_expr = [
|
||||
UserScopePermission.user_id == user_id,
|
||||
UserScopePermission.module_id == module_id,
|
||||
UserScopePermission.action == action,
|
||||
UserScopePermission.scope_type == scope_type,
|
||||
]
|
||||
if scope_type == "company":
|
||||
where_expr.append(UserScopePermission.company_id == company_id)
|
||||
else:
|
||||
where_expr.append(UserScopePermission.site_id == site_id)
|
||||
|
||||
existing = self.db.scalar(select(UserScopePermission).where(and_(*where_expr)))
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
item = Permission(
|
||||
item = UserScopePermission(
|
||||
user_id=user_id,
|
||||
scope_type=scope_type,
|
||||
scope_id=scope_id,
|
||||
module=module,
|
||||
module_id=module_id,
|
||||
action=action,
|
||||
scope_type=scope_type,
|
||||
company_id=company_id,
|
||||
site_id=site_id,
|
||||
)
|
||||
self.db.add(item)
|
||||
self.db.commit()
|
||||
@@ -46,17 +97,21 @@ class PermissionsRepository:
|
||||
def revoke(
|
||||
self,
|
||||
user_id: str,
|
||||
scope_type: str,
|
||||
scope_id: str,
|
||||
module: str,
|
||||
module_id: str,
|
||||
action: str,
|
||||
scope_type: str,
|
||||
company_id: str | None,
|
||||
site_id: str | None,
|
||||
) -> int:
|
||||
stmt = delete(Permission).where(
|
||||
Permission.user_id == user_id,
|
||||
Permission.scope_type == scope_type,
|
||||
Permission.scope_id == scope_id,
|
||||
Permission.module == module,
|
||||
Permission.action == action,
|
||||
stmt = delete(UserScopePermission).where(
|
||||
UserScopePermission.user_id == user_id,
|
||||
UserScopePermission.module_id == module_id,
|
||||
UserScopePermission.action == action,
|
||||
UserScopePermission.scope_type == scope_type,
|
||||
or_(
|
||||
and_(scope_type == "company", UserScopePermission.company_id == company_id),
|
||||
and_(scope_type == "site", UserScopePermission.site_id == site_id),
|
||||
),
|
||||
)
|
||||
result = self.db.execute(stmt)
|
||||
self.db.commit()
|
||||
|
||||
40
app/repositories/sites_repo.py
Normal file
40
app/repositories/sites_repo.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from sqlalchemy import func, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.site import Site
|
||||
|
||||
|
||||
class SitesRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def get_by_key(self, site_key: str) -> Site | None:
|
||||
stmt = select(Site).where(Site.site_key == site_key)
|
||||
return self.db.scalar(stmt)
|
||||
|
||||
def list(
|
||||
self,
|
||||
keyword: str | None = None,
|
||||
company_id: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[Site], int]:
|
||||
stmt = select(Site)
|
||||
count_stmt = select(func.count()).select_from(Site)
|
||||
if keyword:
|
||||
pattern = f"%{keyword}%"
|
||||
cond = or_(Site.site_key.ilike(pattern), Site.name.ilike(pattern))
|
||||
stmt = stmt.where(cond)
|
||||
count_stmt = count_stmt.where(cond)
|
||||
if company_id:
|
||||
stmt = stmt.where(Site.company_id == company_id)
|
||||
count_stmt = count_stmt.where(Site.company_id == company_id)
|
||||
stmt = stmt.order_by(Site.created_at.desc()).limit(limit).offset(offset)
|
||||
return list(self.db.scalars(stmt).all()), int(self.db.scalar(count_stmt) or 0)
|
||||
|
||||
def create(self, site_key: str, company_id: str, name: str, status: str = "active") -> Site:
|
||||
item = Site(site_key=site_key, company_id=company_id, name=name, status=status)
|
||||
self.db.add(item)
|
||||
self.db.commit()
|
||||
self.db.refresh(item)
|
||||
return item
|
||||
33
app/repositories/systems_repo.py
Normal file
33
app/repositories/systems_repo.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.system import System
|
||||
|
||||
|
||||
class SystemsRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def get_by_key(self, system_key: str) -> System | None:
|
||||
stmt = select(System).where(System.system_key == system_key)
|
||||
return self.db.scalar(stmt)
|
||||
|
||||
def get_by_id(self, system_id: str) -> System | None:
|
||||
stmt = select(System).where(System.id == system_id)
|
||||
return self.db.scalar(stmt)
|
||||
|
||||
def list(self, status: str | None = None, limit: int = 100, offset: int = 0) -> tuple[list[System], int]:
|
||||
stmt = select(System)
|
||||
count_stmt = select(func.count()).select_from(System)
|
||||
if status:
|
||||
stmt = stmt.where(System.status == status)
|
||||
count_stmt = count_stmt.where(System.status == status)
|
||||
stmt = stmt.order_by(System.created_at.desc()).limit(limit).offset(offset)
|
||||
return list(self.db.scalars(stmt).all()), int(self.db.scalar(count_stmt) or 0)
|
||||
|
||||
def create(self, system_key: str, name: str, status: str = "active") -> System:
|
||||
item = System(system_key=system_key, name=name, status=status)
|
||||
self.db.add(item)
|
||||
self.db.commit()
|
||||
self.db.refresh(item)
|
||||
return item
|
||||
Reference in New Issue
Block a user