refactor: rebuild backend around role-site authorization model
This commit is contained in:
94
app/repositories/api_clients_repo.py
Normal file
94
app/repositories/api_clients_repo.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from sqlalchemy import func, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.api_client import ApiClient
|
||||
|
||||
|
||||
class ApiClientsRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def get_by_key(self, client_key: str) -> ApiClient | None:
|
||||
return self.db.scalar(select(ApiClient).where(ApiClient.client_key == client_key))
|
||||
|
||||
def list(self, *, keyword: str | None = None, status: str | None = None, limit: int = 100, offset: int = 0) -> tuple[list[ApiClient], int]:
|
||||
stmt = select(ApiClient)
|
||||
count_stmt = select(func.count()).select_from(ApiClient)
|
||||
if keyword:
|
||||
pattern = f"%{keyword}%"
|
||||
cond = or_(ApiClient.client_key.ilike(pattern), ApiClient.name.ilike(pattern))
|
||||
stmt = stmt.where(cond)
|
||||
count_stmt = count_stmt.where(cond)
|
||||
if status:
|
||||
stmt = stmt.where(ApiClient.status == status)
|
||||
count_stmt = count_stmt.where(ApiClient.status == status)
|
||||
|
||||
stmt = stmt.order_by(ApiClient.created_at.desc()).limit(limit).offset(offset)
|
||||
return list(self.db.scalars(stmt).all()), int(self.db.scalar(count_stmt) or 0)
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
client_key: str,
|
||||
name: str,
|
||||
status: str,
|
||||
api_key_hash: str,
|
||||
allowed_origins: list[str],
|
||||
allowed_ips: list[str],
|
||||
allowed_paths: list[str],
|
||||
rate_limit_per_min: int | None,
|
||||
expires_at,
|
||||
) -> ApiClient:
|
||||
item = ApiClient(
|
||||
client_key=client_key,
|
||||
name=name,
|
||||
status=status,
|
||||
api_key_hash=api_key_hash,
|
||||
allowed_origins=allowed_origins,
|
||||
allowed_ips=allowed_ips,
|
||||
allowed_paths=allowed_paths,
|
||||
rate_limit_per_min=rate_limit_per_min,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
self.db.add(item)
|
||||
self.db.commit()
|
||||
self.db.refresh(item)
|
||||
return item
|
||||
|
||||
def update(
|
||||
self,
|
||||
item: ApiClient,
|
||||
*,
|
||||
name: str | None = None,
|
||||
status: str | None = None,
|
||||
api_key_hash: str | None = None,
|
||||
allowed_origins: list[str] | None = None,
|
||||
allowed_ips: list[str] | None = None,
|
||||
allowed_paths: list[str] | None = None,
|
||||
rate_limit_per_min: int | None = None,
|
||||
expires_at=None,
|
||||
) -> ApiClient:
|
||||
if name is not None:
|
||||
item.name = name
|
||||
if status is not None:
|
||||
item.status = status
|
||||
if api_key_hash is not None:
|
||||
item.api_key_hash = api_key_hash
|
||||
if allowed_origins is not None:
|
||||
item.allowed_origins = allowed_origins
|
||||
if allowed_ips is not None:
|
||||
item.allowed_ips = allowed_ips
|
||||
if allowed_paths is not None:
|
||||
item.allowed_paths = allowed_paths
|
||||
if rate_limit_per_min is not None:
|
||||
item.rate_limit_per_min = rate_limit_per_min
|
||||
if expires_at is not None:
|
||||
item.expires_at = expires_at
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(item)
|
||||
return item
|
||||
|
||||
def delete(self, item: ApiClient) -> None:
|
||||
self.db.delete(item)
|
||||
self.db.commit()
|
||||
@@ -9,36 +9,55 @@ class CompaniesRepository:
|
||||
self.db = db
|
||||
|
||||
def get_by_key(self, company_key: str) -> Company | None:
|
||||
stmt = select(Company).where(Company.company_key == company_key)
|
||||
return self.db.scalar(stmt)
|
||||
return self.db.scalar(select(Company).where(Company.company_key == company_key))
|
||||
|
||||
def get_by_id(self, company_id: str) -> Company | None:
|
||||
stmt = select(Company).where(Company.id == company_id)
|
||||
return self.db.scalar(stmt)
|
||||
return self.db.scalar(select(Company).where(Company.id == company_id))
|
||||
|
||||
def list(self, keyword: str | None = None, limit: int = 100, offset: int = 0) -> tuple[list[Company], int]:
|
||||
stmt = select(Company)
|
||||
count_stmt = select(func.count()).select_from(Company)
|
||||
if keyword:
|
||||
pattern = f"%{keyword}%"
|
||||
cond = or_(Company.company_key.ilike(pattern), Company.name.ilike(pattern))
|
||||
cond = or_(
|
||||
Company.company_key.ilike(pattern),
|
||||
Company.display_name.ilike(pattern),
|
||||
Company.legal_name.ilike(pattern),
|
||||
)
|
||||
stmt = stmt.where(cond)
|
||||
count_stmt = count_stmt.where(cond)
|
||||
|
||||
stmt = stmt.order_by(Company.created_at.desc()).limit(limit).offset(offset)
|
||||
return list(self.db.scalars(stmt).all()), int(self.db.scalar(count_stmt) or 0)
|
||||
|
||||
def create(self, company_key: str, name: str, status: str = "active") -> Company:
|
||||
item = Company(company_key=company_key, name=name, status=status)
|
||||
def create(self, *, company_key: str, display_name: str, legal_name: str | None, status: str = "active") -> Company:
|
||||
item = Company(company_key=company_key, display_name=display_name, legal_name=legal_name, status=status)
|
||||
self.db.add(item)
|
||||
self.db.commit()
|
||||
self.db.refresh(item)
|
||||
return item
|
||||
|
||||
def update(self, item: Company, *, name: str | None = None, status: str | None = None) -> Company:
|
||||
if name is not None:
|
||||
item.name = name
|
||||
def update(
|
||||
self,
|
||||
item: Company,
|
||||
*,
|
||||
display_name: str | None = None,
|
||||
legal_name: str | None = None,
|
||||
idp_group_id: str | None = None,
|
||||
status: str | None = None,
|
||||
) -> Company:
|
||||
if display_name is not None:
|
||||
item.display_name = display_name
|
||||
if legal_name is not None:
|
||||
item.legal_name = legal_name
|
||||
if idp_group_id is not None:
|
||||
item.idp_group_id = idp_group_id
|
||||
if status is not None:
|
||||
item.status = status
|
||||
self.db.commit()
|
||||
self.db.refresh(item)
|
||||
return item
|
||||
|
||||
def delete(self, item: Company) -> None:
|
||||
self.db.delete(item)
|
||||
self.db.commit()
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.module import Module
|
||||
|
||||
|
||||
class ModulesRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def get_by_key(self, module_key: str) -> Module | None:
|
||||
stmt = select(Module).where(Module.module_key == module_key)
|
||||
return self.db.scalar(stmt)
|
||||
|
||||
def list(self, limit: int = 200, offset: int = 0) -> tuple[list[Module], int]:
|
||||
stmt = select(Module)
|
||||
count_stmt = select(func.count()).select_from(Module)
|
||||
stmt = stmt.order_by(Module.created_at.desc()).limit(limit).offset(offset)
|
||||
return list(self.db.scalars(stmt).all()), int(self.db.scalar(count_stmt) or 0)
|
||||
|
||||
def create(self, module_key: str, system_key: str, name: str, status: str = "active") -> Module:
|
||||
item = Module(module_key=module_key, system_key=system_key, name=name, status=status)
|
||||
self.db.add(item)
|
||||
self.db.commit()
|
||||
self.db.refresh(item)
|
||||
return item
|
||||
|
||||
def update(self, item: Module, *, name: str | None = None, status: str | None = None) -> Module:
|
||||
if name is not None:
|
||||
item.name = name
|
||||
if status is not None:
|
||||
item.status = status
|
||||
self.db.commit()
|
||||
self.db.refresh(item)
|
||||
return item
|
||||
@@ -1,269 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.permission_group import PermissionGroup
|
||||
from app.models.permission_group_member import PermissionGroupMember
|
||||
from app.models.permission_group_permission import PermissionGroupPermission
|
||||
from app.models.user import User
|
||||
|
||||
|
||||
class PermissionGroupsRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def get_by_key(self, group_key: str) -> PermissionGroup | None:
|
||||
return self.db.scalar(select(PermissionGroup).where(PermissionGroup.group_key == group_key))
|
||||
|
||||
def get_by_id(self, group_id: str) -> PermissionGroup | None:
|
||||
return self.db.scalar(select(PermissionGroup).where(PermissionGroup.id == group_id))
|
||||
|
||||
def get_by_keys(self, group_keys: list[str]) -> list[PermissionGroup]:
|
||||
if not group_keys:
|
||||
return []
|
||||
stmt = select(PermissionGroup).where(PermissionGroup.group_key.in_(group_keys))
|
||||
return list(self.db.scalars(stmt).all())
|
||||
|
||||
def list(self, limit: int = 100, offset: int = 0) -> tuple[list[PermissionGroup], int]:
|
||||
stmt = select(PermissionGroup).order_by(PermissionGroup.created_at.desc()).limit(limit).offset(offset)
|
||||
count_stmt = select(func.count()).select_from(PermissionGroup)
|
||||
return list(self.db.scalars(stmt).all()), int(self.db.scalar(count_stmt) or 0)
|
||||
|
||||
def create(self, group_key: str, name: str, status: str = "active") -> PermissionGroup:
|
||||
item = PermissionGroup(group_key=group_key, name=name, status=status)
|
||||
self.db.add(item)
|
||||
self.db.commit()
|
||||
self.db.refresh(item)
|
||||
return item
|
||||
|
||||
def update(self, item: PermissionGroup, *, name: str | None = None, status: str | None = None) -> PermissionGroup:
|
||||
if name is not None:
|
||||
item.name = name
|
||||
if status is not None:
|
||||
item.status = status
|
||||
self.db.commit()
|
||||
self.db.refresh(item)
|
||||
return item
|
||||
|
||||
def add_member_if_not_exists(self, group_id: str, user_sub: str) -> PermissionGroupMember:
|
||||
existing = self.db.scalar(
|
||||
select(PermissionGroupMember).where(
|
||||
PermissionGroupMember.group_id == group_id, PermissionGroupMember.user_sub == user_sub
|
||||
)
|
||||
)
|
||||
if existing:
|
||||
return existing
|
||||
row = PermissionGroupMember(group_id=group_id, user_sub=user_sub)
|
||||
self.db.add(row)
|
||||
self.db.commit()
|
||||
self.db.refresh(row)
|
||||
return row
|
||||
|
||||
def remove_member(self, group_id: str, user_sub: str) -> int:
|
||||
result = self.db.execute(
|
||||
delete(PermissionGroupMember).where(
|
||||
PermissionGroupMember.group_id == group_id, PermissionGroupMember.user_sub == user_sub
|
||||
)
|
||||
)
|
||||
self.db.commit()
|
||||
return int(result.rowcount or 0)
|
||||
|
||||
def list_group_keys_by_member_sub(self, user_sub: str) -> list[str]:
|
||||
stmt = (
|
||||
select(PermissionGroup.group_key)
|
||||
.select_from(PermissionGroupMember)
|
||||
.join(PermissionGroup, PermissionGroup.id == PermissionGroupMember.group_id)
|
||||
.where(PermissionGroupMember.user_sub == user_sub)
|
||||
.order_by(PermissionGroup.group_key.asc())
|
||||
)
|
||||
return [row[0] for row in self.db.execute(stmt).all()]
|
||||
|
||||
def replace_member_groups(self, user_sub: str, group_ids: list[str]) -> None:
|
||||
self.db.execute(delete(PermissionGroupMember).where(PermissionGroupMember.user_sub == user_sub))
|
||||
for group_id in group_ids:
|
||||
self.db.add(PermissionGroupMember(group_id=group_id, user_sub=user_sub))
|
||||
self.db.commit()
|
||||
|
||||
def grant_group_permission(
|
||||
self,
|
||||
group_id: str,
|
||||
system: str,
|
||||
module: str,
|
||||
action: str,
|
||||
scope_type: str,
|
||||
scope_id: str,
|
||||
) -> PermissionGroupPermission:
|
||||
where = [
|
||||
PermissionGroupPermission.group_id == group_id,
|
||||
PermissionGroupPermission.system == system,
|
||||
PermissionGroupPermission.module == module,
|
||||
PermissionGroupPermission.action == action,
|
||||
PermissionGroupPermission.scope_type == scope_type,
|
||||
PermissionGroupPermission.scope_id == scope_id,
|
||||
]
|
||||
existing = self.db.scalar(select(PermissionGroupPermission).where(*where))
|
||||
if existing:
|
||||
return existing
|
||||
row = PermissionGroupPermission(
|
||||
group_id=group_id,
|
||||
system=system,
|
||||
module=module,
|
||||
action=action,
|
||||
scope_type=scope_type,
|
||||
scope_id=scope_id,
|
||||
)
|
||||
self.db.add(row)
|
||||
self.db.commit()
|
||||
self.db.refresh(row)
|
||||
return row
|
||||
|
||||
def list_group_permissions(self, group_id: str) -> list[PermissionGroupPermission]:
|
||||
stmt = (
|
||||
select(PermissionGroupPermission)
|
||||
.where(PermissionGroupPermission.group_id == group_id)
|
||||
.order_by(PermissionGroupPermission.scope_type.asc(), PermissionGroupPermission.scope_id.asc(), PermissionGroupPermission.system.asc(), PermissionGroupPermission.module.asc(), PermissionGroupPermission.action.asc())
|
||||
)
|
||||
return list(self.db.scalars(stmt).all())
|
||||
|
||||
def replace_group_bindings(
|
||||
self,
|
||||
*,
|
||||
group_id: str,
|
||||
site_keys: list[str],
|
||||
system_keys: list[str],
|
||||
module_keys: list[str],
|
||||
member_subs: list[str],
|
||||
actions: list[str],
|
||||
) -> None:
|
||||
normalized_sites = list(dict.fromkeys([s for s in site_keys if s]))
|
||||
normalized_actions = [a for a in list(dict.fromkeys(actions)) if a in {"view", "edit"}]
|
||||
normalized_member_subs = list(dict.fromkeys([s for s in member_subs if s]))
|
||||
|
||||
normalized_systems = set([s for s in system_keys if s])
|
||||
module_pairs = []
|
||||
for pair in module_keys:
|
||||
if "|" not in pair:
|
||||
continue
|
||||
system_key, module_key = pair.split("|", 1)
|
||||
if not system_key or not module_key:
|
||||
continue
|
||||
module_pairs.append((system_key, module_key))
|
||||
normalized_systems.add(system_key)
|
||||
|
||||
self.db.execute(delete(PermissionGroupPermission).where(PermissionGroupPermission.group_id == group_id))
|
||||
self.db.execute(delete(PermissionGroupMember).where(PermissionGroupMember.group_id == group_id))
|
||||
|
||||
for sub in normalized_member_subs:
|
||||
self.db.add(PermissionGroupMember(group_id=group_id, user_sub=sub))
|
||||
|
||||
for site_key in normalized_sites:
|
||||
for action in normalized_actions:
|
||||
for system_key in sorted(normalized_systems):
|
||||
module_names = [m for s, m in module_pairs if s == system_key] or ["__system__"]
|
||||
for module_name in module_names:
|
||||
self.db.add(
|
||||
PermissionGroupPermission(
|
||||
group_id=group_id,
|
||||
system=system_key,
|
||||
module=module_name,
|
||||
action=action,
|
||||
scope_type="site",
|
||||
scope_id=site_key,
|
||||
)
|
||||
)
|
||||
|
||||
self.db.commit()
|
||||
|
||||
def get_group_binding_snapshot(self, group_id: str, group_key: str) -> dict:
|
||||
permissions = self.list_group_permissions(group_id)
|
||||
site_keys = sorted({p.scope_id for p in permissions if p.scope_type == "site"})
|
||||
system_keys = sorted({p.system for p in permissions})
|
||||
actions = sorted({p.action for p in permissions if p.action in {"view", "edit"}})
|
||||
module_keys = sorted(
|
||||
{
|
||||
f"{p.system}|{p.module}"
|
||||
for p in permissions
|
||||
if p.module and p.module != "__system__"
|
||||
}
|
||||
)
|
||||
member_subs = sorted(self.list_group_member_subs(group_id))
|
||||
return {
|
||||
"group_key": group_key,
|
||||
"site_keys": site_keys,
|
||||
"system_keys": system_keys,
|
||||
"module_keys": module_keys,
|
||||
"member_subs": member_subs,
|
||||
"actions": actions,
|
||||
}
|
||||
|
||||
def list_group_member_subs(self, group_id: str) -> list[str]:
|
||||
stmt = (
|
||||
select(PermissionGroupMember.user_sub)
|
||||
.where(PermissionGroupMember.group_id == group_id)
|
||||
.order_by(PermissionGroupMember.user_sub.asc())
|
||||
)
|
||||
return [row[0] for row in self.db.execute(stmt).all()]
|
||||
|
||||
def list_system_groups(self, system_key: str) -> list[PermissionGroup]:
|
||||
stmt = (
|
||||
select(PermissionGroup)
|
||||
.join(PermissionGroupPermission, PermissionGroupPermission.group_id == PermissionGroup.id)
|
||||
.where(PermissionGroupPermission.system == system_key)
|
||||
.order_by(PermissionGroup.name.asc())
|
||||
.distinct()
|
||||
)
|
||||
return list(self.db.scalars(stmt).all())
|
||||
|
||||
def list_system_members(self, system_key: str) -> list[User]:
|
||||
stmt = (
|
||||
select(User)
|
||||
.join(PermissionGroupMember, PermissionGroupMember.user_sub == User.user_sub)
|
||||
.join(PermissionGroupPermission, PermissionGroupPermission.group_id == PermissionGroupMember.group_id)
|
||||
.where(PermissionGroupPermission.system == system_key)
|
||||
.order_by(User.email.asc(), User.user_sub.asc())
|
||||
.distinct()
|
||||
)
|
||||
return list(self.db.scalars(stmt).all())
|
||||
|
||||
def list_module_groups(self, system_key: str, module_name: str) -> list[PermissionGroup]:
|
||||
stmt = (
|
||||
select(PermissionGroup)
|
||||
.join(PermissionGroupPermission, PermissionGroupPermission.group_id == PermissionGroup.id)
|
||||
.where(PermissionGroupPermission.system == system_key, PermissionGroupPermission.module == module_name)
|
||||
.order_by(PermissionGroup.name.asc())
|
||||
.distinct()
|
||||
)
|
||||
return list(self.db.scalars(stmt).all())
|
||||
|
||||
def list_module_members(self, system_key: str, module_name: str) -> list[User]:
|
||||
stmt = (
|
||||
select(User)
|
||||
.join(PermissionGroupMember, PermissionGroupMember.user_sub == User.user_sub)
|
||||
.join(PermissionGroupPermission, PermissionGroupPermission.group_id == PermissionGroupMember.group_id)
|
||||
.where(PermissionGroupPermission.system == system_key, PermissionGroupPermission.module == module_name)
|
||||
.order_by(User.email.asc(), User.user_sub.asc())
|
||||
.distinct()
|
||||
)
|
||||
return list(self.db.scalars(stmt).all())
|
||||
|
||||
def revoke_group_permission(
|
||||
self,
|
||||
group_id: str,
|
||||
system: str,
|
||||
module: str,
|
||||
action: str,
|
||||
scope_type: str,
|
||||
scope_id: str,
|
||||
) -> int:
|
||||
stmt = delete(PermissionGroupPermission).where(
|
||||
PermissionGroupPermission.group_id == group_id,
|
||||
PermissionGroupPermission.system == system,
|
||||
PermissionGroupPermission.module == module,
|
||||
PermissionGroupPermission.action == action,
|
||||
PermissionGroupPermission.scope_type == scope_type,
|
||||
PermissionGroupPermission.scope_id == scope_id,
|
||||
)
|
||||
result = self.db.execute(stmt)
|
||||
self.db.commit()
|
||||
return int(result.rowcount or 0)
|
||||
@@ -1,231 +0,0 @@
|
||||
from sqlalchemy import and_, delete, func, literal, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.company import Company
|
||||
from app.models.module import Module
|
||||
from app.models.permission_group_member import PermissionGroupMember
|
||||
from app.models.permission_group_permission import PermissionGroupPermission
|
||||
from app.models.site import Site
|
||||
from app.models.user import User
|
||||
from app.models.user_scope_permission import UserScopePermission
|
||||
|
||||
|
||||
class PermissionsRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def list_by_user(self, user_id: str, user_sub: str) -> list[tuple[str, str, str | None, str, str]]:
|
||||
direct_stmt = (
|
||||
select(
|
||||
literal("direct"),
|
||||
UserScopePermission.scope_type,
|
||||
Company.company_key,
|
||||
Site.site_key,
|
||||
Module.system_key,
|
||||
Module.module_key,
|
||||
UserScopePermission.action,
|
||||
)
|
||||
.select_from(UserScopePermission)
|
||||
.join(Module, Module.id == UserScopePermission.module_id)
|
||||
.join(Company, Company.id == UserScopePermission.company_id, isouter=True)
|
||||
.join(Site, Site.id == UserScopePermission.site_id, isouter=True)
|
||||
.where(UserScopePermission.user_id == user_id)
|
||||
.where(UserScopePermission.action.in_(["view", "edit"]))
|
||||
.where(UserScopePermission.scope_type == "site")
|
||||
)
|
||||
group_stmt = (
|
||||
select(
|
||||
literal("group"),
|
||||
PermissionGroupPermission.scope_type,
|
||||
PermissionGroupPermission.scope_id,
|
||||
PermissionGroupPermission.system,
|
||||
PermissionGroupPermission.module,
|
||||
PermissionGroupPermission.action,
|
||||
)
|
||||
.select_from(PermissionGroupPermission)
|
||||
.join(PermissionGroupMember, PermissionGroupMember.group_id == PermissionGroupPermission.group_id)
|
||||
.where(PermissionGroupMember.user_sub == user_sub)
|
||||
.where(PermissionGroupPermission.action.in_(["view", "edit"]))
|
||||
.where(PermissionGroupPermission.scope_type == "site")
|
||||
)
|
||||
rows = self.db.execute(direct_stmt).all() + self.db.execute(group_stmt).all()
|
||||
result: list[tuple[str, str, str | None, str, str]] = []
|
||||
dedup = set()
|
||||
for row in rows:
|
||||
source = row[0]
|
||||
if source == "group":
|
||||
_, scope_type, scope_id, system_key, module_key, action = row
|
||||
if module_key == "__system__":
|
||||
module_key = f"__system__{system_key}"
|
||||
else:
|
||||
_, scope_type, company_key, site_key, system_key, module_key, action = row
|
||||
scope_id = company_key if scope_type == "company" else site_key
|
||||
key = (scope_type, scope_id or "", system_key, module_key, action)
|
||||
if key in dedup:
|
||||
continue
|
||||
dedup.add(key)
|
||||
result.append(key)
|
||||
return result
|
||||
|
||||
def create_if_not_exists(
|
||||
self,
|
||||
user_id: str,
|
||||
module_id: str,
|
||||
action: str,
|
||||
scope_type: str,
|
||||
company_id: str | None,
|
||||
site_id: str | None,
|
||||
) -> UserScopePermission:
|
||||
where_expr = [
|
||||
UserScopePermission.user_id == user_id,
|
||||
UserScopePermission.module_id == module_id,
|
||||
UserScopePermission.action == action,
|
||||
UserScopePermission.scope_type == scope_type,
|
||||
]
|
||||
if scope_type == "company":
|
||||
where_expr.append(UserScopePermission.company_id == company_id)
|
||||
else:
|
||||
where_expr.append(UserScopePermission.site_id == site_id)
|
||||
|
||||
existing = self.db.scalar(select(UserScopePermission).where(and_(*where_expr)))
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
item = UserScopePermission(
|
||||
user_id=user_id,
|
||||
module_id=module_id,
|
||||
action=action,
|
||||
scope_type=scope_type,
|
||||
company_id=company_id,
|
||||
site_id=site_id,
|
||||
)
|
||||
self.db.add(item)
|
||||
self.db.commit()
|
||||
self.db.refresh(item)
|
||||
return item
|
||||
|
||||
def revoke(
|
||||
self,
|
||||
user_id: str,
|
||||
module_id: str,
|
||||
action: str,
|
||||
scope_type: str,
|
||||
company_id: str | None,
|
||||
site_id: str | None,
|
||||
) -> int:
|
||||
stmt = delete(UserScopePermission).where(
|
||||
UserScopePermission.user_id == user_id,
|
||||
UserScopePermission.module_id == module_id,
|
||||
UserScopePermission.action == action,
|
||||
UserScopePermission.scope_type == scope_type,
|
||||
or_(
|
||||
and_(scope_type == "company", UserScopePermission.company_id == company_id),
|
||||
and_(scope_type == "site", UserScopePermission.site_id == site_id),
|
||||
),
|
||||
)
|
||||
result = self.db.execute(stmt)
|
||||
self.db.commit()
|
||||
return int(result.rowcount or 0)
|
||||
|
||||
def list_direct_permissions(
|
||||
self,
|
||||
*,
|
||||
keyword: str | None = None,
|
||||
scope_type: str | None = None,
|
||||
limit: int = 200,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[dict], int]:
|
||||
stmt = (
|
||||
select(
|
||||
UserScopePermission.id,
|
||||
User.user_sub,
|
||||
User.email,
|
||||
User.display_name,
|
||||
UserScopePermission.scope_type,
|
||||
Company.company_key,
|
||||
Site.site_key,
|
||||
Module.system_key,
|
||||
Module.module_key,
|
||||
UserScopePermission.action,
|
||||
UserScopePermission.created_at,
|
||||
)
|
||||
.select_from(UserScopePermission)
|
||||
.join(User, User.id == UserScopePermission.user_id)
|
||||
.join(Module, Module.id == UserScopePermission.module_id)
|
||||
.join(Company, Company.id == UserScopePermission.company_id, isouter=True)
|
||||
.join(Site, Site.id == UserScopePermission.site_id, isouter=True)
|
||||
.where(UserScopePermission.action.in_(["view", "edit"]))
|
||||
.where(UserScopePermission.scope_type == "site")
|
||||
)
|
||||
count_stmt = (
|
||||
select(func.count())
|
||||
.select_from(UserScopePermission)
|
||||
.join(User, User.id == UserScopePermission.user_id)
|
||||
.join(Module, Module.id == UserScopePermission.module_id)
|
||||
.join(Company, Company.id == UserScopePermission.company_id, isouter=True)
|
||||
.join(Site, Site.id == UserScopePermission.site_id, isouter=True)
|
||||
.where(UserScopePermission.action.in_(["view", "edit"]))
|
||||
.where(UserScopePermission.scope_type == "site")
|
||||
)
|
||||
|
||||
if scope_type == "site":
|
||||
stmt = stmt.where(UserScopePermission.scope_type == scope_type)
|
||||
count_stmt = count_stmt.where(UserScopePermission.scope_type == scope_type)
|
||||
|
||||
if keyword:
|
||||
pattern = f"%{keyword}%"
|
||||
cond = or_(
|
||||
User.user_sub.ilike(pattern),
|
||||
User.email.ilike(pattern),
|
||||
User.display_name.ilike(pattern),
|
||||
Module.module_key.ilike(pattern),
|
||||
Company.company_key.ilike(pattern),
|
||||
Site.site_key.ilike(pattern),
|
||||
UserScopePermission.action.ilike(pattern),
|
||||
)
|
||||
stmt = stmt.where(cond)
|
||||
count_stmt = count_stmt.where(cond)
|
||||
|
||||
stmt = stmt.order_by(UserScopePermission.created_at.desc()).limit(limit).offset(offset)
|
||||
rows = self.db.execute(stmt).all()
|
||||
total = int(self.db.scalar(count_stmt) or 0)
|
||||
items: list[dict] = []
|
||||
for row in rows:
|
||||
(
|
||||
permission_id,
|
||||
user_sub,
|
||||
email,
|
||||
display_name,
|
||||
row_scope_type,
|
||||
company_key,
|
||||
site_key,
|
||||
system_key,
|
||||
module_key,
|
||||
action,
|
||||
created_at,
|
||||
) = row
|
||||
scope_id = company_key if row_scope_type == "company" else site_key
|
||||
module_name = module_key
|
||||
if isinstance(module_name, str) and module_name.startswith("__system__"):
|
||||
module_name = None
|
||||
items.append(
|
||||
{
|
||||
"permission_id": permission_id,
|
||||
"user_sub": user_sub,
|
||||
"email": email,
|
||||
"display_name": display_name,
|
||||
"scope_type": row_scope_type,
|
||||
"scope_id": scope_id,
|
||||
"system": system_key,
|
||||
"module": module_name,
|
||||
"action": action,
|
||||
"created_at": created_at,
|
||||
}
|
||||
)
|
||||
return items, total
|
||||
|
||||
def revoke_by_permission_id(self, permission_id: str) -> int:
|
||||
stmt = delete(UserScopePermission).where(UserScopePermission.id == permission_id)
|
||||
result = self.db.execute(stmt)
|
||||
self.db.commit()
|
||||
return int(result.rowcount or 0)
|
||||
97
app/repositories/roles_repo.py
Normal file
97
app/repositories/roles_repo.py
Normal file
@@ -0,0 +1,97 @@
|
||||
from sqlalchemy import func, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.role import Role
|
||||
|
||||
|
||||
class RolesRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def get_by_key(self, role_key: str) -> Role | None:
|
||||
return self.db.scalar(select(Role).where(Role.role_key == role_key))
|
||||
|
||||
def get_by_id(self, role_id: str) -> Role | None:
|
||||
return self.db.scalar(select(Role).where(Role.id == role_id))
|
||||
|
||||
def list(
|
||||
self,
|
||||
*,
|
||||
keyword: str | None = None,
|
||||
system_id: str | None = None,
|
||||
status: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> tuple[list[Role], int]:
|
||||
stmt = select(Role)
|
||||
count_stmt = select(func.count()).select_from(Role)
|
||||
if keyword:
|
||||
pattern = f"%{keyword}%"
|
||||
cond = or_(
|
||||
Role.role_key.ilike(pattern),
|
||||
Role.name.ilike(pattern),
|
||||
Role.idp_role_name.ilike(pattern),
|
||||
Role.description.ilike(pattern),
|
||||
)
|
||||
stmt = stmt.where(cond)
|
||||
count_stmt = count_stmt.where(cond)
|
||||
if system_id:
|
||||
stmt = stmt.where(Role.system_id == system_id)
|
||||
count_stmt = count_stmt.where(Role.system_id == system_id)
|
||||
if status:
|
||||
stmt = stmt.where(Role.status == status)
|
||||
count_stmt = count_stmt.where(Role.status == status)
|
||||
|
||||
stmt = stmt.order_by(Role.created_at.desc()).limit(limit).offset(offset)
|
||||
return list(self.db.scalars(stmt).all()), int(self.db.scalar(count_stmt) or 0)
|
||||
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
role_key: str,
|
||||
system_id: str,
|
||||
name: str,
|
||||
description: str | None,
|
||||
idp_role_name: str,
|
||||
status: str = "active",
|
||||
) -> Role:
|
||||
item = Role(
|
||||
role_key=role_key,
|
||||
system_id=system_id,
|
||||
name=name,
|
||||
description=description,
|
||||
idp_role_name=idp_role_name,
|
||||
status=status,
|
||||
)
|
||||
self.db.add(item)
|
||||
self.db.commit()
|
||||
self.db.refresh(item)
|
||||
return item
|
||||
|
||||
def update(
|
||||
self,
|
||||
item: Role,
|
||||
*,
|
||||
system_id: str | None = None,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
idp_role_name: str | None = None,
|
||||
status: str | None = None,
|
||||
) -> Role:
|
||||
if system_id is not None:
|
||||
item.system_id = system_id
|
||||
if name is not None:
|
||||
item.name = name
|
||||
if description is not None:
|
||||
item.description = description
|
||||
if idp_role_name is not None:
|
||||
item.idp_role_name = idp_role_name
|
||||
if status is not None:
|
||||
item.status = status
|
||||
self.db.commit()
|
||||
self.db.refresh(item)
|
||||
return item
|
||||
|
||||
def delete(self, item: Role) -> None:
|
||||
self.db.delete(item)
|
||||
self.db.commit()
|
||||
37
app/repositories/site_roles_repo.py
Normal file
37
app/repositories/site_roles_repo.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.role import Role
|
||||
from app.models.site import Site
|
||||
from app.models.site_role import SiteRole
|
||||
from app.models.system import System
|
||||
|
||||
|
||||
class SiteRolesRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def list_site_role_rows(self, site_id: str) -> list[tuple[SiteRole, Role, System]]:
|
||||
stmt = (
|
||||
select(SiteRole, Role, System)
|
||||
.join(Role, Role.id == SiteRole.role_id)
|
||||
.join(System, System.id == Role.system_id)
|
||||
.where(SiteRole.site_id == site_id)
|
||||
.order_by(System.name.asc(), Role.name.asc())
|
||||
)
|
||||
return list(self.db.execute(stmt).all())
|
||||
|
||||
def list_role_site_rows(self, role_id: str) -> list[tuple[SiteRole, Site]]:
|
||||
stmt = (
|
||||
select(SiteRole, Site)
|
||||
.join(Site, Site.id == SiteRole.site_id)
|
||||
.where(SiteRole.role_id == role_id)
|
||||
.order_by(Site.display_name.asc())
|
||||
)
|
||||
return list(self.db.execute(stmt).all())
|
||||
|
||||
def set_site_roles(self, *, site_id: str, role_ids: list[str]) -> None:
|
||||
self.db.execute(delete(SiteRole).where(SiteRole.site_id == site_id))
|
||||
for role_id in role_ids:
|
||||
self.db.add(SiteRole(site_id=site_id, role_id=role_id))
|
||||
self.db.commit()
|
||||
@@ -9,11 +9,14 @@ class SitesRepository:
|
||||
self.db = db
|
||||
|
||||
def get_by_key(self, site_key: str) -> Site | None:
|
||||
stmt = select(Site).where(Site.site_key == site_key)
|
||||
return self.db.scalar(stmt)
|
||||
return self.db.scalar(select(Site).where(Site.site_key == site_key))
|
||||
|
||||
def get_by_id(self, site_id: str) -> Site | None:
|
||||
return self.db.scalar(select(Site).where(Site.id == site_id))
|
||||
|
||||
def list(
|
||||
self,
|
||||
*,
|
||||
keyword: str | None = None,
|
||||
company_id: str | None = None,
|
||||
limit: int = 100,
|
||||
@@ -21,19 +24,30 @@ class SitesRepository:
|
||||
) -> tuple[list[Site], int]:
|
||||
stmt = select(Site)
|
||||
count_stmt = select(func.count()).select_from(Site)
|
||||
|
||||
if keyword:
|
||||
pattern = f"%{keyword}%"
|
||||
cond = or_(Site.site_key.ilike(pattern), Site.name.ilike(pattern))
|
||||
cond = or_(Site.site_key.ilike(pattern), Site.display_name.ilike(pattern), Site.domain.ilike(pattern))
|
||||
stmt = stmt.where(cond)
|
||||
count_stmt = count_stmt.where(cond)
|
||||
|
||||
if company_id:
|
||||
stmt = stmt.where(Site.company_id == company_id)
|
||||
count_stmt = count_stmt.where(Site.company_id == company_id)
|
||||
|
||||
stmt = stmt.order_by(Site.created_at.desc()).limit(limit).offset(offset)
|
||||
return list(self.db.scalars(stmt).all()), int(self.db.scalar(count_stmt) or 0)
|
||||
|
||||
def create(self, site_key: str, company_id: str, name: str, status: str = "active") -> Site:
|
||||
item = Site(site_key=site_key, company_id=company_id, name=name, status=status)
|
||||
def create(
|
||||
self,
|
||||
*,
|
||||
site_key: str,
|
||||
company_id: str,
|
||||
display_name: str,
|
||||
domain: str | None,
|
||||
status: str = "active",
|
||||
) -> Site:
|
||||
item = Site(site_key=site_key, company_id=company_id, display_name=display_name, domain=domain, status=status)
|
||||
self.db.add(item)
|
||||
self.db.commit()
|
||||
self.db.refresh(item)
|
||||
@@ -44,15 +58,25 @@ class SitesRepository:
|
||||
item: Site,
|
||||
*,
|
||||
company_id: str | None = None,
|
||||
name: str | None = None,
|
||||
display_name: str | None = None,
|
||||
domain: str | None = None,
|
||||
idp_group_id: str | None = None,
|
||||
status: str | None = None,
|
||||
) -> Site:
|
||||
if company_id is not None:
|
||||
item.company_id = company_id
|
||||
if name is not None:
|
||||
item.name = name
|
||||
if display_name is not None:
|
||||
item.display_name = display_name
|
||||
if domain is not None:
|
||||
item.domain = domain
|
||||
if idp_group_id is not None:
|
||||
item.idp_group_id = idp_group_id
|
||||
if status is not None:
|
||||
item.status = status
|
||||
self.db.commit()
|
||||
self.db.refresh(item)
|
||||
return item
|
||||
|
||||
def delete(self, item: Site) -> None:
|
||||
self.db.delete(item)
|
||||
self.db.commit()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy import func, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.system import System
|
||||
@@ -9,34 +9,51 @@ class SystemsRepository:
|
||||
self.db = db
|
||||
|
||||
def get_by_key(self, system_key: str) -> System | None:
|
||||
stmt = select(System).where(System.system_key == system_key)
|
||||
return self.db.scalar(stmt)
|
||||
return self.db.scalar(select(System).where(System.system_key == system_key))
|
||||
|
||||
def get_by_id(self, system_id: str) -> System | None:
|
||||
stmt = select(System).where(System.id == system_id)
|
||||
return self.db.scalar(stmt)
|
||||
return self.db.scalar(select(System).where(System.id == system_id))
|
||||
|
||||
def list(self, status: str | None = None, limit: int = 100, offset: int = 0) -> tuple[list[System], int]:
|
||||
def list(self, *, keyword: str | None = None, status: str | None = None, limit: int = 100, offset: int = 0) -> tuple[list[System], int]:
|
||||
stmt = select(System)
|
||||
count_stmt = select(func.count()).select_from(System)
|
||||
if keyword:
|
||||
pattern = f"%{keyword}%"
|
||||
cond = or_(System.system_key.ilike(pattern), System.name.ilike(pattern), System.idp_client_id.ilike(pattern))
|
||||
stmt = stmt.where(cond)
|
||||
count_stmt = count_stmt.where(cond)
|
||||
if status:
|
||||
stmt = stmt.where(System.status == status)
|
||||
count_stmt = count_stmt.where(System.status == status)
|
||||
|
||||
stmt = stmt.order_by(System.created_at.desc()).limit(limit).offset(offset)
|
||||
return list(self.db.scalars(stmt).all()), int(self.db.scalar(count_stmt) or 0)
|
||||
|
||||
def create(self, system_key: str, name: str, status: str = "active") -> System:
|
||||
item = System(system_key=system_key, name=name, status=status)
|
||||
def create(self, *, system_key: str, name: str, idp_client_id: str, status: str = "active") -> System:
|
||||
item = System(system_key=system_key, name=name, idp_client_id=idp_client_id, status=status)
|
||||
self.db.add(item)
|
||||
self.db.commit()
|
||||
self.db.refresh(item)
|
||||
return item
|
||||
|
||||
def update(self, item: System, *, name: str | None = None, status: str | None = None) -> System:
|
||||
def update(
|
||||
self,
|
||||
item: System,
|
||||
*,
|
||||
name: str | None = None,
|
||||
idp_client_id: str | None = None,
|
||||
status: str | None = None,
|
||||
) -> System:
|
||||
if name is not None:
|
||||
item.name = name
|
||||
if idp_client_id is not None:
|
||||
item.idp_client_id = idp_client_id
|
||||
if status is not None:
|
||||
item.status = status
|
||||
self.db.commit()
|
||||
self.db.refresh(item)
|
||||
return item
|
||||
|
||||
def delete(self, item: System) -> None:
|
||||
self.db.delete(item)
|
||||
self.db.commit()
|
||||
|
||||
54
app/repositories/user_sites_repo.py
Normal file
54
app/repositories/user_sites_repo.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.company import Company
|
||||
from app.models.role import Role
|
||||
from app.models.site import Site
|
||||
from app.models.site_role import SiteRole
|
||||
from app.models.system import System
|
||||
from app.models.user import User
|
||||
from app.models.user_site import UserSite
|
||||
|
||||
|
||||
class UserSitesRepository:
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def list_user_site_rows(self, user_id: str) -> list[tuple[UserSite, Site, Company]]:
|
||||
stmt = (
|
||||
select(UserSite, Site, Company)
|
||||
.join(Site, Site.id == UserSite.site_id)
|
||||
.join(Company, Company.id == Site.company_id)
|
||||
.where(UserSite.user_id == user_id)
|
||||
.order_by(Company.display_name.asc(), Site.display_name.asc())
|
||||
)
|
||||
return list(self.db.execute(stmt).all())
|
||||
|
||||
def list_site_member_rows(self, site_id: str) -> list[tuple[UserSite, User]]:
|
||||
stmt = (
|
||||
select(UserSite, User)
|
||||
.join(User, User.id == UserSite.user_id)
|
||||
.where(UserSite.site_id == site_id)
|
||||
.order_by(User.display_name.asc().nulls_last(), User.username.asc().nulls_last(), User.user_sub.asc())
|
||||
)
|
||||
return list(self.db.execute(stmt).all())
|
||||
|
||||
def set_user_sites(self, *, user_id: str, site_ids: list[str]) -> None:
|
||||
self.db.execute(delete(UserSite).where(UserSite.user_id == user_id))
|
||||
for site_id in site_ids:
|
||||
self.db.add(UserSite(user_id=user_id, site_id=site_id))
|
||||
self.db.commit()
|
||||
|
||||
def get_user_role_rows(self, user_id: str) -> list[tuple[Site, Company, Role, System]]:
|
||||
stmt = (
|
||||
select(Site, Company, Role, System)
|
||||
.select_from(UserSite)
|
||||
.join(Site, Site.id == UserSite.site_id)
|
||||
.join(Company, Company.id == Site.company_id)
|
||||
.join(SiteRole, SiteRole.site_id == Site.id)
|
||||
.join(Role, Role.id == SiteRole.role_id)
|
||||
.join(System, System.id == Role.system_id)
|
||||
.where(UserSite.user_id == user_id)
|
||||
.order_by(Company.display_name.asc(), Site.display_name.asc(), System.name.asc(), Role.name.asc())
|
||||
)
|
||||
return list(self.db.execute(stmt).all())
|
||||
@@ -9,15 +9,14 @@ class UsersRepository:
|
||||
self.db = db
|
||||
|
||||
def get_by_sub(self, user_sub: str) -> User | None:
|
||||
stmt = select(User).where(User.user_sub == user_sub)
|
||||
return self.db.scalar(stmt)
|
||||
return self.db.scalar(select(User).where(User.user_sub == user_sub))
|
||||
|
||||
def get_by_id(self, user_id: str) -> User | None:
|
||||
stmt = select(User).where(User.id == user_id)
|
||||
return self.db.scalar(stmt)
|
||||
return self.db.scalar(select(User).where(User.id == user_id))
|
||||
|
||||
def list(
|
||||
self,
|
||||
*,
|
||||
keyword: str | None = None,
|
||||
is_active: bool | None = None,
|
||||
limit: int = 50,
|
||||
@@ -48,11 +47,13 @@ class UsersRepository:
|
||||
|
||||
def upsert_by_sub(
|
||||
self,
|
||||
*,
|
||||
user_sub: str,
|
||||
username: str | None,
|
||||
email: str | None,
|
||||
display_name: str | None,
|
||||
is_active: bool,
|
||||
status: str = "active",
|
||||
idp_user_id: str | None = None,
|
||||
) -> User:
|
||||
user = self.get_by_sub(user_sub)
|
||||
@@ -64,6 +65,7 @@ class UsersRepository:
|
||||
email=email,
|
||||
display_name=display_name,
|
||||
is_active=is_active,
|
||||
status=status,
|
||||
)
|
||||
self.db.add(user)
|
||||
else:
|
||||
@@ -73,6 +75,7 @@ class UsersRepository:
|
||||
user.email = email
|
||||
user.display_name = display_name
|
||||
user.is_active = is_active
|
||||
user.status = status
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
@@ -86,6 +89,7 @@ class UsersRepository:
|
||||
email: str | None = None,
|
||||
display_name: str | None = None,
|
||||
is_active: bool | None = None,
|
||||
status: str | None = None,
|
||||
) -> User:
|
||||
if username is not None:
|
||||
user.username = username
|
||||
@@ -95,6 +99,13 @@ class UsersRepository:
|
||||
user.display_name = display_name
|
||||
if is_active is not None:
|
||||
user.is_active = is_active
|
||||
if status is not None:
|
||||
user.status = status
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
return user
|
||||
|
||||
def delete(self, user: User) -> None:
|
||||
self.db.delete(user)
|
||||
self.db.commit()
|
||||
|
||||
Reference in New Issue
Block a user