feat(sync): keycloak as source-of-truth with auto catalog sync and token refresh

This commit is contained in:
Chris
2026-04-03 00:46:46 +08:00
parent 81085e1844
commit 573a75b9e3
7 changed files with 649 additions and 33 deletions

View File

@@ -52,6 +52,7 @@ from app.schemas.catalog import (
from app.security.admin_guard import require_admin_principal from app.security.admin_guard import require_admin_principal
from app.security.api_client_auth import hash_api_key from app.security.api_client_auth import hash_api_key
from app.services.idp_admin_service import KeycloakAdminService from app.services.idp_admin_service import KeycloakAdminService
from app.services.idp_catalog_sync import sync_from_keycloak
from app.core.config import get_settings from app.core.config import get_settings
router = APIRouter( router = APIRouter(
@@ -123,6 +124,7 @@ def list_companies(
limit: int = Query(default=100, ge=1, le=500), limit: int = Query(default=100, ge=1, le=500),
offset: int = Query(default=0, ge=0), offset: int = Query(default=0, ge=0),
) -> ListResponse: ) -> ListResponse:
sync_from_keycloak(db)
repo = CompaniesRepository(db) repo = CompaniesRepository(db)
items, total = repo.list(keyword=keyword, limit=limit, offset=offset) items, total = repo.list(keyword=keyword, limit=limit, offset=offset)
return ListResponse(items=[_company_item(i) for i in items], total=total, limit=limit, offset=offset) return ListResponse(items=[_company_item(i) for i in items], total=total, limit=limit, offset=offset)
@@ -131,8 +133,24 @@ def list_companies(
@router.post("/companies", response_model=CompanyItem) @router.post("/companies", response_model=CompanyItem)
def create_company(payload: CompanyCreateRequest, db: Session = Depends(get_db)) -> CompanyItem: def create_company(payload: CompanyCreateRequest, db: Session = Depends(get_db)) -> CompanyItem:
repo = CompaniesRepository(db) repo = CompaniesRepository(db)
idp = KeycloakAdminService(get_settings())
company_key = _generate_unique_key("CP", lambda key: repo.get_by_key(key) is not None) company_key = _generate_unique_key("CP", lambda key: repo.get_by_key(key) is not None)
item = repo.create(company_key=company_key, display_name=payload.display_name, legal_name=payload.legal_name, status=payload.status) group = idp.ensure_group(
name=company_key,
attributes={
"member_entity_type": "company",
"company_key": company_key,
"display_name": payload.display_name,
"status": payload.status,
},
)
item = repo.create(
company_key=company_key,
display_name=payload.display_name,
legal_name=payload.legal_name,
idp_group_id=group.group_id,
status=payload.status,
)
return _company_item(item) return _company_item(item)
@@ -142,11 +160,25 @@ def update_company(company_key: str, payload: CompanyUpdateRequest, db: Session
item = repo.get_by_key(company_key) item = repo.get_by_key(company_key)
if not item: if not item:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="company_not_found") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="company_not_found")
idp = KeycloakAdminService(get_settings())
resolved_display_name = payload.display_name if payload.display_name is not None else item.display_name
resolved_status = payload.status if payload.status is not None else item.status
resolved_group_id = payload.idp_group_id or item.idp_group_id
group = idp.ensure_group(
group_id=resolved_group_id,
name=company_key,
attributes={
"member_entity_type": "company",
"company_key": company_key,
"display_name": resolved_display_name,
"status": resolved_status,
},
)
item = repo.update( item = repo.update(
item, item,
display_name=payload.display_name, display_name=payload.display_name,
legal_name=payload.legal_name, legal_name=payload.legal_name,
idp_group_id=payload.idp_group_id, idp_group_id=group.group_id,
status=payload.status, status=payload.status,
) )
return _company_item(item) return _company_item(item)
@@ -155,9 +187,11 @@ def update_company(company_key: str, payload: CompanyUpdateRequest, db: Session
@router.delete("/companies/{company_key}") @router.delete("/companies/{company_key}")
def delete_company(company_key: str, db: Session = Depends(get_db)) -> dict[str, str]: def delete_company(company_key: str, db: Session = Depends(get_db)) -> dict[str, str]:
repo = CompaniesRepository(db) repo = CompaniesRepository(db)
idp = KeycloakAdminService(get_settings())
item = repo.get_by_key(company_key) item = repo.get_by_key(company_key)
if not item: if not item:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="company_not_found") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="company_not_found")
idp.delete_group(group_id=item.idp_group_id)
repo.delete(item) repo.delete(item)
return {"deleted": company_key} return {"deleted": company_key}
@@ -181,6 +215,7 @@ def list_sites(
limit: int = Query(default=100, ge=1, le=500), limit: int = Query(default=100, ge=1, le=500),
offset: int = Query(default=0, ge=0), offset: int = Query(default=0, ge=0),
) -> ListResponse: ) -> ListResponse:
sync_from_keycloak(db)
companies_repo = CompaniesRepository(db) companies_repo = CompaniesRepository(db)
sites_repo = SitesRepository(db) sites_repo = SitesRepository(db)
company_id = None company_id = None
@@ -201,16 +236,31 @@ def list_sites(
def create_site(payload: SiteCreateRequest, db: Session = Depends(get_db)) -> SiteItem: def create_site(payload: SiteCreateRequest, db: Session = Depends(get_db)) -> SiteItem:
companies_repo = CompaniesRepository(db) companies_repo = CompaniesRepository(db)
sites_repo = SitesRepository(db) sites_repo = SitesRepository(db)
idp = KeycloakAdminService(get_settings())
company = companies_repo.get_by_key(payload.company_key) company = companies_repo.get_by_key(payload.company_key)
if not company: if not company:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="company_not_found") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="company_not_found")
site_key = _generate_unique_key("ST", lambda key: sites_repo.get_by_key(key) is not None) site_key = _generate_unique_key("ST", lambda key: sites_repo.get_by_key(key) is not None)
group = idp.ensure_group(
group_id=None,
name=site_key,
parent_group_id=company.idp_group_id,
attributes={
"member_entity_type": "site",
"site_key": site_key,
"company_key": company.company_key,
"display_name": payload.display_name,
"domain": payload.domain or "",
"status": payload.status,
},
)
item = sites_repo.create( item = sites_repo.create(
site_key=site_key, site_key=site_key,
company_id=company.id, company_id=company.id,
display_name=payload.display_name, display_name=payload.display_name,
domain=payload.domain, domain=payload.domain,
idp_group_id=group.group_id,
status=payload.status, status=payload.status,
) )
return _site_item(item, company) return _site_item(item, company)
@@ -220,24 +270,46 @@ def create_site(payload: SiteCreateRequest, db: Session = Depends(get_db)) -> Si
def update_site(site_key: str, payload: SiteUpdateRequest, db: Session = Depends(get_db)) -> SiteItem: def update_site(site_key: str, payload: SiteUpdateRequest, db: Session = Depends(get_db)) -> SiteItem:
companies_repo = CompaniesRepository(db) companies_repo = CompaniesRepository(db)
sites_repo = SitesRepository(db) sites_repo = SitesRepository(db)
idp = KeycloakAdminService(get_settings())
item = sites_repo.get_by_key(site_key) item = sites_repo.get_by_key(site_key)
if not item: if not item:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="site_not_found") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="site_not_found")
target_company = companies_repo.get_by_id(item.company_id)
company_id = None company_id = None
if payload.company_key: if payload.company_key:
company = companies_repo.get_by_key(payload.company_key) target_company = companies_repo.get_by_key(payload.company_key)
if not company: if not target_company:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="company_not_found") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="company_not_found")
company_id = company.id company_id = target_company.id
if not target_company:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="company_reference_missing")
resolved_display_name = payload.display_name if payload.display_name is not None else item.display_name
resolved_domain = payload.domain if payload.domain is not None else item.domain
resolved_status = payload.status if payload.status is not None else item.status
resolved_group_id = payload.idp_group_id or item.idp_group_id
group = idp.ensure_group(
group_id=resolved_group_id,
name=site_key,
parent_group_id=target_company.idp_group_id,
attributes={
"member_entity_type": "site",
"site_key": site_key,
"company_key": target_company.company_key,
"display_name": resolved_display_name,
"domain": resolved_domain or "",
"status": resolved_status,
},
)
item = sites_repo.update( item = sites_repo.update(
item, item,
company_id=company_id, company_id=company_id,
display_name=payload.display_name, display_name=payload.display_name,
domain=payload.domain, domain=payload.domain,
idp_group_id=payload.idp_group_id, idp_group_id=group.group_id,
status=payload.status, status=payload.status,
) )
company = companies_repo.get_by_id(item.company_id) company = companies_repo.get_by_id(item.company_id)
@@ -249,9 +321,11 @@ def update_site(site_key: str, payload: SiteUpdateRequest, db: Session = Depends
@router.delete("/sites/{site_key}") @router.delete("/sites/{site_key}")
def delete_site(site_key: str, db: Session = Depends(get_db)) -> dict[str, str]: def delete_site(site_key: str, db: Session = Depends(get_db)) -> dict[str, str]:
repo = SitesRepository(db) repo = SitesRepository(db)
idp = KeycloakAdminService(get_settings())
item = repo.get_by_key(site_key) item = repo.get_by_key(site_key)
if not item: if not item:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="site_not_found") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="site_not_found")
idp.delete_group(group_id=item.idp_group_id)
repo.delete(item) repo.delete(item)
return {"deleted": site_key} return {"deleted": site_key}
@@ -264,6 +338,7 @@ def list_systems(
limit: int = Query(default=100, ge=1, le=500), limit: int = Query(default=100, ge=1, le=500),
offset: int = Query(default=0, ge=0), offset: int = Query(default=0, ge=0),
) -> ListResponse: ) -> ListResponse:
sync_from_keycloak(db)
repo = SystemsRepository(db) repo = SystemsRepository(db)
items, total = repo.list(keyword=keyword, status=status_filter, limit=limit, offset=offset) items, total = repo.list(keyword=keyword, status=status_filter, limit=limit, offset=offset)
return ListResponse(items=[_system_item(i) for i in items], total=total, limit=limit, offset=offset) return ListResponse(items=[_system_item(i) for i in items], total=total, limit=limit, offset=offset)
@@ -271,30 +346,17 @@ def list_systems(
@router.post("/systems", response_model=SystemItem) @router.post("/systems", response_model=SystemItem)
def create_system(payload: SystemCreateRequest, db: Session = Depends(get_db)) -> SystemItem: def create_system(payload: SystemCreateRequest, db: Session = Depends(get_db)) -> SystemItem:
repo = SystemsRepository(db) raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="system_manage_in_keycloak_only")
system_key = _generate_unique_key("SY", lambda key: repo.get_by_key(key) is not None)
item = repo.create(system_key=system_key, name=payload.name, idp_client_id=payload.idp_client_id, status=payload.status)
return _system_item(item)
@router.patch("/systems/{system_key}", response_model=SystemItem) @router.patch("/systems/{system_key}", response_model=SystemItem)
def update_system(system_key: str, payload: SystemUpdateRequest, db: Session = Depends(get_db)) -> SystemItem: def update_system(system_key: str, payload: SystemUpdateRequest, db: Session = Depends(get_db)) -> SystemItem:
repo = SystemsRepository(db) raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="system_manage_in_keycloak_only")
item = repo.get_by_key(system_key)
if not item:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="system_not_found")
item = repo.update(item, name=payload.name, idp_client_id=payload.idp_client_id, status=payload.status)
return _system_item(item)
@router.delete("/systems/{system_key}") @router.delete("/systems/{system_key}")
def delete_system(system_key: str, db: Session = Depends(get_db)) -> dict[str, str]: def delete_system(system_key: str, db: Session = Depends(get_db)) -> dict[str, str]:
repo = SystemsRepository(db) raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="system_manage_in_keycloak_only")
item = repo.get_by_key(system_key)
if not item:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="system_not_found")
repo.delete(item)
return {"deleted": system_key}
@router.get("/roles", response_model=ListResponse) @router.get("/roles", response_model=ListResponse)
@@ -306,6 +368,7 @@ def list_roles(
limit: int = Query(default=100, ge=1, le=500), limit: int = Query(default=100, ge=1, le=500),
offset: int = Query(default=0, ge=0), offset: int = Query(default=0, ge=0),
) -> ListResponse: ) -> ListResponse:
sync_from_keycloak(db)
systems_repo = SystemsRepository(db) systems_repo = SystemsRepository(db)
roles_repo = RolesRepository(db) roles_repo = RolesRepository(db)
@@ -544,6 +607,7 @@ def list_members(
limit: int = Query(default=100, ge=1, le=500), limit: int = Query(default=100, ge=1, le=500),
offset: int = Query(default=0, ge=0), offset: int = Query(default=0, ge=0),
) -> ListResponse: ) -> ListResponse:
sync_from_keycloak(db)
repo = UsersRepository(db) repo = UsersRepository(db)
rows, total = repo.list(keyword=keyword, is_active=is_active, limit=limit, offset=offset) rows, total = repo.list(keyword=keyword, is_active=is_active, limit=limit, offset=offset)
return ListResponse(items=[_member_item(r) for r in rows], total=total, limit=limit, offset=offset) return ListResponse(items=[_member_item(r) for r in rows], total=total, limit=limit, offset=offset)
@@ -754,6 +818,11 @@ def list_api_clients(
) )
@router.post("/sync/from-keycloak")
def sync_catalog_from_keycloak(db: Session = Depends(get_db), force: bool = Query(default=True)) -> dict[str, int]:
return sync_from_keycloak(db, force=force)
@router.post("/api-clients", response_model=ApiClientCreateResponse) @router.post("/api-clients", response_model=ApiClientCreateResponse)
def create_api_client(payload: ApiClientCreateRequest, db: Session = Depends(get_db)) -> ApiClientCreateResponse: def create_api_client(payload: ApiClientCreateRequest, db: Session = Depends(get_db)) -> ApiClientCreateResponse:
repo = ApiClientsRepository(db) repo = ApiClientsRepository(db)

View File

@@ -5,7 +5,7 @@ import httpx
from fastapi import APIRouter, HTTPException, status from fastapi import APIRouter, HTTPException, status
from app.core.config import get_settings from app.core.config import get_settings
from app.schemas.login import LoginRequest, LoginResponse, OIDCAuthUrlResponse, OIDCCodeExchangeRequest from app.schemas.login import LoginRequest, LoginResponse, OIDCAuthUrlResponse, OIDCCodeExchangeRequest, RefreshTokenRequest
router = APIRouter(prefix="/auth", tags=["auth"]) router = APIRouter(prefix="/auth", tags=["auth"])
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -45,12 +45,7 @@ def login(payload: LoginRequest) -> LoginResponse:
token = data.get("access_token") token = data.get("access_token")
if not token: if not token:
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="idp_missing_access_token") raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="idp_missing_access_token")
return LoginResponse( return _build_login_response(data)
access_token=token,
token_type=data.get("token_type", "Bearer"),
expires_in=data.get("expires_in"),
scope=data.get("scope"),
)
@router.get("/oidc/url", response_model=OIDCAuthUrlResponse) @router.get("/oidc/url", response_model=OIDCAuthUrlResponse)
@@ -123,9 +118,50 @@ def exchange_oidc_code(payload: OIDCCodeExchangeRequest) -> LoginResponse:
token = data.get("access_token") token = data.get("access_token")
if not token: if not token:
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="idp_missing_access_token") raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="idp_missing_access_token")
return _build_login_response(data)
@router.post("/refresh", response_model=LoginResponse)
def refresh_access_token(payload: RefreshTokenRequest) -> LoginResponse:
settings = get_settings()
client_id = settings.idp_client_id or settings.idp_audience
if not settings.idp_base_url or not client_id or not settings.idp_client_secret:
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="idp_login_not_configured")
form = {
"grant_type": "refresh_token",
"client_id": client_id,
"client_secret": settings.idp_client_secret,
"refresh_token": payload.refresh_token,
}
try:
resp = httpx.post(
settings.idp_token_endpoint,
data=form,
timeout=10,
verify=settings.idp_verify_tls,
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
except Exception as exc:
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="idp_unreachable") from exc
if resp.status_code >= 400:
logger.warning("idp refresh-token grant failed: status=%s body=%s", resp.status_code, resp.text)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid_refresh_token")
data = resp.json()
token = data.get("access_token")
if not token:
raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="idp_missing_access_token")
return _build_login_response(data)
def _build_login_response(data: dict) -> LoginResponse:
return LoginResponse( return LoginResponse(
access_token=token, access_token=data.get("access_token", ""),
refresh_token=data.get("refresh_token"),
token_type=data.get("token_type", "Bearer"), token_type=data.get("token_type", "Bearer"),
expires_in=data.get("expires_in"), expires_in=data.get("expires_in"),
refresh_expires_in=data.get("refresh_expires_in"),
scope=data.get("scope"), scope=data.get("scope"),
) )

View File

@@ -30,8 +30,22 @@ class CompaniesRepository:
stmt = stmt.order_by(Company.created_at.desc()).limit(limit).offset(offset) stmt = stmt.order_by(Company.created_at.desc()).limit(limit).offset(offset)
return list(self.db.scalars(stmt).all()), int(self.db.scalar(count_stmt) or 0) return list(self.db.scalars(stmt).all()), int(self.db.scalar(count_stmt) or 0)
def create(self, *, company_key: str, display_name: str, legal_name: str | None, status: str = "active") -> Company: def create(
item = Company(company_key=company_key, display_name=display_name, legal_name=legal_name, status=status) self,
*,
company_key: str,
display_name: str,
legal_name: str | None,
idp_group_id: str | None = None,
status: str = "active",
) -> Company:
item = Company(
company_key=company_key,
display_name=display_name,
legal_name=legal_name,
idp_group_id=idp_group_id,
status=status,
)
self.db.add(item) self.db.add(item)
self.db.commit() self.db.commit()
self.db.refresh(item) self.db.refresh(item)

View File

@@ -45,9 +45,17 @@ class SitesRepository:
company_id: str, company_id: str,
display_name: str, display_name: str,
domain: str | None, domain: str | None,
idp_group_id: str | None = None,
status: str = "active", status: str = "active",
) -> Site: ) -> Site:
item = Site(site_key=site_key, company_id=company_id, display_name=display_name, domain=domain, status=status) item = Site(
site_key=site_key,
company_id=company_id,
display_name=display_name,
domain=domain,
idp_group_id=idp_group_id,
status=status,
)
self.db.add(item) self.db.add(item)
self.db.commit() self.db.commit()
self.db.refresh(item) self.db.refresh(item)

View File

@@ -8,8 +8,10 @@ class LoginRequest(BaseModel):
class LoginResponse(BaseModel): class LoginResponse(BaseModel):
access_token: str access_token: str
refresh_token: str | None = None
token_type: str = "Bearer" token_type: str = "Bearer"
expires_in: int | None = None expires_in: int | None = None
refresh_expires_in: int | None = None
scope: str | None = None scope: str | None = None
@@ -21,3 +23,7 @@ class OIDCCodeExchangeRequest(BaseModel):
code: str code: str
redirect_uri: str redirect_uri: str
code_verifier: str | None = None code_verifier: str | None = None
class RefreshTokenRequest(BaseModel):
refresh_token: str

View File

@@ -29,6 +29,12 @@ class KeycloakDeleteResult:
user_id: str | None = None user_id: str | None = None
@dataclass
class KeycloakGroupSyncResult:
group_id: str
action: str
class KeycloakAdminService: class KeycloakAdminService:
def __init__(self, settings: Settings) -> None: def __init__(self, settings: Settings) -> None:
self.base_url = settings.keycloak_base_url.rstrip("/") self.base_url = settings.keycloak_base_url.rstrip("/")
@@ -97,6 +103,64 @@ class KeycloakAdminService:
raise HTTPException(status_code=502, detail="idp_lookup_failed") raise HTTPException(status_code=502, detail="idp_lookup_failed")
return resp.json() return resp.json()
def _lookup_group_by_id(self, client: httpx.Client, group_id: str) -> dict | None:
resp = client.get(f"/admin/realms/{self.realm}/groups/{group_id}")
if resp.status_code == 404:
return None
if resp.status_code >= 400:
raise HTTPException(status_code=502, detail="idp_group_lookup_failed")
payload = resp.json() if resp.content else {}
return payload if isinstance(payload, dict) else None
def _lookup_group_by_name(self, client: httpx.Client, *, name: str, parent_group_id: str | None) -> dict | None:
if parent_group_id:
resp = client.get(
f"/admin/realms/{self.realm}/groups/{parent_group_id}/children",
params={"search": name, "briefRepresentation": "false"},
)
if resp.status_code >= 400:
raise HTTPException(status_code=502, detail="idp_group_lookup_failed")
matches = resp.json() if isinstance(resp.json(), list) else []
for row in matches:
if isinstance(row, dict) and str(row.get("name", "")).strip() == name:
return row
return None
resp = client.get(
f"/admin/realms/{self.realm}/groups",
params={"search": name, "exact": "true", "briefRepresentation": "false"},
)
if resp.status_code >= 400:
raise HTTPException(status_code=502, detail="idp_group_lookup_failed")
matches = resp.json() if isinstance(resp.json(), list) else []
for row in matches:
if not isinstance(row, dict):
continue
if str(row.get("name", "")).strip() != name:
continue
parent_id = row.get("parentId")
if parent_group_id:
if str(parent_id or "") == parent_group_id:
return row
elif not parent_id:
return row
return None
@staticmethod
def _normalize_group_attributes(attributes: dict[str, str | list[str]] | None) -> dict[str, list[str]]:
if not attributes:
return {}
output: dict[str, list[str]] = {}
for key, value in attributes.items():
normalized_key = str(key).strip()
if not normalized_key:
continue
if isinstance(value, list):
output[normalized_key] = [str(v) for v in value if str(v)]
elif value is not None and str(value):
output[normalized_key] = [str(value)]
return output
def _lookup_user_by_email_or_username( def _lookup_user_by_email_or_username(
self, client: httpx.Client, *, email: str | None, username: str | None self, client: httpx.Client, *, email: str | None, username: str | None
) -> dict | None: ) -> dict | None:
@@ -162,6 +226,59 @@ class KeycloakAdminService:
raise HTTPException(status_code=502, detail="idp_create_failed") raise HTTPException(status_code=502, detail="idp_create_failed")
return KeycloakSyncResult(user_id=user_id, action="created", user_sub=user_id) return KeycloakSyncResult(user_id=user_id, action="created", user_sub=user_id)
def ensure_group(
self,
*,
name: str,
group_id: str | None = None,
parent_group_id: str | None = None,
attributes: dict[str, str | list[str]] | None = None,
) -> KeycloakGroupSyncResult:
if not name:
raise HTTPException(status_code=400, detail="idp_group_name_required")
normalized_attrs = self._normalize_group_attributes(attributes)
with self._client() as client:
existing = self._lookup_group_by_id(client, group_id) if group_id else None
if existing is None:
existing = self._lookup_group_by_name(client, name=name, parent_group_id=parent_group_id)
if existing and existing.get("id"):
resolved_id = str(existing["id"])
payload = {"name": name, "attributes": normalized_attrs}
put_resp = client.put(f"/admin/realms/{self.realm}/groups/{resolved_id}", json=payload)
if put_resp.status_code >= 400:
raise HTTPException(status_code=502, detail="idp_group_update_failed")
return KeycloakGroupSyncResult(group_id=resolved_id, action="updated")
payload = {"name": name, "attributes": normalized_attrs}
if parent_group_id:
create_resp = client.post(f"/admin/realms/{self.realm}/groups/{parent_group_id}/children", json=payload)
else:
create_resp = client.post(f"/admin/realms/{self.realm}/groups", json=payload)
if create_resp.status_code >= 400:
raise HTTPException(status_code=502, detail="idp_group_create_failed")
location = create_resp.headers.get("Location", "")
resolved_id = location.rstrip("/").split("/")[-1] if location and "/" in location else ""
if not resolved_id:
found = self._lookup_group_by_name(client, name=name, parent_group_id=parent_group_id)
resolved_id = str(found.get("id")) if found and found.get("id") else ""
if not resolved_id:
raise HTTPException(status_code=502, detail="idp_group_create_failed")
return KeycloakGroupSyncResult(group_id=resolved_id, action="created")
def delete_group(self, *, group_id: str | None) -> KeycloakDeleteResult:
if not group_id:
return KeycloakDeleteResult(action="not_found")
with self._client() as client:
resp = client.delete(f"/admin/realms/{self.realm}/groups/{group_id}")
if resp.status_code in {204, 404}:
return KeycloakDeleteResult(action="deleted" if resp.status_code == 204 else "not_found")
if resp.status_code >= 400:
raise HTTPException(status_code=502, detail="idp_group_delete_failed")
return KeycloakDeleteResult(action="deleted")
def reset_password( def reset_password(
self, self,
*, *,
@@ -207,3 +324,63 @@ class KeycloakAdminService:
if resp.status_code >= 400: if resp.status_code >= 400:
raise HTTPException(status_code=502, detail="idp_delete_failed") raise HTTPException(status_code=502, detail="idp_delete_failed")
return KeycloakDeleteResult(action="deleted", user_id=user_id) return KeycloakDeleteResult(action="deleted", user_id=user_id)
def list_groups_tree(self) -> list[dict]:
with self._client() as client:
resp = client.get(
f"/admin/realms/{self.realm}/groups",
params={"first": 0, "max": 5000, "briefRepresentation": "false"},
)
if resp.status_code >= 400:
raise HTTPException(status_code=502, detail="idp_group_lookup_failed")
payload = resp.json() if resp.content else []
return payload if isinstance(payload, list) else []
def list_users(self) -> list[dict]:
users: list[dict] = []
first = 0
page_size = 200
with self._client() as client:
while True:
resp = client.get(
f"/admin/realms/{self.realm}/users",
params={"first": first, "max": page_size},
)
if resp.status_code >= 400:
raise HTTPException(status_code=502, detail="idp_lookup_failed")
batch = resp.json() if isinstance(resp.json(), list) else []
users.extend([row for row in batch if isinstance(row, dict)])
if len(batch) < page_size:
break
first += page_size
return users
def list_clients(self) -> list[dict]:
clients: list[dict] = []
first = 0
page_size = 200
with self._client() as client:
while True:
resp = client.get(
f"/admin/realms/{self.realm}/clients",
params={"first": first, "max": page_size},
)
if resp.status_code >= 400:
raise HTTPException(status_code=502, detail="idp_lookup_failed")
batch = resp.json() if isinstance(resp.json(), list) else []
clients.extend([row for row in batch if isinstance(row, dict)])
if len(batch) < page_size:
break
first += page_size
return clients
def list_client_roles(self, client_uuid: str) -> list[dict]:
with self._client() as client:
resp = client.get(
f"/admin/realms/{self.realm}/clients/{client_uuid}/roles",
params={"first": 0, "max": 5000},
)
if resp.status_code >= 400:
raise HTTPException(status_code=502, detail="idp_lookup_failed")
payload = resp.json() if resp.content else []
return payload if isinstance(payload, list) else []

View File

@@ -0,0 +1,306 @@
from __future__ import annotations
import threading
import time
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.core.config import get_settings
from app.core.keygen import generate_key
from app.models.company import Company
from app.models.role import Role
from app.models.site import Site
from app.models.system import System
from app.repositories.companies_repo import CompaniesRepository
from app.repositories.roles_repo import RolesRepository
from app.repositories.sites_repo import SitesRepository
from app.repositories.systems_repo import SystemsRepository
from app.repositories.users_repo import UsersRepository
from app.services.idp_admin_service import KeycloakAdminService
BUILTIN_CLIENT_IDS = {
"account",
"account-console",
"admin-cli",
"broker",
"realm-management",
"security-admin-console",
"master-realm",
}
_sync_lock = threading.Lock()
_last_synced_at = 0.0
_min_sync_interval_sec = 30.0
def _generate_unique_key(prefix: str, exists_check) -> str:
for salt in range(5000):
key = generate_key(prefix, salt)
if not exists_check(key):
return key
raise RuntimeError(f"failed_generate_{prefix.lower()}_key")
def _first_attr(attrs: dict | None, key: str) -> str | None:
if not isinstance(attrs, dict):
return None
raw = attrs.get(key)
if isinstance(raw, list) and raw:
value = str(raw[0]).strip()
return value or None
if isinstance(raw, str):
value = raw.strip()
return value or None
return None
def _flatten_groups(nodes: list[dict], inherited_company_key: str | None = None) -> tuple[dict[str, dict], dict[str, dict]]:
companies: dict[str, dict] = {}
sites: dict[str, dict] = {}
for node in nodes:
if not isinstance(node, dict):
continue
attrs = node.get("attributes")
group_id = str(node.get("id", "")).strip() or None
name = str(node.get("name", "")).strip()
children = node.get("subGroups") if isinstance(node.get("subGroups"), list) else []
company_key = _first_attr(attrs, "company_key")
if not company_key and name.startswith("CP"):
company_key = name
if _first_attr(attrs, "member_entity_type") == "company" and not company_key:
company_key = name or None
current_company_key = company_key or inherited_company_key
if company_key:
companies[company_key] = {
"company_key": company_key,
"display_name": _first_attr(attrs, "display_name") or name or company_key,
"status": _first_attr(attrs, "status") or "active",
"idp_group_id": group_id,
}
site_key = _first_attr(attrs, "site_key")
if not site_key and name.startswith("ST"):
site_key = name
if _first_attr(attrs, "member_entity_type") == "site" and not site_key:
site_key = name or None
if site_key:
sites[site_key] = {
"site_key": site_key,
"company_key": _first_attr(attrs, "company_key") or current_company_key,
"display_name": _first_attr(attrs, "display_name") or name or site_key,
"domain": _first_attr(attrs, "domain"),
"status": _first_attr(attrs, "status") or "active",
"idp_group_id": group_id,
}
child_companies, child_sites = _flatten_groups(children, current_company_key)
companies.update(child_companies)
sites.update(child_sites)
return companies, sites
def sync_from_keycloak(db: Session, *, force: bool = False) -> dict[str, int]:
global _last_synced_at
now = time.time()
if not force and now - _last_synced_at < _min_sync_interval_sec:
return {"synced": 0}
if not _sync_lock.acquire(blocking=False):
return {"synced": 0}
try:
now = time.time()
if not force and now - _last_synced_at < _min_sync_interval_sec:
return {"synced": 0}
idp = KeycloakAdminService(get_settings())
companies_repo = CompaniesRepository(db)
sites_repo = SitesRepository(db)
systems_repo = SystemsRepository(db)
roles_repo = RolesRepository(db)
users_repo = UsersRepository(db)
companies_created = 0
companies_updated = 0
sites_created = 0
sites_updated = 0
systems_created = 0
systems_updated = 0
roles_created = 0
roles_updated = 0
users_created_or_updated = 0
group_tree = idp.list_groups_tree()
company_records, site_records = _flatten_groups(group_tree)
company_id_map: dict[str, str] = {}
for company_key, row in company_records.items():
company = companies_repo.get_by_key(company_key)
if company is None:
company = companies_repo.create(
company_key=company_key,
display_name=row["display_name"],
legal_name=None,
idp_group_id=row["idp_group_id"],
status=row["status"],
)
companies_created += 1
else:
company = companies_repo.update(
company,
display_name=row["display_name"],
idp_group_id=row["idp_group_id"],
status=row["status"],
)
companies_updated += 1
company_id_map[company_key] = company.id
for site_key, row in site_records.items():
company_key = row.get("company_key")
if not company_key:
continue
company_id = company_id_map.get(company_key)
if not company_id:
placeholder = companies_repo.get_by_key(company_key)
if placeholder is None:
placeholder = companies_repo.create(
company_key=company_key,
display_name=company_key,
legal_name=None,
idp_group_id=None,
status="active",
)
companies_created += 1
company_id = placeholder.id
company_id_map[company_key] = company_id
site = sites_repo.get_by_key(site_key)
if site is None:
sites_repo.create(
site_key=site_key,
company_id=company_id,
display_name=row["display_name"],
domain=row["domain"],
idp_group_id=row["idp_group_id"],
status=row["status"],
)
sites_created += 1
else:
sites_repo.update(
site,
company_id=company_id,
display_name=row["display_name"],
domain=row["domain"],
idp_group_id=row["idp_group_id"],
status=row["status"],
)
sites_updated += 1
client_rows = idp.list_clients()
system_map_by_client_id: dict[str, System] = {}
for client in client_rows:
client_uuid = str(client.get("id", "")).strip()
client_id = str(client.get("clientId", "")).strip()
if not client_uuid or not client_id:
continue
if client_id in BUILTIN_CLIENT_IDS:
continue
system = db.scalar(select(System).where(System.idp_client_id == client_id))
system_name = str(client.get("name", "")).strip() or client_id
system_status = "active" if client.get("enabled", True) else "inactive"
if system is None:
system_key = _generate_unique_key("SY", lambda key: systems_repo.get_by_key(key) is not None)
system = systems_repo.create(
system_key=system_key,
name=system_name,
idp_client_id=client_id,
status=system_status,
)
systems_created += 1
else:
system = systems_repo.update(
system,
name=system_name,
status=system_status,
)
systems_updated += 1
system_map_by_client_id[client_id] = system
client_roles = idp.list_client_roles(client_uuid)
for role_row in client_roles:
if not isinstance(role_row, dict):
continue
role_name = str(role_row.get("name", "")).strip()
if not role_name:
continue
role_desc = str(role_row.get("description", "")).strip() or None
role_status = "active" if not role_row.get("composite", False) else "active"
role = db.scalar(
select(Role).where(
Role.system_id == system.id,
Role.idp_role_name == role_name,
)
)
if role is None:
role_key = _generate_unique_key("RL", lambda key: roles_repo.get_by_key(key) is not None)
roles_repo.create(
role_key=role_key,
system_id=system.id,
name=role_name,
description=role_desc,
idp_role_name=role_name,
status=role_status,
)
roles_created += 1
else:
roles_repo.update(
role,
name=role_name,
description=role_desc,
status=role_status,
)
roles_updated += 1
for user in idp.list_users():
user_id = str(user.get("id", "")).strip()
if not user_id:
continue
display_name = (
str(user.get("firstName", "")).strip()
or str(user.get("username", "")).strip()
or str(user.get("email", "")).strip()
or user_id
)
users_repo.upsert_by_sub(
user_sub=user_id,
idp_user_id=user_id,
username=str(user.get("username", "")).strip() or None,
email=str(user.get("email", "")).strip() or None,
display_name=display_name,
is_active=bool(user.get("enabled", True)),
status="active" if user.get("enabled", True) else "inactive",
)
users_created_or_updated += 1
_last_synced_at = time.time()
return {
"synced": 1,
"companies_created": companies_created,
"companies_updated": companies_updated,
"sites_created": sites_created,
"sites_updated": sites_updated,
"systems_created": systems_created,
"systems_updated": systems_updated,
"roles_created": roles_created,
"roles_updated": roles_updated,
"users_upserted": users_created_or_updated,
}
finally:
_sync_lock.release()