from __future__ import annotations

from typing import Any

from fastapi import HTTPException

from app.domain.keys import generate_variant_key
from app.domain.mappers import to_variant
from app.repositories.directus.variants import VariantRepository
from app.schemas.admin import VariantCreate, VariantRead, VariantUpdate, variant_to_read_model


class VariantService:
    """Application service for variant management."""

    def __init__(self, repository: VariantRepository | None = None) -> None:
        """Store the repository to use, creating a default one when none is given.

        Args:
            repository: Repository used for all variant reads and writes.
                A new ``VariantRepository`` is constructed when omitted.
        """
        # Compare against None explicitly so an injected repository that
        # happens to be falsy is not silently swapped for the default.
        self.repository = repository if repository is not None else VariantRepository()

    @staticmethod
    def _weight_of(raw: Any) -> int:
        """Coerce a stored or submitted weight to ``int``, counting ``None`` as 0."""
        return 0 if raw is None else int(raw)

    async def _assert_weights_sum_100(
        self,
        experiment_id: str,
        new_weight: int,
        exclude_variant_id: str | None = None,
        access_token: str | None = None,
    ) -> None:
        """Check that the experiment's traffic weights would total exactly 100.

        Lists the variants whose ``experiment_id`` equals ``experiment_id``,
        sums their ``traffic_weight`` values (skipping ``exclude_variant_id``
        when given) and adds ``new_weight``.

        Args:
            experiment_id: Experiment whose variants are summed.
            new_weight: Weight being created or assigned; ``None`` counts as 0.
            exclude_variant_id: Id of the variant being updated, so its
                currently stored weight is not counted alongside the new one.
            access_token: Forwarded unchanged to the repository.

        Raises:
            HTTPException: Status 400 when the resulting total is not 100.
        """
        items = await self.repository.list(
            params={"filter[experiment_id][_eq]": experiment_id},
            access_token=access_token,
        )
        # Normalise once so the comparison below is always str-to-str.
        excluded = None if exclude_variant_id is None else str(exclude_variant_id)
        # A missing or null stored weight counts as 0 instead of crashing int().
        existing_sum = sum(
            self._weight_of(item.get("traffic_weight"))
            for item in items
            if excluded is None or str(item["id"]) != excluded
        )
        incoming_weight = self._weight_of(new_weight)
        if existing_sum + incoming_weight != 100:
            raise HTTPException(
                status_code=400,
                detail=f"所有變體的流量權重加總必須等於 100(目前其他變體合計 {existing_sum},本次輸入 {incoming_weight})",
            )

    async def list_variants(
        self,
        experiment_id: str | None = None,
        access_token: str | None = None,
    ) -> list[VariantRead]:
        """Return variants as read models, optionally limited to one experiment.

        Args:
            experiment_id: When truthy, only variants whose ``experiment_id``
                equals this value are requested; otherwise no filter is sent.
            access_token: Forwarded unchanged to the repository.
        """
        params: dict[str, Any] | None = None
        if experiment_id:
            params = {"filter[experiment_id][_eq]": experiment_id}

        items = await self.repository.list(params=params, access_token=access_token)
        return [variant_to_read_model(to_variant(item)) for item in items]

    async def get_variant(
        self,
        variant_id: str,
        access_token: str | None = None,
    ) -> VariantRead | None:
        """Return the read model for one variant, or ``None`` when the repository yields nothing.

        Args:
            variant_id: Identifier passed to the repository lookup.
            access_token: Forwarded unchanged to the repository.
        """
        item = await self.repository.get(variant_id, access_token=access_token)
        if not item:
            return None
        return variant_to_read_model(to_variant(item))

    async def create_variant(
        self,
        payload: VariantCreate,
        access_token: str | None = None,
    ) -> VariantRead:
        """Validate the weight total, then create a variant with a generated key.

        Args:
            payload: Creation data; fields that are ``None`` are not sent.
            access_token: Forwarded unchanged to the repository.

        Raises:
            HTTPException: Status 400 when the weight total would not be 100.
        """
        # NOTE(review): the total must be exactly 100 including this new
        # variant, so a create only succeeds when it completes the split.
        # Confirm this is the intended workflow.
        await self._assert_weights_sum_100(
            experiment_id=payload.experiment_id,
            new_weight=payload.traffic_weight,
            access_token=access_token,
        )
        data: dict[str, Any] = payload.model_dump(exclude_none=True)
        data["variant_key"] = generate_variant_key()
        item = await self.repository.create(data, access_token=access_token)
        return variant_to_read_model(to_variant(item))

    async def update_variant(
        self,
        variant_id: str,
        payload: VariantUpdate,
        access_token: str | None = None,
    ) -> VariantRead | None:
        """Apply the fields set on ``payload`` to a variant.

        The weight total is validated only when ``traffic_weight`` is among
        the fields set on the payload and the variant can be fetched; if the
        lookup yields nothing the update call is still attempted.

        Args:
            variant_id: Identifier of the variant to update.
            payload: Update data; only explicitly set fields are sent.
            access_token: Forwarded unchanged to the repository.

        Returns:
            The updated read model, or ``None`` when the repository update
            yields nothing.

        Raises:
            HTTPException: Status 400 when the weight total would not be 100.
        """
        data: dict[str, Any] = payload.model_dump(exclude_unset=True)
        if "traffic_weight" in data:
            existing = await self.repository.get(variant_id, access_token=access_token)
            if existing:
                # Exclude this variant so its stored weight is replaced by,
                # not added to, the incoming one.
                await self._assert_weights_sum_100(
                    experiment_id=str(existing["experiment_id"]),
                    new_weight=data["traffic_weight"],
                    exclude_variant_id=variant_id,
                    access_token=access_token,
                )
        item = await self.repository.update(variant_id, data, access_token=access_token)
        if not item:
            return None
        return variant_to_read_model(to_variant(item))