Files
2026-03-23 20:23:58 +08:00

98 lines
3.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
from typing import Any
from fastapi import HTTPException
from app.domain.keys import generate_variant_key
from app.domain.mappers import to_variant
from app.repositories.directus.variants import VariantRepository
from app.schemas.admin import VariantCreate, VariantRead, VariantUpdate, variant_to_read_model
class VariantService:
    """Application service for variant management.

    Wraps a ``VariantRepository`` and converts the raw items it returns into
    ``VariantRead`` models. Query params use Directus-style filter keys
    (``filter[experiment_id][_eq]``). Create and update operations that set a
    traffic weight first check that all variants of the experiment would
    total exactly 100.

    NOTE(review): each write is validated on its own. Once an experiment's
    weights already sum to 100, changing a single variant's weight can never
    pass the check, because the other variants still carry the old total.
    Rebalancing needs a batch update or a relaxed rule. This is left as-is
    because callers may rely on the current contract.

    NOTE(review): validation and the subsequent write are separate repository
    calls and are not atomic. Concurrent writes to one experiment can still
    leave its total != 100.
    """

    def __init__(self, repository: VariantRepository | None = None) -> None:
        """Create the service.

        Args:
            repository: Repository to use. When omitted (or falsy), a default
                ``VariantRepository()`` is constructed.
        """
        self.repository = repository or VariantRepository()

    @staticmethod
    def _experiment_filter(experiment_id: str) -> dict[str, Any]:
        """Build repository query params selecting one experiment's variants."""
        return {"filter[experiment_id][_eq]": experiment_id}

    async def _assert_weights_sum_100(
        self,
        experiment_id: str,
        new_weight: int | None,
        exclude_variant_id: str | None = None,
        access_token: str | None = None,
    ) -> None:
        """Ensure the experiment's traffic weights would total exactly 100.

        Sums ``traffic_weight`` over the experiment's stored variants, skipping
        ``exclude_variant_id`` (the variant being updated), and adds
        ``new_weight``.

        Args:
            experiment_id: Experiment whose variants are summed.
            new_weight: Weight about to be written. ``None`` counts as 0.
            exclude_variant_id: Variant left out of the stored sum, or ``None``.
            access_token: Forwarded to the repository.

        Raises:
            HTTPException: 400 when the resulting total is not 100.
        """
        items = await self.repository.list(
            params=self._experiment_filter(experiment_id),
            access_token=access_token,
        )
        # Item ids are stringified below, so normalise the excluded id the same
        # way; otherwise a non-str id would never match and never be excluded.
        excluded = None if exclude_variant_id is None else str(exclude_variant_id)
        # A null stored weight used to crash int(None); treat it as 0.
        existing_sum = sum(
            int(item.get("traffic_weight") or 0)
            for item in items or ()
            if str(item.get("id")) != excluded
        )
        # An explicit null in an update payload used to crash the addition.
        new_weight = int(new_weight or 0)
        if existing_sum + new_weight != 100:
            raise HTTPException(
                status_code=400,
                detail=f"所有變體的流量權重加總必須等於 100目前其他變體合計 {existing_sum},本次輸入 {new_weight}",
            )

    async def list_variants(
        self,
        experiment_id: str | None = None,
        access_token: str | None = None,
    ) -> list[VariantRead]:
        """List variants, optionally restricted to a single experiment.

        A falsy ``experiment_id`` (``None`` or ``""``) applies no filter.
        """
        params: dict[str, Any] | None = (
            self._experiment_filter(experiment_id) if experiment_id else None
        )
        items = await self.repository.list(params=params, access_token=access_token)
        return [variant_to_read_model(to_variant(item)) for item in items]

    async def get_variant(
        self,
        variant_id: str,
        access_token: str | None = None,
    ) -> VariantRead | None:
        """Fetch one variant, or return ``None`` when the repository has none."""
        item = await self.repository.get(variant_id, access_token=access_token)
        if not item:
            return None
        return variant_to_read_model(to_variant(item))

    async def create_variant(
        self,
        payload: VariantCreate,
        access_token: str | None = None,
    ) -> VariantRead:
        """Validate the weight total, then create a variant with a fresh key.

        ``None`` fields are dropped from the payload. ``variant_key`` is always
        generated server-side and overrides anything in the payload.

        Raises:
            HTTPException: 400 when the experiment's weights would not total 100.
        """
        await self._assert_weights_sum_100(
            experiment_id=payload.experiment_id,
            new_weight=payload.traffic_weight,
            access_token=access_token,
        )
        data: dict[str, Any] = payload.model_dump(exclude_none=True)
        data["variant_key"] = generate_variant_key()
        item = await self.repository.create(data, access_token=access_token)
        return variant_to_read_model(to_variant(item))

    async def update_variant(
        self,
        variant_id: str,
        payload: VariantUpdate,
        access_token: str | None = None,
    ) -> VariantRead | None:
        """Apply a partial update to a variant.

        Only fields explicitly set on ``payload`` are sent. When
        ``traffic_weight`` is among them, the weight total of the variant's
        current experiment is validated first.

        NOTE(review): validation uses the stored ``experiment_id``. If
        ``VariantUpdate`` can also move a variant to another experiment, that
        move is not validated here — confirm against the schema.

        Returns:
            The updated variant, or ``None`` when it does not exist.

        Raises:
            HTTPException: 400 when the weights would not total 100.
        """
        data: dict[str, Any] = payload.model_dump(exclude_unset=True)
        if "traffic_weight" in data:
            existing = await self.repository.get(variant_id, access_token=access_token)
            if not existing:
                # Unknown variant: report "not found" rather than issuing a
                # write whose weight check was silently skipped.
                return None
            await self._assert_weights_sum_100(
                experiment_id=str(existing["experiment_id"]),
                new_weight=data["traffic_weight"],
                exclude_variant_id=variant_id,
                access_token=access_token,
            )
        item = await self.repository.update(variant_id, data, access_token=access_token)
        if not item:
            return None
        return variant_to_read_model(to_variant(item))