176 lines
5.8 KiB
Python
176 lines
5.8 KiB
Python
from __future__ import annotations

from collections import OrderedDict
from dataclasses import dataclass
from hashlib import sha256
import logging
import pickle
from threading import RLock
import time
from typing import Callable, Protocol, TypeVar

from app.core.config import get_settings

try:
    import redis
except Exception:  # pragma: no cover - optional dependency in local dev
    redis = None
|
|
|
|
# Module-level logger; used to report which runtime cache backend is selected.
logger = logging.getLogger(__name__)

# Type produced by a ``get_or_set`` factory and returned to its caller.
T = TypeVar("T")
|
|
|
|
|
|
class CacheBackend(Protocol):
    """Structural interface implemented by the runtime cache backends.

    ``MemoryRuntimeCache`` and ``RedisRuntimeCache`` both provide these
    methods, and ``_build_runtime_cache`` returns whichever one the settings
    select.
    """

    def get(self, key: str) -> object | None:
        """Return the value cached under ``key``, or ``None`` on a miss."""

    def set(self, key: str, value: object, ttl_seconds: int = 30) -> object:
        """Cache ``value`` under ``key`` for ``ttl_seconds`` and return ``value``."""

    def get_or_set(self, key: str, factory: Callable[[], T], ttl_seconds: int = 30) -> T:
        """Return the cached value for ``key``, computing and caching it via ``factory`` on a miss."""

    def bump_revision(self) -> int:
        """Invalidate all cached entries by advancing the revision; return the resulting revision."""

    def revision(self) -> int:
        """Return the current cache revision."""
|
|
|
|
|
|
@dataclass
class _CacheEntry:
    """A cached value together with the metadata that decides its validity."""

    # The cached payload handed back to callers on a hit.
    value: object
    # ``time.monotonic()`` deadline; the entry is expired once this is reached.
    expires_at: float
    # Cache revision current at write time; a mismatch marks the entry stale.
    revision: int


class MemoryRuntimeCache:
    """Simple in-memory cache for local/single-instance environments.

    Each entry carries a TTL and the cache revision it was written under;
    ``bump_revision`` invalidates every entry at once. The cache is bounded:
    once it holds more than ``max_entries`` items, the least recently used
    ones are evicted. All state is guarded by a re-entrant lock.

    ``get`` signals a miss with ``None``, so a stored ``None`` cannot be told
    apart from a miss and ``get_or_set`` re-runs its factory for it.
    """

    # Revisions wrap back to 1 past this value to keep the counter bounded.
    _MAX_REVISION = 1_000_000_000

    def __init__(self, max_entries: int = 2000) -> None:
        """Create an empty cache.

        Args:
            max_entries: Maximum number of stored entries (clamped to at
                least 1). Defaults to 2000.
        """
        self._lock = RLock()
        self._revision = 0
        self._max_entries = max(int(max_entries), 1)
        # Ordered from least to most recently used so eviction is O(1).
        self._entries: OrderedDict[str, _CacheEntry] = OrderedDict()

    def get(self, key: str) -> object | None:
        """Return the live value stored under ``key``, or ``None`` on a miss.

        Expired entries and entries written under an older revision are
        removed and reported as misses.
        """
        # Monotonic clock: TTLs must not jump when the wall clock is adjusted.
        now = time.monotonic()
        with self._lock:
            entry = self._entries.get(key)
            if entry is None:
                return None
            if entry.expires_at <= now or entry.revision != self._revision:
                self._entries.pop(key, None)
                return None
            # Mark as recently used so LRU eviction keeps hot keys.
            self._entries.move_to_end(key)
            return entry.value

    def set(self, key: str, value: object, ttl_seconds: int = 30) -> object:
        """Store ``value`` under ``key`` for ``ttl_seconds`` (at least 1) and return it."""
        with self._lock:
            self._store(key, value, ttl_seconds, self._revision)
        return value

    def get_or_set(self, key: str, factory: Callable[[], T], ttl_seconds: int = 30) -> T:
        """Return the cached value for ``key``, or compute, cache and return ``factory()``.

        ``factory`` runs outside the lock. The revision is captured before the
        lookup; if ``bump_revision`` runs while ``factory`` is computing, the
        result (possibly derived from pre-invalidation data) is returned to
        this caller but not cached.
        """
        revision = self.revision()
        cached = self.get(key)
        if cached is not None:
            return cached  # type: ignore[return-value]
        value = factory()
        with self._lock:
            self._store(key, value, ttl_seconds, revision)
        return value

    def bump_revision(self) -> int:
        """Invalidate every cached entry and return the new revision."""
        with self._lock:
            self._revision += 1
            if self._revision > self._MAX_REVISION:
                self._revision = 1
            # Every existing entry now carries an outdated revision and can
            # never be served again, so free the memory immediately.
            self._entries.clear()
            return self._revision

    def revision(self) -> int:
        """Return the current cache revision."""
        with self._lock:
            return self._revision

    def _store(self, key: str, value: object, ttl_seconds: int, revision: int) -> None:
        """Write one entry and enforce the size bound; caller must hold ``self._lock``.

        Writes tagged with an outdated ``revision`` are dropped, because a
        ``bump_revision`` has already invalidated the data they came from.
        """
        if revision != self._revision:
            return
        self._entries[key] = _CacheEntry(
            value=value,
            expires_at=time.monotonic() + max(ttl_seconds, 1),
            revision=revision,
        )
        self._entries.move_to_end(key)
        # Evict least recently used entries instead of wiping the whole cache
        # (which would also discard the value just written).
        while len(self._entries) > self._max_entries:
            self._entries.popitem(last=False)
|
|
|
|
|
|
class RedisRuntimeCache:
    """Redis-backed cache for multi-instance deployments.

    Values are pickled and stored under keys namespaced by ``prefix`` and by a
    shared revision counter kept in Redis. ``bump_revision`` increments that
    counter, so readers stop using keys written under older revisions (those
    keys then age out through their TTL). Redis errors are swallowed so the
    request path keeps working without the cache.

    NOTE(review): ``pickle.loads`` is applied to whatever is stored in Redis,
    so anyone able to write to this Redis can execute code in the app. Only
    point it at a trusted, access-controlled instance.
    NOTE(review): no socket timeout is configured here, so a hung Redis may
    block callers; consider setting ``socket_timeout`` in ``redis_url``.
    """

    # Seconds a revision read from Redis (or a failed read) is reused locally.
    _REVISION_CACHE_SECONDS = 1.0

    def __init__(self, *, redis_url: str, prefix: str, default_ttl_seconds: int = 30) -> None:
        """Create a client for ``redis_url`` using ``prefix`` as the key namespace.

        Args:
            redis_url: Connection URL passed to ``redis.Redis.from_url``.
            prefix: Key namespace; blank or missing values fall back to ``"memberapi"``.
            default_ttl_seconds: TTL used when ``set`` receives a falsy TTL (at least 1).

        Raises:
            RuntimeError: If the ``redis`` package is not installed.
        """
        if redis is None:
            raise RuntimeError("redis_package_not_installed")
        # Raw bytes are required because values are pickled.
        self._redis = redis.Redis.from_url(redis_url, decode_responses=False)
        self._prefix = (prefix or "").strip() or "memberapi"
        self._default_ttl_seconds = max(int(default_ttl_seconds), 1)
        self._revision_key = f"{self._prefix}:cache:revision"
        # Local copy of the shared revision and the monotonic time it goes stale.
        self._rev_cache_value = 0
        self._rev_cache_expires_at = 0.0

    def _cache_key(self, key: str, revision: int) -> str:
        """Build the Redis key for ``key`` under ``revision``.

        The caller's key is SHA-256 hashed so any string maps to a fixed-length,
        colon-free key segment.
        """
        key_hash = sha256(key.encode("utf-8")).hexdigest()
        return f"{self._prefix}:cache:{revision}:{key_hash}"

    def _get_revision_cached(self) -> int:
        """Return the shared revision, reading Redis at most once per cache window.

        If Redis cannot be read, the last known revision is returned (0 before
        any successful read), and that fallback is reused for the same window so
        an outage does not add a failing round-trip to every call.
        """
        now = time.monotonic()
        if now < self._rev_cache_expires_at:
            return self._rev_cache_value
        try:
            raw = self._redis.get(self._revision_key)
            value = int(raw) if raw else 0
        except Exception:
            # Falling back to 0 would redirect reads and writes to revision-0
            # keys, which may hold entries written before later invalidations.
            value = self._rev_cache_value
        self._rev_cache_value = value
        self._rev_cache_expires_at = now + self._REVISION_CACHE_SECONDS
        return value

    def _read(self, key: str, revision: int) -> object | None:
        """Fetch and unpickle ``key`` under ``revision``; ``None`` on miss or any error."""
        try:
            raw = self._redis.get(self._cache_key(key, revision))
            if raw is None:
                return None
            return pickle.loads(raw)
        except Exception:
            return None

    def _write(self, key: str, value: object, ttl_seconds: int, revision: int) -> None:
        """Pickle ``value`` into ``key`` under ``revision``; Redis errors are ignored.

        A falsy ``ttl_seconds`` means "use the configured default TTL".
        """
        ttl = max(int(ttl_seconds), 1) if ttl_seconds else self._default_ttl_seconds
        try:
            self._redis.setex(self._cache_key(key, revision), ttl, pickle.dumps(value))
        except Exception:
            # Keep request path healthy even when Redis has issues.
            pass

    def get(self, key: str) -> object | None:
        """Return the cached value for ``key`` at the current revision, or ``None``."""
        return self._read(key, self._get_revision_cached())

    def set(self, key: str, value: object, ttl_seconds: int = 30) -> object:
        """Cache ``value`` under ``key`` (best effort) and return ``value``."""
        self._write(key, value, ttl_seconds, self._get_revision_cached())
        return value

    def get_or_set(self, key: str, factory: Callable[[], T], ttl_seconds: int = 30) -> T:
        """Return the cached value for ``key``, or compute, cache and return ``factory()``.

        The revision is resolved once, so the lookup and the write target the
        same key. If the revision is bumped while ``factory`` runs, the result
        lands under the old revision and is not served to readers that see the
        new revision. A ``None`` result is indistinguishable from a miss and is
        recomputed on the next call.
        """
        revision = self._get_revision_cached()
        cached = self._read(key, revision)
        if cached is not None:
            return cached  # type: ignore[return-value]
        value = factory()
        self._write(key, value, ttl_seconds, revision)
        return value

    def bump_revision(self) -> int:
        """Invalidate all cached entries by incrementing the shared revision.

        Returns:
            The new revision, or the last known revision if Redis is unavailable.
        """
        try:
            value = int(self._redis.incr(self._revision_key))
        except Exception:
            # Fail-open: keep app usable; caller still succeeds. The invalidation
            # was not recorded, so other readers may keep serving old entries
            # until they expire; surface that instead of hiding it.
            logger.warning("redis cache revision bump failed", exc_info=True)
            return self._get_revision_cached()
        self._rev_cache_value = value
        self._rev_cache_expires_at = time.monotonic() + self._REVISION_CACHE_SECONDS
        return value

    def revision(self) -> int:
        """Return the current shared revision (locally cached for a short window)."""
        return self._get_revision_cached()
|
|
|
|
|
|
def _build_runtime_cache() -> CacheBackend:
    """Build the cache backend selected by application settings.

    ``settings.cache_backend`` picks the implementation (case-insensitive;
    empty means ``"memory"``). For ``"redis"``, a ``RedisRuntimeCache`` is built
    from ``cache_redis_url``, ``cache_prefix`` and
    ``cache_default_ttl_seconds``. If that constructor raises (e.g. the
    ``redis`` package is missing), a warning is logged and the in-memory
    cache is used instead. Any other unrecognized value also falls back to
    memory, with a warning.
    """
    settings = get_settings()
    backend = (settings.cache_backend or "memory").strip().lower()
    if backend == "redis":
        try:
            cache = RedisRuntimeCache(
                redis_url=settings.cache_redis_url,
                prefix=settings.cache_prefix,
                default_ttl_seconds=settings.cache_default_ttl_seconds,
            )
        except Exception as exc:
            logger.warning("redis cache unavailable, fallback to memory: %s", exc)
        else:
            logger.info("runtime cache backend: redis")
            return cache
    elif backend != "memory":
        # A typo here would otherwise silently yield a per-process cache whose
        # revision bumps are not shared with other instances.
        logger.warning("unknown cache backend %r, using memory", backend)
    logger.info("runtime cache backend: memory")
    return MemoryRuntimeCache()
|
|
|
|
|
|
# Process-wide cache instance, built once when this module is imported; backend
# selection and fallback behavior live in ``_build_runtime_cache``.
runtime_cache: CacheBackend = _build_runtime_cache()
|