diff --git a/app/api/admin_catalog.py b/app/api/admin_catalog.py index d821dbc..5f70e9a 100644 --- a/app/api/admin_catalog.py +++ b/app/api/admin_catalog.py @@ -53,6 +53,7 @@ from app.security.admin_guard import require_admin_principal from app.security.api_client_auth import hash_api_key from app.services.idp_admin_service import ProviderAdminService from app.services.idp_catalog_sync import sync_from_provider +from app.services.runtime_cache import runtime_cache from app.core.config import get_settings router = APIRouter( @@ -161,9 +162,16 @@ def list_companies( limit: int = Query(default=100, ge=1, le=500), offset: int = Query(default=0, ge=0), ) -> ListResponse: + cache_key = f"admin:companies:{keyword or ''}:{limit}:{offset}" + cached = runtime_cache.get(cache_key) + if isinstance(cached, ListResponse): + return cached + repo = CompaniesRepository(db) items, total = repo.list(keyword=keyword, limit=limit, offset=offset) - return ListResponse(items=[_company_item(i) for i in items], total=total, limit=limit, offset=offset) + result = ListResponse(items=[_company_item(i) for i in items], total=total, limit=limit, offset=offset) + runtime_cache.set(cache_key, result, ttl_seconds=20) + return result @router.post("/companies", response_model=CompanyItem) @@ -251,6 +259,11 @@ def list_sites( limit: int = Query(default=100, ge=1, le=500), offset: int = Query(default=0, ge=0), ) -> ListResponse: + cache_key = f"admin:sites:{keyword or ''}:{company_key or ''}:{limit}:{offset}" + cached = runtime_cache.get(cache_key) + if isinstance(cached, ListResponse): + return cached + companies_repo = CompaniesRepository(db) sites_repo = SitesRepository(db) company_id = None @@ -264,7 +277,9 @@ def list_sites( company_map = {c.id: c for c in companies} items, total = sites_repo.list(keyword=keyword, company_id=company_id, limit=limit, offset=offset) response_items = [_site_item(i, company_map[i.company_id]) for i in items if i.company_id in company_map] - return ListResponse(items=response_items, total=total, limit=limit, offset=offset) + result = ListResponse(items=response_items, total=total, limit=limit, offset=offset) + runtime_cache.set(cache_key, result, ttl_seconds=20) + return result @router.post("/sites", response_model=SiteItem) @@ -375,9 +390,16 @@ def list_systems( limit: int = Query(default=100, ge=1, le=500), offset: int = Query(default=0, ge=0), ) -> ListResponse: + cache_key = f"admin:systems:{keyword or ''}:{status_filter or ''}:{limit}:{offset}" + cached = runtime_cache.get(cache_key) + if isinstance(cached, ListResponse): + return cached + repo = SystemsRepository(db) items, total = repo.list(keyword=keyword, status=status_filter, limit=limit, offset=offset) - return ListResponse(items=[_system_item(i) for i in items], total=total, limit=limit, offset=offset) + result = ListResponse(items=[_system_item(i) for i in items], total=total, limit=limit, offset=offset) + runtime_cache.set(cache_key, result, ttl_seconds=20) + return result @router.post("/systems", response_model=SystemItem) @@ -404,6 +426,11 @@ def list_roles( limit: int = Query(default=100, ge=1, le=500), offset: int = Query(default=0, ge=0), ) -> ListResponse: + cache_key = f"admin:roles:{keyword or ''}:{system_key or ''}:{status_filter or ''}:{limit}:{offset}" + cached = runtime_cache.get(cache_key) + if isinstance(cached, ListResponse): + return cached + systems_repo = SystemsRepository(db) roles_repo = RolesRepository(db) @@ -433,7 +460,9 @@ def list_roles( for row in rows if row.system_id in system_map ] - return ListResponse(items=items, total=total, limit=limit, offset=offset) + result = ListResponse(items=items, total=total, limit=limit, offset=offset) + runtime_cache.set(cache_key, result, ttl_seconds=20) + return result @router.post("/roles", response_model=RoleItem) @@ -751,9 +780,17 @@ def list_members( limit: int = Query(default=100, ge=1, le=500), offset: int = Query(default=0, ge=0), ) -> ListResponse: + is_active_key = "" if is_active is None else ("1" if is_active else "0") + cache_key = f"admin:members:{keyword or ''}:{is_active_key}:{limit}:{offset}" + cached = runtime_cache.get(cache_key) + if isinstance(cached, ListResponse): + return cached + repo = UsersRepository(db) rows, total = repo.list(keyword=keyword, is_active=is_active, limit=limit, offset=offset) - return ListResponse(items=[_member_item(r) for r in rows], total=total, limit=limit, offset=offset) + result = ListResponse(items=[_member_item(r) for r in rows], total=total, limit=limit, offset=offset) + runtime_cache.set(cache_key, result, ttl_seconds=20) + return result @router.post("/members", response_model=MemberItem) @@ -950,14 +987,21 @@ def list_api_clients( limit: int = Query(default=100, ge=1, le=500), offset: int = Query(default=0, ge=0), ) -> ListResponse: + cache_key = f"admin:api_clients:{keyword or ''}:{status_filter or ''}:{limit}:{offset}" + cached = runtime_cache.get(cache_key) + if isinstance(cached, ListResponse): + return cached + repo = ApiClientsRepository(db) items, total = repo.list(keyword=keyword, status=status_filter, limit=limit, offset=offset) - return ListResponse( + result = ListResponse( items=[ApiClientItem.model_validate(i, from_attributes=True) for i in items], total=total, limit=limit, offset=offset, ) + runtime_cache.set(cache_key, result, ttl_seconds=20) + return result @router.post("/sync/from-provider") diff --git a/app/api/internal.py b/app/api/internal.py index ef713ba..00b8437 100644 --- a/app/api/internal.py +++ b/app/api/internal.py @@ -12,6 +12,7 @@ from app.schemas.users import UserUpsertBySubRequest from app.security.api_client_auth import require_api_client from app.services.idp_admin_service import ProviderAdminService from app.services.permission_service import PermissionService +from app.services.runtime_cache import runtime_cache router = APIRouter(prefix="/internal", tags=["internal"], dependencies=[Depends(require_api_client)]) @@ -68,8 +69,13 @@ def _build_user_role_rows(db: Session, user_sub: str) -> list[tuple[str, str, st @router.get("/users/{user_sub}/roles", response_model=InternalUserRoleResponse) def get_user_roles(user_sub: str, db: Session = Depends(get_db)) -> InternalUserRoleResponse: + cache_key = f"internal:user_roles:{user_sub}" + cached = runtime_cache.get(cache_key) + if isinstance(cached, InternalUserRoleResponse): + return cached + rows = _build_user_role_rows(db, user_sub) - return InternalUserRoleResponse( + result = InternalUserRoleResponse( user_sub=user_sub, roles=[ InternalUserRoleItem( @@ -94,6 +100,8 @@ def get_user_roles(user_sub: str, db: Session = Depends(get_db)) -> InternalUser ) in rows ], ) + runtime_cache.set(cache_key, result, ttl_seconds=30) + return result @router.get("/permissions/{user_sub}/snapshot", response_model=RoleSnapshotResponse) @@ -101,8 +109,15 @@ def get_permission_snapshot( user_sub: str, db: Session = Depends(get_db), ) -> RoleSnapshotResponse: + cache_key = f"internal:permissions_snapshot:{user_sub}" + cached = runtime_cache.get(cache_key) + if isinstance(cached, RoleSnapshotResponse): + return cached + rows = _build_user_role_rows(db, user_sub) - return PermissionService.build_role_snapshot(user_sub=user_sub, rows=rows) + result = PermissionService.build_role_snapshot(user_sub=user_sub, rows=rows) + runtime_cache.set(cache_key, result, ttl_seconds=30) + return result @router.post("/provider/users/ensure", response_model=ProviderEnsureUserResponse) diff --git a/app/api/me.py b/app/api/me.py index ac1a106..432ddd5 100644 --- a/app/api/me.py +++ b/app/api/me.py @@ -9,6 +9,7 @@ from app.schemas.auth import ProviderPrincipal, MeSummaryResponse from app.schemas.permissions import RoleSnapshotResponse from app.security.idp_jwt import require_authenticated_principal from app.services.permission_service import PermissionService +from app.services.runtime_cache import runtime_cache router = APIRouter(prefix="/me", tags=["me"]) @@ -18,6 +19,10 @@ def get_me( principal: ProviderPrincipal = Depends(require_authenticated_principal), db: Session = Depends(get_db), ) -> MeSummaryResponse: + cache_key = f"me:{principal.sub}" + cached = runtime_cache.get(cache_key) + if isinstance(cached, MeSummaryResponse): + return cached try: users_repo = UsersRepository(db) user = users_repo.upsert_by_sub( @@ -28,13 +33,17 @@ def get_me( is_active=True, status="active", ) - return MeSummaryResponse(sub=user.user_sub, email=user.email, display_name=user.display_name) + result = MeSummaryResponse(sub=user.user_sub, email=user.email, display_name=user.display_name) + runtime_cache.set(cache_key, result, ttl_seconds=30) + return result except SQLAlchemyError: - return MeSummaryResponse( + result = MeSummaryResponse( sub=principal.sub, email=principal.email, display_name=principal.name or principal.preferred_username, ) + runtime_cache.set(cache_key, result, ttl_seconds=15) + return result @router.get("/permissions/snapshot", response_model=RoleSnapshotResponse) @@ -42,6 +51,10 @@ def get_my_permission_snapshot( principal: ProviderPrincipal = Depends(require_authenticated_principal), db: Session = Depends(get_db), ) -> RoleSnapshotResponse: + cache_key = f"me:permissions_snapshot:{principal.sub}" + cached = runtime_cache.get(cache_key) + if isinstance(cached, RoleSnapshotResponse): + return cached try: users_repo = UsersRepository(db) user_sites_repo = UserSitesRepository(db) @@ -68,6 +81,10 @@ def get_my_permission_snapshot( ) for site, company, role, system in rows ] - return PermissionService.build_role_snapshot(user_sub=principal.sub, rows=serialized) + result = PermissionService.build_role_snapshot(user_sub=principal.sub, rows=serialized) + runtime_cache.set(cache_key, result, ttl_seconds=30) + return result except SQLAlchemyError: - return RoleSnapshotResponse(user_sub=principal.sub, roles=[]) + result = RoleSnapshotResponse(user_sub=principal.sub, roles=[]) + runtime_cache.set(cache_key, result, ttl_seconds=10) + return result diff --git a/app/main.py b/app/main.py index 4ae1a2c..623c046 100644 --- a/app/main.py +++ b/app/main.py @@ -7,6 +7,7 @@ from app.api.internal_catalog import router as internal_catalog_router from app.api.internal import router as internal_router from app.api.me import router as me_router from app.core.config import get_settings +from app.services.runtime_cache import runtime_cache app = FastAPI(title="memberapi.ose.tw", version="0.1.0") @@ -20,6 +21,18 @@ app.add_middleware( ) +@app.middleware("http") +async def invalidate_runtime_cache_on_cud(request, call_next): + response = await call_next(request) + if ( + request.method in {"POST", "PUT", "PATCH", "DELETE"} + and request.url.path.startswith(("/admin", "/internal")) + and response.status_code < 400 + ): + runtime_cache.bump_revision() + return response + + @app.get("/healthz", tags=["health"]) def healthz() -> dict[str, str]: return {"status": "ok"} diff --git a/app/repositories/users_repo.py b/app/repositories/users_repo.py index 6f346ec..8030add 100644 --- a/app/repositories/users_repo.py +++ b/app/repositories/users_repo.py @@ -57,6 +57,7 @@ class UsersRepository: provider_user_id: str | None = None, ) -> User: user = self.get_by_sub(user_sub) + changed = False if user is None: user = User( user_sub=user_sub, @@ -68,17 +69,30 @@ class UsersRepository: status=status, ) self.db.add(user) + changed = True else: - if provider_user_id is not None: + if provider_user_id is not None and user.provider_user_id != provider_user_id: user.provider_user_id = provider_user_id - user.username = username - user.email = email - user.display_name = display_name - user.is_active = is_active - user.status = status + changed = True + if user.username != username: + user.username = username + changed = True + if user.email != email: + user.email = email + changed = True + if user.display_name != display_name: + user.display_name = display_name + changed = True + if user.is_active != is_active: + user.is_active = is_active + changed = True + if user.status != status: + user.status = status + changed = True - self.db.commit() - self.db.refresh(user) + if changed: + self.db.commit() + self.db.refresh(user) return user def update_member( diff --git a/app/services/runtime_cache.py b/app/services/runtime_cache.py new file mode 100644 index 0000000..b44bbc9 --- /dev/null +++ b/app/services/runtime_cache.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from dataclasses import dataclass +from threading import RLock +import time +from typing import Callable, TypeVar + +T = TypeVar("T") + + +@dataclass +class _CacheEntry: + value: object + expires_at: float + revision: int + + +class RuntimeCache: + """Simple in-memory cache for local/prototype use. + + Cache is globally invalidated by `bump_revision()` which we call after CUD. + """ + + def __init__(self) -> None: + self._lock = RLock() + self._revision = 0 + self._entries: dict[str, _CacheEntry] = {} + + def get(self, key: str) -> object | None: + now = time.time() + with self._lock: + entry = self._entries.get(key) + if not entry: + return None + if entry.expires_at <= now or entry.revision != self._revision: + self._entries.pop(key, None) + return None + return entry.value + + def set(self, key: str, value: object, ttl_seconds: int = 30) -> object: + now = time.time() + with self._lock: + self._entries[key] = _CacheEntry( + value=value, + expires_at=now + max(ttl_seconds, 1), + revision=self._revision, + ) + if len(self._entries) > 2000: + self._entries.clear() + return value + + def get_or_set(self, key: str, factory: Callable[[], T], ttl_seconds: int = 30) -> T: + cached = self.get(key) + if cached is not None: + return cached # type: ignore[return-value] + return self.set(key, factory(), ttl_seconds=ttl_seconds) # type: ignore[return-value] + + def bump_revision(self) -> int: + with self._lock: + self._revision += 1 + if self._revision > 1_000_000_000: + self._revision = 1 + self._entries.clear() + return self._revision + + def revision(self) -> int: + with self._lock: + return self._revision + + +runtime_cache = RuntimeCache()