From 573a75b9e3e6dc015a8c50cc952be46e50920e41 Mon Sep 17 00:00:00 2001 From: Chris Date: Fri, 3 Apr 2026 00:46:46 +0800 Subject: [PATCH] feat(sync): keycloak as source-of-truth with auto catalog sync and token refresh --- app/api/admin_catalog.py | 113 ++++++++--- app/api/auth.py | 52 ++++- app/repositories/companies_repo.py | 18 +- app/repositories/sites_repo.py | 10 +- app/schemas/login.py | 6 + app/services/idp_admin_service.py | 177 +++++++++++++++++ app/services/idp_catalog_sync.py | 306 +++++++++++++++++++++++++++++ 7 files changed, 649 insertions(+), 33 deletions(-) create mode 100644 app/services/idp_catalog_sync.py diff --git a/app/api/admin_catalog.py b/app/api/admin_catalog.py index ebd3768..50c0dfc 100644 --- a/app/api/admin_catalog.py +++ b/app/api/admin_catalog.py @@ -52,6 +52,7 @@ from app.schemas.catalog import ( from app.security.admin_guard import require_admin_principal from app.security.api_client_auth import hash_api_key from app.services.idp_admin_service import KeycloakAdminService +from app.services.idp_catalog_sync import sync_from_keycloak from app.core.config import get_settings router = APIRouter( @@ -123,6 +124,7 @@ def list_companies( limit: int = Query(default=100, ge=1, le=500), offset: int = Query(default=0, ge=0), ) -> ListResponse: + sync_from_keycloak(db) repo = CompaniesRepository(db) items, total = repo.list(keyword=keyword, limit=limit, offset=offset) return ListResponse(items=[_company_item(i) for i in items], total=total, limit=limit, offset=offset) @@ -131,8 +133,24 @@ def list_companies( @router.post("/companies", response_model=CompanyItem) def create_company(payload: CompanyCreateRequest, db: Session = Depends(get_db)) -> CompanyItem: repo = CompaniesRepository(db) + idp = KeycloakAdminService(get_settings()) company_key = _generate_unique_key("CP", lambda key: repo.get_by_key(key) is not None) - item = repo.create(company_key=company_key, display_name=payload.display_name, legal_name=payload.legal_name, status=payload.status) + group = idp.ensure_group( + name=company_key, + attributes={ + "member_entity_type": "company", + "company_key": company_key, + "display_name": payload.display_name, + "status": payload.status, + }, + ) + item = repo.create( + company_key=company_key, + display_name=payload.display_name, + legal_name=payload.legal_name, + idp_group_id=group.group_id, + status=payload.status, + ) return _company_item(item) @@ -142,11 +160,25 @@ def update_company(company_key: str, payload: CompanyUpdateRequest, db: Session item = repo.get_by_key(company_key) if not item: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="company_not_found") + idp = KeycloakAdminService(get_settings()) + resolved_display_name = payload.display_name if payload.display_name is not None else item.display_name + resolved_status = payload.status if payload.status is not None else item.status + resolved_group_id = payload.idp_group_id or item.idp_group_id + group = idp.ensure_group( + group_id=resolved_group_id, + name=company_key, + attributes={ + "member_entity_type": "company", + "company_key": company_key, + "display_name": resolved_display_name, + "status": resolved_status, + }, + ) item = repo.update( item, display_name=payload.display_name, legal_name=payload.legal_name, - idp_group_id=payload.idp_group_id, + idp_group_id=group.group_id, status=payload.status, ) return _company_item(item) @@ -155,9 +187,11 @@ def update_company(company_key: str, payload: CompanyUpdateRequest, db: Session @router.delete("/companies/{company_key}") def delete_company(company_key: str, db: Session = Depends(get_db)) -> dict[str, str]: repo = CompaniesRepository(db) + idp = KeycloakAdminService(get_settings()) item = repo.get_by_key(company_key) if not item: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="company_not_found") + idp.delete_group(group_id=item.idp_group_id) repo.delete(item) return {"deleted": company_key} @@ -181,6 +215,7 @@ def list_sites( limit: int = Query(default=100, ge=1, le=500), offset: int = Query(default=0, ge=0), ) -> ListResponse: + sync_from_keycloak(db) companies_repo = CompaniesRepository(db) sites_repo = SitesRepository(db) company_id = None @@ -201,16 +236,31 @@ def list_sites( def create_site(payload: SiteCreateRequest, db: Session = Depends(get_db)) -> SiteItem: companies_repo = CompaniesRepository(db) sites_repo = SitesRepository(db) + idp = KeycloakAdminService(get_settings()) company = companies_repo.get_by_key(payload.company_key) if not company: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="company_not_found") site_key = _generate_unique_key("ST", lambda key: sites_repo.get_by_key(key) is not None) + group = idp.ensure_group( + group_id=None, + name=site_key, + parent_group_id=company.idp_group_id, + attributes={ + "member_entity_type": "site", + "site_key": site_key, + "company_key": company.company_key, + "display_name": payload.display_name, + "domain": payload.domain or "", + "status": payload.status, + }, + ) item = sites_repo.create( site_key=site_key, company_id=company.id, display_name=payload.display_name, domain=payload.domain, + idp_group_id=group.group_id, status=payload.status, ) return _site_item(item, company) @@ -220,24 +270,46 @@ def create_site(payload: SiteCreateRequest, db: Session = Depends(get_db)) -> Si def update_site(site_key: str, payload: SiteUpdateRequest, db: Session = Depends(get_db)) -> SiteItem: companies_repo = CompaniesRepository(db) sites_repo = SitesRepository(db) + idp = KeycloakAdminService(get_settings()) item = sites_repo.get_by_key(site_key) if not item: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="site_not_found") + target_company = companies_repo.get_by_id(item.company_id) company_id = None if payload.company_key: - company = companies_repo.get_by_key(payload.company_key) - if not company: + target_company = companies_repo.get_by_key(payload.company_key) + if not target_company: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="company_not_found") - company_id = company.id + company_id = target_company.id + if not target_company: + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="company_reference_missing") + + resolved_display_name = payload.display_name if payload.display_name is not None else item.display_name + resolved_domain = payload.domain if payload.domain is not None else item.domain + resolved_status = payload.status if payload.status is not None else item.status + resolved_group_id = payload.idp_group_id or item.idp_group_id + group = idp.ensure_group( + group_id=resolved_group_id, + name=site_key, + parent_group_id=target_company.idp_group_id, + attributes={ + "member_entity_type": "site", + "site_key": site_key, + "company_key": target_company.company_key, + "display_name": resolved_display_name, + "domain": resolved_domain or "", + "status": resolved_status, + }, + ) item = sites_repo.update( item, company_id=company_id, display_name=payload.display_name, domain=payload.domain, - idp_group_id=payload.idp_group_id, + idp_group_id=group.group_id, status=payload.status, ) company = companies_repo.get_by_id(item.company_id) @@ -249,9 +321,11 @@ def update_site(site_key: str, payload: SiteUpdateRequest, db: Session = Depends @router.delete("/sites/{site_key}") def delete_site(site_key: str, db: Session = Depends(get_db)) -> dict[str, str]: repo = SitesRepository(db) + idp = KeycloakAdminService(get_settings()) item = repo.get_by_key(site_key) if not item: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="site_not_found") + idp.delete_group(group_id=item.idp_group_id) repo.delete(item) return {"deleted": site_key} @@ -264,6 +338,7 @@ def list_systems( limit: int = Query(default=100, ge=1, le=500), offset: int = Query(default=0, ge=0), ) -> ListResponse: + sync_from_keycloak(db) repo = SystemsRepository(db) items, total = repo.list(keyword=keyword, status=status_filter, limit=limit, offset=offset) return ListResponse(items=[_system_item(i) for i in items], total=total, limit=limit, offset=offset) @@ -271,30 +346,17 @@ def list_systems( @router.post("/systems", response_model=SystemItem) def create_system(payload: SystemCreateRequest, db: Session = Depends(get_db)) -> SystemItem: - repo = SystemsRepository(db) - system_key = _generate_unique_key("SY", lambda key: repo.get_by_key(key) is not None) - item = repo.create(system_key=system_key, name=payload.name, idp_client_id=payload.idp_client_id, status=payload.status) - return _system_item(item) + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="system_manage_in_keycloak_only") @router.patch("/systems/{system_key}", response_model=SystemItem) def update_system(system_key: str, payload: SystemUpdateRequest, db: Session = Depends(get_db)) -> SystemItem: - repo = SystemsRepository(db) - item = repo.get_by_key(system_key) - if not item: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="system_not_found") - item = repo.update(item, name=payload.name, idp_client_id=payload.idp_client_id, status=payload.status) - return _system_item(item) + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="system_manage_in_keycloak_only") @router.delete("/systems/{system_key}") def delete_system(system_key: str, db: Session = Depends(get_db)) -> dict[str, str]: - repo = SystemsRepository(db) - item = repo.get_by_key(system_key) - if not item: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="system_not_found") - repo.delete(item) - return {"deleted": system_key} + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="system_manage_in_keycloak_only") @router.get("/roles", response_model=ListResponse) @@ -306,6 +368,7 @@ def list_roles( limit: int = Query(default=100, ge=1, le=500), offset: int = Query(default=0, ge=0), ) -> ListResponse: + sync_from_keycloak(db) systems_repo = SystemsRepository(db) roles_repo = RolesRepository(db) @@ -544,6 +607,7 @@ def list_members( limit: int = Query(default=100, ge=1, le=500), offset: int = Query(default=0, ge=0), ) -> ListResponse: + sync_from_keycloak(db) repo = UsersRepository(db) rows, total = repo.list(keyword=keyword, is_active=is_active, limit=limit, offset=offset) return ListResponse(items=[_member_item(r) for r in rows], total=total, limit=limit, offset=offset) @@ -754,6 +818,11 @@ def list_api_clients( ) +@router.post("/sync/from-keycloak") +def sync_catalog_from_keycloak(db: Session = Depends(get_db), force: bool = Query(default=True)) -> dict[str, int]: + return sync_from_keycloak(db, force=force) + + @router.post("/api-clients", response_model=ApiClientCreateResponse) def create_api_client(payload: ApiClientCreateRequest, db: Session = Depends(get_db)) -> ApiClientCreateResponse: repo = ApiClientsRepository(db) diff --git a/app/api/auth.py b/app/api/auth.py index 3dd102e..3891e94 100644 --- a/app/api/auth.py +++ b/app/api/auth.py @@ -5,7 +5,7 @@ import httpx from fastapi import APIRouter, HTTPException, status from app.core.config import get_settings -from app.schemas.login import LoginRequest, LoginResponse, OIDCAuthUrlResponse, OIDCCodeExchangeRequest +from app.schemas.login import LoginRequest, LoginResponse, OIDCAuthUrlResponse, OIDCCodeExchangeRequest, RefreshTokenRequest router = APIRouter(prefix="/auth", tags=["auth"]) logger = logging.getLogger(__name__) @@ -45,12 +45,7 @@ def login(payload: LoginRequest) -> LoginResponse: token = data.get("access_token") if not token: raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="idp_missing_access_token") - return LoginResponse( - access_token=token, - token_type=data.get("token_type", "Bearer"), - expires_in=data.get("expires_in"), - scope=data.get("scope"), - ) + return _build_login_response(data) @router.get("/oidc/url", response_model=OIDCAuthUrlResponse) @@ -123,9 +118,50 @@ def exchange_oidc_code(payload: OIDCCodeExchangeRequest) -> LoginResponse: token = data.get("access_token") if not token: raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="idp_missing_access_token") + return _build_login_response(data) + + +@router.post("/refresh", response_model=LoginResponse) +def refresh_access_token(payload: RefreshTokenRequest) -> LoginResponse: + settings = get_settings() + client_id = settings.idp_client_id or settings.idp_audience + if not settings.idp_base_url or not client_id or not settings.idp_client_secret: + raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="idp_login_not_configured") + + form = { + "grant_type": "refresh_token", + "client_id": client_id, + "client_secret": settings.idp_client_secret, + "refresh_token": payload.refresh_token, + } + try: + resp = httpx.post( + settings.idp_token_endpoint, + data=form, + timeout=10, + verify=settings.idp_verify_tls, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + except Exception as exc: + raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="idp_unreachable") from exc + + if resp.status_code >= 400: + logger.warning("idp refresh-token grant failed: status=%s body=%s", resp.status_code, resp.text) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid_refresh_token") + + data = resp.json() + token = data.get("access_token") + if not token: + raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="idp_missing_access_token") + return _build_login_response(data) + + +def _build_login_response(data: dict) -> LoginResponse: return LoginResponse( - access_token=token, + access_token=data.get("access_token", ""), + refresh_token=data.get("refresh_token"), token_type=data.get("token_type", "Bearer"), expires_in=data.get("expires_in"), + refresh_expires_in=data.get("refresh_expires_in"), scope=data.get("scope"), ) diff --git a/app/repositories/companies_repo.py b/app/repositories/companies_repo.py index bbe8979..3ef2ba2 100644 --- a/app/repositories/companies_repo.py +++ b/app/repositories/companies_repo.py @@ -30,8 +30,22 @@ class CompaniesRepository: stmt = stmt.order_by(Company.created_at.desc()).limit(limit).offset(offset) return list(self.db.scalars(stmt).all()), int(self.db.scalar(count_stmt) or 0) - def create(self, *, company_key: str, display_name: str, legal_name: str | None, status: str = "active") -> Company: - item = Company(company_key=company_key, display_name=display_name, legal_name=legal_name, status=status) + def create( + self, + *, + company_key: str, + display_name: str, + legal_name: str | None, + idp_group_id: str | None = None, + status: str = "active", + ) -> Company: + item = Company( + company_key=company_key, + display_name=display_name, + legal_name=legal_name, + idp_group_id=idp_group_id, + status=status, + ) self.db.add(item) self.db.commit() self.db.refresh(item) diff --git a/app/repositories/sites_repo.py b/app/repositories/sites_repo.py index 679647e..8ed1164 100644 --- a/app/repositories/sites_repo.py +++ b/app/repositories/sites_repo.py @@ -45,9 +45,17 @@ class SitesRepository: company_id: str, display_name: str, domain: str | None, + idp_group_id: str | None = None, status: str = "active", ) -> Site: - item = Site(site_key=site_key, company_id=company_id, display_name=display_name, domain=domain, status=status) + item = Site( + site_key=site_key, + company_id=company_id, + display_name=display_name, + domain=domain, + idp_group_id=idp_group_id, + status=status, + ) self.db.add(item) self.db.commit() self.db.refresh(item) diff --git a/app/schemas/login.py b/app/schemas/login.py index c8ee3f5..4d1e99d 100644 --- a/app/schemas/login.py +++ b/app/schemas/login.py @@ -8,8 +8,10 @@ class LoginRequest(BaseModel): class LoginResponse(BaseModel): access_token: str + refresh_token: str | None = None token_type: str = "Bearer" expires_in: int | None = None + refresh_expires_in: int | None = None scope: str | None = None @@ -21,3 +23,7 @@ class OIDCCodeExchangeRequest(BaseModel): code: str redirect_uri: str code_verifier: str | None = None + + +class RefreshTokenRequest(BaseModel): + refresh_token: str diff --git a/app/services/idp_admin_service.py b/app/services/idp_admin_service.py index 4e0e337..18ee26a 100644 --- a/app/services/idp_admin_service.py +++ b/app/services/idp_admin_service.py @@ -29,6 +29,12 @@ class KeycloakDeleteResult: user_id: str | None = None +@dataclass +class KeycloakGroupSyncResult: + group_id: str + action: str + + class KeycloakAdminService: def __init__(self, settings: Settings) -> None: self.base_url = settings.keycloak_base_url.rstrip("/") @@ -97,6 +103,64 @@ class KeycloakAdminService: raise HTTPException(status_code=502, detail="idp_lookup_failed") return resp.json() + def _lookup_group_by_id(self, client: httpx.Client, group_id: str) -> dict | None: + resp = client.get(f"/admin/realms/{self.realm}/groups/{group_id}") + if resp.status_code == 404: + return None + if resp.status_code >= 400: + raise HTTPException(status_code=502, detail="idp_group_lookup_failed") + payload = resp.json() if resp.content else {} + return payload if isinstance(payload, dict) else None + + def _lookup_group_by_name(self, client: httpx.Client, *, name: str, parent_group_id: str | None) -> dict | None: + if parent_group_id: + resp = client.get( + f"/admin/realms/{self.realm}/groups/{parent_group_id}/children", + params={"search": name, "briefRepresentation": "false"}, + ) + if resp.status_code >= 400: + raise HTTPException(status_code=502, detail="idp_group_lookup_failed") + matches = resp.json() if isinstance(resp.json(), list) else [] + for row in matches: + if isinstance(row, dict) and str(row.get("name", "")).strip() == name: + return row + return None + + resp = client.get( + f"/admin/realms/{self.realm}/groups", + params={"search": name, "exact": "true", "briefRepresentation": "false"}, + ) + if resp.status_code >= 400: + raise HTTPException(status_code=502, detail="idp_group_lookup_failed") + matches = resp.json() if isinstance(resp.json(), list) else [] + for row in matches: + if not isinstance(row, dict): + continue + if str(row.get("name", "")).strip() != name: + continue + parent_id = row.get("parentId") + if parent_group_id: + if str(parent_id or "") == parent_group_id: + return row + elif not parent_id: + return row + return None + + @staticmethod + def _normalize_group_attributes(attributes: dict[str, str | list[str]] | None) -> dict[str, list[str]]: + if not attributes: + return {} + output: dict[str, list[str]] = {} + for key, value in attributes.items(): + normalized_key = str(key).strip() + if not normalized_key: + continue + if isinstance(value, list): + output[normalized_key] = [str(v) for v in value if str(v)] + elif value is not None and str(value): + output[normalized_key] = [str(value)] + return output + def _lookup_user_by_email_or_username( self, client: httpx.Client, *, email: str | None, username: str | None ) -> dict | None: @@ -162,6 +226,59 @@ class KeycloakAdminService: raise HTTPException(status_code=502, detail="idp_create_failed") return KeycloakSyncResult(user_id=user_id, action="created", user_sub=user_id) + def ensure_group( + self, + *, + name: str, + group_id: str | None = None, + parent_group_id: str | None = None, + attributes: dict[str, str | list[str]] | None = None, + ) -> KeycloakGroupSyncResult: + if not name: + raise HTTPException(status_code=400, detail="idp_group_name_required") + normalized_attrs = self._normalize_group_attributes(attributes) + + with self._client() as client: + existing = self._lookup_group_by_id(client, group_id) if group_id else None + if existing is None: + existing = self._lookup_group_by_name(client, name=name, parent_group_id=parent_group_id) + + if existing and existing.get("id"): + resolved_id = str(existing["id"]) + payload = {"name": name, "attributes": normalized_attrs} + put_resp = client.put(f"/admin/realms/{self.realm}/groups/{resolved_id}", json=payload) + if put_resp.status_code >= 400: + raise HTTPException(status_code=502, detail="idp_group_update_failed") + return KeycloakGroupSyncResult(group_id=resolved_id, action="updated") + + payload = {"name": name, "attributes": normalized_attrs} + if parent_group_id: + create_resp = client.post(f"/admin/realms/{self.realm}/groups/{parent_group_id}/children", json=payload) + else: + create_resp = client.post(f"/admin/realms/{self.realm}/groups", json=payload) + if create_resp.status_code >= 400: + raise HTTPException(status_code=502, detail="idp_group_create_failed") + + location = create_resp.headers.get("Location", "") + resolved_id = location.rstrip("/").split("/")[-1] if location and "/" in location else "" + if not resolved_id: + found = self._lookup_group_by_name(client, name=name, parent_group_id=parent_group_id) + resolved_id = str(found.get("id")) if found and found.get("id") else "" + if not resolved_id: + raise HTTPException(status_code=502, detail="idp_group_create_failed") + return KeycloakGroupSyncResult(group_id=resolved_id, action="created") + + def delete_group(self, *, group_id: str | None) -> KeycloakDeleteResult: + if not group_id: + return KeycloakDeleteResult(action="not_found") + with self._client() as client: + resp = client.delete(f"/admin/realms/{self.realm}/groups/{group_id}") + if resp.status_code in {204, 404}: + return KeycloakDeleteResult(action="deleted" if resp.status_code == 204 else "not_found") + if resp.status_code >= 400: + raise HTTPException(status_code=502, detail="idp_group_delete_failed") + return KeycloakDeleteResult(action="deleted") + def reset_password( self, *, @@ -207,3 +324,63 @@ class KeycloakAdminService: if resp.status_code >= 400: raise HTTPException(status_code=502, detail="idp_delete_failed") return KeycloakDeleteResult(action="deleted", user_id=user_id) + + def list_groups_tree(self) -> list[dict]: + with self._client() as client: + resp = client.get( + f"/admin/realms/{self.realm}/groups", + params={"first": 0, "max": 5000, "briefRepresentation": "false"}, + ) + if resp.status_code >= 400: + raise HTTPException(status_code=502, detail="idp_group_lookup_failed") + payload = resp.json() if resp.content else [] + return payload if isinstance(payload, list) else [] + + def list_users(self) -> list[dict]: + users: list[dict] = [] + first = 0 + page_size = 200 + with self._client() as client: + while True: + resp = client.get( + f"/admin/realms/{self.realm}/users", + params={"first": first, "max": page_size}, + ) + if resp.status_code >= 400: + raise HTTPException(status_code=502, detail="idp_lookup_failed") + batch = resp.json() if isinstance(resp.json(), list) else [] + users.extend([row for row in batch if isinstance(row, dict)]) + if len(batch) < page_size: + break + first += page_size + return users + + def list_clients(self) -> list[dict]: + clients: list[dict] = [] + first = 0 + page_size = 200 + with self._client() as client: + while True: + resp = client.get( + f"/admin/realms/{self.realm}/clients", + params={"first": first, "max": page_size}, + ) + if resp.status_code >= 400: + raise HTTPException(status_code=502, detail="idp_lookup_failed") + batch = resp.json() if isinstance(resp.json(), list) else [] + clients.extend([row for row in batch if isinstance(row, dict)]) + if len(batch) < page_size: + break + first += page_size + return clients + + def list_client_roles(self, client_uuid: str) -> list[dict]: + with self._client() as client: + resp = client.get( + f"/admin/realms/{self.realm}/clients/{client_uuid}/roles", + params={"first": 0, "max": 5000}, + ) + if resp.status_code >= 400: + raise HTTPException(status_code=502, detail="idp_lookup_failed") + payload = resp.json() if resp.content else [] + return payload if isinstance(payload, list) else [] diff --git a/app/services/idp_catalog_sync.py b/app/services/idp_catalog_sync.py new file mode 100644 index 0000000..e460fbd --- /dev/null +++ b/app/services/idp_catalog_sync.py @@ -0,0 +1,306 @@ +from __future__ import annotations + +import threading +import time + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from app.core.config import get_settings +from app.core.keygen import generate_key +from app.models.company import Company +from app.models.role import Role +from app.models.site import Site +from app.models.system import System +from app.repositories.companies_repo import CompaniesRepository +from app.repositories.roles_repo import RolesRepository +from app.repositories.sites_repo import SitesRepository +from app.repositories.systems_repo import SystemsRepository +from app.repositories.users_repo import UsersRepository +from app.services.idp_admin_service import KeycloakAdminService + +BUILTIN_CLIENT_IDS = { + "account", + "account-console", + "admin-cli", + "broker", + "realm-management", + "security-admin-console", + "master-realm", +} + +_sync_lock = threading.Lock() +_last_synced_at = 0.0 +_min_sync_interval_sec = 30.0 + + +def _generate_unique_key(prefix: str, exists_check) -> str: + for salt in range(5000): + key = generate_key(prefix, salt) + if not exists_check(key): + return key + raise RuntimeError(f"failed_generate_{prefix.lower()}_key") + + +def _first_attr(attrs: dict | None, key: str) -> str | None: + if not isinstance(attrs, dict): + return None + raw = attrs.get(key) + if isinstance(raw, list) and raw: + value = str(raw[0]).strip() + return value or None + if isinstance(raw, str): + value = raw.strip() + return value or None + return None + + +def _flatten_groups(nodes: list[dict], inherited_company_key: str | None = None) -> tuple[dict[str, dict], dict[str, dict]]: + companies: dict[str, dict] = {} + sites: dict[str, dict] = {} + + for node in nodes: + if not isinstance(node, dict): + continue + attrs = node.get("attributes") + group_id = str(node.get("id", "")).strip() or None + name = str(node.get("name", "")).strip() + children = node.get("subGroups") if isinstance(node.get("subGroups"), list) else [] + + company_key = _first_attr(attrs, "company_key") + if not company_key and name.startswith("CP"): + company_key = name + if _first_attr(attrs, "member_entity_type") == "company" and not company_key: + company_key = name or None + + current_company_key = company_key or inherited_company_key + + if company_key: + companies[company_key] = { + "company_key": company_key, + "display_name": _first_attr(attrs, "display_name") or name or company_key, + "status": _first_attr(attrs, "status") or "active", + "idp_group_id": group_id, + } + + site_key = _first_attr(attrs, "site_key") + if not site_key and name.startswith("ST"): + site_key = name + if _first_attr(attrs, "member_entity_type") == "site" and not site_key: + site_key = name or None + if site_key: + sites[site_key] = { + "site_key": site_key, + "company_key": _first_attr(attrs, "company_key") or current_company_key, + "display_name": _first_attr(attrs, "display_name") or name or site_key, + "domain": _first_attr(attrs, "domain"), + "status": _first_attr(attrs, "status") or "active", + "idp_group_id": group_id, + } + + child_companies, child_sites = _flatten_groups(children, current_company_key) + companies.update(child_companies) + sites.update(child_sites) + + return companies, sites + + +def sync_from_keycloak(db: Session, *, force: bool = False) -> dict[str, int]: + global _last_synced_at + now = time.time() + if not force and now - _last_synced_at < _min_sync_interval_sec: + return {"synced": 0} + + if not _sync_lock.acquire(blocking=False): + return {"synced": 0} + + try: + now = time.time() + if not force and now - _last_synced_at < _min_sync_interval_sec: + return {"synced": 0} + + idp = KeycloakAdminService(get_settings()) + companies_repo = CompaniesRepository(db) + sites_repo = SitesRepository(db) + systems_repo = SystemsRepository(db) + roles_repo = RolesRepository(db) + users_repo = UsersRepository(db) + + companies_created = 0 + companies_updated = 0 + sites_created = 0 + sites_updated = 0 + systems_created = 0 + systems_updated = 0 + roles_created = 0 + roles_updated = 0 + users_created_or_updated = 0 + + group_tree = idp.list_groups_tree() + company_records, site_records = _flatten_groups(group_tree) + + company_id_map: dict[str, str] = {} + for company_key, row in company_records.items(): + company = companies_repo.get_by_key(company_key) + if company is None: + company = companies_repo.create( + company_key=company_key, + display_name=row["display_name"], + legal_name=None, + idp_group_id=row["idp_group_id"], + status=row["status"], + ) + companies_created += 1 + else: + company = companies_repo.update( + company, + display_name=row["display_name"], + idp_group_id=row["idp_group_id"], + status=row["status"], + ) + companies_updated += 1 + company_id_map[company_key] = company.id + + for site_key, row in site_records.items(): + company_key = row.get("company_key") + if not company_key: + continue + company_id = company_id_map.get(company_key) + if not company_id: + placeholder = companies_repo.get_by_key(company_key) + if placeholder is None: + placeholder = companies_repo.create( + company_key=company_key, + display_name=company_key, + legal_name=None, + idp_group_id=None, + status="active", + ) + companies_created += 1 + company_id = placeholder.id + company_id_map[company_key] = company_id + + site = sites_repo.get_by_key(site_key) + if site is None: + sites_repo.create( + site_key=site_key, + company_id=company_id, + display_name=row["display_name"], + domain=row["domain"], + idp_group_id=row["idp_group_id"], + status=row["status"], + ) + sites_created += 1 + else: + sites_repo.update( + site, + company_id=company_id, + display_name=row["display_name"], + domain=row["domain"], + idp_group_id=row["idp_group_id"], + status=row["status"], + ) + sites_updated += 1 + + client_rows = idp.list_clients() + system_map_by_client_id: dict[str, System] = {} + for client in client_rows: + client_uuid = str(client.get("id", "")).strip() + client_id = str(client.get("clientId", "")).strip() + if not client_uuid or not client_id: + continue + if client_id in BUILTIN_CLIENT_IDS: + continue + + system = db.scalar(select(System).where(System.idp_client_id == client_id)) + system_name = str(client.get("name", "")).strip() or client_id + system_status = "active" if client.get("enabled", True) else "inactive" + if system is None: + system_key = _generate_unique_key("SY", lambda key: systems_repo.get_by_key(key) is not None) + system = systems_repo.create( + system_key=system_key, + name=system_name, + idp_client_id=client_id, + status=system_status, + ) + systems_created += 1 + else: + system = systems_repo.update( + system, + name=system_name, + status=system_status, + ) + systems_updated += 1 + system_map_by_client_id[client_id] = system + + client_roles = idp.list_client_roles(client_uuid) + for role_row in client_roles: + if not isinstance(role_row, dict): + continue + role_name = str(role_row.get("name", "")).strip() + if not role_name: + continue + role_desc = str(role_row.get("description", "")).strip() or None + role_status = "active" if not role_row.get("composite", False) else "active" + role = db.scalar( + select(Role).where( + Role.system_id == system.id, + Role.idp_role_name == role_name, + ) + ) + if role is None: + role_key = _generate_unique_key("RL", lambda key: roles_repo.get_by_key(key) is not None) + roles_repo.create( + role_key=role_key, + system_id=system.id, + name=role_name, + description=role_desc, + idp_role_name=role_name, + status=role_status, + ) + roles_created += 1 + else: + roles_repo.update( + role, + name=role_name, + description=role_desc, + status=role_status, + ) + roles_updated += 1 + + for user in idp.list_users(): + user_id = str(user.get("id", "")).strip() + if not user_id: + continue + display_name = ( + str(user.get("firstName", "")).strip() + or str(user.get("username", "")).strip() + or str(user.get("email", "")).strip() + or user_id + ) + users_repo.upsert_by_sub( + user_sub=user_id, + idp_user_id=user_id, + username=str(user.get("username", "")).strip() or None, + email=str(user.get("email", "")).strip() or None, + display_name=display_name, + is_active=bool(user.get("enabled", True)), + status="active" if user.get("enabled", True) else "inactive", + ) + users_created_or_updated += 1 + + _last_synced_at = time.time() + return { + "synced": 1, + "companies_created": companies_created, + "companies_updated": companies_updated, + "sites_created": sites_created, + "sites_updated": sites_updated, + "systems_created": systems_created, + "systems_updated": systems_updated, + "roles_created": roles_created, + "roles_updated": roles_updated, + "users_upserted": users_created_or_updated, + } + finally: + _sync_lock.release()