Compare commits

...

10 Commits

19 changed files with 646 additions and 59 deletions

13
.dockerignore Normal file
View File

@@ -0,0 +1,13 @@
.git
.gitignore
.venv
__pycache__
*.pyc
*.pyo
*.pyd
.pytest_cache
.ruff_cache
tests
.env
.env.development
*.log

View File

@@ -19,4 +19,10 @@ KEYCLOAK_ADMIN_REALM=master
PUBLIC_FRONTEND_ORIGINS=http://127.0.0.1:5173,http://localhost:5173 PUBLIC_FRONTEND_ORIGINS=http://127.0.0.1:5173,http://localhost:5173
INTERNAL_SHARED_SECRET=CHANGE_ME INTERNAL_SHARED_SECRET=CHANGE_ME
ADMIN_REQUIRED_GROUPS=member-admin MEMBER_REQUIRED_REALM_ROLES=admin,manager
ADMIN_REQUIRED_REALM_ROLES=admin,manager
CACHE_BACKEND=memory
CACHE_REDIS_URL=redis://127.0.0.1:6379/0
CACHE_PREFIX=memberapi
CACHE_DEFAULT_TTL_SECONDS=30

View File

@@ -25,4 +25,11 @@ KEYCLOAK_ADMIN_REALM=
PUBLIC_FRONTEND_ORIGINS=https://member.ose.tw,https://mkt.ose.tw,https://admin.ose.tw PUBLIC_FRONTEND_ORIGINS=https://member.ose.tw,https://mkt.ose.tw,https://admin.ose.tw
INTERNAL_SHARED_SECRET=CHANGE_ME INTERNAL_SHARED_SECRET=CHANGE_ME
ADMIN_REQUIRED_GROUPS=member-admin MEMBER_REQUIRED_REALM_ROLES=admin,manager
ADMIN_REQUIRED_REALM_ROLES=admin,manager
# Cache backend: memory | redis
CACHE_BACKEND=memory
CACHE_REDIS_URL=redis://127.0.0.1:6379/0
CACHE_PREFIX=memberapi
CACHE_DEFAULT_TTL_SECONDS=30

View File

@@ -25,3 +25,11 @@ KEYCLOAK_ADMIN_REALM=
PUBLIC_FRONTEND_ORIGINS=https://member.ose.tw,https://mkt.ose.tw,https://admin.ose.tw PUBLIC_FRONTEND_ORIGINS=https://member.ose.tw,https://mkt.ose.tw,https://admin.ose.tw
INTERNAL_SHARED_SECRET=CHANGE_ME INTERNAL_SHARED_SECRET=CHANGE_ME
MEMBER_REQUIRED_REALM_ROLES=admin,manager
ADMIN_REQUIRED_REALM_ROLES=admin,manager
# Cache backend: memory | redis
CACHE_BACKEND=redis
CACHE_REDIS_URL=redis://redis:6379/0
CACHE_PREFIX=memberapi
CACHE_DEFAULT_TTL_SECONDS=30

30
Dockerfile Normal file
View File

@@ -0,0 +1,30 @@
FROM python:3.12-alpine AS builder
WORKDIR /app
RUN apk add --no-cache build-base libffi-dev openssl-dev cargo
COPY pyproject.toml /app/pyproject.toml
COPY app /app/app
COPY scripts /app/scripts
COPY README.md /app/README.md
RUN pip install --no-cache-dir --upgrade pip && \
pip wheel --no-cache-dir --wheel-dir /wheels .
FROM python:3.12-alpine
ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1
WORKDIR /app
RUN apk add --no-cache libstdc++ libffi openssl
COPY --from=builder /wheels /wheels
RUN pip install --no-cache-dir --upgrade pip && \
pip install --no-cache-dir /wheels/*
EXPOSE 8000
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]

View File

@@ -12,6 +12,30 @@ psql "$DATABASE_URL" -f scripts/init_schema.sql
./scripts/start_dev.sh ./scripts/start_dev.sh
``` ```
## Docker (VPS / Production)
> 目前 Dockerfile 為 Alpine 多階段建置(較小體積)。
Build image:
```bash
cd backend
docker build -t memberapi-backend:latest .
```
Run container:
```bash
docker run -d \
--name memberapi-backend \
--restart unless-stopped \
-p 127.0.0.1:8000:8000 \
--env-file .env \
memberapi-backend:latest
```
Health check:
```bash
curl http://127.0.0.1:8000/healthz
```
## Keycloak env ## Keycloak env
- Required: - Required:
@@ -29,6 +53,8 @@ psql "$DATABASE_URL" -f scripts/init_schema.sql
- `KEYCLOAK_USERINFO_ENDPOINT` - `KEYCLOAK_USERINFO_ENDPOINT`
- `KEYCLOAK_AUDIENCE` - `KEYCLOAK_AUDIENCE`
- `KEYCLOAK_VERIFY_TLS` - `KEYCLOAK_VERIFY_TLS`
- `MEMBER_REQUIRED_REALM_ROLES` (default: `admin,manager`)
- `ADMIN_REQUIRED_REALM_ROLES` (default: `admin,manager`)
## Main APIs ## Main APIs
@@ -38,7 +64,7 @@ psql "$DATABASE_URL" -f scripts/init_schema.sql
- `GET /me` (Bearer token required) - `GET /me` (Bearer token required)
- `GET /me/permissions/snapshot` (Bearer token required) - `GET /me/permissions/snapshot` (Bearer token required)
### Admin APIs (Bearer + admin group required) ### Admin APIs (Bearer + admin realm role required)
- `GET/POST/PATCH/DELETE /admin/companies` - `GET/POST/PATCH/DELETE /admin/companies`
- `GET/POST/PATCH/DELETE /admin/sites` - `GET/POST/PATCH/DELETE /admin/sites`
- `GET/POST/PATCH/DELETE /admin/systems` - `GET/POST/PATCH/DELETE /admin/systems`
@@ -57,5 +83,4 @@ psql "$DATABASE_URL" -f scripts/init_schema.sql
- `GET /internal/members` - `GET /internal/members`
- `POST /internal/users/upsert-by-sub` - `POST /internal/users/upsert-by-sub`
- `GET /internal/users/{user_sub}/roles` - `GET /internal/users/{user_sub}/roles`
- `GET /internal/permissions/{user_sub}/snapshot`
- `POST /internal/idp/users/ensure` - `POST /internal/idp/users/ensure`

View File

@@ -53,6 +53,7 @@ from app.security.admin_guard import require_admin_principal
from app.security.api_client_auth import hash_api_key from app.security.api_client_auth import hash_api_key
from app.services.idp_admin_service import ProviderAdminService from app.services.idp_admin_service import ProviderAdminService
from app.services.idp_catalog_sync import sync_from_provider from app.services.idp_catalog_sync import sync_from_provider
from app.services.runtime_cache import runtime_cache
from app.core.config import get_settings from app.core.config import get_settings
router = APIRouter( router = APIRouter(
@@ -129,6 +130,31 @@ def _site_group_name(display_name: str, site_key: str) -> str:
return normalized return normalized
def _sync_site_client_roles(
*,
idp: ProviderAdminService,
site,
site_role_rows,
provider_client_ids: set[str],
) -> None:
if not site.provider_group_id:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=f"site_provider_group_missing:{site.site_key}")
role_names_by_client: dict[str, list[str]] = {}
for _, role, system in site_role_rows:
provider_client_id = str(system.name or "").strip()
if not provider_client_id:
continue
role_names_by_client.setdefault(provider_client_id, []).append(role.name)
for provider_client_id in sorted(provider_client_ids):
idp.set_group_client_roles(
group_id=site.provider_group_id,
provider_client_id=provider_client_id,
role_names=role_names_by_client.get(provider_client_id, []),
)
@router.get("/companies", response_model=ListResponse) @router.get("/companies", response_model=ListResponse)
def list_companies( def list_companies(
db: Session = Depends(get_db), db: Session = Depends(get_db),
@@ -136,9 +162,16 @@ def list_companies(
limit: int = Query(default=100, ge=1, le=500), limit: int = Query(default=100, ge=1, le=500),
offset: int = Query(default=0, ge=0), offset: int = Query(default=0, ge=0),
) -> ListResponse: ) -> ListResponse:
cache_key = f"admin:companies:{keyword or ''}:{limit}:{offset}"
cached = runtime_cache.get(cache_key)
if isinstance(cached, ListResponse):
return cached
repo = CompaniesRepository(db) repo = CompaniesRepository(db)
items, total = repo.list(keyword=keyword, limit=limit, offset=offset) items, total = repo.list(keyword=keyword, limit=limit, offset=offset)
return ListResponse(items=[_company_item(i) for i in items], total=total, limit=limit, offset=offset) result = ListResponse(items=[_company_item(i) for i in items], total=total, limit=limit, offset=offset)
runtime_cache.set(cache_key, result, ttl_seconds=20)
return result
@router.post("/companies", response_model=CompanyItem) @router.post("/companies", response_model=CompanyItem)
@@ -226,6 +259,11 @@ def list_sites(
limit: int = Query(default=100, ge=1, le=500), limit: int = Query(default=100, ge=1, le=500),
offset: int = Query(default=0, ge=0), offset: int = Query(default=0, ge=0),
) -> ListResponse: ) -> ListResponse:
cache_key = f"admin:sites:{keyword or ''}:{company_key or ''}:{limit}:{offset}"
cached = runtime_cache.get(cache_key)
if isinstance(cached, ListResponse):
return cached
companies_repo = CompaniesRepository(db) companies_repo = CompaniesRepository(db)
sites_repo = SitesRepository(db) sites_repo = SitesRepository(db)
company_id = None company_id = None
@@ -239,7 +277,9 @@ def list_sites(
company_map = {c.id: c for c in companies} company_map = {c.id: c for c in companies}
items, total = sites_repo.list(keyword=keyword, company_id=company_id, limit=limit, offset=offset) items, total = sites_repo.list(keyword=keyword, company_id=company_id, limit=limit, offset=offset)
response_items = [_site_item(i, company_map[i.company_id]) for i in items if i.company_id in company_map] response_items = [_site_item(i, company_map[i.company_id]) for i in items if i.company_id in company_map]
return ListResponse(items=response_items, total=total, limit=limit, offset=offset) result = ListResponse(items=response_items, total=total, limit=limit, offset=offset)
runtime_cache.set(cache_key, result, ttl_seconds=20)
return result
@router.post("/sites", response_model=SiteItem) @router.post("/sites", response_model=SiteItem)
@@ -350,9 +390,16 @@ def list_systems(
limit: int = Query(default=100, ge=1, le=500), limit: int = Query(default=100, ge=1, le=500),
offset: int = Query(default=0, ge=0), offset: int = Query(default=0, ge=0),
) -> ListResponse: ) -> ListResponse:
cache_key = f"admin:systems:{keyword or ''}:{status_filter or ''}:{limit}:{offset}"
cached = runtime_cache.get(cache_key)
if isinstance(cached, ListResponse):
return cached
repo = SystemsRepository(db) repo = SystemsRepository(db)
items, total = repo.list(keyword=keyword, status=status_filter, limit=limit, offset=offset) items, total = repo.list(keyword=keyword, status=status_filter, limit=limit, offset=offset)
return ListResponse(items=[_system_item(i) for i in items], total=total, limit=limit, offset=offset) result = ListResponse(items=[_system_item(i) for i in items], total=total, limit=limit, offset=offset)
runtime_cache.set(cache_key, result, ttl_seconds=20)
return result
@router.post("/systems", response_model=SystemItem) @router.post("/systems", response_model=SystemItem)
@@ -379,6 +426,11 @@ def list_roles(
limit: int = Query(default=100, ge=1, le=500), limit: int = Query(default=100, ge=1, le=500),
offset: int = Query(default=0, ge=0), offset: int = Query(default=0, ge=0),
) -> ListResponse: ) -> ListResponse:
cache_key = f"admin:roles:{keyword or ''}:{system_key or ''}:{status_filter or ''}:{limit}:{offset}"
cached = runtime_cache.get(cache_key)
if isinstance(cached, ListResponse):
return cached
systems_repo = SystemsRepository(db) systems_repo = SystemsRepository(db)
roles_repo = RolesRepository(db) roles_repo = RolesRepository(db)
@@ -408,7 +460,9 @@ def list_roles(
for row in rows for row in rows
if row.system_id in system_map if row.system_id in system_map
] ]
return ListResponse(items=items, total=total, limit=limit, offset=offset) result = ListResponse(items=items, total=total, limit=limit, offset=offset)
runtime_cache.set(cache_key, result, ttl_seconds=20)
return result
@router.post("/roles", response_model=RoleItem) @router.post("/roles", response_model=RoleItem)
@@ -597,11 +651,15 @@ def assign_site_roles(site_key: str, payload: SiteRoleAssignRequest, db: Session
sites_repo = SitesRepository(db) sites_repo = SitesRepository(db)
roles_repo = RolesRepository(db) roles_repo = RolesRepository(db)
site_roles_repo = SiteRolesRepository(db) site_roles_repo = SiteRolesRepository(db)
idp = ProviderAdminService(get_settings())
site = sites_repo.get_by_key(site_key) site = sites_repo.get_by_key(site_key)
if not site: if not site:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="site_not_found") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="site_not_found")
current_rows = site_roles_repo.list_site_role_rows(site.id)
current_client_ids = {str(system.name or "").strip() for _, _, system in current_rows if str(system.name or "").strip()}
role_ids: list[str] = [] role_ids: list[str] = []
for role_key in list(dict.fromkeys(payload.role_keys)): for role_key in list(dict.fromkeys(payload.role_keys)):
role = roles_repo.get_by_key(role_key) role = roles_repo.get_by_key(role_key)
@@ -609,7 +667,23 @@ def assign_site_roles(site_key: str, payload: SiteRoleAssignRequest, db: Session
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"role_not_found:{role_key}") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"role_not_found:{role_key}")
role_ids.append(role.id) role_ids.append(role.id)
site_roles_repo.set_site_roles(site_id=site.id, role_ids=role_ids) site_roles_repo.set_site_roles(site_id=site.id, role_ids=role_ids, commit=False)
updated_rows = site_roles_repo.list_site_role_rows(site.id)
updated_client_ids = {str(system.name or "").strip() for _, _, system in updated_rows if str(system.name or "").strip()}
clients_to_sync = current_client_ids | updated_client_ids
try:
_sync_site_client_roles(
idp=idp,
site=site,
site_role_rows=updated_rows,
provider_client_ids=clients_to_sync,
)
db.commit()
except Exception:
db.rollback()
raise
return list_site_roles(site_key=site_key, db=db) return list_site_roles(site_key=site_key, db=db)
@@ -646,6 +720,58 @@ def list_role_sites(role_key: str, db: Session = Depends(get_db)) -> RoleSitesRe
return RoleSitesResponse(role_key=role_key, sites=result) return RoleSitesResponse(role_key=role_key, sites=result)
@router.put("/roles/{role_key}/sites", response_model=RoleSitesResponse)
def assign_role_sites(role_key: str, payload: UserSiteAssignRequest, db: Session = Depends(get_db)) -> RoleSitesResponse:
roles_repo = RolesRepository(db)
sites_repo = SitesRepository(db)
systems_repo = SystemsRepository(db)
site_roles_repo = SiteRolesRepository(db)
idp = ProviderAdminService(get_settings())
role = roles_repo.get_by_key(role_key)
if not role:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="role_not_found")
system = systems_repo.get_by_id(role.system_id)
if not system:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="system_not_found")
provider_client_id = str(system.name or "").strip()
if not provider_client_id:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=f"provider_client_id_missing:{system.system_key}")
previous_rows = site_roles_repo.list_role_site_rows(role.id)
previous_site_ids = {site.id for _, site in previous_rows}
site_ids: list[str] = []
for site_key in list(dict.fromkeys(payload.site_keys)):
site = sites_repo.get_by_key(site_key)
if not site:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"site_not_found:{site_key}")
site_ids.append(site.id)
site_roles_repo.set_role_sites(role_id=role.id, site_ids=site_ids, commit=False)
affected_site_ids = previous_site_ids | set(site_ids)
try:
for site_id in affected_site_ids:
site = sites_repo.get_by_id(site_id)
if not site:
continue
site_rows = site_roles_repo.list_site_role_rows(site.id)
_sync_site_client_roles(
idp=idp,
site=site,
site_role_rows=site_rows,
provider_client_ids={provider_client_id},
)
db.commit()
except Exception:
db.rollback()
raise
return list_role_sites(role_key=role_key, db=db)
@router.get("/members", response_model=ListResponse) @router.get("/members", response_model=ListResponse)
def list_members( def list_members(
db: Session = Depends(get_db), db: Session = Depends(get_db),
@@ -654,9 +780,17 @@ def list_members(
limit: int = Query(default=100, ge=1, le=500), limit: int = Query(default=100, ge=1, le=500),
offset: int = Query(default=0, ge=0), offset: int = Query(default=0, ge=0),
) -> ListResponse: ) -> ListResponse:
is_active_key = "" if is_active is None else ("1" if is_active else "0")
cache_key = f"admin:members:{keyword or ''}:{is_active_key}:{limit}:{offset}"
cached = runtime_cache.get(cache_key)
if isinstance(cached, ListResponse):
return cached
repo = UsersRepository(db) repo = UsersRepository(db)
rows, total = repo.list(keyword=keyword, is_active=is_active, limit=limit, offset=offset) rows, total = repo.list(keyword=keyword, is_active=is_active, limit=limit, offset=offset)
return ListResponse(items=[_member_item(r) for r in rows], total=total, limit=limit, offset=offset) result = ListResponse(items=[_member_item(r) for r in rows], total=total, limit=limit, offset=offset)
runtime_cache.set(cache_key, result, ttl_seconds=20)
return result
@router.post("/members", response_model=MemberItem) @router.post("/members", response_model=MemberItem)
@@ -853,14 +987,21 @@ def list_api_clients(
limit: int = Query(default=100, ge=1, le=500), limit: int = Query(default=100, ge=1, le=500),
offset: int = Query(default=0, ge=0), offset: int = Query(default=0, ge=0),
) -> ListResponse: ) -> ListResponse:
cache_key = f"admin:api_clients:{keyword or ''}:{status_filter or ''}:{limit}:{offset}"
cached = runtime_cache.get(cache_key)
if isinstance(cached, ListResponse):
return cached
repo = ApiClientsRepository(db) repo = ApiClientsRepository(db)
items, total = repo.list(keyword=keyword, status=status_filter, limit=limit, offset=offset) items, total = repo.list(keyword=keyword, status=status_filter, limit=limit, offset=offset)
return ListResponse( result = ListResponse(
items=[ApiClientItem.model_validate(i, from_attributes=True) for i in items], items=[ApiClientItem.model_validate(i, from_attributes=True) for i in items],
total=total, total=total,
limit=limit, limit=limit,
offset=offset, offset=offset,
) )
runtime_cache.set(cache_key, result, ttl_seconds=20)
return result
@router.post("/sync/from-provider") @router.post("/sync/from-provider")

View File

@@ -7,11 +7,10 @@ from app.repositories.users_repo import UsersRepository
from app.repositories.user_sites_repo import UserSitesRepository from app.repositories.user_sites_repo import UserSitesRepository
from app.schemas.idp_admin import ProviderEnsureUserRequest, ProviderEnsureUserResponse from app.schemas.idp_admin import ProviderEnsureUserRequest, ProviderEnsureUserResponse
from app.schemas.internal import InternalUpsertUserBySubResponse, InternalUserRoleItem, InternalUserRoleResponse from app.schemas.internal import InternalUpsertUserBySubResponse, InternalUserRoleItem, InternalUserRoleResponse
from app.schemas.permissions import RoleSnapshotResponse
from app.schemas.users import UserUpsertBySubRequest from app.schemas.users import UserUpsertBySubRequest
from app.security.api_client_auth import require_api_client from app.security.api_client_auth import require_api_client
from app.services.idp_admin_service import ProviderAdminService from app.services.idp_admin_service import ProviderAdminService
from app.services.permission_service import PermissionService from app.services.runtime_cache import runtime_cache
router = APIRouter(prefix="/internal", tags=["internal"], dependencies=[Depends(require_api_client)]) router = APIRouter(prefix="/internal", tags=["internal"], dependencies=[Depends(require_api_client)])
@@ -68,8 +67,13 @@ def _build_user_role_rows(db: Session, user_sub: str) -> list[tuple[str, str, st
@router.get("/users/{user_sub}/roles", response_model=InternalUserRoleResponse) @router.get("/users/{user_sub}/roles", response_model=InternalUserRoleResponse)
def get_user_roles(user_sub: str, db: Session = Depends(get_db)) -> InternalUserRoleResponse: def get_user_roles(user_sub: str, db: Session = Depends(get_db)) -> InternalUserRoleResponse:
cache_key = f"internal:user_roles:{user_sub}"
cached = runtime_cache.get(cache_key)
if isinstance(cached, InternalUserRoleResponse):
return cached
rows = _build_user_role_rows(db, user_sub) rows = _build_user_role_rows(db, user_sub)
return InternalUserRoleResponse( result = InternalUserRoleResponse(
user_sub=user_sub, user_sub=user_sub,
roles=[ roles=[
InternalUserRoleItem( InternalUserRoleItem(
@@ -94,15 +98,8 @@ def get_user_roles(user_sub: str, db: Session = Depends(get_db)) -> InternalUser
) in rows ) in rows
], ],
) )
runtime_cache.set(cache_key, result, ttl_seconds=30)
return result
@router.get("/permissions/{user_sub}/snapshot", response_model=RoleSnapshotResponse)
def get_permission_snapshot(
user_sub: str,
db: Session = Depends(get_db),
) -> RoleSnapshotResponse:
rows = _build_user_role_rows(db, user_sub)
return PermissionService.build_role_snapshot(user_sub=user_sub, rows=rows)
@router.post("/provider/users/ensure", response_model=ProviderEnsureUserResponse) @router.post("/provider/users/ensure", response_model=ProviderEnsureUserResponse)

View File

@@ -9,6 +9,7 @@ from app.schemas.auth import ProviderPrincipal, MeSummaryResponse
from app.schemas.permissions import RoleSnapshotResponse from app.schemas.permissions import RoleSnapshotResponse
from app.security.idp_jwt import require_authenticated_principal from app.security.idp_jwt import require_authenticated_principal
from app.services.permission_service import PermissionService from app.services.permission_service import PermissionService
from app.services.runtime_cache import runtime_cache
router = APIRouter(prefix="/me", tags=["me"]) router = APIRouter(prefix="/me", tags=["me"])
@@ -18,6 +19,10 @@ def get_me(
principal: ProviderPrincipal = Depends(require_authenticated_principal), principal: ProviderPrincipal = Depends(require_authenticated_principal),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> MeSummaryResponse: ) -> MeSummaryResponse:
cache_key = f"me:{principal.sub}"
cached = runtime_cache.get(cache_key)
if isinstance(cached, MeSummaryResponse):
return cached
try: try:
users_repo = UsersRepository(db) users_repo = UsersRepository(db)
user = users_repo.upsert_by_sub( user = users_repo.upsert_by_sub(
@@ -28,13 +33,17 @@ def get_me(
is_active=True, is_active=True,
status="active", status="active",
) )
return MeSummaryResponse(sub=user.user_sub, email=user.email, display_name=user.display_name) result = MeSummaryResponse(sub=user.user_sub, email=user.email, display_name=user.display_name)
runtime_cache.set(cache_key, result, ttl_seconds=30)
return result
except SQLAlchemyError: except SQLAlchemyError:
return MeSummaryResponse( result = MeSummaryResponse(
sub=principal.sub, sub=principal.sub,
email=principal.email, email=principal.email,
display_name=principal.name or principal.preferred_username, display_name=principal.name or principal.preferred_username,
) )
runtime_cache.set(cache_key, result, ttl_seconds=15)
return result
@router.get("/permissions/snapshot", response_model=RoleSnapshotResponse) @router.get("/permissions/snapshot", response_model=RoleSnapshotResponse)
@@ -42,6 +51,10 @@ def get_my_permission_snapshot(
principal: ProviderPrincipal = Depends(require_authenticated_principal), principal: ProviderPrincipal = Depends(require_authenticated_principal),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> RoleSnapshotResponse: ) -> RoleSnapshotResponse:
cache_key = f"me:permissions_snapshot:{principal.sub}"
cached = runtime_cache.get(cache_key)
if isinstance(cached, RoleSnapshotResponse):
return cached
try: try:
users_repo = UsersRepository(db) users_repo = UsersRepository(db)
user_sites_repo = UserSitesRepository(db) user_sites_repo = UserSitesRepository(db)
@@ -68,6 +81,10 @@ def get_my_permission_snapshot(
) )
for site, company, role, system in rows for site, company, role, system in rows
] ]
return PermissionService.build_role_snapshot(user_sub=principal.sub, rows=serialized) result = PermissionService.build_role_snapshot(user_sub=principal.sub, rows=serialized)
runtime_cache.set(cache_key, result, ttl_seconds=30)
return result
except SQLAlchemyError: except SQLAlchemyError:
return RoleSnapshotResponse(user_sub=principal.sub, roles=[]) result = RoleSnapshotResponse(user_sub=principal.sub, roles=[])
runtime_cache.set(cache_key, result, ttl_seconds=10)
return result

View File

@@ -35,6 +35,12 @@ class Settings(BaseSettings):
public_frontend_origins: Annotated[list[str], NoDecode] = ["https://member.ose.tw"] public_frontend_origins: Annotated[list[str], NoDecode] = ["https://member.ose.tw"]
internal_shared_secret: str = "" internal_shared_secret: str = ""
admin_required_groups: Annotated[list[str], NoDecode] = [] admin_required_groups: Annotated[list[str], NoDecode] = []
member_required_realm_roles: Annotated[list[str], NoDecode] = ["admin", "manager"]
admin_required_realm_roles: Annotated[list[str], NoDecode] = ["admin", "manager"]
cache_backend: str = "memory"
cache_redis_url: str = "redis://127.0.0.1:6379/0"
cache_prefix: str = "memberapi"
cache_default_ttl_seconds: int = 30
@field_validator("public_frontend_origins", mode="before") @field_validator("public_frontend_origins", mode="before")
@classmethod @classmethod
@@ -54,6 +60,15 @@ class Settings(BaseSettings):
return [] return []
return [part.strip() for part in value.split(",") if part.strip()] return [part.strip() for part in value.split(",") if part.strip()]
@field_validator("member_required_realm_roles", "admin_required_realm_roles", mode="before")
@classmethod
def parse_roles_csv(cls, value: str | list[str]) -> list[str]:
if isinstance(value, list):
return [str(v).strip() for v in value if str(v).strip()]
if not value:
return []
return [part.strip() for part in value.split(",") if part.strip()]
@property @property
def database_url(self) -> str: def database_url(self) -> str:
return ( return (

View File

@@ -7,6 +7,7 @@ from app.api.internal_catalog import router as internal_catalog_router
from app.api.internal import router as internal_router from app.api.internal import router as internal_router
from app.api.me import router as me_router from app.api.me import router as me_router
from app.core.config import get_settings from app.core.config import get_settings
from app.services.runtime_cache import runtime_cache
app = FastAPI(title="memberapi.ose.tw", version="0.1.0") app = FastAPI(title="memberapi.ose.tw", version="0.1.0")
@@ -20,6 +21,18 @@ app.add_middleware(
) )
@app.middleware("http")
async def invalidate_runtime_cache_on_cud(request, call_next):
response = await call_next(request)
if (
request.method in {"POST", "PUT", "PATCH", "DELETE"}
and request.url.path.startswith(("/admin", "/internal"))
and response.status_code < 400
):
runtime_cache.bump_revision()
return response
@app.get("/healthz", tags=["health"]) @app.get("/healthz", tags=["health"])
def healthz() -> dict[str, str]: def healthz() -> dict[str, str]:
return {"status": "ok"} return {"status": "ok"}

View File

@@ -30,8 +30,16 @@ class SiteRolesRepository:
) )
return list(self.db.execute(stmt).all()) return list(self.db.execute(stmt).all())
def set_site_roles(self, *, site_id: str, role_ids: list[str]) -> None: def set_site_roles(self, *, site_id: str, role_ids: list[str], commit: bool = True) -> None:
self.db.execute(delete(SiteRole).where(SiteRole.site_id == site_id)) self.db.execute(delete(SiteRole).where(SiteRole.site_id == site_id))
for role_id in role_ids: for role_id in role_ids:
self.db.add(SiteRole(site_id=site_id, role_id=role_id)) self.db.add(SiteRole(site_id=site_id, role_id=role_id))
self.db.commit() if commit:
self.db.commit()
def set_role_sites(self, *, role_id: str, site_ids: list[str], commit: bool = True) -> None:
self.db.execute(delete(SiteRole).where(SiteRole.role_id == role_id))
for site_id in site_ids:
self.db.add(SiteRole(site_id=site_id, role_id=role_id))
if commit:
self.db.commit()

View File

@@ -57,6 +57,7 @@ class UsersRepository:
provider_user_id: str | None = None, provider_user_id: str | None = None,
) -> User: ) -> User:
user = self.get_by_sub(user_sub) user = self.get_by_sub(user_sub)
changed = False
if user is None: if user is None:
user = User( user = User(
user_sub=user_sub, user_sub=user_sub,
@@ -68,17 +69,30 @@ class UsersRepository:
status=status, status=status,
) )
self.db.add(user) self.db.add(user)
changed = True
else: else:
if provider_user_id is not None: if provider_user_id is not None and user.provider_user_id != provider_user_id:
user.provider_user_id = provider_user_id user.provider_user_id = provider_user_id
user.username = username changed = True
user.email = email if user.username != username:
user.display_name = display_name user.username = username
user.is_active = is_active changed = True
user.status = status if user.email != email:
user.email = email
changed = True
if user.display_name != display_name:
user.display_name = display_name
changed = True
if user.is_active != is_active:
user.is_active = is_active
changed = True
if user.status != status:
user.status = status
changed = True
self.db.commit() if changed:
self.db.refresh(user) self.db.commit()
self.db.refresh(user)
return user return user
def update_member( def update_member(

View File

@@ -7,6 +7,7 @@ class ProviderPrincipal(BaseModel):
name: str | None = None name: str | None = None
preferred_username: str | None = None preferred_username: str | None = None
groups: list[str] = Field(default_factory=list) groups: list[str] = Field(default_factory=list)
realm_roles: list[str] = Field(default_factory=list)
class MeSummaryResponse(BaseModel): class MeSummaryResponse(BaseModel):

View File

@@ -5,33 +5,27 @@ from app.schemas.auth import ProviderPrincipal
from app.security.idp_jwt import require_authenticated_principal from app.security.idp_jwt import require_authenticated_principal
def _expand_group_aliases(groups: set[str]) -> set[str]: def _normalize_roles(values: set[str]) -> set[str]:
expanded: set[str] = set() normalized: set[str] = set()
for group in groups: for value in values:
value = group.strip().lower() role = value.strip().lower()
if not value: if role:
continue normalized.add(role)
expanded.add(value) return normalized
stripped = value.lstrip("/")
if stripped:
expanded.add(stripped)
if "/" in stripped:
expanded.add(stripped.rsplit("/", 1)[-1])
return expanded
def require_admin_principal( def require_admin_principal(
principal: ProviderPrincipal = Depends(require_authenticated_principal), principal: ProviderPrincipal = Depends(require_authenticated_principal),
) -> ProviderPrincipal: ) -> ProviderPrincipal:
settings = get_settings() settings = get_settings()
required_groups = _expand_group_aliases(set(settings.admin_required_groups)) required_roles = _normalize_roles(set(settings.admin_required_realm_roles))
if not required_groups: if not required_roles:
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="admin_policy_not_configured") raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="admin_policy_not_configured")
principal_groups = _expand_group_aliases(set(principal.groups)) principal_roles = _normalize_roles(set(principal.realm_roles))
group_ok = bool(required_groups.intersection(principal_groups)) role_ok = bool(required_roles.intersection(principal_roles))
if not group_ok: if not role_ok:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="admin_forbidden") raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="admin_forbidden")
return principal return principal

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
from functools import lru_cache from functools import lru_cache
import logging import logging
import time
import httpx import httpx
import jwt import jwt
@@ -30,6 +31,7 @@ class ProviderTokenVerifier:
admin_realm: str | None, admin_realm: str | None,
admin_client_id: str | None, admin_client_id: str | None,
admin_client_secret: str | None, admin_client_secret: str | None,
member_required_realm_roles: list[str],
) -> None: ) -> None:
self.issuer = issuer.strip() if issuer else None self.issuer = issuer.strip() if issuer else None
self.jwks_url = jwks_url.strip() if jwks_url else self._infer_jwks_url(self.issuer) self.jwks_url = jwks_url.strip() if jwks_url else self._infer_jwks_url(self.issuer)
@@ -49,7 +51,20 @@ class ProviderTokenVerifier:
if not self.jwks_url: if not self.jwks_url:
raise ValueError("KEYCLOAK_JWKS_URL or KEYCLOAK_ISSUER is required") raise ValueError("KEYCLOAK_JWKS_URL or KEYCLOAK_ISSUER is required")
self._jwk_client = jwt.PyJWKClient(self.jwks_url) self._jwk_client = jwt.PyJWKClient(
self.jwks_url,
cache_jwk_set=True,
lifespan=600,
headers={
"Accept": "application/json",
"User-Agent": "member-ose-backend/1.0",
},
timeout=5,
)
self._admin_token_cached: str | None = None
self._admin_token_expires_at: float = 0
self._principal_cache: dict[str, tuple[float, ProviderPrincipal]] = {}
self.member_required_realm_roles = {r.strip().lower() for r in member_required_realm_roles if r and r.strip()}
@staticmethod @staticmethod
def _infer_introspection_endpoint(issuer: str | None) -> str | None: def _infer_introspection_endpoint(issuer: str | None) -> str | None:
@@ -138,10 +153,14 @@ class ProviderTokenVerifier:
name=name, name=name,
preferred_username=preferred_username, preferred_username=preferred_username,
groups=groups, groups=groups,
realm_roles=principal.realm_roles,
) )
return self._enrich_groups_from_admin(enriched) return self._enrich_groups_from_admin(enriched)
def _get_admin_token(self) -> str | None: def _get_admin_token(self) -> str | None:
now = time.time()
if self._admin_token_cached and now < self._admin_token_expires_at:
return self._admin_token_cached
if ( if (
not self.base_url not self.base_url
or not self.admin_realm or not self.admin_realm
@@ -167,7 +186,14 @@ class ProviderTokenVerifier:
if resp.status_code >= 400: if resp.status_code >= 400:
return None return None
token = resp.json().get("access_token") token = resp.json().get("access_token")
return str(token) if token else None expires_in = resp.json().get("expires_in")
if token:
ttl = int(expires_in) if isinstance(expires_in, int) else 30
# Keep a small buffer to avoid using near-expiry admin token.
self._admin_token_cached = str(token)
self._admin_token_expires_at = max(now + ttl - 15, now + 5)
return self._admin_token_cached
return None
def _enrich_groups_from_admin(self, principal: ProviderPrincipal) -> ProviderPrincipal: def _enrich_groups_from_admin(self, principal: ProviderPrincipal) -> ProviderPrincipal:
if principal.groups: if principal.groups:
@@ -210,9 +236,22 @@ class ProviderTokenVerifier:
name=principal.name, name=principal.name,
preferred_username=principal.preferred_username, preferred_username=principal.preferred_username,
groups=groups, groups=groups,
realm_roles=principal.realm_roles,
) )
def _require_member_role(self, principal: ProviderPrincipal) -> None:
if not self.member_required_realm_roles:
return
user_roles = {r.strip().lower() for r in principal.realm_roles if isinstance(r, str) and r.strip()}
if not user_roles.intersection(self.member_required_realm_roles):
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="member_forbidden")
def verify_access_token(self, token: str) -> ProviderPrincipal: def verify_access_token(self, token: str) -> ProviderPrincipal:
now = time.time()
cached = self._principal_cache.get(token)
if cached and now < cached[0]:
return cached[1]
try: try:
header = jwt.get_unverified_header(token) header = jwt.get_unverified_header(token)
algorithm = str(header.get("alg", "")).upper() algorithm = str(header.get("alg", "")).upper()
@@ -247,7 +286,7 @@ class ProviderTokenVerifier:
except Exception as exc: except Exception as exc:
claims = self._introspect_token(token) claims = self._introspect_token(token)
if claims: if claims:
logger.warning("jwt verify failed, used introspection fallback: %s", exc) logger.debug("jwt verify failed, used introspection fallback: %s", exc)
else: else:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid_bearer_token") from exc raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="invalid_bearer_token") from exc
@@ -261,8 +300,31 @@ class ProviderTokenVerifier:
name=claims.get("name"), name=claims.get("name"),
preferred_username=claims.get("preferred_username"), preferred_username=claims.get("preferred_username"),
groups=[str(g) for g in claims.get("groups", []) if str(g)] if isinstance(claims.get("groups"), list) else [], groups=[str(g) for g in claims.get("groups", []) if str(g)] if isinstance(claims.get("groups"), list) else [],
realm_roles=[
str(r)
for r in (
claims.get("realm_access", {}).get("roles", [])
if isinstance(claims.get("realm_access"), dict)
else []
)
if str(r)
],
) )
return self._enrich_from_userinfo(principal, token) enriched = self._enrich_from_userinfo(principal, token)
self._require_member_role(enriched)
exp = claims.get("exp")
if isinstance(exp, int):
cache_until = min(float(exp), now + 60)
else:
cache_until = now + 30
if cache_until > now:
self._principal_cache[token] = (cache_until, enriched)
if len(self._principal_cache) > 512:
# Simple bound to avoid unbounded memory growth.
self._principal_cache.clear()
return enriched
@lru_cache @lru_cache
@@ -281,6 +343,7 @@ def _get_verifier() -> ProviderTokenVerifier:
admin_realm=settings.keycloak_admin_realm, admin_realm=settings.keycloak_admin_realm,
admin_client_id=settings.keycloak_admin_client_id, admin_client_id=settings.keycloak_admin_client_id,
admin_client_secret=settings.keycloak_admin_client_secret, admin_client_secret=settings.keycloak_admin_client_secret,
member_required_realm_roles=settings.member_required_realm_roles,
) )

View File

@@ -413,6 +413,65 @@ class ProviderAdminService:
return client_uuid return client_uuid
raise HTTPException(status_code=404, detail="provider_client_not_found") raise HTTPException(status_code=404, detail="provider_client_not_found")
def _get_client_role_representation(self, client: httpx.Client, *, client_uuid: str, role_name: str) -> dict:
resp = client.get(f"/admin/realms/{self.realm}/clients/{client_uuid}/roles/{role_name}")
if resp.status_code == 404:
raise HTTPException(status_code=404, detail=f"provider_role_not_found:{role_name}")
if resp.status_code >= 400:
raise HTTPException(status_code=502, detail="idp_lookup_failed")
payload = resp.json() if resp.content else {}
if not isinstance(payload, dict):
raise HTTPException(status_code=502, detail="idp_lookup_failed")
return payload
def set_group_client_roles(self, *, group_id: str, provider_client_id: str, role_names: list[str]) -> None:
if not group_id:
raise HTTPException(status_code=400, detail="provider_group_id_required")
desired_names = [name.strip() for name in role_names if isinstance(name, str) and name.strip()]
desired_name_set = set(desired_names)
with self._client() as client:
client_uuid = self._resolve_client_uuid(client, provider_client_id)
current_resp = client.get(f"/admin/realms/{self.realm}/groups/{group_id}/role-mappings/clients/{client_uuid}")
if current_resp.status_code >= 400:
raise HTTPException(status_code=502, detail="idp_group_role_mapping_lookup_failed")
current_payload = current_resp.json() if current_resp.content else []
current_rows = current_payload if isinstance(current_payload, list) else []
current_map: dict[str, dict] = {}
for row in current_rows:
if not isinstance(row, dict):
continue
name = str(row.get("name", "")).strip()
if name:
current_map[name] = row
to_add_names = sorted(desired_name_set - set(current_map.keys()))
to_remove_names = sorted(set(current_map.keys()) - desired_name_set)
if to_add_names:
add_payload = [
self._get_client_role_representation(client, client_uuid=client_uuid, role_name=role_name)
for role_name in to_add_names
]
add_resp = client.post(
f"/admin/realms/{self.realm}/groups/{group_id}/role-mappings/clients/{client_uuid}",
json=add_payload,
)
if add_resp.status_code >= 400:
raise HTTPException(status_code=502, detail="idp_group_role_mapping_add_failed")
if to_remove_names:
remove_payload = [current_map[name] for name in to_remove_names]
remove_resp = client.request(
"DELETE",
f"/admin/realms/{self.realm}/groups/{group_id}/role-mappings/clients/{client_uuid}",
json=remove_payload,
)
if remove_resp.status_code >= 400:
raise HTTPException(status_code=502, detail="idp_group_role_mapping_remove_failed")
def ensure_client_role( def ensure_client_role(
self, self,
*, *,

View File

@@ -0,0 +1,175 @@
from __future__ import annotations
from dataclasses import dataclass
from hashlib import sha256
import logging
import pickle
from threading import RLock
import time
from typing import Callable, Protocol, TypeVar
from app.core.config import get_settings
try:
import redis
except Exception: # pragma: no cover - optional dependency in local dev
redis = None
T = TypeVar("T")
logger = logging.getLogger(__name__)
class CacheBackend(Protocol):
def get(self, key: str) -> object | None: ...
def set(self, key: str, value: object, ttl_seconds: int = 30) -> object: ...
def get_or_set(self, key: str, factory: Callable[[], T], ttl_seconds: int = 30) -> T: ...
def bump_revision(self) -> int: ...
def revision(self) -> int: ...
@dataclass
class _CacheEntry:
value: object
expires_at: float
revision: int
class MemoryRuntimeCache:
"""Simple in-memory cache for local/single-instance environments."""
def __init__(self) -> None:
self._lock = RLock()
self._revision = 0
self._entries: dict[str, _CacheEntry] = {}
def get(self, key: str) -> object | None:
now = time.time()
with self._lock:
entry = self._entries.get(key)
if not entry:
return None
if entry.expires_at <= now or entry.revision != self._revision:
self._entries.pop(key, None)
return None
return entry.value
def set(self, key: str, value: object, ttl_seconds: int = 30) -> object:
now = time.time()
with self._lock:
self._entries[key] = _CacheEntry(
value=value,
expires_at=now + max(ttl_seconds, 1),
revision=self._revision,
)
if len(self._entries) > 2000:
self._entries.clear()
return value
def get_or_set(self, key: str, factory: Callable[[], T], ttl_seconds: int = 30) -> T:
cached = self.get(key)
if cached is not None:
return cached # type: ignore[return-value]
return self.set(key, factory(), ttl_seconds=ttl_seconds) # type: ignore[return-value]
def bump_revision(self) -> int:
with self._lock:
self._revision += 1
if self._revision > 1_000_000_000:
self._revision = 1
self._entries.clear()
return self._revision
def revision(self) -> int:
with self._lock:
return self._revision
class RedisRuntimeCache:
"""Redis-backed cache for multi-instance deployments."""
def __init__(self, *, redis_url: str, prefix: str, default_ttl_seconds: int = 30) -> None:
if redis is None:
raise RuntimeError("redis_package_not_installed")
self._redis = redis.Redis.from_url(redis_url, decode_responses=False)
self._prefix = prefix.strip() or "memberapi"
self._default_ttl_seconds = max(int(default_ttl_seconds), 1)
self._revision_key = f"{self._prefix}:cache:revision"
self._rev_cache_value = 0
self._rev_cache_expires_at = 0.0
def _cache_key(self, key: str, revision: int) -> str:
key_hash = sha256(key.encode("utf-8")).hexdigest()
return f"{self._prefix}:cache:{revision}:{key_hash}"
def _get_revision_cached(self) -> int:
now = time.time()
if now < self._rev_cache_expires_at:
return self._rev_cache_value
try:
raw = self._redis.get(self._revision_key)
value = int(raw) if raw else 0
except Exception:
return 0
self._rev_cache_value = value
self._rev_cache_expires_at = now + 1.0
return value
def get(self, key: str) -> object | None:
try:
revision = self._get_revision_cached()
raw = self._redis.get(self._cache_key(key, revision))
if raw is None:
return None
return pickle.loads(raw)
except Exception:
return None
def set(self, key: str, value: object, ttl_seconds: int = 30) -> object:
ttl = max(int(ttl_seconds), 1) if ttl_seconds else self._default_ttl_seconds
try:
revision = self._get_revision_cached()
self._redis.setex(self._cache_key(key, revision), ttl, pickle.dumps(value))
except Exception:
# Keep request path healthy even when Redis has issues.
pass
return value
def get_or_set(self, key: str, factory: Callable[[], T], ttl_seconds: int = 30) -> T:
cached = self.get(key)
if cached is not None:
return cached # type: ignore[return-value]
return self.set(key, factory(), ttl_seconds=ttl_seconds) # type: ignore[return-value]
def bump_revision(self) -> int:
try:
value = int(self._redis.incr(self._revision_key))
self._rev_cache_value = value
self._rev_cache_expires_at = time.time() + 1.0
return value
except Exception:
# Fail-open: keep app usable; caller still succeeds.
return self._get_revision_cached()
def revision(self) -> int:
return self._get_revision_cached()
def _build_runtime_cache() -> CacheBackend:
settings = get_settings()
backend = (settings.cache_backend or "memory").strip().lower()
if backend == "redis":
try:
cache = RedisRuntimeCache(
redis_url=settings.cache_redis_url,
prefix=settings.cache_prefix,
default_ttl_seconds=settings.cache_default_ttl_seconds,
)
logger.info("runtime cache backend: redis")
return cache
except Exception as exc:
logger.warning("redis cache unavailable, fallback to memory: %s", exc)
logger.info("runtime cache backend: memory")
return MemoryRuntimeCache()
runtime_cache: CacheBackend = _build_runtime_cache()

View File

@@ -13,6 +13,7 @@ dependencies = [
"passlib[bcrypt]>=1.7.4", "passlib[bcrypt]>=1.7.4",
"pyjwt[crypto]>=2.10.1", "pyjwt[crypto]>=2.10.1",
"httpx>=0.28.1", "httpx>=0.28.1",
"redis>=5.2.0",
] ]
[project.optional-dependencies] [project.optional-dependencies]