Compare commits
10 Commits
64246984ba
...
ade60bdbaa
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ade60bdbaa | ||
|
|
75cf22f7e5 | ||
|
|
a187acf58c | ||
|
|
6d6f5aa2a6 | ||
|
|
73ed4ff0ff | ||
|
|
55e640f2fb | ||
|
|
e912d1498e | ||
|
|
c431fe180b | ||
|
|
da08bc01ec | ||
|
|
d5418d47e7 |
13
.dockerignore
Normal file
13
.dockerignore
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
.git
|
||||||
|
.gitignore
|
||||||
|
.venv
|
||||||
|
__pycache__
|
||||||
|
*.pyc
|
||||||
|
*.pyo
|
||||||
|
*.pyd
|
||||||
|
.pytest_cache
|
||||||
|
.ruff_cache
|
||||||
|
tests
|
||||||
|
.env
|
||||||
|
.env.development
|
||||||
|
*.log
|
||||||
@@ -19,4 +19,10 @@ KEYCLOAK_ADMIN_REALM=master
|
|||||||
|
|
||||||
PUBLIC_FRONTEND_ORIGINS=http://127.0.0.1:5173,http://localhost:5173
|
PUBLIC_FRONTEND_ORIGINS=http://127.0.0.1:5173,http://localhost:5173
|
||||||
INTERNAL_SHARED_SECRET=CHANGE_ME
|
INTERNAL_SHARED_SECRET=CHANGE_ME
|
||||||
ADMIN_REQUIRED_GROUPS=member-admin
|
MEMBER_REQUIRED_REALM_ROLES=admin,manager
|
||||||
|
ADMIN_REQUIRED_REALM_ROLES=admin,manager
|
||||||
|
|
||||||
|
CACHE_BACKEND=memory
|
||||||
|
CACHE_REDIS_URL=redis://127.0.0.1:6379/0
|
||||||
|
CACHE_PREFIX=memberapi
|
||||||
|
CACHE_DEFAULT_TTL_SECONDS=30
|
||||||
|
|||||||
@@ -25,4 +25,11 @@ KEYCLOAK_ADMIN_REALM=
|
|||||||
|
|
||||||
PUBLIC_FRONTEND_ORIGINS=https://member.ose.tw,https://mkt.ose.tw,https://admin.ose.tw
|
PUBLIC_FRONTEND_ORIGINS=https://member.ose.tw,https://mkt.ose.tw,https://admin.ose.tw
|
||||||
INTERNAL_SHARED_SECRET=CHANGE_ME
|
INTERNAL_SHARED_SECRET=CHANGE_ME
|
||||||
ADMIN_REQUIRED_GROUPS=member-admin
|
MEMBER_REQUIRED_REALM_ROLES=admin,manager
|
||||||
|
ADMIN_REQUIRED_REALM_ROLES=admin,manager
|
||||||
|
|
||||||
|
# Cache backend: memory | redis
|
||||||
|
CACHE_BACKEND=memory
|
||||||
|
CACHE_REDIS_URL=redis://127.0.0.1:6379/0
|
||||||
|
CACHE_PREFIX=memberapi
|
||||||
|
CACHE_DEFAULT_TTL_SECONDS=30
|
||||||
|
|||||||
@@ -25,3 +25,11 @@ KEYCLOAK_ADMIN_REALM=
|
|||||||
|
|
||||||
PUBLIC_FRONTEND_ORIGINS=https://member.ose.tw,https://mkt.ose.tw,https://admin.ose.tw
|
PUBLIC_FRONTEND_ORIGINS=https://member.ose.tw,https://mkt.ose.tw,https://admin.ose.tw
|
||||||
INTERNAL_SHARED_SECRET=CHANGE_ME
|
INTERNAL_SHARED_SECRET=CHANGE_ME
|
||||||
|
MEMBER_REQUIRED_REALM_ROLES=admin,manager
|
||||||
|
ADMIN_REQUIRED_REALM_ROLES=admin,manager
|
||||||
|
|
||||||
|
# Cache backend: memory | redis
|
||||||
|
CACHE_BACKEND=redis
|
||||||
|
CACHE_REDIS_URL=redis://redis:6379/0
|
||||||
|
CACHE_PREFIX=memberapi
|
||||||
|
CACHE_DEFAULT_TTL_SECONDS=30
|
||||||
|
|||||||
30
Dockerfile
Normal file
30
Dockerfile
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
FROM python:3.12-alpine AS builder
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
RUN apk add --no-cache build-base libffi-dev openssl-dev cargo
|
||||||
|
|
||||||
|
COPY pyproject.toml /app/pyproject.toml
|
||||||
|
COPY app /app/app
|
||||||
|
COPY scripts /app/scripts
|
||||||
|
COPY README.md /app/README.md
|
||||||
|
|
||||||
|
RUN pip install --no-cache-dir --upgrade pip && \
|
||||||
|
pip wheel --no-cache-dir --wheel-dir /wheels .
|
||||||
|
|
||||||
|
FROM python:3.12-alpine
|
||||||
|
|
||||||
|
ENV PYTHONDONTWRITEBYTECODE=1 \
|
||||||
|
PYTHONUNBUFFERED=1
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
RUN apk add --no-cache libstdc++ libffi openssl
|
||||||
|
|
||||||
|
COPY --from=builder /wheels /wheels
|
||||||
|
RUN pip install --no-cache-dir --upgrade pip && \
|
||||||
|
pip install --no-cache-dir /wheels/*
|
||||||
|
|
||||||
|
EXPOSE 8000
|
||||||
|
|
||||||
|
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||||
29
README.md
29
README.md
@@ -12,6 +12,30 @@ psql "$DATABASE_URL" -f scripts/init_schema.sql
|
|||||||
./scripts/start_dev.sh
|
./scripts/start_dev.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Docker (VPS / Production)
|
||||||
|
> 目前 Dockerfile 為 Alpine 多階段建置(較小體積)。
|
||||||
|
|
||||||
|
Build image:
|
||||||
|
```bash
|
||||||
|
cd backend
|
||||||
|
docker build -t memberapi-backend:latest .
|
||||||
|
```
|
||||||
|
|
||||||
|
Run container:
|
||||||
|
```bash
|
||||||
|
docker run -d \
|
||||||
|
--name memberapi-backend \
|
||||||
|
--restart unless-stopped \
|
||||||
|
-p 127.0.0.1:8000:8000 \
|
||||||
|
--env-file .env \
|
||||||
|
memberapi-backend:latest
|
||||||
|
```
|
||||||
|
|
||||||
|
Health check:
|
||||||
|
```bash
|
||||||
|
curl http://127.0.0.1:8000/healthz
|
||||||
|
```
|
||||||
|
|
||||||
## Keycloak env
|
## Keycloak env
|
||||||
|
|
||||||
- Required:
|
- Required:
|
||||||
@@ -29,6 +53,8 @@ psql "$DATABASE_URL" -f scripts/init_schema.sql
|
|||||||
- `KEYCLOAK_USERINFO_ENDPOINT`
|
- `KEYCLOAK_USERINFO_ENDPOINT`
|
||||||
- `KEYCLOAK_AUDIENCE`
|
- `KEYCLOAK_AUDIENCE`
|
||||||
- `KEYCLOAK_VERIFY_TLS`
|
- `KEYCLOAK_VERIFY_TLS`
|
||||||
|
- `MEMBER_REQUIRED_REALM_ROLES` (default: `admin,manager`)
|
||||||
|
- `ADMIN_REQUIRED_REALM_ROLES` (default: `admin,manager`)
|
||||||
|
|
||||||
## Main APIs
|
## Main APIs
|
||||||
|
|
||||||
@@ -38,7 +64,7 @@ psql "$DATABASE_URL" -f scripts/init_schema.sql
|
|||||||
- `GET /me` (Bearer token required)
|
- `GET /me` (Bearer token required)
|
||||||
- `GET /me/permissions/snapshot` (Bearer token required)
|
- `GET /me/permissions/snapshot` (Bearer token required)
|
||||||
|
|
||||||
### Admin APIs (Bearer + admin group required)
|
### Admin APIs (Bearer + admin realm role required)
|
||||||
- `GET/POST/PATCH/DELETE /admin/companies`
|
- `GET/POST/PATCH/DELETE /admin/companies`
|
||||||
- `GET/POST/PATCH/DELETE /admin/sites`
|
- `GET/POST/PATCH/DELETE /admin/sites`
|
||||||
- `GET/POST/PATCH/DELETE /admin/systems`
|
- `GET/POST/PATCH/DELETE /admin/systems`
|
||||||
@@ -57,5 +83,4 @@ psql "$DATABASE_URL" -f scripts/init_schema.sql
|
|||||||
- `GET /internal/members`
|
- `GET /internal/members`
|
||||||
- `POST /internal/users/upsert-by-sub`
|
- `POST /internal/users/upsert-by-sub`
|
||||||
- `GET /internal/users/{user_sub}/roles`
|
- `GET /internal/users/{user_sub}/roles`
|
||||||
- `GET /internal/permissions/{user_sub}/snapshot`
|
|
||||||
- `POST /internal/idp/users/ensure`
|
- `POST /internal/idp/users/ensure`
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ from app.security.admin_guard import require_admin_principal
|
|||||||
from app.security.api_client_auth import hash_api_key
|
from app.security.api_client_auth import hash_api_key
|
||||||
from app.services.idp_admin_service import ProviderAdminService
|
from app.services.idp_admin_service import ProviderAdminService
|
||||||
from app.services.idp_catalog_sync import sync_from_provider
|
from app.services.idp_catalog_sync import sync_from_provider
|
||||||
|
from app.services.runtime_cache import runtime_cache
|
||||||
from app.core.config import get_settings
|
from app.core.config import get_settings
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
@@ -129,6 +130,31 @@ def _site_group_name(display_name: str, site_key: str) -> str:
|
|||||||
return normalized
|
return normalized
|
||||||
|
|
||||||
|
|
||||||
|
def _sync_site_client_roles(
|
||||||
|
*,
|
||||||
|
idp: ProviderAdminService,
|
||||||
|
site,
|
||||||
|
site_role_rows,
|
||||||
|
provider_client_ids: set[str],
|
||||||
|
) -> None:
|
||||||
|
if not site.provider_group_id:
|
||||||
|
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=f"site_provider_group_missing:{site.site_key}")
|
||||||
|
|
||||||
|
role_names_by_client: dict[str, list[str]] = {}
|
||||||
|
for _, role, system in site_role_rows:
|
||||||
|
provider_client_id = str(system.name or "").strip()
|
||||||
|
if not provider_client_id:
|
||||||
|
continue
|
||||||
|
role_names_by_client.setdefault(provider_client_id, []).append(role.name)
|
||||||
|
|
||||||
|
for provider_client_id in sorted(provider_client_ids):
|
||||||
|
idp.set_group_client_roles(
|
||||||
|
group_id=site.provider_group_id,
|
||||||
|
provider_client_id=provider_client_id,
|
||||||
|
role_names=role_names_by_client.get(provider_client_id, []),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/companies", response_model=ListResponse)
|
@router.get("/companies", response_model=ListResponse)
|
||||||
def list_companies(
|
def list_companies(
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
@@ -136,9 +162,16 @@ def list_companies(
|
|||||||
limit: int = Query(default=100, ge=1, le=500),
|
limit: int = Query(default=100, ge=1, le=500),
|
||||||
offset: int = Query(default=0, ge=0),
|
offset: int = Query(default=0, ge=0),
|
||||||
) -> ListResponse:
|
) -> ListResponse:
|
||||||
|
cache_key = f"admin:companies:{keyword or ''}:{limit}:{offset}"
|
||||||
|
cached = runtime_cache.get(cache_key)
|
||||||
|
if isinstance(cached, ListResponse):
|
||||||
|
return cached
|
||||||
|
|
||||||
repo = CompaniesRepository(db)
|
repo = CompaniesRepository(db)
|
||||||
items, total = repo.list(keyword=keyword, limit=limit, offset=offset)
|
items, total = repo.list(keyword=keyword, limit=limit, offset=offset)
|
||||||
return ListResponse(items=[_company_item(i) for i in items], total=total, limit=limit, offset=offset)
|
result = ListResponse(items=[_company_item(i) for i in items], total=total, limit=limit, offset=offset)
|
||||||
|
runtime_cache.set(cache_key, result, ttl_seconds=20)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
@router.post("/companies", response_model=CompanyItem)
|
@router.post("/companies", response_model=CompanyItem)
|
||||||
@@ -226,6 +259,11 @@ def list_sites(
|
|||||||
limit: int = Query(default=100, ge=1, le=500),
|
limit: int = Query(default=100, ge=1, le=500),
|
||||||
offset: int = Query(default=0, ge=0),
|
offset: int = Query(default=0, ge=0),
|
||||||
) -> ListResponse:
|
) -> ListResponse:
|
||||||
|
cache_key = f"admin:sites:{keyword or ''}:{company_key or ''}:{limit}:{offset}"
|
||||||
|
cached = runtime_cache.get(cache_key)
|
||||||
|
if isinstance(cached, ListResponse):
|
||||||
|
return cached
|
||||||
|
|
||||||
companies_repo = CompaniesRepository(db)
|
companies_repo = CompaniesRepository(db)
|
||||||
sites_repo = SitesRepository(db)
|
sites_repo = SitesRepository(db)
|
||||||
company_id = None
|
company_id = None
|
||||||
@@ -239,7 +277,9 @@ def list_sites(
|
|||||||
company_map = {c.id: c for c in companies}
|
company_map = {c.id: c for c in companies}
|
||||||
items, total = sites_repo.list(keyword=keyword, company_id=company_id, limit=limit, offset=offset)
|
items, total = sites_repo.list(keyword=keyword, company_id=company_id, limit=limit, offset=offset)
|
||||||
response_items = [_site_item(i, company_map[i.company_id]) for i in items if i.company_id in company_map]
|
response_items = [_site_item(i, company_map[i.company_id]) for i in items if i.company_id in company_map]
|
||||||
return ListResponse(items=response_items, total=total, limit=limit, offset=offset)
|
result = ListResponse(items=response_items, total=total, limit=limit, offset=offset)
|
||||||
|
runtime_cache.set(cache_key, result, ttl_seconds=20)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
@router.post("/sites", response_model=SiteItem)
|
@router.post("/sites", response_model=SiteItem)
|
||||||
@@ -350,9 +390,16 @@ def list_systems(
|
|||||||
limit: int = Query(default=100, ge=1, le=500),
|
limit: int = Query(default=100, ge=1, le=500),
|
||||||
offset: int = Query(default=0, ge=0),
|
offset: int = Query(default=0, ge=0),
|
||||||
) -> ListResponse:
|
) -> ListResponse:
|
||||||
|
cache_key = f"admin:systems:{keyword or ''}:{status_filter or ''}:{limit}:{offset}"
|
||||||
|
cached = runtime_cache.get(cache_key)
|
||||||
|
if isinstance(cached, ListResponse):
|
||||||
|
return cached
|
||||||
|
|
||||||
repo = SystemsRepository(db)
|
repo = SystemsRepository(db)
|
||||||
items, total = repo.list(keyword=keyword, status=status_filter, limit=limit, offset=offset)
|
items, total = repo.list(keyword=keyword, status=status_filter, limit=limit, offset=offset)
|
||||||
return ListResponse(items=[_system_item(i) for i in items], total=total, limit=limit, offset=offset)
|
result = ListResponse(items=[_system_item(i) for i in items], total=total, limit=limit, offset=offset)
|
||||||
|
runtime_cache.set(cache_key, result, ttl_seconds=20)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
@router.post("/systems", response_model=SystemItem)
|
@router.post("/systems", response_model=SystemItem)
|
||||||
@@ -379,6 +426,11 @@ def list_roles(
|
|||||||
limit: int = Query(default=100, ge=1, le=500),
|
limit: int = Query(default=100, ge=1, le=500),
|
||||||
offset: int = Query(default=0, ge=0),
|
offset: int = Query(default=0, ge=0),
|
||||||
) -> ListResponse:
|
) -> ListResponse:
|
||||||
|
cache_key = f"admin:roles:{keyword or ''}:{system_key or ''}:{status_filter or ''}:{limit}:{offset}"
|
||||||
|
cached = runtime_cache.get(cache_key)
|
||||||
|
if isinstance(cached, ListResponse):
|
||||||
|
return cached
|
||||||
|
|
||||||
systems_repo = SystemsRepository(db)
|
systems_repo = SystemsRepository(db)
|
||||||
roles_repo = RolesRepository(db)
|
roles_repo = RolesRepository(db)
|
||||||
|
|
||||||
@@ -408,7 +460,9 @@ def list_roles(
|
|||||||
for row in rows
|
for row in rows
|
||||||
if row.system_id in system_map
|
if row.system_id in system_map
|
||||||
]
|
]
|
||||||
return ListResponse(items=items, total=total, limit=limit, offset=offset)
|
result = ListResponse(items=items, total=total, limit=limit, offset=offset)
|
||||||
|
runtime_cache.set(cache_key, result, ttl_seconds=20)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
@router.post("/roles", response_model=RoleItem)
|
@router.post("/roles", response_model=RoleItem)
|
||||||
@@ -597,11 +651,15 @@ def assign_site_roles(site_key: str, payload: SiteRoleAssignRequest, db: Session
|
|||||||
sites_repo = SitesRepository(db)
|
sites_repo = SitesRepository(db)
|
||||||
roles_repo = RolesRepository(db)
|
roles_repo = RolesRepository(db)
|
||||||
site_roles_repo = SiteRolesRepository(db)
|
site_roles_repo = SiteRolesRepository(db)
|
||||||
|
idp = ProviderAdminService(get_settings())
|
||||||
|
|
||||||
site = sites_repo.get_by_key(site_key)
|
site = sites_repo.get_by_key(site_key)
|
||||||
if not site:
|
if not site:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="site_not_found")
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="site_not_found")
|
||||||
|
|
||||||
|
current_rows = site_roles_repo.list_site_role_rows(site.id)
|
||||||
|
current_client_ids = {str(system.name or "").strip() for _, _, system in current_rows if str(system.name or "").strip()}
|
||||||
|
|
||||||
role_ids: list[str] = []
|
role_ids: list[str] = []
|
||||||
for role_key in list(dict.fromkeys(payload.role_keys)):
|
for role_key in list(dict.fromkeys(payload.role_keys)):
|
||||||
role = roles_repo.get_by_key(role_key)
|
role = roles_repo.get_by_key(role_key)
|
||||||
@@ -609,7 +667,23 @@ def assign_site_roles(site_key: str, payload: SiteRoleAssignRequest, db: Session
|
|||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"role_not_found:{role_key}")
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"role_not_found:{role_key}")
|
||||||
role_ids.append(role.id)
|
role_ids.append(role.id)
|
||||||
|
|
||||||
site_roles_repo.set_site_roles(site_id=site.id, role_ids=role_ids)
|
site_roles_repo.set_site_roles(site_id=site.id, role_ids=role_ids, commit=False)
|
||||||
|
updated_rows = site_roles_repo.list_site_role_rows(site.id)
|
||||||
|
updated_client_ids = {str(system.name or "").strip() for _, _, system in updated_rows if str(system.name or "").strip()}
|
||||||
|
clients_to_sync = current_client_ids | updated_client_ids
|
||||||
|
|
||||||
|
try:
|
||||||
|
_sync_site_client_roles(
|
||||||
|
idp=idp,
|
||||||
|
site=site,
|
||||||
|
site_role_rows=updated_rows,
|
||||||
|
provider_client_ids=clients_to_sync,
|
||||||
|
)
|
||||||
|
db.commit()
|
||||||
|
except Exception:
|
||||||
|
db.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
return list_site_roles(site_key=site_key, db=db)
|
return list_site_roles(site_key=site_key, db=db)
|
||||||
|
|
||||||
|
|
||||||
@@ -646,6 +720,58 @@ def list_role_sites(role_key: str, db: Session = Depends(get_db)) -> RoleSitesRe
|
|||||||
return RoleSitesResponse(role_key=role_key, sites=result)
|
return RoleSitesResponse(role_key=role_key, sites=result)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/roles/{role_key}/sites", response_model=RoleSitesResponse)
|
||||||
|
def assign_role_sites(role_key: str, payload: UserSiteAssignRequest, db: Session = Depends(get_db)) -> RoleSitesResponse:
|
||||||
|
roles_repo = RolesRepository(db)
|
||||||
|
sites_repo = SitesRepository(db)
|
||||||
|
systems_repo = SystemsRepository(db)
|
||||||
|
site_roles_repo = SiteRolesRepository(db)
|
||||||
|
idp = ProviderAdminService(get_settings())
|
||||||
|
|
||||||
|
role = roles_repo.get_by_key(role_key)
|
||||||
|
if not role:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="role_not_found")
|
||||||
|
|
||||||
|
system = systems_repo.get_by_id(role.system_id)
|
||||||
|
if not system:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="system_not_found")
|
||||||
|
provider_client_id = str(system.name or "").strip()
|
||||||
|
if not provider_client_id:
|
||||||
|
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=f"provider_client_id_missing:{system.system_key}")
|
||||||
|
|
||||||
|
previous_rows = site_roles_repo.list_role_site_rows(role.id)
|
||||||
|
previous_site_ids = {site.id for _, site in previous_rows}
|
||||||
|
|
||||||
|
site_ids: list[str] = []
|
||||||
|
for site_key in list(dict.fromkeys(payload.site_keys)):
|
||||||
|
site = sites_repo.get_by_key(site_key)
|
||||||
|
if not site:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"site_not_found:{site_key}")
|
||||||
|
site_ids.append(site.id)
|
||||||
|
|
||||||
|
site_roles_repo.set_role_sites(role_id=role.id, site_ids=site_ids, commit=False)
|
||||||
|
|
||||||
|
affected_site_ids = previous_site_ids | set(site_ids)
|
||||||
|
try:
|
||||||
|
for site_id in affected_site_ids:
|
||||||
|
site = sites_repo.get_by_id(site_id)
|
||||||
|
if not site:
|
||||||
|
continue
|
||||||
|
site_rows = site_roles_repo.list_site_role_rows(site.id)
|
||||||
|
_sync_site_client_roles(
|
||||||
|
idp=idp,
|
||||||
|
site=site,
|
||||||
|
site_role_rows=site_rows,
|
||||||
|
provider_client_ids={provider_client_id},
|
||||||
|
)
|
||||||
|
db.commit()
|
||||||
|
except Exception:
|
||||||
|
db.rollback()
|
||||||
|
raise
|
||||||
|
|
||||||
|
return list_role_sites(role_key=role_key, db=db)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/members", response_model=ListResponse)
|
@router.get("/members", response_model=ListResponse)
|
||||||
def list_members(
|
def list_members(
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
@@ -654,9 +780,17 @@ def list_members(
|
|||||||
limit: int = Query(default=100, ge=1, le=500),
|
limit: int = Query(default=100, ge=1, le=500),
|
||||||
offset: int = Query(default=0, ge=0),
|
offset: int = Query(default=0, ge=0),
|
||||||
) -> ListResponse:
|
) -> ListResponse:
|
||||||
|
is_active_key = "" if is_active is None else ("1" if is_active else "0")
|
||||||
|
cache_key = f"admin:members:{keyword or ''}:{is_active_key}:{limit}:{offset}"
|
||||||
|
cached = runtime_cache.get(cache_key)
|
||||||
|
if isinstance(cached, ListResponse):
|
||||||
|
return cached
|
||||||
|
|
||||||
repo = UsersRepository(db)
|
repo = UsersRepository(db)
|
||||||
rows, total = repo.list(keyword=keyword, is_active=is_active, limit=limit, offset=offset)
|
rows, total = repo.list(keyword=keyword, is_active=is_active, limit=limit, offset=offset)
|
||||||
return ListResponse(items=[_member_item(r) for r in rows], total=total, limit=limit, offset=offset)
|
result = ListResponse(items=[_member_item(r) for r in rows], total=total, limit=limit, offset=offset)
|
||||||
|
runtime_cache.set(cache_key, result, ttl_seconds=20)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
@router.post("/members", response_model=MemberItem)
|
@router.post("/members", response_model=MemberItem)
|
||||||
@@ -853,14 +987,21 @@ def list_api_clients(
|
|||||||
limit: int = Query(default=100, ge=1, le=500),
|
limit: int = Query(default=100, ge=1, le=500),
|
||||||
offset: int = Query(default=0, ge=0),
|
offset: int = Query(default=0, ge=0),
|
||||||
) -> ListResponse:
|
) -> ListResponse:
|
||||||
|
cache_key = f"admin:api_clients:{keyword or ''}:{status_filter or ''}:{limit}:{offset}"
|
||||||
|
cached = runtime_cache.get(cache_key)
|
||||||
|
if isinstance(cached, ListResponse):
|
||||||
|
return cached
|
||||||
|
|
||||||
repo = ApiClientsRepository(db)
|
repo = ApiClientsRepository(db)
|
||||||
items, total = repo.list(keyword=keyword, status=status_filter, limit=limit, offset=offset)
|
items, total = repo.list(keyword=keyword, status=status_filter, limit=limit, offset=offset)
|
||||||
return ListResponse(
|
result = ListResponse(
|
||||||
items=[ApiClientItem.model_validate(i, from_attributes=True) for i in items],
|
items=[ApiClientItem.model_validate(i, from_attributes=True) for i in items],
|
||||||
total=total,
|
total=total,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
offset=offset,
|
offset=offset,
|
||||||
)
|
)
|
||||||
|
runtime_cache.set(cache_key, result, ttl_seconds=20)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
@router.post("/sync/from-provider")
|
@router.post("/sync/from-provider")
|
||||||
|
|||||||
@@ -7,11 +7,10 @@ from app.repositories.users_repo import UsersRepository
|
|||||||
from app.repositories.user_sites_repo import UserSitesRepository
|
from app.repositories.user_sites_repo import UserSitesRepository
|
||||||
from app.schemas.idp_admin import ProviderEnsureUserRequest, ProviderEnsureUserResponse
|
from app.schemas.idp_admin import ProviderEnsureUserRequest, ProviderEnsureUserResponse
|
||||||
from app.schemas.internal import InternalUpsertUserBySubResponse, InternalUserRoleItem, InternalUserRoleResponse
|
from app.schemas.internal import InternalUpsertUserBySubResponse, InternalUserRoleItem, InternalUserRoleResponse
|
||||||
from app.schemas.permissions import RoleSnapshotResponse
|
|
||||||
from app.schemas.users import UserUpsertBySubRequest
|
from app.schemas.users import UserUpsertBySubRequest
|
||||||
from app.security.api_client_auth import require_api_client
|
from app.security.api_client_auth import require_api_client
|
||||||
from app.services.idp_admin_service import ProviderAdminService
|
from app.services.idp_admin_service import ProviderAdminService
|
||||||
from app.services.permission_service import PermissionService
|
from app.services.runtime_cache import runtime_cache
|
||||||
|
|
||||||
router = APIRouter(prefix="/internal", tags=["internal"], dependencies=[Depends(require_api_client)])
|
router = APIRouter(prefix="/internal", tags=["internal"], dependencies=[Depends(require_api_client)])
|
||||||
|
|
||||||
@@ -68,8 +67,13 @@ def _build_user_role_rows(db: Session, user_sub: str) -> list[tuple[str, str, st
|
|||||||
|
|
||||||
@router.get("/users/{user_sub}/roles", response_model=InternalUserRoleResponse)
|
@router.get("/users/{user_sub}/roles", response_model=InternalUserRoleResponse)
|
||||||
def get_user_roles(user_sub: str, db: Session = Depends(get_db)) -> InternalUserRoleResponse:
|
def get_user_roles(user_sub: str, db: Session = Depends(get_db)) -> InternalUserRoleResponse:
|
||||||
|
cache_key = f"internal:user_roles:{user_sub}"
|
||||||
|
cached = runtime_cache.get(cache_key)
|
||||||
|
if isinstance(cached, InternalUserRoleResponse):
|
||||||
|
return cached
|
||||||
|
|
||||||
rows = _build_user_role_rows(db, user_sub)
|
rows = _build_user_role_rows(db, user_sub)
|
||||||
return InternalUserRoleResponse(
|
result = InternalUserRoleResponse(
|
||||||
user_sub=user_sub,
|
user_sub=user_sub,
|
||||||
roles=[
|
roles=[
|
||||||
InternalUserRoleItem(
|
InternalUserRoleItem(
|
||||||
@@ -94,15 +98,8 @@ def get_user_roles(user_sub: str, db: Session = Depends(get_db)) -> InternalUser
|
|||||||
) in rows
|
) in rows
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
runtime_cache.set(cache_key, result, ttl_seconds=30)
|
||||||
|
return result
|
||||||
@router.get("/permissions/{user_sub}/snapshot", response_model=RoleSnapshotResponse)
|
|
||||||
def get_permission_snapshot(
|
|
||||||
user_sub: str,
|
|
||||||
db: Session = Depends(get_db),
|
|
||||||
) -> RoleSnapshotResponse:
|
|
||||||
rows = _build_user_role_rows(db, user_sub)
|
|
||||||
return PermissionService.build_role_snapshot(user_sub=user_sub, rows=rows)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/provider/users/ensure", response_model=ProviderEnsureUserResponse)
|
@router.post("/provider/users/ensure", response_model=ProviderEnsureUserResponse)
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from app.schemas.auth import ProviderPrincipal, MeSummaryResponse
|
|||||||
from app.schemas.permissions import RoleSnapshotResponse
|
from app.schemas.permissions import RoleSnapshotResponse
|
||||||
from app.security.idp_jwt import require_authenticated_principal
|
from app.security.idp_jwt import require_authenticated_principal
|
||||||
from app.services.permission_service import PermissionService
|
from app.services.permission_service import PermissionService
|
||||||
|
from app.services.runtime_cache import runtime_cache
|
||||||
|
|
||||||
router = APIRouter(prefix="/me", tags=["me"])
|
router = APIRouter(prefix="/me", tags=["me"])
|
||||||
|
|
||||||
@@ -18,6 +19,10 @@ def get_me(
|
|||||||
principal: ProviderPrincipal = Depends(require_authenticated_principal),
|
principal: ProviderPrincipal = Depends(require_authenticated_principal),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> MeSummaryResponse:
|
) -> MeSummaryResponse:
|
||||||
|
cache_key = f"me:{principal.sub}"
|
||||||
|
cached = runtime_cache.get(cache_key)
|
||||||
|
if isinstance(cached, MeSummaryResponse):
|
||||||
|
return cached
|
||||||
try:
|
try:
|
||||||
users_repo = UsersRepository(db)
|
users_repo = UsersRepository(db)
|
||||||
user = users_repo.upsert_by_sub(
|
user = users_repo.upsert_by_sub(
|
||||||
@@ -28,13 +33,17 @@ def get_me(
|
|||||||
is_active=True,
|
is_active=True,
|
||||||
status="active",
|
status="active",
|
||||||
)
|
)
|
||||||
return MeSummaryResponse(sub=user.user_sub, email=user.email, display_name=user.display_name)
|
result = MeSummaryResponse(sub=user.user_sub, email=user.email, display_name=user.display_name)
|
||||||
|
runtime_cache.set(cache_key, result, ttl_seconds=30)
|
||||||
|
return result
|
||||||
except SQLAlchemyError:
|
except SQLAlchemyError:
|
||||||
return MeSummaryResponse(
|
result = MeSummaryResponse(
|
||||||
sub=principal.sub,
|
sub=principal.sub,
|
||||||
email=principal.email,
|
email=principal.email,
|
||||||
display_name=principal.name or principal.preferred_username,
|
display_name=principal.name or principal.preferred_username,
|
||||||
)
|
)
|
||||||
|
runtime_cache.set(cache_key, result, ttl_seconds=15)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
@router.get("/permissions/snapshot", response_model=RoleSnapshotResponse)
|
@router.get("/permissions/snapshot", response_model=RoleSnapshotResponse)
|
||||||
@@ -42,6 +51,10 @@ def get_my_permission_snapshot(
|
|||||||
principal: ProviderPrincipal = Depends(require_authenticated_principal),
|
principal: ProviderPrincipal = Depends(require_authenticated_principal),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> RoleSnapshotResponse:
|
) -> RoleSnapshotResponse:
|
||||||
|
cache_key = f"me:permissions_snapshot:{principal.sub}"
|
||||||
|
cached = runtime_cache.get(cache_key)
|
||||||
|
if isinstance(cached, RoleSnapshotResponse):
|
||||||
|
return cached
|
||||||
try:
|
try:
|
||||||
users_repo = UsersRepository(db)
|
users_repo = UsersRepository(db)
|
||||||
user_sites_repo = UserSitesRepository(db)
|
user_sites_repo = UserSitesRepository(db)
|
||||||
@@ -68,6 +81,10 @@ def get_my_permission_snapshot(
|
|||||||
)
|
)
|
||||||
for site, company, role, system in rows
|
for site, company, role, system in rows
|
||||||
]
|
]
|
||||||
return PermissionService.build_role_snapshot(user_sub=principal.sub, rows=serialized)
|
result = PermissionService.build_role_snapshot(user_sub=principal.sub, rows=serialized)
|
||||||
|
runtime_cache.set(cache_key, result, ttl_seconds=30)
|
||||||
|
return result
|
||||||
except SQLAlchemyError:
|
except SQLAlchemyError:
|
||||||
return RoleSnapshotResponse(user_sub=principal.sub, roles=[])
|
result = RoleSnapshotResponse(user_sub=principal.sub, roles=[])
|
||||||
|
runtime_cache.set(cache_key, result, ttl_seconds=10)
|
||||||
|
return result
|
||||||
|
|||||||
@@ -35,6 +35,12 @@ class Settings(BaseSettings):
|
|||||||
public_frontend_origins: Annotated[list[str], NoDecode] = ["https://member.ose.tw"]
|
public_frontend_origins: Annotated[list[str], NoDecode] = ["https://member.ose.tw"]
|
||||||
internal_shared_secret: str = ""
|
internal_shared_secret: str = ""
|
||||||
admin_required_groups: Annotated[list[str], NoDecode] = []
|
admin_required_groups: Annotated[list[str], NoDecode] = []
|
||||||
|
member_required_realm_roles: Annotated[list[str], NoDecode] = ["admin", "manager"]
|
||||||
|
admin_required_realm_roles: Annotated[list[str], NoDecode] = ["admin", "manager"]
|
||||||
|
cache_backend: str = "memory"
|
||||||
|
cache_redis_url: str = "redis://127.0.0.1:6379/0"
|
||||||
|
cache_prefix: str = "memberapi"
|
||||||
|
cache_default_ttl_seconds: int = 30
|
||||||
|
|
||||||
@field_validator("public_frontend_origins", mode="before")
|
@field_validator("public_frontend_origins", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -54,6 +60,15 @@ class Settings(BaseSettings):
|
|||||||
return []
|
return []
|
||||||
return [part.strip() for part in value.split(",") if part.strip()]
|
return [part.strip() for part in value.split(",") if part.strip()]
|
||||||
|
|
||||||
|
@field_validator("member_required_realm_roles", "admin_required_realm_roles", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def parse_roles_csv(cls, value: str | list[str]) -> list[str]:
|
||||||
|
if isinstance(value, list):
|
||||||
|
return [str(v).strip() for v in value if str(v).strip()]
|
||||||
|
if not value:
|
||||||
|
return []
|
||||||
|
return [part.strip() for part in value.split(",") if part.strip()]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def database_url(self) -> str:
|
def database_url(self) -> str:
|
||||||
return (
|
return (
|
||||||
|
|||||||
13
app/main.py
13
app/main.py
@@ -7,6 +7,7 @@ from app.api.internal_catalog import router as internal_catalog_router
|
|||||||
from app.api.internal import router as internal_router
|
from app.api.internal import router as internal_router
|
||||||
from app.api.me import router as me_router
|
from app.api.me import router as me_router
|
||||||
from app.core.config import get_settings
|
from app.core.config import get_settings
|
||||||
|
from app.services.runtime_cache import runtime_cache
|
||||||
|
|
||||||
app = FastAPI(title="memberapi.ose.tw", version="0.1.0")
|
app = FastAPI(title="memberapi.ose.tw", version="0.1.0")
|
||||||
|
|
||||||
@@ -20,6 +21,18 @@ app.add_middleware(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.middleware("http")
|
||||||
|
async def invalidate_runtime_cache_on_cud(request, call_next):
|
||||||
|
response = await call_next(request)
|
||||||
|
if (
|
||||||
|
request.method in {"POST", "PUT", "PATCH", "DELETE"}
|
||||||
|
and request.url.path.startswith(("/admin", "/internal"))
|
||||||
|
and response.status_code < 400
|
||||||
|
):
|
||||||
|
runtime_cache.bump_revision()
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
@app.get("/healthz", tags=["health"])
|
@app.get("/healthz", tags=["health"])
|
||||||
def healthz() -> dict[str, str]:
|
def healthz() -> dict[str, str]:
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|||||||
@@ -30,8 +30,16 @@ class SiteRolesRepository:
|
|||||||
)
|
)
|
||||||
return list(self.db.execute(stmt).all())
|
return list(self.db.execute(stmt).all())
|
||||||
|
|
||||||
def set_site_roles(self, *, site_id: str, role_ids: list[str]) -> None:
|
def set_site_roles(self, *, site_id: str, role_ids: list[str], commit: bool = True) -> None:
|
||||||
self.db.execute(delete(SiteRole).where(SiteRole.site_id == site_id))
|
self.db.execute(delete(SiteRole).where(SiteRole.site_id == site_id))
|
||||||
for role_id in role_ids:
|
for role_id in role_ids:
|
||||||
self.db.add(SiteRole(site_id=site_id, role_id=role_id))
|
self.db.add(SiteRole(site_id=site_id, role_id=role_id))
|
||||||
self.db.commit()
|
if commit:
|
||||||
|
self.db.commit()
|
||||||
|
|
||||||
|
def set_role_sites(self, *, role_id: str, site_ids: list[str], commit: bool = True) -> None:
|
||||||
|
self.db.execute(delete(SiteRole).where(SiteRole.role_id == role_id))
|
||||||
|
for site_id in site_ids:
|
||||||
|
self.db.add(SiteRole(site_id=site_id, role_id=role_id))
|
||||||
|
if commit:
|
||||||
|
self.db.commit()
|
||||||
|
|||||||
@@ -57,6 +57,7 @@ class UsersRepository:
|
|||||||
provider_user_id: str | None = None,
|
provider_user_id: str | None = None,
|
||||||
) -> User:
|
) -> User:
|
||||||
user = self.get_by_sub(user_sub)
|
user = self.get_by_sub(user_sub)
|
||||||
|
changed = False
|
||||||
if user is None:
|
if user is None:
|
||||||
user = User(
|
user = User(
|
||||||
user_sub=user_sub,
|
user_sub=user_sub,
|
||||||
@@ -68,17 +69,30 @@ class UsersRepository:
|
|||||||
status=status,
|
status=status,
|
||||||
)
|
)
|
||||||
self.db.add(user)
|
self.db.add(user)
|
||||||
|
changed = True
|
||||||
else:
|
else:
|
||||||
if provider_user_id is not None:
|
if provider_user_id is not None and user.provider_user_id != provider_user_id:
|
||||||
user.provider_user_id = provider_user_id
|
user.provider_user_id = provider_user_id
|
||||||
user.username = username
|
changed = True
|
||||||
user.email = email
|
if user.username != username:
|
||||||
user.display_name = display_name
|
user.username = username
|
||||||
user.is_active = is_active
|
changed = True
|
||||||
user.status = status
|
if user.email != email:
|
||||||
|
user.email = email
|
||||||
|
changed = True
|
||||||
|
if user.display_name != display_name:
|
||||||
|
user.display_name = display_name
|
||||||
|
changed = True
|
||||||
|
if user.is_active != is_active:
|
||||||
|
user.is_active = is_active
|
||||||
|
changed = True
|
||||||
|
if user.status != status:
|
||||||
|
user.status = status
|
||||||
|
changed = True
|
||||||
|
|
||||||
self.db.commit()
|
if changed:
|
||||||
self.db.refresh(user)
|
self.db.commit()
|
||||||
|
self.db.refresh(user)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
def update_member(
|
def update_member(
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ class ProviderPrincipal(BaseModel):
|
|||||||
name: str | None = None
|
name: str | None = None
|
||||||
preferred_username: str | None = None
|
preferred_username: str | None = None
|
||||||
groups: list[str] = Field(default_factory=list)
|
groups: list[str] = Field(default_factory=list)
|
||||||
|
realm_roles: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class MeSummaryResponse(BaseModel):
|
class MeSummaryResponse(BaseModel):
|
||||||
|
|||||||
@@ -5,33 +5,27 @@ from app.schemas.auth import ProviderPrincipal
|
|||||||
from app.security.idp_jwt import require_authenticated_principal
|
from app.security.idp_jwt import require_authenticated_principal
|
||||||
|
|
||||||
|
|
||||||
def _expand_group_aliases(groups: set[str]) -> set[str]:
|
def _normalize_roles(values: set[str]) -> set[str]:
|
||||||
expanded: set[str] = set()
|
normalized: set[str] = set()
|
||||||
for group in groups:
|
for value in values:
|
||||||
value = group.strip().lower()
|
role = value.strip().lower()
|
||||||
if not value:
|
if role:
|
||||||
continue
|
normalized.add(role)
|
||||||
expanded.add(value)
|
return normalized
|
||||||
stripped = value.lstrip("/")
|
|
||||||
if stripped:
|
|
||||||
expanded.add(stripped)
|
|
||||||
if "/" in stripped:
|
|
||||||
expanded.add(stripped.rsplit("/", 1)[-1])
|
|
||||||
return expanded
|
|
||||||
|
|
||||||
|
|
||||||
def require_admin_principal(
|
def require_admin_principal(
|
||||||
principal: ProviderPrincipal = Depends(require_authenticated_principal),
|
principal: ProviderPrincipal = Depends(require_authenticated_principal),
|
||||||
) -> ProviderPrincipal:
|
) -> ProviderPrincipal:
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
required_groups = _expand_group_aliases(set(settings.admin_required_groups))
|
required_roles = _normalize_roles(set(settings.admin_required_realm_roles))
|
||||||
|
|
||||||
if not required_groups:
|
if not required_roles:
|
||||||
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="admin_policy_not_configured")
|
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="admin_policy_not_configured")
|
||||||
|
|
||||||
principal_groups = _expand_group_aliases(set(principal.groups))
|
principal_roles = _normalize_roles(set(principal.realm_roles))
|
||||||
group_ok = bool(required_groups.intersection(principal_groups))
|
role_ok = bool(required_roles.intersection(principal_roles))
|
||||||
|
|
||||||
if not group_ok:
|
if not role_ok:
|
||||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="admin_forbidden")
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="admin_forbidden")
|
||||||
return principal
|
return principal
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
import logging
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import jwt
|
import jwt
|
||||||
@@ -30,6 +31,7 @@ class ProviderTokenVerifier:
|
|||||||
admin_realm: str | None,
|
admin_realm: str | None,
|
||||||
admin_client_id: str | None,
|
admin_client_id: str | None,
|
||||||
admin_client_secret: str | None,
|
admin_client_secret: str | None,
|
||||||
|
member_required_realm_roles: list[str],
|
||||||
) -> None:
|
) -> None:
|
||||||
self.issuer = issuer.strip() if issuer else None
|
self.issuer = issuer.strip() if issuer else None
|
||||||
self.jwks_url = jwks_url.strip() if jwks_url else self._infer_jwks_url(self.issuer)
|
self.jwks_url = jwks_url.strip() if jwks_url else self._infer_jwks_url(self.issuer)
|
||||||
@@ -49,7 +51,20 @@ class ProviderTokenVerifier:
|
|||||||
if not self.jwks_url:
|
if not self.jwks_url:
|
||||||
raise ValueError("KEYCLOAK_JWKS_URL or KEYCLOAK_ISSUER is required")
|
raise ValueError("KEYCLOAK_JWKS_URL or KEYCLOAK_ISSUER is required")
|
||||||
|
|
||||||
self._jwk_client = jwt.PyJWKClient(self.jwks_url)
|
self._jwk_client = jwt.PyJWKClient(
|
||||||
|
self.jwks_url,
|
||||||
|
cache_jwk_set=True,
|
||||||
|
lifespan=600,
|
||||||
|
headers={
|
||||||
|
"Accept": "application/json",
|
||||||
|
"User-Agent": "member-ose-backend/1.0",
|
||||||
|
},
|
||||||
|
timeout=5,
|
||||||
|
)
|
||||||
|
self._admin_token_cached: str | None = None
|
||||||
|
self._admin_token_expires_at: float = 0
|
||||||
|
self._principal_cache: dict[str, tuple[float, ProviderPrincipal]] = {}
|
||||||
|
self.member_required_realm_roles = {r.strip().lower() for r in member_required_realm_roles if r and r.strip()}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _infer_introspection_endpoint(issuer: str | None) -> str | None:
|
def _infer_introspection_endpoint(issuer: str | None) -> str | None:
|
||||||
@@ -138,10 +153,14 @@ class ProviderTokenVerifier:
|
|||||||
name=name,
|
name=name,
|
||||||
preferred_username=preferred_username,
|
preferred_username=preferred_username,
|
||||||
groups=groups,
|
groups=groups,
|
||||||
|
realm_roles=principal.realm_roles,
|
||||||
)
|
)
|
||||||
return self._enrich_groups_from_admin(enriched)
|
return self._enrich_groups_from_admin(enriched)
|
||||||
|
|
||||||
def _get_admin_token(self) -> str | None:
|
def _get_admin_token(self) -> str | None:
|
||||||
|
now = time.time()
|
||||||
|
if self._admin_token_cached and now < self._admin_token_expires_at:
|
||||||
|
return self._admin_token_cached
|
||||||
if (
|
if (
|
||||||
not self.base_url
|
not self.base_url
|
||||||
or not self.admin_realm
|
or not self.admin_realm
|
||||||
@@ -167,7 +186,14 @@ class ProviderTokenVerifier:
|
|||||||
if resp.status_code >= 400:
|
if resp.status_code >= 400:
|
||||||
return None
|
return None
|
||||||
token = resp.json().get("access_token")
|
token = resp.json().get("access_token")
|
||||||
return str(token) if token else None
|
expires_in = resp.json().get("expires_in")
|
||||||
|
if token:
|
||||||
|
ttl = int(expires_in) if isinstance(expires_in, int) else 30
|
||||||
|
# Keep a small buffer to avoid using near-expiry admin token.
|
||||||
|
self._admin_token_cached = str(token)
|
||||||
|
self._admin_token_expires_at = max(now + ttl - 15, now + 5)
|
||||||
|
return self._admin_token_cached
|
||||||
|
return None
|
||||||
|
|
||||||
def _enrich_groups_from_admin(self, principal: ProviderPrincipal) -> ProviderPrincipal:
|
def _enrich_groups_from_admin(self, principal: ProviderPrincipal) -> ProviderPrincipal:
|
||||||
if principal.groups:
|
if principal.groups:
|
||||||
@@ -210,9 +236,22 @@ class ProviderTokenVerifier:
|
|||||||
name=principal.name,
|
name=principal.name,
|
||||||
preferred_username=principal.preferred_username,
|
preferred_username=principal.preferred_username,
|
||||||
groups=groups,
|
groups=groups,
|
||||||
|
realm_roles=principal.realm_roles,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _require_member_role(self, principal: ProviderPrincipal) -> None:
|
||||||
|
if not self.member_required_realm_roles:
|
||||||
|
return
|
||||||
|
user_roles = {r.strip().lower() for r in principal.realm_roles if isinstance(r, str) and r.strip()}
|
||||||
|
if not user_roles.intersection(self.member_required_realm_roles):
|
||||||
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="member_forbidden")
|
||||||
|
|
||||||
def verify_access_token(self, token: str) -> ProviderPrincipal:
|
def verify_access_token(self, token: str) -> ProviderPrincipal:
|
||||||
|
now = time.time()
|
||||||
|
cached = self._principal_cache.get(token)
|
||||||
|
if cached and now < cached[0]:
|
||||||
|
return cached[1]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
header = jwt.get_unverified_header(token)
|
header = jwt.get_unverified_header(token)
|
||||||
algorithm = str(header.get("alg", "")).upper()
|
algorithm = str(header.get("alg", "")).upper()
|
||||||
@@ -247,7 +286,7 @@ class ProviderTokenVerifier:
|
|||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
claims = self._introspect_token(token)
|
claims = self._introspect_token(token)
|
||||||
if claims:
|
if claims:
|
||||||
logger.warning("jwt verify failed, used introspection fallback: %s", exc)
|
logger.debug("jwt verify failed, used introspection fallback: %s", exc)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid_bearer_token") from exc
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid_bearer_token") from exc
|
||||||
|
|
||||||
@@ -261,8 +300,31 @@ class ProviderTokenVerifier:
|
|||||||
name=claims.get("name"),
|
name=claims.get("name"),
|
||||||
preferred_username=claims.get("preferred_username"),
|
preferred_username=claims.get("preferred_username"),
|
||||||
groups=[str(g) for g in claims.get("groups", []) if str(g)] if isinstance(claims.get("groups"), list) else [],
|
groups=[str(g) for g in claims.get("groups", []) if str(g)] if isinstance(claims.get("groups"), list) else [],
|
||||||
|
realm_roles=[
|
||||||
|
str(r)
|
||||||
|
for r in (
|
||||||
|
claims.get("realm_access", {}).get("roles", [])
|
||||||
|
if isinstance(claims.get("realm_access"), dict)
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
if str(r)
|
||||||
|
],
|
||||||
)
|
)
|
||||||
return self._enrich_from_userinfo(principal, token)
|
enriched = self._enrich_from_userinfo(principal, token)
|
||||||
|
self._require_member_role(enriched)
|
||||||
|
|
||||||
|
exp = claims.get("exp")
|
||||||
|
if isinstance(exp, int):
|
||||||
|
cache_until = min(float(exp), now + 60)
|
||||||
|
else:
|
||||||
|
cache_until = now + 30
|
||||||
|
if cache_until > now:
|
||||||
|
self._principal_cache[token] = (cache_until, enriched)
|
||||||
|
if len(self._principal_cache) > 512:
|
||||||
|
# Simple bound to avoid unbounded memory growth.
|
||||||
|
self._principal_cache.clear()
|
||||||
|
|
||||||
|
return enriched
|
||||||
|
|
||||||
|
|
||||||
@lru_cache
|
@lru_cache
|
||||||
@@ -281,6 +343,7 @@ def _get_verifier() -> ProviderTokenVerifier:
|
|||||||
admin_realm=settings.keycloak_admin_realm,
|
admin_realm=settings.keycloak_admin_realm,
|
||||||
admin_client_id=settings.keycloak_admin_client_id,
|
admin_client_id=settings.keycloak_admin_client_id,
|
||||||
admin_client_secret=settings.keycloak_admin_client_secret,
|
admin_client_secret=settings.keycloak_admin_client_secret,
|
||||||
|
member_required_realm_roles=settings.member_required_realm_roles,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -413,6 +413,65 @@ class ProviderAdminService:
|
|||||||
return client_uuid
|
return client_uuid
|
||||||
raise HTTPException(status_code=404, detail="provider_client_not_found")
|
raise HTTPException(status_code=404, detail="provider_client_not_found")
|
||||||
|
|
||||||
|
def _get_client_role_representation(self, client: httpx.Client, *, client_uuid: str, role_name: str) -> dict:
|
||||||
|
resp = client.get(f"/admin/realms/{self.realm}/clients/{client_uuid}/roles/{role_name}")
|
||||||
|
if resp.status_code == 404:
|
||||||
|
raise HTTPException(status_code=404, detail=f"provider_role_not_found:{role_name}")
|
||||||
|
if resp.status_code >= 400:
|
||||||
|
raise HTTPException(status_code=502, detail="idp_lookup_failed")
|
||||||
|
payload = resp.json() if resp.content else {}
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
raise HTTPException(status_code=502, detail="idp_lookup_failed")
|
||||||
|
return payload
|
||||||
|
|
||||||
|
def set_group_client_roles(self, *, group_id: str, provider_client_id: str, role_names: list[str]) -> None:
|
||||||
|
if not group_id:
|
||||||
|
raise HTTPException(status_code=400, detail="provider_group_id_required")
|
||||||
|
|
||||||
|
desired_names = [name.strip() for name in role_names if isinstance(name, str) and name.strip()]
|
||||||
|
desired_name_set = set(desired_names)
|
||||||
|
|
||||||
|
with self._client() as client:
|
||||||
|
client_uuid = self._resolve_client_uuid(client, provider_client_id)
|
||||||
|
|
||||||
|
current_resp = client.get(f"/admin/realms/{self.realm}/groups/{group_id}/role-mappings/clients/{client_uuid}")
|
||||||
|
if current_resp.status_code >= 400:
|
||||||
|
raise HTTPException(status_code=502, detail="idp_group_role_mapping_lookup_failed")
|
||||||
|
current_payload = current_resp.json() if current_resp.content else []
|
||||||
|
current_rows = current_payload if isinstance(current_payload, list) else []
|
||||||
|
current_map: dict[str, dict] = {}
|
||||||
|
for row in current_rows:
|
||||||
|
if not isinstance(row, dict):
|
||||||
|
continue
|
||||||
|
name = str(row.get("name", "")).strip()
|
||||||
|
if name:
|
||||||
|
current_map[name] = row
|
||||||
|
|
||||||
|
to_add_names = sorted(desired_name_set - set(current_map.keys()))
|
||||||
|
to_remove_names = sorted(set(current_map.keys()) - desired_name_set)
|
||||||
|
|
||||||
|
if to_add_names:
|
||||||
|
add_payload = [
|
||||||
|
self._get_client_role_representation(client, client_uuid=client_uuid, role_name=role_name)
|
||||||
|
for role_name in to_add_names
|
||||||
|
]
|
||||||
|
add_resp = client.post(
|
||||||
|
f"/admin/realms/{self.realm}/groups/{group_id}/role-mappings/clients/{client_uuid}",
|
||||||
|
json=add_payload,
|
||||||
|
)
|
||||||
|
if add_resp.status_code >= 400:
|
||||||
|
raise HTTPException(status_code=502, detail="idp_group_role_mapping_add_failed")
|
||||||
|
|
||||||
|
if to_remove_names:
|
||||||
|
remove_payload = [current_map[name] for name in to_remove_names]
|
||||||
|
remove_resp = client.request(
|
||||||
|
"DELETE",
|
||||||
|
f"/admin/realms/{self.realm}/groups/{group_id}/role-mappings/clients/{client_uuid}",
|
||||||
|
json=remove_payload,
|
||||||
|
)
|
||||||
|
if remove_resp.status_code >= 400:
|
||||||
|
raise HTTPException(status_code=502, detail="idp_group_role_mapping_remove_failed")
|
||||||
|
|
||||||
def ensure_client_role(
|
def ensure_client_role(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
|||||||
175
app/services/runtime_cache.py
Normal file
175
app/services/runtime_cache.py
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from hashlib import sha256
|
||||||
|
import logging
|
||||||
|
import pickle
|
||||||
|
from threading import RLock
|
||||||
|
import time
|
||||||
|
from typing import Callable, Protocol, TypeVar
|
||||||
|
|
||||||
|
from app.core.config import get_settings
|
||||||
|
|
||||||
|
try:
|
||||||
|
import redis
|
||||||
|
except Exception: # pragma: no cover - optional dependency in local dev
|
||||||
|
redis = None
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CacheBackend(Protocol):
|
||||||
|
def get(self, key: str) -> object | None: ...
|
||||||
|
def set(self, key: str, value: object, ttl_seconds: int = 30) -> object: ...
|
||||||
|
def get_or_set(self, key: str, factory: Callable[[], T], ttl_seconds: int = 30) -> T: ...
|
||||||
|
def bump_revision(self) -> int: ...
|
||||||
|
def revision(self) -> int: ...
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _CacheEntry:
|
||||||
|
value: object
|
||||||
|
expires_at: float
|
||||||
|
revision: int
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryRuntimeCache:
|
||||||
|
"""Simple in-memory cache for local/single-instance environments."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._lock = RLock()
|
||||||
|
self._revision = 0
|
||||||
|
self._entries: dict[str, _CacheEntry] = {}
|
||||||
|
|
||||||
|
def get(self, key: str) -> object | None:
|
||||||
|
now = time.time()
|
||||||
|
with self._lock:
|
||||||
|
entry = self._entries.get(key)
|
||||||
|
if not entry:
|
||||||
|
return None
|
||||||
|
if entry.expires_at <= now or entry.revision != self._revision:
|
||||||
|
self._entries.pop(key, None)
|
||||||
|
return None
|
||||||
|
return entry.value
|
||||||
|
|
||||||
|
def set(self, key: str, value: object, ttl_seconds: int = 30) -> object:
|
||||||
|
now = time.time()
|
||||||
|
with self._lock:
|
||||||
|
self._entries[key] = _CacheEntry(
|
||||||
|
value=value,
|
||||||
|
expires_at=now + max(ttl_seconds, 1),
|
||||||
|
revision=self._revision,
|
||||||
|
)
|
||||||
|
if len(self._entries) > 2000:
|
||||||
|
self._entries.clear()
|
||||||
|
return value
|
||||||
|
|
||||||
|
def get_or_set(self, key: str, factory: Callable[[], T], ttl_seconds: int = 30) -> T:
|
||||||
|
cached = self.get(key)
|
||||||
|
if cached is not None:
|
||||||
|
return cached # type: ignore[return-value]
|
||||||
|
return self.set(key, factory(), ttl_seconds=ttl_seconds) # type: ignore[return-value]
|
||||||
|
|
||||||
|
def bump_revision(self) -> int:
|
||||||
|
with self._lock:
|
||||||
|
self._revision += 1
|
||||||
|
if self._revision > 1_000_000_000:
|
||||||
|
self._revision = 1
|
||||||
|
self._entries.clear()
|
||||||
|
return self._revision
|
||||||
|
|
||||||
|
def revision(self) -> int:
|
||||||
|
with self._lock:
|
||||||
|
return self._revision
|
||||||
|
|
||||||
|
|
||||||
|
class RedisRuntimeCache:
|
||||||
|
"""Redis-backed cache for multi-instance deployments."""
|
||||||
|
|
||||||
|
def __init__(self, *, redis_url: str, prefix: str, default_ttl_seconds: int = 30) -> None:
|
||||||
|
if redis is None:
|
||||||
|
raise RuntimeError("redis_package_not_installed")
|
||||||
|
self._redis = redis.Redis.from_url(redis_url, decode_responses=False)
|
||||||
|
self._prefix = prefix.strip() or "memberapi"
|
||||||
|
self._default_ttl_seconds = max(int(default_ttl_seconds), 1)
|
||||||
|
self._revision_key = f"{self._prefix}:cache:revision"
|
||||||
|
self._rev_cache_value = 0
|
||||||
|
self._rev_cache_expires_at = 0.0
|
||||||
|
|
||||||
|
def _cache_key(self, key: str, revision: int) -> str:
|
||||||
|
key_hash = sha256(key.encode("utf-8")).hexdigest()
|
||||||
|
return f"{self._prefix}:cache:{revision}:{key_hash}"
|
||||||
|
|
||||||
|
def _get_revision_cached(self) -> int:
|
||||||
|
now = time.time()
|
||||||
|
if now < self._rev_cache_expires_at:
|
||||||
|
return self._rev_cache_value
|
||||||
|
try:
|
||||||
|
raw = self._redis.get(self._revision_key)
|
||||||
|
value = int(raw) if raw else 0
|
||||||
|
except Exception:
|
||||||
|
return 0
|
||||||
|
self._rev_cache_value = value
|
||||||
|
self._rev_cache_expires_at = now + 1.0
|
||||||
|
return value
|
||||||
|
|
||||||
|
def get(self, key: str) -> object | None:
|
||||||
|
try:
|
||||||
|
revision = self._get_revision_cached()
|
||||||
|
raw = self._redis.get(self._cache_key(key, revision))
|
||||||
|
if raw is None:
|
||||||
|
return None
|
||||||
|
return pickle.loads(raw)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def set(self, key: str, value: object, ttl_seconds: int = 30) -> object:
|
||||||
|
ttl = max(int(ttl_seconds), 1) if ttl_seconds else self._default_ttl_seconds
|
||||||
|
try:
|
||||||
|
revision = self._get_revision_cached()
|
||||||
|
self._redis.setex(self._cache_key(key, revision), ttl, pickle.dumps(value))
|
||||||
|
except Exception:
|
||||||
|
# Keep request path healthy even when Redis has issues.
|
||||||
|
pass
|
||||||
|
return value
|
||||||
|
|
||||||
|
def get_or_set(self, key: str, factory: Callable[[], T], ttl_seconds: int = 30) -> T:
|
||||||
|
cached = self.get(key)
|
||||||
|
if cached is not None:
|
||||||
|
return cached # type: ignore[return-value]
|
||||||
|
return self.set(key, factory(), ttl_seconds=ttl_seconds) # type: ignore[return-value]
|
||||||
|
|
||||||
|
def bump_revision(self) -> int:
|
||||||
|
try:
|
||||||
|
value = int(self._redis.incr(self._revision_key))
|
||||||
|
self._rev_cache_value = value
|
||||||
|
self._rev_cache_expires_at = time.time() + 1.0
|
||||||
|
return value
|
||||||
|
except Exception:
|
||||||
|
# Fail-open: keep app usable; caller still succeeds.
|
||||||
|
return self._get_revision_cached()
|
||||||
|
|
||||||
|
def revision(self) -> int:
|
||||||
|
return self._get_revision_cached()
|
||||||
|
|
||||||
|
|
||||||
|
def _build_runtime_cache() -> CacheBackend:
|
||||||
|
settings = get_settings()
|
||||||
|
backend = (settings.cache_backend or "memory").strip().lower()
|
||||||
|
if backend == "redis":
|
||||||
|
try:
|
||||||
|
cache = RedisRuntimeCache(
|
||||||
|
redis_url=settings.cache_redis_url,
|
||||||
|
prefix=settings.cache_prefix,
|
||||||
|
default_ttl_seconds=settings.cache_default_ttl_seconds,
|
||||||
|
)
|
||||||
|
logger.info("runtime cache backend: redis")
|
||||||
|
return cache
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("redis cache unavailable, fallback to memory: %s", exc)
|
||||||
|
logger.info("runtime cache backend: memory")
|
||||||
|
return MemoryRuntimeCache()
|
||||||
|
|
||||||
|
|
||||||
|
runtime_cache: CacheBackend = _build_runtime_cache()
|
||||||
@@ -13,6 +13,7 @@ dependencies = [
|
|||||||
"passlib[bcrypt]>=1.7.4",
|
"passlib[bcrypt]>=1.7.4",
|
||||||
"pyjwt[crypto]>=2.10.1",
|
"pyjwt[crypto]>=2.10.1",
|
||||||
"httpx>=0.28.1",
|
"httpx>=0.28.1",
|
||||||
|
"redis>=5.2.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
|||||||
Reference in New Issue
Block a user