refactor: simplify schema names and remove provider id columns

This commit is contained in:
Chris
2026-04-03 01:49:36 +08:00
parent e91639d6d9
commit 64246984ba
18 changed files with 148 additions and 133 deletions

View File

@@ -74,8 +74,7 @@ def _company_item(company) -> CompanyItem:
return CompanyItem( return CompanyItem(
id=company.id, id=company.id,
company_key=company.company_key, company_key=company.company_key,
display_name=company.display_name, name=company.name,
legal_name=company.legal_name,
provider_group_id=company.provider_group_id, provider_group_id=company.provider_group_id,
status=company.status, status=company.status,
) )
@@ -86,7 +85,7 @@ def _site_item(site, company) -> SiteItem:
id=site.id, id=site.id,
site_key=site.site_key, site_key=site.site_key,
company_key=company.company_key, company_key=company.company_key,
company_display_name=company.display_name, company_display_name=company.name,
display_name=site.display_name, display_name=site.display_name,
domain=site.domain, domain=site.domain,
provider_group_id=site.provider_group_id, provider_group_id=site.provider_group_id,
@@ -99,7 +98,6 @@ def _system_item(system) -> SystemItem:
id=system.id, id=system.id,
system_key=system.system_key, system_key=system.system_key,
name=system.name, name=system.name,
provider_client_id=system.provider_client_id,
status=system.status, status=system.status,
) )
@@ -117,8 +115,8 @@ def _member_item(user) -> MemberItem:
) )
def _company_group_name(display_name: str, company_key: str) -> str: def _company_group_name(name: str, company_key: str) -> str:
normalized = display_name.strip() if isinstance(display_name, str) else "" normalized = name.strip() if isinstance(name, str) else ""
if not normalized: if not normalized:
return company_key return company_key
return normalized return normalized
@@ -148,20 +146,19 @@ def create_company(payload: CompanyCreateRequest, db: Session = Depends(get_db))
repo = CompaniesRepository(db) repo = CompaniesRepository(db)
idp = ProviderAdminService(get_settings()) idp = ProviderAdminService(get_settings())
company_key = _generate_unique_key("CP", lambda key: repo.get_by_key(key) is not None) company_key = _generate_unique_key("CP", lambda key: repo.get_by_key(key) is not None)
group_name = _company_group_name(payload.display_name, company_key) group_name = _company_group_name(payload.name, company_key)
group = idp.ensure_group( group = idp.ensure_group(
name=group_name, name=group_name,
attributes={ attributes={
"member_entity_type": "company", "member_entity_type": "company",
"company_key": company_key, "company_key": company_key,
"display_name": payload.display_name, "name": payload.name,
"status": payload.status, "status": payload.status,
}, },
) )
item = repo.create( item = repo.create(
company_key=company_key, company_key=company_key,
display_name=payload.display_name, name=payload.name,
legal_name=payload.legal_name,
provider_group_id=group.group_id, provider_group_id=group.group_id,
status=payload.status, status=payload.status,
) )
@@ -175,24 +172,23 @@ def update_company(company_key: str, payload: CompanyUpdateRequest, db: Session
if not item: if not item:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="company_not_found") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="company_not_found")
idp = ProviderAdminService(get_settings()) idp = ProviderAdminService(get_settings())
resolved_display_name = payload.display_name if payload.display_name is not None else item.display_name resolved_name = payload.name if payload.name is not None else item.name
resolved_status = payload.status if payload.status is not None else item.status resolved_status = payload.status if payload.status is not None else item.status
resolved_group_id = payload.provider_group_id or item.provider_group_id resolved_group_id = payload.provider_group_id or item.provider_group_id
group_name = _company_group_name(resolved_display_name, company_key) group_name = _company_group_name(resolved_name, company_key)
group = idp.ensure_group( group = idp.ensure_group(
group_id=resolved_group_id, group_id=resolved_group_id,
name=group_name, name=group_name,
attributes={ attributes={
"member_entity_type": "company", "member_entity_type": "company",
"company_key": company_key, "company_key": company_key,
"display_name": resolved_display_name, "name": resolved_name,
"status": resolved_status, "status": resolved_status,
}, },
) )
item = repo.update( item = repo.update(
item, item,
display_name=payload.display_name, name=payload.name,
legal_name=payload.legal_name,
provider_group_id=group.group_id, provider_group_id=group.group_id,
status=payload.status, status=payload.status,
) )
@@ -406,7 +402,6 @@ def list_roles(
system_key=system_map[row.system_id].system_key, system_key=system_map[row.system_id].system_key,
system_name=system_map[row.system_id].name, system_name=system_map[row.system_id].name,
name=row.name, name=row.name,
provider_role_name=row.provider_role_name,
description=row.description, description=row.description,
status=row.status, status=row.status,
) )
@@ -425,12 +420,9 @@ def create_role(payload: RoleCreateRequest, db: Session = Depends(get_db)) -> Ro
system = systems_repo.get_by_key(payload.system_key) system = systems_repo.get_by_key(payload.system_key)
if not system: if not system:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="system_not_found") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="system_not_found")
if not system.provider_client_id:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="system_provider_client_id_missing")
idp.ensure_client_role( idp.ensure_client_role(
provider_client_id=system.provider_client_id, provider_client_id=system.name,
provider_role_name=payload.provider_role_name, provider_role_name=payload.name,
description=payload.description, description=payload.description,
) )
@@ -441,7 +433,6 @@ def create_role(payload: RoleCreateRequest, db: Session = Depends(get_db)) -> Ro
system_id=system.id, system_id=system.id,
name=payload.name, name=payload.name,
description=payload.description, description=payload.description,
provider_role_name=payload.provider_role_name,
status=payload.status, status=payload.status,
) )
except IntegrityError: except IntegrityError:
@@ -454,7 +445,6 @@ def create_role(payload: RoleCreateRequest, db: Session = Depends(get_db)) -> Ro
system_key=system.system_key, system_key=system.system_key,
system_name=system.name, system_name=system.name,
name=row.name, name=row.name,
provider_role_name=row.provider_role_name,
description=row.description, description=row.description,
status=row.status, status=row.status,
) )
@@ -473,8 +463,6 @@ def update_role(role_key: str, payload: RoleUpdateRequest, db: Session = Depends
old_system = systems_repo.get_by_id(role.system_id) old_system = systems_repo.get_by_id(role.system_id)
if not old_system: if not old_system:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="system_reference_missing") raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="system_reference_missing")
if not old_system.provider_client_id:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="system_provider_client_id_missing")
target_system = old_system target_system = old_system
system_id = None system_id = None
@@ -484,26 +472,23 @@ def update_role(role_key: str, payload: RoleUpdateRequest, db: Session = Depends
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="system_not_found") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="system_not_found")
system_id = system.id system_id = system.id
target_system = system target_system = system
if not target_system.provider_client_id: next_provider_role_name = payload.name if payload.name is not None else role.name
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="system_provider_client_id_missing")
next_provider_role_name = payload.provider_role_name if payload.provider_role_name is not None else role.provider_role_name
next_description = payload.description if payload.description is not None else role.description next_description = payload.description if payload.description is not None else role.description
if target_system.id != old_system.id: if target_system.id != old_system.id:
idp.ensure_client_role( idp.ensure_client_role(
provider_client_id=target_system.provider_client_id, provider_client_id=target_system.name,
provider_role_name=next_provider_role_name, provider_role_name=next_provider_role_name,
description=next_description, description=next_description,
) )
idp.delete_client_role( idp.delete_client_role(
provider_client_id=old_system.provider_client_id, provider_client_id=old_system.name,
provider_role_name=role.provider_role_name, provider_role_name=role.name,
) )
else: else:
idp.update_client_role( idp.update_client_role(
provider_client_id=target_system.provider_client_id, provider_client_id=target_system.name,
old_provider_role_name=role.provider_role_name, old_provider_role_name=role.name,
new_provider_role_name=next_provider_role_name, new_provider_role_name=next_provider_role_name,
description=next_description, description=next_description,
) )
@@ -514,7 +499,6 @@ def update_role(role_key: str, payload: RoleUpdateRequest, db: Session = Depends
system_id=system_id, system_id=system_id,
name=payload.name, name=payload.name,
description=payload.description, description=payload.description,
provider_role_name=payload.provider_role_name,
status=payload.status, status=payload.status,
) )
except IntegrityError: except IntegrityError:
@@ -531,7 +515,6 @@ def update_role(role_key: str, payload: RoleUpdateRequest, db: Session = Depends
system_key=system.system_key, system_key=system.system_key,
system_name=system.name, system_name=system.name,
name=role.name, name=role.name,
provider_role_name=role.provider_role_name,
description=role.description, description=role.description,
status=role.status, status=role.status,
) )
@@ -549,12 +532,9 @@ def delete_role(role_key: str, db: Session = Depends(get_db)) -> dict[str, str]:
system = systems_repo.get_by_id(role.system_id) system = systems_repo.get_by_id(role.system_id)
if not system: if not system:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="system_reference_missing") raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="system_reference_missing")
if not system.provider_client_id:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="system_provider_client_id_missing")
idp.delete_client_role( idp.delete_client_role(
provider_client_id=system.provider_client_id, provider_client_id=system.name,
provider_role_name=role.provider_role_name, provider_role_name=role.name,
) )
roles_repo.delete(role) roles_repo.delete(role)
return {"deleted": role_key} return {"deleted": role_key}
@@ -579,7 +559,6 @@ def list_system_roles(system_key: str, db: Session = Depends(get_db)) -> SystemR
system_key=system.system_key, system_key=system.system_key,
system_name=system.name, system_name=system.name,
name=row.name, name=row.name,
provider_role_name=row.provider_role_name,
description=row.description, description=row.description,
status=row.status, status=row.status,
) )
@@ -660,7 +639,7 @@ def list_role_sites(role_key: str, db: Session = Depends(get_db)) -> RoleSitesRe
site_key=site.site_key, site_key=site.site_key,
site_display_name=site.display_name, site_display_name=site.display_name,
company_key=company.company_key, company_key=company.company_key,
company_display_name=company.display_name, company_display_name=company.name,
) )
) )
@@ -812,7 +791,7 @@ def list_member_sites(user_sub: str, db: Session = Depends(get_db)) -> UserSites
site_key=site.site_key, site_key=site.site_key,
site_display_name=site.display_name, site_display_name=site.display_name,
company_key=company.company_key, company_key=company.company_key,
company_display_name=company.display_name, company_display_name=company.name,
) )
for user_site, site, company in rows for user_site, site, company in rows
] ]
@@ -855,12 +834,11 @@ def list_member_effective_roles(user_sub: str, db: Session = Depends(get_db)) ->
site_key=site.site_key, site_key=site.site_key,
site_display_name=site.display_name, site_display_name=site.display_name,
company_key=company.company_key, company_key=company.company_key,
company_display_name=company.display_name, company_display_name=company.name,
system_key=system.system_key, system_key=system.system_key,
system_name=system.name, system_name=system.name,
role_key=role.role_key, role_key=role.role_key,
role_name=role.name, role_name=role.name,
provider_role_name=role.provider_role_name,
) )
for site, company, role, system in rows for site, company, role, system in rows
] ]
@@ -903,14 +881,14 @@ def sync_provider_group_names(db: Session = Depends(get_db)) -> dict[str, int]:
for company in companies: for company in companies:
if not company.provider_group_id: if not company.provider_group_id:
continue continue
group_name = _company_group_name(company.display_name, company.company_key) group_name = _company_group_name(company.name, company.company_key)
idp.ensure_group( idp.ensure_group(
group_id=company.provider_group_id, group_id=company.provider_group_id,
name=group_name, name=group_name,
attributes={ attributes={
"member_entity_type": "company", "member_entity_type": "company",
"company_key": company.company_key, "company_key": company.company_key,
"display_name": company.display_name, "name": company.name,
"status": company.status, "status": company.status,
}, },
) )

View File

@@ -42,7 +42,7 @@ def upsert_user_by_sub(
) )
def _build_user_role_rows(db: Session, user_sub: str) -> list[tuple[str, str, str, str, str, str, str, str, str]]: def _build_user_role_rows(db: Session, user_sub: str) -> list[tuple[str, str, str, str, str, str, str, str]]:
users_repo = UsersRepository(db) users_repo = UsersRepository(db)
user_sites_repo = UserSitesRepository(db) user_sites_repo = UserSitesRepository(db)
@@ -56,12 +56,11 @@ def _build_user_role_rows(db: Session, user_sub: str) -> list[tuple[str, str, st
site.site_key, site.site_key,
site.display_name, site.display_name,
company.company_key, company.company_key,
company.display_name, company.name,
system.system_key, system.system_key,
system.name, system.name,
role.role_key, role.role_key,
role.name, role.name,
role.provider_role_name,
) )
for site, company, role, system in rows for site, company, role, system in rows
] ]
@@ -82,7 +81,6 @@ def get_user_roles(user_sub: str, db: Session = Depends(get_db)) -> InternalUser
system_name=system_name, system_name=system_name,
role_key=role_key, role_key=role_key,
role_name=role_name, role_name=role_name,
provider_role_name=provider_role_name,
) )
for ( for (
site_key, site_key,
@@ -93,7 +91,6 @@ def get_user_roles(user_sub: str, db: Session = Depends(get_db)) -> InternalUser
system_name, system_name,
role_key, role_key,
role_name, role_name,
provider_role_name,
) in rows ) in rows
], ],
) )

View File

@@ -34,7 +34,6 @@ def internal_list_systems(
"id": i.id, "id": i.id,
"system_key": i.system_key, "system_key": i.system_key,
"name": i.name, "name": i.name,
"provider_client_id": i.provider_client_id,
"status": i.status, "status": i.status,
} }
for i in items for i in items
@@ -72,7 +71,6 @@ def internal_list_roles(
system_key=system_map[i.system_id].system_key, system_key=system_map[i.system_id].system_key,
system_name=system_map[i.system_id].name, system_name=system_map[i.system_id].name,
name=i.name, name=i.name,
provider_role_name=i.provider_role_name,
description=i.description, description=i.description,
status=i.status, status=i.status,
) )
@@ -96,8 +94,7 @@ def internal_list_companies(
{ {
"id": i.id, "id": i.id,
"company_key": i.company_key, "company_key": i.company_key,
"display_name": i.display_name, "name": i.name,
"legal_name": i.legal_name,
"status": i.status, "status": i.status,
} }
for i in items for i in items
@@ -131,7 +128,7 @@ def internal_list_sites(
"id": i.id, "id": i.id,
"site_key": i.site_key, "site_key": i.site_key,
"company_key": mapping[i.company_id].company_key, "company_key": mapping[i.company_id].company_key,
"company_display_name": mapping[i.company_id].display_name, "company_display_name": mapping[i.company_id].name,
"display_name": i.display_name, "display_name": i.display_name,
"domain": i.domain, "domain": i.domain,
"status": i.status, "status": i.status,

View File

@@ -60,12 +60,11 @@ def get_my_permission_snapshot(
site.site_key, site.site_key,
site.display_name, site.display_name,
company.company_key, company.company_key,
company.display_name, company.name,
system.system_key, system.system_key,
system.name, system.name,
role.role_key, role.role_key,
role.name, role.name,
role.provider_role_name,
) )
for site, company, role, system in rows for site, company, role, system in rows
] ]

View File

@@ -13,11 +13,18 @@ class Company(Base):
id: Mapped[str] = mapped_column(UUID(as_uuid=False), primary_key=True, default=lambda: str(uuid4())) id: Mapped[str] = mapped_column(UUID(as_uuid=False), primary_key=True, default=lambda: str(uuid4()))
company_key: Mapped[str] = mapped_column(String(128), unique=True, nullable=False, index=True) company_key: Mapped[str] = mapped_column(String(128), unique=True, nullable=False, index=True)
display_name: Mapped[str] = mapped_column(String(255), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False)
legal_name: Mapped[str | None] = mapped_column(String(255))
provider_group_id: Mapped[str | None] = mapped_column(String(128)) provider_group_id: Mapped[str | None] = mapped_column(String(128))
status: Mapped[str] = mapped_column(String(16), nullable=False, default="active") status: Mapped[str] = mapped_column(String(16), nullable=False, default="active")
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at: Mapped[datetime] = mapped_column( updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False
) )
@property
def display_name(self) -> str:
return self.name
@display_name.setter
def display_name(self, value: str) -> None:
self.name = value

View File

@@ -10,16 +10,23 @@ from app.db.base import Base
class Role(Base): class Role(Base):
__tablename__ = "roles" __tablename__ = "roles"
__table_args__ = (UniqueConstraint("system_id", "provider_role_name", name="uq_roles_system_provider_role_name"),) __table_args__ = (UniqueConstraint("system_id", "name", name="uq_roles_system_name"),)
id: Mapped[str] = mapped_column(UUID(as_uuid=False), primary_key=True, default=lambda: str(uuid4())) id: Mapped[str] = mapped_column(UUID(as_uuid=False), primary_key=True, default=lambda: str(uuid4()))
role_key: Mapped[str] = mapped_column(String(128), unique=True, nullable=False, index=True) role_key: Mapped[str] = mapped_column(String(128), unique=True, nullable=False, index=True)
system_id: Mapped[str] = mapped_column(UUID(as_uuid=False), ForeignKey("systems.id", ondelete="CASCADE"), nullable=False) system_id: Mapped[str] = mapped_column(UUID(as_uuid=False), ForeignKey("systems.id", ondelete="CASCADE"), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False)
description: Mapped[str | None] = mapped_column(String(1024)) description: Mapped[str | None] = mapped_column(String(1024))
provider_role_name: Mapped[str] = mapped_column(String(255), nullable=False)
status: Mapped[str] = mapped_column(String(16), nullable=False, default="active") status: Mapped[str] = mapped_column(String(16), nullable=False, default="active")
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at: Mapped[datetime] = mapped_column( updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False
) )
@property
def provider_role_name(self) -> str:
return self.name
@provider_role_name.setter
def provider_role_name(self, value: str) -> None:
self.name = value

View File

@@ -14,9 +14,16 @@ class System(Base):
id: Mapped[str] = mapped_column(UUID(as_uuid=False), primary_key=True, default=lambda: str(uuid4())) id: Mapped[str] = mapped_column(UUID(as_uuid=False), primary_key=True, default=lambda: str(uuid4()))
system_key: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True) system_key: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
name: Mapped[str] = mapped_column(String(255), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False)
provider_client_id: Mapped[str] = mapped_column(String(128), unique=True, nullable=False)
status: Mapped[str] = mapped_column(String(16), nullable=False, default="active") status: Mapped[str] = mapped_column(String(16), nullable=False, default="active")
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at: Mapped[datetime] = mapped_column( updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False
) )
@property
def provider_client_id(self) -> str:
return self.name
@provider_client_id.setter
def provider_client_id(self, value: str) -> None:
self.name = value

View File

@@ -21,8 +21,7 @@ class CompaniesRepository:
pattern = f"%{keyword}%" pattern = f"%{keyword}%"
cond = or_( cond = or_(
Company.company_key.ilike(pattern), Company.company_key.ilike(pattern),
Company.display_name.ilike(pattern), Company.name.ilike(pattern),
Company.legal_name.ilike(pattern),
) )
stmt = stmt.where(cond) stmt = stmt.where(cond)
count_stmt = count_stmt.where(cond) count_stmt = count_stmt.where(cond)
@@ -34,15 +33,13 @@ class CompaniesRepository:
self, self,
*, *,
company_key: str, company_key: str,
display_name: str, name: str,
legal_name: str | None,
provider_group_id: str | None = None, provider_group_id: str | None = None,
status: str = "active", status: str = "active",
) -> Company: ) -> Company:
item = Company( item = Company(
company_key=company_key, company_key=company_key,
display_name=display_name, name=name,
legal_name=legal_name,
provider_group_id=provider_group_id, provider_group_id=provider_group_id,
status=status, status=status,
) )
@@ -55,15 +52,12 @@ class CompaniesRepository:
self, self,
item: Company, item: Company,
*, *,
display_name: str | None = None, name: str | None = None,
legal_name: str | None = None,
provider_group_id: str | None = None, provider_group_id: str | None = None,
status: str | None = None, status: str | None = None,
) -> Company: ) -> Company:
if display_name is not None: if name is not None:
item.display_name = display_name item.name = name
if legal_name is not None:
item.legal_name = legal_name
if provider_group_id is not None: if provider_group_id is not None:
item.provider_group_id = provider_group_id item.provider_group_id = provider_group_id
if status is not None: if status is not None:

View File

@@ -30,7 +30,6 @@ class RolesRepository:
cond = or_( cond = or_(
Role.role_key.ilike(pattern), Role.role_key.ilike(pattern),
Role.name.ilike(pattern), Role.name.ilike(pattern),
Role.provider_role_name.ilike(pattern),
Role.description.ilike(pattern), Role.description.ilike(pattern),
) )
stmt = stmt.where(cond) stmt = stmt.where(cond)
@@ -52,7 +51,6 @@ class RolesRepository:
system_id: str, system_id: str,
name: str, name: str,
description: str | None, description: str | None,
provider_role_name: str,
status: str = "active", status: str = "active",
) -> Role: ) -> Role:
item = Role( item = Role(
@@ -60,7 +58,6 @@ class RolesRepository:
system_id=system_id, system_id=system_id,
name=name, name=name,
description=description, description=description,
provider_role_name=provider_role_name,
status=status, status=status,
) )
self.db.add(item) self.db.add(item)
@@ -75,7 +72,6 @@ class RolesRepository:
system_id: str | None = None, system_id: str | None = None,
name: str | None = None, name: str | None = None,
description: str | None = None, description: str | None = None,
provider_role_name: str | None = None,
status: str | None = None, status: str | None = None,
) -> Role: ) -> Role:
if system_id is not None: if system_id is not None:
@@ -84,8 +80,6 @@ class RolesRepository:
item.name = name item.name = name
if description is not None: if description is not None:
item.description = description item.description = description
if provider_role_name is not None:
item.provider_role_name = provider_role_name
if status is not None: if status is not None:
item.status = status item.status = status
self.db.commit() self.db.commit()

View File

@@ -19,7 +19,7 @@ class SystemsRepository:
count_stmt = select(func.count()).select_from(System) count_stmt = select(func.count()).select_from(System)
if keyword: if keyword:
pattern = f"%{keyword}%" pattern = f"%{keyword}%"
cond = or_(System.system_key.ilike(pattern), System.name.ilike(pattern), System.provider_client_id.ilike(pattern)) cond = or_(System.system_key.ilike(pattern), System.name.ilike(pattern))
stmt = stmt.where(cond) stmt = stmt.where(cond)
count_stmt = count_stmt.where(cond) count_stmt = count_stmt.where(cond)
if status: if status:
@@ -29,8 +29,8 @@ class SystemsRepository:
stmt = stmt.order_by(System.created_at.desc()).limit(limit).offset(offset) stmt = stmt.order_by(System.created_at.desc()).limit(limit).offset(offset)
return list(self.db.scalars(stmt).all()), int(self.db.scalar(count_stmt) or 0) return list(self.db.scalars(stmt).all()), int(self.db.scalar(count_stmt) or 0)
def create(self, *, system_key: str, name: str, provider_client_id: str, status: str = "active") -> System: def create(self, *, system_key: str, name: str, status: str = "active") -> System:
item = System(system_key=system_key, name=name, provider_client_id=provider_client_id, status=status) item = System(system_key=system_key, name=name, status=status)
self.db.add(item) self.db.add(item)
self.db.commit() self.db.commit()
self.db.refresh(item) self.db.refresh(item)
@@ -41,13 +41,10 @@ class SystemsRepository:
item: System, item: System,
*, *,
name: str | None = None, name: str | None = None,
provider_client_id: str | None = None,
status: str | None = None, status: str | None = None,
) -> System: ) -> System:
if name is not None: if name is not None:
item.name = name item.name = name
if provider_client_id is not None:
item.provider_client_id = provider_client_id
if status is not None: if status is not None:
item.status = status item.status = status
self.db.commit() self.db.commit()

View File

@@ -20,7 +20,7 @@ class UserSitesRepository:
.join(Site, Site.id == UserSite.site_id) .join(Site, Site.id == UserSite.site_id)
.join(Company, Company.id == Site.company_id) .join(Company, Company.id == Site.company_id)
.where(UserSite.user_id == user_id) .where(UserSite.user_id == user_id)
.order_by(Company.display_name.asc(), Site.display_name.asc()) .order_by(Company.name.asc(), Site.display_name.asc())
) )
return list(self.db.execute(stmt).all()) return list(self.db.execute(stmt).all())
@@ -49,6 +49,6 @@ class UserSitesRepository:
.join(Role, Role.id == SiteRole.role_id) .join(Role, Role.id == SiteRole.role_id)
.join(System, System.id == Role.system_id) .join(System, System.id == Role.system_id)
.where(UserSite.user_id == user_id) .where(UserSite.user_id == user_id)
.order_by(Company.display_name.asc(), Site.display_name.asc(), System.name.asc(), Role.name.asc()) .order_by(Company.name.asc(), Site.display_name.asc(), System.name.asc(), Role.name.asc())
) )
return list(self.db.execute(stmt).all()) return list(self.db.execute(stmt).all())

View File

@@ -11,14 +11,12 @@ class ListResponse(BaseModel):
class CompanyCreateRequest(BaseModel): class CompanyCreateRequest(BaseModel):
display_name: str name: str
legal_name: str | None = None
status: str = "active" status: str = "active"
class CompanyUpdateRequest(BaseModel): class CompanyUpdateRequest(BaseModel):
display_name: str | None = None name: str | None = None
legal_name: str | None = None
provider_group_id: str | None = None provider_group_id: str | None = None
status: str | None = None status: str | None = None
@@ -26,8 +24,7 @@ class CompanyUpdateRequest(BaseModel):
class CompanyItem(BaseModel): class CompanyItem(BaseModel):
id: str id: str
company_key: str company_key: str
display_name: str name: str
legal_name: str | None = None
provider_group_id: str | None = None provider_group_id: str | None = None
status: str status: str
@@ -60,13 +57,11 @@ class SiteItem(BaseModel):
class SystemCreateRequest(BaseModel): class SystemCreateRequest(BaseModel):
name: str name: str
provider_client_id: str
status: str = "active" status: str = "active"
class SystemUpdateRequest(BaseModel): class SystemUpdateRequest(BaseModel):
name: str | None = None name: str | None = None
provider_client_id: str | None = None
status: str | None = None status: str | None = None
@@ -74,14 +69,12 @@ class SystemItem(BaseModel):
id: str id: str
system_key: str system_key: str
name: str name: str
provider_client_id: str
status: str status: str
class RoleCreateRequest(BaseModel): class RoleCreateRequest(BaseModel):
system_key: str system_key: str
name: str name: str
provider_role_name: str
description: str | None = None description: str | None = None
status: str = "active" status: str = "active"
@@ -89,7 +82,6 @@ class RoleCreateRequest(BaseModel):
class RoleUpdateRequest(BaseModel): class RoleUpdateRequest(BaseModel):
system_key: str | None = None system_key: str | None = None
name: str | None = None name: str | None = None
provider_role_name: str | None = None
description: str | None = None description: str | None = None
status: str | None = None status: str | None = None
@@ -100,7 +92,6 @@ class RoleItem(BaseModel):
system_key: str system_key: str
system_name: str system_name: str
name: str name: str
provider_role_name: str
description: str | None = None description: str | None = None
status: str status: str
@@ -173,7 +164,6 @@ class UserEffectiveRoleItem(BaseModel):
system_name: str system_name: str
role_key: str role_key: str
role_name: str role_name: str
provider_role_name: str
class UserEffectiveRolesResponse(BaseModel): class UserEffectiveRolesResponse(BaseModel):

View File

@@ -5,7 +5,6 @@ class InternalSystemItem(BaseModel):
id: str id: str
system_key: str system_key: str
name: str name: str
provider_client_id: str
status: str status: str
@@ -22,7 +21,6 @@ class InternalRoleItem(BaseModel):
system_key: str system_key: str
system_name: str system_name: str
name: str name: str
provider_role_name: str
description: str | None = None description: str | None = None
status: str status: str
@@ -37,8 +35,7 @@ class InternalRoleListResponse(BaseModel):
class InternalCompanyItem(BaseModel): class InternalCompanyItem(BaseModel):
id: str id: str
company_key: str company_key: str
display_name: str name: str
legal_name: str | None = None
status: str status: str
@@ -103,7 +100,6 @@ class InternalUserRoleItem(BaseModel):
system_name: str system_name: str
role_key: str role_key: str
role_name: str role_name: str
provider_role_name: str
class InternalUserRoleResponse(BaseModel): class InternalUserRoleResponse(BaseModel):

View File

@@ -10,7 +10,6 @@ class RoleSnapshotItem(BaseModel):
system_name: str system_name: str
role_key: str role_key: str
role_name: str role_name: str
provider_role_name: str
class RoleSnapshotResponse(BaseModel): class RoleSnapshotResponse(BaseModel):

View File

@@ -79,7 +79,7 @@ def _flatten_groups(nodes: list[dict], inherited_company_key: str | None = None)
if company_key: if company_key:
companies[company_key] = { companies[company_key] = {
"company_key": company_key, "company_key": company_key,
"display_name": _first_attr(attrs, "display_name") or name or company_key, "name": _first_attr(attrs, "name") or _first_attr(attrs, "display_name") or name or company_key,
"status": _first_attr(attrs, "status") or "active", "status": _first_attr(attrs, "status") or "active",
"provider_group_id": group_id, "provider_group_id": group_id,
} }
@@ -146,8 +146,7 @@ def sync_from_provider(db: Session, *, force: bool = False) -> dict[str, int]:
if company is None: if company is None:
company = companies_repo.create( company = companies_repo.create(
company_key=company_key, company_key=company_key,
display_name=row["display_name"], name=row["name"],
legal_name=None,
provider_group_id=row["provider_group_id"], provider_group_id=row["provider_group_id"],
status=row["status"], status=row["status"],
) )
@@ -155,7 +154,7 @@ def sync_from_provider(db: Session, *, force: bool = False) -> dict[str, int]:
else: else:
company = companies_repo.update( company = companies_repo.update(
company, company,
display_name=row["display_name"], name=row["name"],
provider_group_id=row["provider_group_id"], provider_group_id=row["provider_group_id"],
status=row["status"], status=row["status"],
) )
@@ -172,8 +171,7 @@ def sync_from_provider(db: Session, *, force: bool = False) -> dict[str, int]:
if placeholder is None: if placeholder is None:
placeholder = companies_repo.create( placeholder = companies_repo.create(
company_key=company_key, company_key=company_key,
display_name=company_key, name=company_key,
legal_name=None,
provider_group_id=None, provider_group_id=None,
status="active", status="active",
) )
@@ -213,7 +211,7 @@ def sync_from_provider(db: Session, *, force: bool = False) -> dict[str, int]:
if client_id in BUILTIN_CLIENT_IDS: if client_id in BUILTIN_CLIENT_IDS:
continue continue
system = db.scalar(select(System).where(System.provider_client_id == client_id)) system = db.scalar(select(System).where(System.name == client_id))
system_name = str(client.get("name", "")).strip() or client_id system_name = str(client.get("name", "")).strip() or client_id
system_status = "active" if client.get("enabled", True) else "inactive" system_status = "active" if client.get("enabled", True) else "inactive"
if system is None: if system is None:
@@ -221,7 +219,6 @@ def sync_from_provider(db: Session, *, force: bool = False) -> dict[str, int]:
system = systems_repo.create( system = systems_repo.create(
system_key=system_key, system_key=system_key,
name=system_name, name=system_name,
provider_client_id=client_id,
status=system_status, status=system_status,
) )
systems_created += 1 systems_created += 1
@@ -246,7 +243,7 @@ def sync_from_provider(db: Session, *, force: bool = False) -> dict[str, int]:
role = db.scalar( role = db.scalar(
select(Role).where( select(Role).where(
Role.system_id == system.id, Role.system_id == system.id,
Role.provider_role_name == role_name, Role.name == role_name,
) )
) )
if role is None: if role is None:
@@ -256,7 +253,6 @@ def sync_from_provider(db: Session, *, force: bool = False) -> dict[str, int]:
system_id=system.id, system_id=system.id,
name=role_name, name=role_name,
description=role_desc, description=role_desc,
provider_role_name=role_name,
status=role_status, status=role_status,
) )
roles_created += 1 roles_created += 1
@@ -339,7 +335,7 @@ def sync_systems_from_provider(db: Session, *, force: bool = False) -> dict[str,
if client_id in BUILTIN_CLIENT_IDS: if client_id in BUILTIN_CLIENT_IDS:
continue continue
system = db.scalar(select(System).where(System.provider_client_id == client_id)) system = db.scalar(select(System).where(System.name == client_id))
system_name = str(client.get("name", "")).strip() or client_id system_name = str(client.get("name", "")).strip() or client_id
system_status = "active" if client.get("enabled", True) else "inactive" system_status = "active" if client.get("enabled", True) else "inactive"
if system is None: if system is None:
@@ -347,7 +343,6 @@ def sync_systems_from_provider(db: Session, *, force: bool = False) -> dict[str,
system = systems_repo.create( system = systems_repo.create(
system_key=system_key, system_key=system_key,
name=system_name, name=system_name,
provider_client_id=client_id,
status=system_status, status=system_status,
) )
systems_created += 1 systems_created += 1
@@ -371,7 +366,7 @@ def sync_systems_from_provider(db: Session, *, force: bool = False) -> dict[str,
role = db.scalar( role = db.scalar(
select(Role).where( select(Role).where(
Role.system_id == system.id, Role.system_id == system.id,
Role.provider_role_name == role_name, Role.name == role_name,
) )
) )
if role is None: if role is None:
@@ -381,7 +376,6 @@ def sync_systems_from_provider(db: Session, *, force: bool = False) -> dict[str,
system_id=system.id, system_id=system.id,
name=role_name, name=role_name,
description=role_desc, description=role_desc,
provider_role_name=role_name,
status=role_status, status=role_status,
) )
roles_created += 1 roles_created += 1

View File

@@ -3,7 +3,7 @@ from app.schemas.permissions import RoleSnapshotItem, RoleSnapshotResponse
class PermissionService: class PermissionService:
@staticmethod @staticmethod
def build_role_snapshot(user_sub: str, rows: list[tuple[str, str, str, str, str, str, str, str, str]]) -> RoleSnapshotResponse: def build_role_snapshot(user_sub: str, rows: list[tuple[str, str, str, str, str, str, str, str]]) -> RoleSnapshotResponse:
return RoleSnapshotResponse( return RoleSnapshotResponse(
user_sub=user_sub, user_sub=user_sub,
roles=[ roles=[
@@ -16,7 +16,6 @@ class PermissionService:
system_name=system_name, system_name=system_name,
role_key=role_key, role_key=role_key,
role_name=role_name, role_name=role_name,
provider_role_name=provider_role_name,
) )
for ( for (
site_key, site_key,
@@ -27,7 +26,6 @@ class PermissionService:
system_name, system_name,
role_key, role_key,
role_name, role_name,
provider_role_name,
) in rows ) in rows
], ],
) )

View File

@@ -37,8 +37,7 @@ CREATE TABLE users (
CREATE TABLE companies ( CREATE TABLE companies (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(), id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
company_key TEXT NOT NULL UNIQUE, company_key TEXT NOT NULL UNIQUE,
display_name TEXT NOT NULL, name TEXT NOT NULL,
legal_name TEXT,
provider_group_id TEXT, provider_group_id TEXT,
status VARCHAR(16) NOT NULL DEFAULT 'active', status VARCHAR(16) NOT NULL DEFAULT 'active',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
@@ -61,7 +60,6 @@ CREATE TABLE systems (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(), id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
system_key TEXT NOT NULL UNIQUE, system_key TEXT NOT NULL UNIQUE,
name TEXT NOT NULL, name TEXT NOT NULL,
provider_client_id TEXT NOT NULL UNIQUE,
status VARCHAR(16) NOT NULL DEFAULT 'active', status VARCHAR(16) NOT NULL DEFAULT 'active',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
@@ -73,11 +71,10 @@ CREATE TABLE roles (
system_id UUID NOT NULL REFERENCES systems(id) ON DELETE CASCADE, system_id UUID NOT NULL REFERENCES systems(id) ON DELETE CASCADE,
name TEXT NOT NULL, name TEXT NOT NULL,
description TEXT, description TEXT,
provider_role_name TEXT NOT NULL,
status VARCHAR(16) NOT NULL DEFAULT 'active', status VARCHAR(16) NOT NULL DEFAULT 'active',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
CONSTRAINT uq_roles_system_provider_role_name UNIQUE (system_id, provider_role_name) CONSTRAINT uq_roles_system_name UNIQUE (system_id, name)
); );
CREATE TABLE site_roles ( CREATE TABLE site_roles (

View File

@@ -14,6 +14,70 @@ BEGIN
END IF; END IF;
END $$; END $$;
-- companies.display_name -> companies.name
DO $$
BEGIN
IF EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_schema = 'public' AND table_name = 'companies' AND column_name = 'display_name'
) AND NOT EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_schema = 'public' AND table_name = 'companies' AND column_name = 'name'
) THEN
ALTER TABLE public.companies RENAME COLUMN display_name TO name;
END IF;
END $$;
DO $$
BEGIN
IF EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_schema = 'public' AND table_name = 'companies' AND column_name = 'legal_name'
) THEN
ALTER TABLE public.companies DROP COLUMN legal_name;
END IF;
END $$;
DO $$
BEGIN
IF EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_schema = 'public' AND table_name = 'systems' AND column_name = 'provider_client_id'
) THEN
ALTER TABLE public.systems DROP COLUMN provider_client_id;
END IF;
END $$;
DO $$
BEGIN
IF EXISTS (
SELECT 1 FROM information_schema.columns
WHERE table_schema = 'public' AND table_name = 'roles' AND column_name = 'provider_role_name'
) THEN
ALTER TABLE public.roles DROP COLUMN provider_role_name;
END IF;
END $$;
DO $$
BEGIN
IF EXISTS (
SELECT 1 FROM information_schema.table_constraints
WHERE table_schema='public' AND table_name='roles' AND constraint_name='uq_roles_system_provider_role_name'
) THEN
ALTER TABLE public.roles DROP CONSTRAINT uq_roles_system_provider_role_name;
END IF;
END $$;
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1 FROM information_schema.table_constraints
WHERE table_schema='public' AND table_name='roles' AND constraint_name='uq_roles_system_name'
) THEN
ALTER TABLE public.roles ADD CONSTRAINT uq_roles_system_name UNIQUE (system_id, name);
END IF;
END $$;
DO $$ DO $$
BEGIN BEGIN
IF EXISTS ( IF EXISTS (