from __future__ import annotations

from typing import Any

from fastapi import HTTPException

from app.domain.keys import generate_variant_key
from app.domain.mappers import to_variant
from app.repositories.directus.variants import VariantRepository
from app.schemas.admin import (
    VariantCreate,
    VariantRead,
    VariantUpdate,
    variant_to_read_model,
)


class VariantService:
    """Application service for variant management.

    Wraps a ``VariantRepository``. It enforces that, whenever a variant is
    created or its ``traffic_weight`` is updated, the traffic weights of all
    variants in the affected experiment add up to exactly 100.
    """

    def __init__(self, repository: VariantRepository | None = None) -> None:
        """Create the service.

        Args:
            repository: Repository used for all persistence calls. When omitted
                (or falsy), a new ``VariantRepository()`` is constructed.
        """
        self.repository = repository or VariantRepository()

    @staticmethod
    def _coerce_weight(value: Any) -> int:
        """Convert a stored or requested traffic weight to ``int``, treating ``None`` as 0."""
        # A record can carry the key with a null value, and an update payload can
        # carry an explicit null. ``dict.get(key, 0)`` does not cover either case.
        return 0 if value is None else int(value)

    async def _assert_weights_sum_100(
        self,
        experiment_id: str,
        new_weight: int | None,
        exclude_variant_id: str | None = None,
        access_token: str | None = None,
    ) -> None:
        """Raise HTTP 400 unless the experiment's weights would total exactly 100.

        Args:
            experiment_id: Experiment whose variants are summed.
            new_weight: Weight being written for the variant under change.
                ``None`` counts as 0.
            exclude_variant_id: ID of the variant being updated. Its stored
                weight is left out of the sum so it is not counted twice.
            access_token: Forwarded to the repository.

        Raises:
            HTTPException: 400 when the other variants' weights plus
                ``new_weight`` differ from 100.
        """
        # NOTE(review): because every create must bring the total to exactly 100,
        # the first variant of an experiment must be created with weight 100. Later
        # creations only pass with weight 0 unless weights are rebalanced first.
        # Confirm this is the intended workflow.
        items = await self.repository.list(
            params={"filter[experiment_id][_eq]": experiment_id},
            access_token=access_token,
        )
        existing_sum = sum(
            self._coerce_weight(item.get("traffic_weight"))
            for item in items
            if str(item["id"]) != exclude_variant_id
        )
        new_weight = self._coerce_weight(new_weight)
        if existing_sum + new_weight != 100:
            # Message (zh-TW): "Traffic weights of all variants must sum to 100
            # (other variants currently total X; this request entered Y)".
            raise HTTPException(
                status_code=400,
                detail=f"所有變體的流量權重加總必須等於 100(目前其他變體合計 {existing_sum},本次輸入 {new_weight})",
            )

    async def list_variants(
        self,
        experiment_id: str | None = None,
        access_token: str | None = None,
    ) -> list[VariantRead]:
        """List variants, optionally restricted to one experiment.

        Args:
            experiment_id: When truthy, only variants of this experiment are
                returned. Otherwise no filter is sent to the repository.
            access_token: Forwarded to the repository.

        Returns:
            The repository records mapped to ``VariantRead`` models.
        """
        params: dict[str, Any] | None = None
        if experiment_id:
            params = {"filter[experiment_id][_eq]": experiment_id}
        items = await self.repository.list(params=params, access_token=access_token)
        return [variant_to_read_model(to_variant(item)) for item in items]

    async def get_variant(
        self,
        variant_id: str,
        access_token: str | None = None,
    ) -> VariantRead | None:
        """Fetch one variant by ID.

        Returns:
            The variant as ``VariantRead``, or ``None`` when the repository
            returns a falsy result.
        """
        item = await self.repository.get(variant_id, access_token=access_token)
        if not item:
            return None
        return variant_to_read_model(to_variant(item))

    async def create_variant(
        self,
        payload: VariantCreate,
        access_token: str | None = None,
    ) -> VariantRead:
        """Validate weights, then create a variant with a freshly generated key.

        Fields that are ``None`` in ``payload`` are not sent to the repository.
        ``variant_key`` is always generated server-side via ``generate_variant_key()``.

        Raises:
            HTTPException: 400 when the experiment's weights would not total 100.
        """
        await self._assert_weights_sum_100(
            experiment_id=payload.experiment_id,
            new_weight=payload.traffic_weight,
            access_token=access_token,
        )
        data: dict[str, Any] = payload.model_dump(exclude_none=True)
        data["variant_key"] = generate_variant_key()
        item = await self.repository.create(data, access_token=access_token)
        return variant_to_read_model(to_variant(item))

    async def update_variant(
        self,
        variant_id: str,
        payload: VariantUpdate,
        access_token: str | None = None,
    ) -> VariantRead | None:
        """Apply a partial update to a variant.

        Only fields explicitly set on ``payload`` are written. When
        ``traffic_weight`` is among them and the variant exists, the weight
        total is validated first.

        Returns:
            The updated variant, or ``None`` when the repository update returns
            a falsy result.

        Raises:
            HTTPException: 400 when the experiment's weights would not total 100.
        """
        data: dict[str, Any] = payload.model_dump(exclude_unset=True)
        if "traffic_weight" in data:
            existing = await self.repository.get(variant_id, access_token=access_token)
            if existing:
                # If this same update also sets experiment_id, the new weight must
                # fit the experiment the variant is moving into, not the old one.
                # NOTE(review): the source experiment's total is not re-validated
                # on such a move.
                target_experiment_id = data.get("experiment_id") or existing["experiment_id"]
                await self._assert_weights_sum_100(
                    experiment_id=str(target_experiment_id),
                    new_weight=data["traffic_weight"],
                    exclude_variant_id=variant_id,
                    access_token=access_token,
                )
        item = await self.repository.update(variant_id, data, access_token=access_token)
        if not item:
            return None
        return variant_to_read_model(to_variant(item))