from __future__ import annotations from dataclasses import dataclass from hashlib import sha256 import logging import pickle from threading import RLock import time from typing import Callable, Protocol, TypeVar from app.core.config import get_settings try: import redis except Exception: # pragma: no cover - optional dependency in local dev redis = None T = TypeVar("T") logger = logging.getLogger(__name__) class CacheBackend(Protocol): def get(self, key: str) -> object | None: ... def set(self, key: str, value: object, ttl_seconds: int = 30) -> object: ... def get_or_set(self, key: str, factory: Callable[[], T], ttl_seconds: int = 30) -> T: ... def bump_revision(self) -> int: ... def revision(self) -> int: ... @dataclass class _CacheEntry: value: object expires_at: float revision: int class MemoryRuntimeCache: """Simple in-memory cache for local/single-instance environments.""" def __init__(self) -> None: self._lock = RLock() self._revision = 0 self._entries: dict[str, _CacheEntry] = {} def get(self, key: str) -> object | None: now = time.time() with self._lock: entry = self._entries.get(key) if not entry: return None if entry.expires_at <= now or entry.revision != self._revision: self._entries.pop(key, None) return None return entry.value def set(self, key: str, value: object, ttl_seconds: int = 30) -> object: now = time.time() with self._lock: self._entries[key] = _CacheEntry( value=value, expires_at=now + max(ttl_seconds, 1), revision=self._revision, ) if len(self._entries) > 2000: self._entries.clear() return value def get_or_set(self, key: str, factory: Callable[[], T], ttl_seconds: int = 30) -> T: cached = self.get(key) if cached is not None: return cached # type: ignore[return-value] return self.set(key, factory(), ttl_seconds=ttl_seconds) # type: ignore[return-value] def bump_revision(self) -> int: with self._lock: self._revision += 1 if self._revision > 1_000_000_000: self._revision = 1 self._entries.clear() return self._revision def revision(self) -> int: with self._lock: return self._revision class RedisRuntimeCache: """Redis-backed cache for multi-instance deployments.""" def __init__(self, *, redis_url: str, prefix: str, default_ttl_seconds: int = 30) -> None: if redis is None: raise RuntimeError("redis_package_not_installed") self._redis = redis.Redis.from_url(redis_url, decode_responses=False) self._prefix = prefix.strip() or "memberapi" self._default_ttl_seconds = max(int(default_ttl_seconds), 1) self._revision_key = f"{self._prefix}:cache:revision" self._rev_cache_value = 0 self._rev_cache_expires_at = 0.0 def _cache_key(self, key: str, revision: int) -> str: key_hash = sha256(key.encode("utf-8")).hexdigest() return f"{self._prefix}:cache:{revision}:{key_hash}" def _get_revision_cached(self) -> int: now = time.time() if now < self._rev_cache_expires_at: return self._rev_cache_value try: raw = self._redis.get(self._revision_key) value = int(raw) if raw else 0 except Exception: return 0 self._rev_cache_value = value self._rev_cache_expires_at = now + 1.0 return value def get(self, key: str) -> object | None: try: revision = self._get_revision_cached() raw = self._redis.get(self._cache_key(key, revision)) if raw is None: return None return pickle.loads(raw) except Exception: return None def set(self, key: str, value: object, ttl_seconds: int = 30) -> object: ttl = max(int(ttl_seconds), 1) if ttl_seconds else self._default_ttl_seconds try: revision = self._get_revision_cached() self._redis.setex(self._cache_key(key, revision), ttl, pickle.dumps(value)) except Exception: # Keep request path healthy even when Redis has issues. pass return value def get_or_set(self, key: str, factory: Callable[[], T], ttl_seconds: int = 30) -> T: cached = self.get(key) if cached is not None: return cached # type: ignore[return-value] return self.set(key, factory(), ttl_seconds=ttl_seconds) # type: ignore[return-value] def bump_revision(self) -> int: try: value = int(self._redis.incr(self._revision_key)) self._rev_cache_value = value self._rev_cache_expires_at = time.time() + 1.0 return value except Exception: # Fail-open: keep app usable; caller still succeeds. return self._get_revision_cached() def revision(self) -> int: return self._get_revision_cached() def _build_runtime_cache() -> CacheBackend: settings = get_settings() backend = (settings.cache_backend or "memory").strip().lower() if backend == "redis": try: cache = RedisRuntimeCache( redis_url=settings.cache_redis_url, prefix=settings.cache_prefix, default_ttl_seconds=settings.cache_default_ttl_seconds, ) logger.info("runtime cache backend: redis") return cache except Exception as exc: logger.warning("redis cache unavailable, fallback to memory: %s", exc) logger.info("runtime cache backend: memory") return MemoryRuntimeCache() runtime_cache: CacheBackend = _build_runtime_cache()