Files
member-backend/app/services/runtime_cache.py

176 lines
5.8 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
from hashlib import sha256
import logging
import pickle
from threading import RLock
import time
from typing import Callable, Protocol, TypeVar
from app.core.config import get_settings
# `redis` is an optional dependency: if importing it raises anything (e.g. the
# package is not installed in local dev), the name is bound to None and
# RedisRuntimeCache.__init__ raises RuntimeError instead of constructing.
try:
    import redis
except Exception:  # pragma: no cover - optional dependency in local dev
    redis = None

# Value type returned by get_or_set factories.
T = TypeVar("T")
logger = logging.getLogger(__name__)
class CacheBackend(Protocol):
    """Structural interface implemented by the runtime cache backends.

    The implementations in this module return ``None`` from ``get`` on a
    miss, return the passed-in value from ``set``, and use a revision
    counter so that ``bump_revision`` invalidates every existing entry.
    """

    def get(self, key: str) -> object | None:
        """Return the cached value for *key*, or ``None`` when absent."""

    def set(self, key: str, value: object, ttl_seconds: int = 30) -> object:
        """Store *value* under *key* for roughly *ttl_seconds* seconds."""

    def get_or_set(self, key: str, factory: Callable[[], T], ttl_seconds: int = 30) -> T:
        """Return the cached value for *key*, computing it with *factory* on a miss."""

    def bump_revision(self) -> int:
        """Advance the cache revision (invalidating entries) and return it."""

    def revision(self) -> int:
        """Return the current cache revision."""
@dataclass
class _CacheEntry:
    """One value stored by ``MemoryRuntimeCache``."""

    # The cached payload, returned as-is (not copied) on a hit.
    value: object
    # Deadline on the ``time.monotonic()`` clock; at or after it the entry
    # is treated as missing.
    expires_at: float
    # Cache revision current when the entry was written; any later
    # ``bump_revision()`` makes the entry stale.
    revision: int


class MemoryRuntimeCache:
    """Simple in-memory cache for local/single-instance environments.

    Entries expire after a TTL and are invalidated wholesale by
    ``bump_revision()``. The store is bounded: once it holds more than
    ``max_entries`` keys, the least recently written keys are evicted.
    All state is guarded by a re-entrant lock.
    """

    # Private sentinel that lets get_or_set tell a cached ``None`` apart from
    # a miss (the public ``get`` returns ``None`` for both).
    _MISSING = object()

    def __init__(self, *, default_ttl_seconds: int = 30, max_entries: int = 2000) -> None:
        """Create an empty cache.

        Args:
            default_ttl_seconds: TTL applied when ``set`` receives a falsy
                ``ttl_seconds`` (same rule as ``RedisRuntimeCache``);
                clamped to >= 1.
            max_entries: Maximum number of stored keys; clamped to >= 1.
        """
        self._lock = RLock()
        self._revision = 0
        self._entries: dict[str, _CacheEntry] = {}
        self._default_ttl_seconds = max(int(default_ttl_seconds), 1)
        self._max_entries = max(int(max_entries), 1)

    def _lookup(self, key: str) -> object:
        """Return the live value for *key*, or ``_MISSING`` on a miss.

        Expired or stale-revision entries are dropped as a side effect.
        """
        # Monotonic clock: TTLs must not jump when the wall clock is adjusted.
        now = time.monotonic()
        with self._lock:
            entry = self._entries.get(key)
            if entry is None:
                return self._MISSING
            if entry.expires_at <= now or entry.revision != self._revision:
                self._entries.pop(key, None)
                return self._MISSING
            return entry.value

    def get(self, key: str) -> object | None:
        """Return the cached value for *key*, or ``None`` if absent, expired or stale."""
        value = self._lookup(key)
        return None if value is self._MISSING else value

    def set(self, key: str, value: object, ttl_seconds: int = 30) -> object:
        """Store *value* under *key* and return *value*.

        Args:
            key: Cache key.
            value: Value to store (kept by reference).
            ttl_seconds: Lifetime in seconds; values below 1 are raised to 1,
                and a falsy value (0/None) means the default TTL.
        """
        ttl = max(ttl_seconds, 1) if ttl_seconds else self._default_ttl_seconds
        expires_at = time.monotonic() + ttl
        with self._lock:
            # Re-inserting moves the key to the end of the insertion-ordered
            # dict, so the front always holds the least recently written key.
            self._entries.pop(key, None)
            self._entries[key] = _CacheEntry(
                value=value,
                expires_at=expires_at,
                revision=self._revision,
            )
            # Evict oldest writes one by one; the previous clear-everything
            # approach also discarded the entry that was just written.
            while len(self._entries) > self._max_entries:
                del self._entries[next(iter(self._entries))]
        return value

    def get_or_set(self, key: str, factory: Callable[[], T], ttl_seconds: int = 30) -> T:
        """Return the cached value for *key*, computing and storing it on a miss.

        A cached ``None`` counts as a hit. ``factory`` runs outside the lock,
        so concurrent misses for the same key may each call it; the last
        write wins.
        """
        cached = self._lookup(key)
        if cached is not self._MISSING:
            return cached  # type: ignore[return-value]
        return self.set(key, factory(), ttl_seconds=ttl_seconds)  # type: ignore[return-value]

    def bump_revision(self) -> int:
        """Invalidate every entry by advancing the revision; return the new revision."""
        with self._lock:
            self._revision += 1
            # Keep the counter bounded; safe because all entries are cleared below.
            if self._revision > 1_000_000_000:
                self._revision = 1
            self._entries.clear()
            return self._revision

    def revision(self) -> int:
        """Return the current revision."""
        with self._lock:
            return self._revision
class RedisRuntimeCache:
    """Redis-backed cache for multi-instance deployments.

    Keys are namespaced as ``{prefix}:cache:{revision}:{sha256(key)}``.
    Bumping the shared revision counter moves every instance to a fresh
    namespace, and old entries simply age out via their TTL. Each instance
    re-reads the revision from Redis at most once per second, so a bump made
    elsewhere can take about a second to be observed here.

    Redis failures fail open: reads behave as misses and writes are dropped,
    so request handling keeps working.

    SECURITY: values are serialized with ``pickle``. ``pickle.loads`` can
    execute arbitrary code, so anyone able to write into this Redis keyspace
    can run code in the application. Only use a trusted Redis instance.
    """

    # Seconds a revision read from Redis is reused before it is re-read.
    _REVISION_CACHE_SECONDS = 1.0
    # Private sentinel distinguishing a miss from a cached ``None``.
    _MISSING = object()

    def __init__(self, *, redis_url: str, prefix: str, default_ttl_seconds: int = 30) -> None:
        """Create a client for *redis_url*.

        Args:
            redis_url: Connection URL passed to ``redis.Redis.from_url``.
            prefix: Namespace for all keys; blank or ``None`` falls back to
                ``"memberapi"``.
            default_ttl_seconds: TTL used when ``set`` receives a falsy
                ``ttl_seconds``; clamped to >= 1.

        Raises:
            RuntimeError: if the optional ``redis`` package is not installed.
        """
        if redis is None:
            raise RuntimeError("redis_package_not_installed")
        self._redis = redis.Redis.from_url(redis_url, decode_responses=False)
        self._prefix = (prefix or "").strip() or "memberapi"
        self._default_ttl_seconds = max(int(default_ttl_seconds), 1)
        self._revision_key = f"{self._prefix}:cache:revision"
        # Local copy of the shared revision and its expiry on the
        # time.monotonic() clock; -inf forces a read on first use.
        self._rev_cache_value = 0
        self._rev_cache_expires_at = float("-inf")

    def _cache_key(self, key: str, revision: int) -> str:
        """Build the Redis key for *key* under *revision*.

        Hashing keeps arbitrary caller keys bounded in length and free of
        ``:`` separators.
        """
        key_hash = sha256(key.encode("utf-8")).hexdigest()
        return f"{self._prefix}:cache:{revision}:{key_hash}"

    def _get_revision_cached(self) -> int:
        """Return the shared revision, re-reading Redis at most once per interval."""
        now = time.monotonic()
        if now < self._rev_cache_expires_at:
            return self._rev_cache_value
        try:
            raw = self._redis.get(self._revision_key)
            value = int(raw) if raw else 0
        except Exception:
            # Fall back to the last known revision rather than 0. Using
            # revision 0 could serve or overwrite entries that an earlier
            # bump_revision() invalidated, until their TTL expires.
            logger.debug("runtime cache: revision read failed", exc_info=True)
            return self._rev_cache_value
        self._rev_cache_value = value
        self._rev_cache_expires_at = now + self._REVISION_CACHE_SECONDS
        return value

    def _lookup(self, key: str) -> object:
        """Return the cached value for *key*, or ``_MISSING`` on a miss or error."""
        try:
            revision = self._get_revision_cached()
            raw = self._redis.get(self._cache_key(key, revision))
            if raw is None:
                return self._MISSING
            # SECURITY: unpickles bytes read from Redis (see class docstring).
            return pickle.loads(raw)
        except Exception:
            logger.debug("runtime cache: get failed", exc_info=True)
            return self._MISSING

    def get(self, key: str) -> object | None:
        """Return the cached value for *key*, or ``None`` on a miss or Redis error."""
        value = self._lookup(key)
        return None if value is self._MISSING else value

    def set(self, key: str, value: object, ttl_seconds: int = 30) -> object:
        """Store *value* under *key* (best effort) and return *value*.

        ``ttl_seconds`` below 1 is raised to 1; a falsy value means the
        default TTL. Errors (Redis or pickling) are logged at debug level
        and otherwise ignored.
        """
        ttl = max(int(ttl_seconds), 1) if ttl_seconds else self._default_ttl_seconds
        try:
            revision = self._get_revision_cached()
            self._redis.setex(self._cache_key(key, revision), ttl, pickle.dumps(value))
        except Exception:
            # Keep request path healthy even when Redis has issues.
            logger.debug("runtime cache: set failed", exc_info=True)
        return value

    def get_or_set(self, key: str, factory: Callable[[], T], ttl_seconds: int = 30) -> T:
        """Return the cached value for *key*, computing and storing it on a miss.

        A cached ``None`` counts as a hit. There is no cross-process locking,
        so concurrent misses may each call *factory*.
        """
        cached = self._lookup(key)
        if cached is not self._MISSING:
            return cached  # type: ignore[return-value]
        return self.set(key, factory(), ttl_seconds=ttl_seconds)  # type: ignore[return-value]

    def bump_revision(self) -> int:
        """Atomically increment the shared revision and return it.

        On Redis failure, the error is logged and the currently known
        revision is returned, so the caller still succeeds (fail-open).
        Cached data may then stay stale until its TTL expires.
        """
        try:
            value = int(self._redis.incr(self._revision_key))
        except Exception as exc:
            logger.warning("runtime cache revision bump failed; cached data may be stale: %s", exc)
            return self._get_revision_cached()
        self._rev_cache_value = value
        self._rev_cache_expires_at = time.monotonic() + self._REVISION_CACHE_SECONDS
        return value

    def revision(self) -> int:
        """Return the shared revision (possibly up to one interval old)."""
        return self._get_revision_cached()
def _build_runtime_cache() -> CacheBackend:
    """Create the cache backend selected by ``settings.cache_backend``.

    The setting is compared case-insensitively after stripping whitespace,
    and an empty or missing value means ``"memory"``. When ``"redis"`` is
    requested but ``RedisRuntimeCache`` cannot be constructed, a warning is
    logged and the in-memory backend is returned instead. Any other value
    also yields the in-memory backend.

    NOTE(review): constructing the Redis client presumably does not contact
    the server (redis-py connects lazily), so an unreachable server is not
    detected here. That backend then fails open at request time. Confirm if
    a startup health check is wanted.
    """
    settings = get_settings()
    chosen = (settings.cache_backend or "memory").strip().lower()

    if chosen == "redis":
        try:
            redis_backend = RedisRuntimeCache(
                redis_url=settings.cache_redis_url,
                prefix=settings.cache_prefix,
                default_ttl_seconds=settings.cache_default_ttl_seconds,
            )
        except Exception as exc:
            logger.warning("redis cache unavailable, fallback to memory: %s", exc)
        else:
            logger.info("runtime cache backend: redis")
            return redis_backend

    # Shared fallback path: memory requested, unknown value, or Redis failed.
    logger.info("runtime cache backend: memory")
    return MemoryRuntimeCache()


# Built once at import time; every importer of this module shares this instance.
runtime_cache: CacheBackend = _build_runtime_cache()