diff --git a/app/api/admin_catalog.py b/app/api/admin_catalog.py index 9a019c4..ce1ac42 100644 --- a/app/api/admin_catalog.py +++ b/app/api/admin_catalog.py @@ -646,6 +646,27 @@ def list_role_sites(role_key: str, db: Session = Depends(get_db)) -> RoleSitesRe return RoleSitesResponse(role_key=role_key, sites=result) +@router.put("/roles/{role_key}/sites", response_model=RoleSitesResponse) +def assign_role_sites(role_key: str, payload: UserSiteAssignRequest, db: Session = Depends(get_db)) -> RoleSitesResponse: + roles_repo = RolesRepository(db) + sites_repo = SitesRepository(db) + site_roles_repo = SiteRolesRepository(db) + + role = roles_repo.get_by_key(role_key) + if not role: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="role_not_found") + + site_ids: list[str] = [] + for site_key in list(dict.fromkeys(payload.site_keys)): + site = sites_repo.get_by_key(site_key) + if not site: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"site_not_found:{site_key}") + site_ids.append(site.id) + + site_roles_repo.set_role_sites(role_id=role.id, site_ids=site_ids) + return list_role_sites(role_key=role_key, db=db) + + @router.get("/members", response_model=ListResponse) def list_members( db: Session = Depends(get_db), diff --git a/app/repositories/site_roles_repo.py b/app/repositories/site_roles_repo.py index 4771bdd..689138e 100644 --- a/app/repositories/site_roles_repo.py +++ b/app/repositories/site_roles_repo.py @@ -35,3 +35,9 @@ class SiteRolesRepository: for role_id in role_ids: self.db.add(SiteRole(site_id=site_id, role_id=role_id)) self.db.commit() + + def set_role_sites(self, *, role_id: str, site_ids: list[str]) -> None: + self.db.execute(delete(SiteRole).where(SiteRole.role_id == role_id)) + for site_id in site_ids: + self.db.add(SiteRole(site_id=site_id, role_id=role_id)) + self.db.commit()