refactor: rebuild backend around role-site authorization model

This commit is contained in:
Chris
2026-04-02 23:58:13 +08:00
parent 0bc667847d
commit 2f92b94f59
43 changed files with 1593 additions and 2257 deletions

View File

@@ -0,0 +1,94 @@
from sqlalchemy import func, or_, select
from sqlalchemy.orm import Session
from app.models.api_client import ApiClient
class ApiClientsRepository:
def __init__(self, db: Session) -> None:
self.db = db
def get_by_key(self, client_key: str) -> ApiClient | None:
return self.db.scalar(select(ApiClient).where(ApiClient.client_key == client_key))
def list(self, *, keyword: str | None = None, status: str | None = None, limit: int = 100, offset: int = 0) -> tuple[list[ApiClient], int]:
stmt = select(ApiClient)
count_stmt = select(func.count()).select_from(ApiClient)
if keyword:
pattern = f"%{keyword}%"
cond = or_(ApiClient.client_key.ilike(pattern), ApiClient.name.ilike(pattern))
stmt = stmt.where(cond)
count_stmt = count_stmt.where(cond)
if status:
stmt = stmt.where(ApiClient.status == status)
count_stmt = count_stmt.where(ApiClient.status == status)
stmt = stmt.order_by(ApiClient.created_at.desc()).limit(limit).offset(offset)
return list(self.db.scalars(stmt).all()), int(self.db.scalar(count_stmt) or 0)
def create(
self,
*,
client_key: str,
name: str,
status: str,
api_key_hash: str,
allowed_origins: list[str],
allowed_ips: list[str],
allowed_paths: list[str],
rate_limit_per_min: int | None,
expires_at,
) -> ApiClient:
item = ApiClient(
client_key=client_key,
name=name,
status=status,
api_key_hash=api_key_hash,
allowed_origins=allowed_origins,
allowed_ips=allowed_ips,
allowed_paths=allowed_paths,
rate_limit_per_min=rate_limit_per_min,
expires_at=expires_at,
)
self.db.add(item)
self.db.commit()
self.db.refresh(item)
return item
def update(
self,
item: ApiClient,
*,
name: str | None = None,
status: str | None = None,
api_key_hash: str | None = None,
allowed_origins: list[str] | None = None,
allowed_ips: list[str] | None = None,
allowed_paths: list[str] | None = None,
rate_limit_per_min: int | None = None,
expires_at=None,
) -> ApiClient:
if name is not None:
item.name = name
if status is not None:
item.status = status
if api_key_hash is not None:
item.api_key_hash = api_key_hash
if allowed_origins is not None:
item.allowed_origins = allowed_origins
if allowed_ips is not None:
item.allowed_ips = allowed_ips
if allowed_paths is not None:
item.allowed_paths = allowed_paths
if rate_limit_per_min is not None:
item.rate_limit_per_min = rate_limit_per_min
if expires_at is not None:
item.expires_at = expires_at
self.db.commit()
self.db.refresh(item)
return item
def delete(self, item: ApiClient) -> None:
self.db.delete(item)
self.db.commit()

View File

@@ -9,36 +9,55 @@ class CompaniesRepository:
self.db = db
def get_by_key(self, company_key: str) -> Company | None:
stmt = select(Company).where(Company.company_key == company_key)
return self.db.scalar(stmt)
return self.db.scalar(select(Company).where(Company.company_key == company_key))
def get_by_id(self, company_id: str) -> Company | None:
stmt = select(Company).where(Company.id == company_id)
return self.db.scalar(stmt)
return self.db.scalar(select(Company).where(Company.id == company_id))
def list(self, keyword: str | None = None, limit: int = 100, offset: int = 0) -> tuple[list[Company], int]:
stmt = select(Company)
count_stmt = select(func.count()).select_from(Company)
if keyword:
pattern = f"%{keyword}%"
cond = or_(Company.company_key.ilike(pattern), Company.name.ilike(pattern))
cond = or_(
Company.company_key.ilike(pattern),
Company.display_name.ilike(pattern),
Company.legal_name.ilike(pattern),
)
stmt = stmt.where(cond)
count_stmt = count_stmt.where(cond)
stmt = stmt.order_by(Company.created_at.desc()).limit(limit).offset(offset)
return list(self.db.scalars(stmt).all()), int(self.db.scalar(count_stmt) or 0)
def create(self, company_key: str, name: str, status: str = "active") -> Company:
item = Company(company_key=company_key, name=name, status=status)
def create(self, *, company_key: str, display_name: str, legal_name: str | None, status: str = "active") -> Company:
item = Company(company_key=company_key, display_name=display_name, legal_name=legal_name, status=status)
self.db.add(item)
self.db.commit()
self.db.refresh(item)
return item
def update(self, item: Company, *, name: str | None = None, status: str | None = None) -> Company:
if name is not None:
item.name = name
def update(
self,
item: Company,
*,
display_name: str | None = None,
legal_name: str | None = None,
idp_group_id: str | None = None,
status: str | None = None,
) -> Company:
if display_name is not None:
item.display_name = display_name
if legal_name is not None:
item.legal_name = legal_name
if idp_group_id is not None:
item.idp_group_id = idp_group_id
if status is not None:
item.status = status
self.db.commit()
self.db.refresh(item)
return item
def delete(self, item: Company) -> None:
self.db.delete(item)
self.db.commit()

View File

@@ -1,35 +0,0 @@
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from app.models.module import Module
class ModulesRepository:
def __init__(self, db: Session) -> None:
self.db = db
def get_by_key(self, module_key: str) -> Module | None:
stmt = select(Module).where(Module.module_key == module_key)
return self.db.scalar(stmt)
def list(self, limit: int = 200, offset: int = 0) -> tuple[list[Module], int]:
stmt = select(Module)
count_stmt = select(func.count()).select_from(Module)
stmt = stmt.order_by(Module.created_at.desc()).limit(limit).offset(offset)
return list(self.db.scalars(stmt).all()), int(self.db.scalar(count_stmt) or 0)
def create(self, module_key: str, system_key: str, name: str, status: str = "active") -> Module:
item = Module(module_key=module_key, system_key=system_key, name=name, status=status)
self.db.add(item)
self.db.commit()
self.db.refresh(item)
return item
def update(self, item: Module, *, name: str | None = None, status: str | None = None) -> Module:
if name is not None:
item.name = name
if status is not None:
item.status = status
self.db.commit()
self.db.refresh(item)
return item

View File

@@ -1,269 +0,0 @@
from __future__ import annotations
from sqlalchemy import delete, func, select
from sqlalchemy.orm import Session
from app.models.permission_group import PermissionGroup
from app.models.permission_group_member import PermissionGroupMember
from app.models.permission_group_permission import PermissionGroupPermission
from app.models.user import User
class PermissionGroupsRepository:
def __init__(self, db: Session) -> None:
self.db = db
def get_by_key(self, group_key: str) -> PermissionGroup | None:
return self.db.scalar(select(PermissionGroup).where(PermissionGroup.group_key == group_key))
def get_by_id(self, group_id: str) -> PermissionGroup | None:
return self.db.scalar(select(PermissionGroup).where(PermissionGroup.id == group_id))
def get_by_keys(self, group_keys: list[str]) -> list[PermissionGroup]:
if not group_keys:
return []
stmt = select(PermissionGroup).where(PermissionGroup.group_key.in_(group_keys))
return list(self.db.scalars(stmt).all())
def list(self, limit: int = 100, offset: int = 0) -> tuple[list[PermissionGroup], int]:
stmt = select(PermissionGroup).order_by(PermissionGroup.created_at.desc()).limit(limit).offset(offset)
count_stmt = select(func.count()).select_from(PermissionGroup)
return list(self.db.scalars(stmt).all()), int(self.db.scalar(count_stmt) or 0)
def create(self, group_key: str, name: str, status: str = "active") -> PermissionGroup:
item = PermissionGroup(group_key=group_key, name=name, status=status)
self.db.add(item)
self.db.commit()
self.db.refresh(item)
return item
def update(self, item: PermissionGroup, *, name: str | None = None, status: str | None = None) -> PermissionGroup:
if name is not None:
item.name = name
if status is not None:
item.status = status
self.db.commit()
self.db.refresh(item)
return item
def add_member_if_not_exists(self, group_id: str, user_sub: str) -> PermissionGroupMember:
existing = self.db.scalar(
select(PermissionGroupMember).where(
PermissionGroupMember.group_id == group_id, PermissionGroupMember.user_sub == user_sub
)
)
if existing:
return existing
row = PermissionGroupMember(group_id=group_id, user_sub=user_sub)
self.db.add(row)
self.db.commit()
self.db.refresh(row)
return row
def remove_member(self, group_id: str, user_sub: str) -> int:
result = self.db.execute(
delete(PermissionGroupMember).where(
PermissionGroupMember.group_id == group_id, PermissionGroupMember.user_sub == user_sub
)
)
self.db.commit()
return int(result.rowcount or 0)
def list_group_keys_by_member_sub(self, user_sub: str) -> list[str]:
stmt = (
select(PermissionGroup.group_key)
.select_from(PermissionGroupMember)
.join(PermissionGroup, PermissionGroup.id == PermissionGroupMember.group_id)
.where(PermissionGroupMember.user_sub == user_sub)
.order_by(PermissionGroup.group_key.asc())
)
return [row[0] for row in self.db.execute(stmt).all()]
def replace_member_groups(self, user_sub: str, group_ids: list[str]) -> None:
self.db.execute(delete(PermissionGroupMember).where(PermissionGroupMember.user_sub == user_sub))
for group_id in group_ids:
self.db.add(PermissionGroupMember(group_id=group_id, user_sub=user_sub))
self.db.commit()
def grant_group_permission(
self,
group_id: str,
system: str,
module: str,
action: str,
scope_type: str,
scope_id: str,
) -> PermissionGroupPermission:
where = [
PermissionGroupPermission.group_id == group_id,
PermissionGroupPermission.system == system,
PermissionGroupPermission.module == module,
PermissionGroupPermission.action == action,
PermissionGroupPermission.scope_type == scope_type,
PermissionGroupPermission.scope_id == scope_id,
]
existing = self.db.scalar(select(PermissionGroupPermission).where(*where))
if existing:
return existing
row = PermissionGroupPermission(
group_id=group_id,
system=system,
module=module,
action=action,
scope_type=scope_type,
scope_id=scope_id,
)
self.db.add(row)
self.db.commit()
self.db.refresh(row)
return row
def list_group_permissions(self, group_id: str) -> list[PermissionGroupPermission]:
stmt = (
select(PermissionGroupPermission)
.where(PermissionGroupPermission.group_id == group_id)
.order_by(PermissionGroupPermission.scope_type.asc(), PermissionGroupPermission.scope_id.asc(), PermissionGroupPermission.system.asc(), PermissionGroupPermission.module.asc(), PermissionGroupPermission.action.asc())
)
return list(self.db.scalars(stmt).all())
def replace_group_bindings(
self,
*,
group_id: str,
site_keys: list[str],
system_keys: list[str],
module_keys: list[str],
member_subs: list[str],
actions: list[str],
) -> None:
normalized_sites = list(dict.fromkeys([s for s in site_keys if s]))
normalized_actions = [a for a in list(dict.fromkeys(actions)) if a in {"view", "edit"}]
normalized_member_subs = list(dict.fromkeys([s for s in member_subs if s]))
normalized_systems = set([s for s in system_keys if s])
module_pairs = []
for pair in module_keys:
if "|" not in pair:
continue
system_key, module_key = pair.split("|", 1)
if not system_key or not module_key:
continue
module_pairs.append((system_key, module_key))
normalized_systems.add(system_key)
self.db.execute(delete(PermissionGroupPermission).where(PermissionGroupPermission.group_id == group_id))
self.db.execute(delete(PermissionGroupMember).where(PermissionGroupMember.group_id == group_id))
for sub in normalized_member_subs:
self.db.add(PermissionGroupMember(group_id=group_id, user_sub=sub))
for site_key in normalized_sites:
for action in normalized_actions:
for system_key in sorted(normalized_systems):
module_names = [m for s, m in module_pairs if s == system_key] or ["__system__"]
for module_name in module_names:
self.db.add(
PermissionGroupPermission(
group_id=group_id,
system=system_key,
module=module_name,
action=action,
scope_type="site",
scope_id=site_key,
)
)
self.db.commit()
def get_group_binding_snapshot(self, group_id: str, group_key: str) -> dict:
permissions = self.list_group_permissions(group_id)
site_keys = sorted({p.scope_id for p in permissions if p.scope_type == "site"})
system_keys = sorted({p.system for p in permissions})
actions = sorted({p.action for p in permissions if p.action in {"view", "edit"}})
module_keys = sorted(
{
f"{p.system}|{p.module}"
for p in permissions
if p.module and p.module != "__system__"
}
)
member_subs = sorted(self.list_group_member_subs(group_id))
return {
"group_key": group_key,
"site_keys": site_keys,
"system_keys": system_keys,
"module_keys": module_keys,
"member_subs": member_subs,
"actions": actions,
}
def list_group_member_subs(self, group_id: str) -> list[str]:
stmt = (
select(PermissionGroupMember.user_sub)
.where(PermissionGroupMember.group_id == group_id)
.order_by(PermissionGroupMember.user_sub.asc())
)
return [row[0] for row in self.db.execute(stmt).all()]
def list_system_groups(self, system_key: str) -> list[PermissionGroup]:
stmt = (
select(PermissionGroup)
.join(PermissionGroupPermission, PermissionGroupPermission.group_id == PermissionGroup.id)
.where(PermissionGroupPermission.system == system_key)
.order_by(PermissionGroup.name.asc())
.distinct()
)
return list(self.db.scalars(stmt).all())
def list_system_members(self, system_key: str) -> list[User]:
stmt = (
select(User)
.join(PermissionGroupMember, PermissionGroupMember.user_sub == User.user_sub)
.join(PermissionGroupPermission, PermissionGroupPermission.group_id == PermissionGroupMember.group_id)
.where(PermissionGroupPermission.system == system_key)
.order_by(User.email.asc(), User.user_sub.asc())
.distinct()
)
return list(self.db.scalars(stmt).all())
def list_module_groups(self, system_key: str, module_name: str) -> list[PermissionGroup]:
stmt = (
select(PermissionGroup)
.join(PermissionGroupPermission, PermissionGroupPermission.group_id == PermissionGroup.id)
.where(PermissionGroupPermission.system == system_key, PermissionGroupPermission.module == module_name)
.order_by(PermissionGroup.name.asc())
.distinct()
)
return list(self.db.scalars(stmt).all())
def list_module_members(self, system_key: str, module_name: str) -> list[User]:
stmt = (
select(User)
.join(PermissionGroupMember, PermissionGroupMember.user_sub == User.user_sub)
.join(PermissionGroupPermission, PermissionGroupPermission.group_id == PermissionGroupMember.group_id)
.where(PermissionGroupPermission.system == system_key, PermissionGroupPermission.module == module_name)
.order_by(User.email.asc(), User.user_sub.asc())
.distinct()
)
return list(self.db.scalars(stmt).all())
def revoke_group_permission(
self,
group_id: str,
system: str,
module: str,
action: str,
scope_type: str,
scope_id: str,
) -> int:
stmt = delete(PermissionGroupPermission).where(
PermissionGroupPermission.group_id == group_id,
PermissionGroupPermission.system == system,
PermissionGroupPermission.module == module,
PermissionGroupPermission.action == action,
PermissionGroupPermission.scope_type == scope_type,
PermissionGroupPermission.scope_id == scope_id,
)
result = self.db.execute(stmt)
self.db.commit()
return int(result.rowcount or 0)

View File

@@ -1,231 +0,0 @@
from sqlalchemy import and_, delete, func, literal, or_, select
from sqlalchemy.orm import Session
from app.models.company import Company
from app.models.module import Module
from app.models.permission_group_member import PermissionGroupMember
from app.models.permission_group_permission import PermissionGroupPermission
from app.models.site import Site
from app.models.user import User
from app.models.user_scope_permission import UserScopePermission
class PermissionsRepository:
def __init__(self, db: Session) -> None:
self.db = db
def list_by_user(self, user_id: str, user_sub: str) -> list[tuple[str, str, str | None, str, str]]:
direct_stmt = (
select(
literal("direct"),
UserScopePermission.scope_type,
Company.company_key,
Site.site_key,
Module.system_key,
Module.module_key,
UserScopePermission.action,
)
.select_from(UserScopePermission)
.join(Module, Module.id == UserScopePermission.module_id)
.join(Company, Company.id == UserScopePermission.company_id, isouter=True)
.join(Site, Site.id == UserScopePermission.site_id, isouter=True)
.where(UserScopePermission.user_id == user_id)
.where(UserScopePermission.action.in_(["view", "edit"]))
.where(UserScopePermission.scope_type == "site")
)
group_stmt = (
select(
literal("group"),
PermissionGroupPermission.scope_type,
PermissionGroupPermission.scope_id,
PermissionGroupPermission.system,
PermissionGroupPermission.module,
PermissionGroupPermission.action,
)
.select_from(PermissionGroupPermission)
.join(PermissionGroupMember, PermissionGroupMember.group_id == PermissionGroupPermission.group_id)
.where(PermissionGroupMember.user_sub == user_sub)
.where(PermissionGroupPermission.action.in_(["view", "edit"]))
.where(PermissionGroupPermission.scope_type == "site")
)
rows = self.db.execute(direct_stmt).all() + self.db.execute(group_stmt).all()
result: list[tuple[str, str, str | None, str, str]] = []
dedup = set()
for row in rows:
source = row[0]
if source == "group":
_, scope_type, scope_id, system_key, module_key, action = row
if module_key == "__system__":
module_key = f"__system__{system_key}"
else:
_, scope_type, company_key, site_key, system_key, module_key, action = row
scope_id = company_key if scope_type == "company" else site_key
key = (scope_type, scope_id or "", system_key, module_key, action)
if key in dedup:
continue
dedup.add(key)
result.append(key)
return result
def create_if_not_exists(
self,
user_id: str,
module_id: str,
action: str,
scope_type: str,
company_id: str | None,
site_id: str | None,
) -> UserScopePermission:
where_expr = [
UserScopePermission.user_id == user_id,
UserScopePermission.module_id == module_id,
UserScopePermission.action == action,
UserScopePermission.scope_type == scope_type,
]
if scope_type == "company":
where_expr.append(UserScopePermission.company_id == company_id)
else:
where_expr.append(UserScopePermission.site_id == site_id)
existing = self.db.scalar(select(UserScopePermission).where(and_(*where_expr)))
if existing:
return existing
item = UserScopePermission(
user_id=user_id,
module_id=module_id,
action=action,
scope_type=scope_type,
company_id=company_id,
site_id=site_id,
)
self.db.add(item)
self.db.commit()
self.db.refresh(item)
return item
def revoke(
self,
user_id: str,
module_id: str,
action: str,
scope_type: str,
company_id: str | None,
site_id: str | None,
) -> int:
stmt = delete(UserScopePermission).where(
UserScopePermission.user_id == user_id,
UserScopePermission.module_id == module_id,
UserScopePermission.action == action,
UserScopePermission.scope_type == scope_type,
or_(
and_(scope_type == "company", UserScopePermission.company_id == company_id),
and_(scope_type == "site", UserScopePermission.site_id == site_id),
),
)
result = self.db.execute(stmt)
self.db.commit()
return int(result.rowcount or 0)
def list_direct_permissions(
self,
*,
keyword: str | None = None,
scope_type: str | None = None,
limit: int = 200,
offset: int = 0,
) -> tuple[list[dict], int]:
stmt = (
select(
UserScopePermission.id,
User.user_sub,
User.email,
User.display_name,
UserScopePermission.scope_type,
Company.company_key,
Site.site_key,
Module.system_key,
Module.module_key,
UserScopePermission.action,
UserScopePermission.created_at,
)
.select_from(UserScopePermission)
.join(User, User.id == UserScopePermission.user_id)
.join(Module, Module.id == UserScopePermission.module_id)
.join(Company, Company.id == UserScopePermission.company_id, isouter=True)
.join(Site, Site.id == UserScopePermission.site_id, isouter=True)
.where(UserScopePermission.action.in_(["view", "edit"]))
.where(UserScopePermission.scope_type == "site")
)
count_stmt = (
select(func.count())
.select_from(UserScopePermission)
.join(User, User.id == UserScopePermission.user_id)
.join(Module, Module.id == UserScopePermission.module_id)
.join(Company, Company.id == UserScopePermission.company_id, isouter=True)
.join(Site, Site.id == UserScopePermission.site_id, isouter=True)
.where(UserScopePermission.action.in_(["view", "edit"]))
.where(UserScopePermission.scope_type == "site")
)
if scope_type == "site":
stmt = stmt.where(UserScopePermission.scope_type == scope_type)
count_stmt = count_stmt.where(UserScopePermission.scope_type == scope_type)
if keyword:
pattern = f"%{keyword}%"
cond = or_(
User.user_sub.ilike(pattern),
User.email.ilike(pattern),
User.display_name.ilike(pattern),
Module.module_key.ilike(pattern),
Company.company_key.ilike(pattern),
Site.site_key.ilike(pattern),
UserScopePermission.action.ilike(pattern),
)
stmt = stmt.where(cond)
count_stmt = count_stmt.where(cond)
stmt = stmt.order_by(UserScopePermission.created_at.desc()).limit(limit).offset(offset)
rows = self.db.execute(stmt).all()
total = int(self.db.scalar(count_stmt) or 0)
items: list[dict] = []
for row in rows:
(
permission_id,
user_sub,
email,
display_name,
row_scope_type,
company_key,
site_key,
system_key,
module_key,
action,
created_at,
) = row
scope_id = company_key if row_scope_type == "company" else site_key
module_name = module_key
if isinstance(module_name, str) and module_name.startswith("__system__"):
module_name = None
items.append(
{
"permission_id": permission_id,
"user_sub": user_sub,
"email": email,
"display_name": display_name,
"scope_type": row_scope_type,
"scope_id": scope_id,
"system": system_key,
"module": module_name,
"action": action,
"created_at": created_at,
}
)
return items, total
def revoke_by_permission_id(self, permission_id: str) -> int:
stmt = delete(UserScopePermission).where(UserScopePermission.id == permission_id)
result = self.db.execute(stmt)
self.db.commit()
return int(result.rowcount or 0)

View File

@@ -0,0 +1,97 @@
from sqlalchemy import func, or_, select
from sqlalchemy.orm import Session
from app.models.role import Role
class RolesRepository:
def __init__(self, db: Session) -> None:
self.db = db
def get_by_key(self, role_key: str) -> Role | None:
return self.db.scalar(select(Role).where(Role.role_key == role_key))
def get_by_id(self, role_id: str) -> Role | None:
return self.db.scalar(select(Role).where(Role.id == role_id))
def list(
self,
*,
keyword: str | None = None,
system_id: str | None = None,
status: str | None = None,
limit: int = 100,
offset: int = 0,
) -> tuple[list[Role], int]:
stmt = select(Role)
count_stmt = select(func.count()).select_from(Role)
if keyword:
pattern = f"%{keyword}%"
cond = or_(
Role.role_key.ilike(pattern),
Role.name.ilike(pattern),
Role.idp_role_name.ilike(pattern),
Role.description.ilike(pattern),
)
stmt = stmt.where(cond)
count_stmt = count_stmt.where(cond)
if system_id:
stmt = stmt.where(Role.system_id == system_id)
count_stmt = count_stmt.where(Role.system_id == system_id)
if status:
stmt = stmt.where(Role.status == status)
count_stmt = count_stmt.where(Role.status == status)
stmt = stmt.order_by(Role.created_at.desc()).limit(limit).offset(offset)
return list(self.db.scalars(stmt).all()), int(self.db.scalar(count_stmt) or 0)
def create(
self,
*,
role_key: str,
system_id: str,
name: str,
description: str | None,
idp_role_name: str,
status: str = "active",
) -> Role:
item = Role(
role_key=role_key,
system_id=system_id,
name=name,
description=description,
idp_role_name=idp_role_name,
status=status,
)
self.db.add(item)
self.db.commit()
self.db.refresh(item)
return item
def update(
self,
item: Role,
*,
system_id: str | None = None,
name: str | None = None,
description: str | None = None,
idp_role_name: str | None = None,
status: str | None = None,
) -> Role:
if system_id is not None:
item.system_id = system_id
if name is not None:
item.name = name
if description is not None:
item.description = description
if idp_role_name is not None:
item.idp_role_name = idp_role_name
if status is not None:
item.status = status
self.db.commit()
self.db.refresh(item)
return item
def delete(self, item: Role) -> None:
self.db.delete(item)
self.db.commit()

View File

@@ -0,0 +1,37 @@
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from app.models.role import Role
from app.models.site import Site
from app.models.site_role import SiteRole
from app.models.system import System
class SiteRolesRepository:
def __init__(self, db: Session) -> None:
self.db = db
def list_site_role_rows(self, site_id: str) -> list[tuple[SiteRole, Role, System]]:
stmt = (
select(SiteRole, Role, System)
.join(Role, Role.id == SiteRole.role_id)
.join(System, System.id == Role.system_id)
.where(SiteRole.site_id == site_id)
.order_by(System.name.asc(), Role.name.asc())
)
return list(self.db.execute(stmt).all())
def list_role_site_rows(self, role_id: str) -> list[tuple[SiteRole, Site]]:
stmt = (
select(SiteRole, Site)
.join(Site, Site.id == SiteRole.site_id)
.where(SiteRole.role_id == role_id)
.order_by(Site.display_name.asc())
)
return list(self.db.execute(stmt).all())
def set_site_roles(self, *, site_id: str, role_ids: list[str]) -> None:
self.db.execute(delete(SiteRole).where(SiteRole.site_id == site_id))
for role_id in role_ids:
self.db.add(SiteRole(site_id=site_id, role_id=role_id))
self.db.commit()

View File

@@ -9,11 +9,14 @@ class SitesRepository:
self.db = db
def get_by_key(self, site_key: str) -> Site | None:
stmt = select(Site).where(Site.site_key == site_key)
return self.db.scalar(stmt)
return self.db.scalar(select(Site).where(Site.site_key == site_key))
def get_by_id(self, site_id: str) -> Site | None:
return self.db.scalar(select(Site).where(Site.id == site_id))
def list(
self,
*,
keyword: str | None = None,
company_id: str | None = None,
limit: int = 100,
@@ -21,19 +24,30 @@ class SitesRepository:
) -> tuple[list[Site], int]:
stmt = select(Site)
count_stmt = select(func.count()).select_from(Site)
if keyword:
pattern = f"%{keyword}%"
cond = or_(Site.site_key.ilike(pattern), Site.name.ilike(pattern))
cond = or_(Site.site_key.ilike(pattern), Site.display_name.ilike(pattern), Site.domain.ilike(pattern))
stmt = stmt.where(cond)
count_stmt = count_stmt.where(cond)
if company_id:
stmt = stmt.where(Site.company_id == company_id)
count_stmt = count_stmt.where(Site.company_id == company_id)
stmt = stmt.order_by(Site.created_at.desc()).limit(limit).offset(offset)
return list(self.db.scalars(stmt).all()), int(self.db.scalar(count_stmt) or 0)
def create(self, site_key: str, company_id: str, name: str, status: str = "active") -> Site:
item = Site(site_key=site_key, company_id=company_id, name=name, status=status)
def create(
self,
*,
site_key: str,
company_id: str,
display_name: str,
domain: str | None,
status: str = "active",
) -> Site:
item = Site(site_key=site_key, company_id=company_id, display_name=display_name, domain=domain, status=status)
self.db.add(item)
self.db.commit()
self.db.refresh(item)
@@ -44,15 +58,25 @@ class SitesRepository:
item: Site,
*,
company_id: str | None = None,
name: str | None = None,
display_name: str | None = None,
domain: str | None = None,
idp_group_id: str | None = None,
status: str | None = None,
) -> Site:
if company_id is not None:
item.company_id = company_id
if name is not None:
item.name = name
if display_name is not None:
item.display_name = display_name
if domain is not None:
item.domain = domain
if idp_group_id is not None:
item.idp_group_id = idp_group_id
if status is not None:
item.status = status
self.db.commit()
self.db.refresh(item)
return item
def delete(self, item: Site) -> None:
self.db.delete(item)
self.db.commit()

View File

@@ -1,4 +1,4 @@
from sqlalchemy import func, select
from sqlalchemy import func, or_, select
from sqlalchemy.orm import Session
from app.models.system import System
@@ -9,34 +9,51 @@ class SystemsRepository:
self.db = db
def get_by_key(self, system_key: str) -> System | None:
stmt = select(System).where(System.system_key == system_key)
return self.db.scalar(stmt)
return self.db.scalar(select(System).where(System.system_key == system_key))
def get_by_id(self, system_id: str) -> System | None:
stmt = select(System).where(System.id == system_id)
return self.db.scalar(stmt)
return self.db.scalar(select(System).where(System.id == system_id))
def list(self, status: str | None = None, limit: int = 100, offset: int = 0) -> tuple[list[System], int]:
def list(self, *, keyword: str | None = None, status: str | None = None, limit: int = 100, offset: int = 0) -> tuple[list[System], int]:
stmt = select(System)
count_stmt = select(func.count()).select_from(System)
if keyword:
pattern = f"%{keyword}%"
cond = or_(System.system_key.ilike(pattern), System.name.ilike(pattern), System.idp_client_id.ilike(pattern))
stmt = stmt.where(cond)
count_stmt = count_stmt.where(cond)
if status:
stmt = stmt.where(System.status == status)
count_stmt = count_stmt.where(System.status == status)
stmt = stmt.order_by(System.created_at.desc()).limit(limit).offset(offset)
return list(self.db.scalars(stmt).all()), int(self.db.scalar(count_stmt) or 0)
def create(self, system_key: str, name: str, status: str = "active") -> System:
item = System(system_key=system_key, name=name, status=status)
def create(self, *, system_key: str, name: str, idp_client_id: str, status: str = "active") -> System:
item = System(system_key=system_key, name=name, idp_client_id=idp_client_id, status=status)
self.db.add(item)
self.db.commit()
self.db.refresh(item)
return item
def update(self, item: System, *, name: str | None = None, status: str | None = None) -> System:
def update(
self,
item: System,
*,
name: str | None = None,
idp_client_id: str | None = None,
status: str | None = None,
) -> System:
if name is not None:
item.name = name
if idp_client_id is not None:
item.idp_client_id = idp_client_id
if status is not None:
item.status = status
self.db.commit()
self.db.refresh(item)
return item
def delete(self, item: System) -> None:
self.db.delete(item)
self.db.commit()

View File

@@ -0,0 +1,54 @@
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from app.models.company import Company
from app.models.role import Role
from app.models.site import Site
from app.models.site_role import SiteRole
from app.models.system import System
from app.models.user import User
from app.models.user_site import UserSite
class UserSitesRepository:
def __init__(self, db: Session) -> None:
self.db = db
def list_user_site_rows(self, user_id: str) -> list[tuple[UserSite, Site, Company]]:
stmt = (
select(UserSite, Site, Company)
.join(Site, Site.id == UserSite.site_id)
.join(Company, Company.id == Site.company_id)
.where(UserSite.user_id == user_id)
.order_by(Company.display_name.asc(), Site.display_name.asc())
)
return list(self.db.execute(stmt).all())
def list_site_member_rows(self, site_id: str) -> list[tuple[UserSite, User]]:
stmt = (
select(UserSite, User)
.join(User, User.id == UserSite.user_id)
.where(UserSite.site_id == site_id)
.order_by(User.display_name.asc().nulls_last(), User.username.asc().nulls_last(), User.user_sub.asc())
)
return list(self.db.execute(stmt).all())
def set_user_sites(self, *, user_id: str, site_ids: list[str]) -> None:
self.db.execute(delete(UserSite).where(UserSite.user_id == user_id))
for site_id in site_ids:
self.db.add(UserSite(user_id=user_id, site_id=site_id))
self.db.commit()
def get_user_role_rows(self, user_id: str) -> list[tuple[Site, Company, Role, System]]:
stmt = (
select(Site, Company, Role, System)
.select_from(UserSite)
.join(Site, Site.id == UserSite.site_id)
.join(Company, Company.id == Site.company_id)
.join(SiteRole, SiteRole.site_id == Site.id)
.join(Role, Role.id == SiteRole.role_id)
.join(System, System.id == Role.system_id)
.where(UserSite.user_id == user_id)
.order_by(Company.display_name.asc(), Site.display_name.asc(), System.name.asc(), Role.name.asc())
)
return list(self.db.execute(stmt).all())

View File

@@ -9,15 +9,14 @@ class UsersRepository:
self.db = db
def get_by_sub(self, user_sub: str) -> User | None:
stmt = select(User).where(User.user_sub == user_sub)
return self.db.scalar(stmt)
return self.db.scalar(select(User).where(User.user_sub == user_sub))
def get_by_id(self, user_id: str) -> User | None:
stmt = select(User).where(User.id == user_id)
return self.db.scalar(stmt)
return self.db.scalar(select(User).where(User.id == user_id))
def list(
self,
*,
keyword: str | None = None,
is_active: bool | None = None,
limit: int = 50,
@@ -48,11 +47,13 @@ class UsersRepository:
def upsert_by_sub(
self,
*,
user_sub: str,
username: str | None,
email: str | None,
display_name: str | None,
is_active: bool,
status: str = "active",
idp_user_id: str | None = None,
) -> User:
user = self.get_by_sub(user_sub)
@@ -64,6 +65,7 @@ class UsersRepository:
email=email,
display_name=display_name,
is_active=is_active,
status=status,
)
self.db.add(user)
else:
@@ -73,6 +75,7 @@ class UsersRepository:
user.email = email
user.display_name = display_name
user.is_active = is_active
user.status = status
self.db.commit()
self.db.refresh(user)
@@ -86,6 +89,7 @@ class UsersRepository:
email: str | None = None,
display_name: str | None = None,
is_active: bool | None = None,
status: str | None = None,
) -> User:
if username is not None:
user.username = username
@@ -95,6 +99,13 @@ class UsersRepository:
user.display_name = display_name
if is_active is not None:
user.is_active = is_active
if status is not None:
user.status = status
self.db.commit()
self.db.refresh(user)
return user
def delete(self, user: User) -> None:
self.db.delete(user)
self.db.commit()