"""Utility functions for pydantic-settings sources.""" from __future__ import annotations as _annotations from collections import deque from collections.abc import Mapping, Sequence from dataclasses import is_dataclass from enum import Enum from typing import Any, TypeVar, cast, get_args, get_origin from pydantic import BaseModel, Json, RootModel, Secret from pydantic._internal._utils import is_model_class from pydantic.dataclasses import is_pydantic_dataclass from pydantic.fields import FieldInfo from typing_inspection import typing_objects from ..exceptions import SettingsError from ..utils import _lenient_issubclass from .types import EnvNoneType def _get_env_var_key(key: str, case_sensitive: bool = False) -> str: return key if case_sensitive else key.lower() def _parse_env_none_str(value: str | None, parse_none_str: str | None = None) -> str | None | EnvNoneType: return value if not (value == parse_none_str and parse_none_str is not None) else EnvNoneType(value) def parse_env_vars( env_vars: Mapping[str, str | None], case_sensitive: bool = False, ignore_empty: bool = False, parse_none_str: str | None = None, ) -> Mapping[str, str | None]: return { _get_env_var_key(k, case_sensitive): _parse_env_none_str(v, parse_none_str) for k, v in env_vars.items() if not (ignore_empty and v == '') } def _substitute_typevars(tp: Any, param_map: dict[Any, Any]) -> Any: """Substitute TypeVars in a type annotation with concrete types from param_map.""" if isinstance(tp, TypeVar) and tp in param_map: return param_map[tp] args = get_args(tp) if not args: return tp new_args = tuple(_substitute_typevars(arg, param_map) for arg in args) if new_args == args: return tp origin = get_origin(tp) if origin is not None: try: return origin[new_args] except TypeError: # types.UnionType and similar are not directly subscriptable, # reconstruct using | operator import functools import operator return functools.reduce(operator.or_, new_args) return tp def _resolve_type_alias(annotation: Any) -> Any: """Resolve a TypeAliasType to its underlying value, substituting type params if parameterized.""" if typing_objects.is_typealiastype(annotation): return annotation.__value__ origin = get_origin(annotation) if typing_objects.is_typealiastype(origin): type_params = getattr(origin, '__type_params__', ()) type_args = get_args(annotation) value = origin.__value__ if type_params and type_args: return _substitute_typevars(value, dict(zip(type_params, type_args))) return value return annotation def _annotation_is_complex(annotation: Any, metadata: list[Any]) -> bool: # If the model is a root model, the root annotation should be used to # evaluate the complexity. annotation = _resolve_type_alias(annotation) if annotation is not None and _lenient_issubclass(annotation, RootModel) and annotation is not RootModel: annotation = cast('type[RootModel[Any]]', annotation) root_annotation = annotation.model_fields['root'].annotation if root_annotation is not None: # pragma: no branch annotation = root_annotation if any(isinstance(md, Json) for md in metadata): # type: ignore[misc] return False origin = get_origin(annotation) # Check if annotation is of the form Annotated[type, metadata]. if typing_objects.is_annotated(origin): # Return result of recursive call on inner type. inner, *meta = get_args(annotation) return _annotation_is_complex(inner, meta) if origin is Secret: return False return ( _annotation_is_complex_inner(annotation) or _annotation_is_complex_inner(origin) or hasattr(origin, '__pydantic_core_schema__') or hasattr(origin, '__get_pydantic_core_schema__') ) def _get_field_metadata(field: FieldInfo) -> list[Any]: annotation = _resolve_type_alias(field.annotation) metadata = field.metadata origin = get_origin(annotation) if typing_objects.is_annotated(origin): _, *meta = get_args(annotation) metadata += meta return metadata def _annotation_is_complex_inner(annotation: type[Any] | None) -> bool: if _lenient_issubclass(annotation, (str, bytes)): return False return _lenient_issubclass( annotation, (BaseModel, Mapping, Sequence, tuple, set, frozenset, deque) ) or is_dataclass(annotation) def _union_is_complex(annotation: type[Any] | None, metadata: list[Any]) -> bool: """Check if a union type contains any complex types.""" return any(_annotation_is_complex(arg, metadata) for arg in get_args(annotation)) def _annotation_contains_types( annotation: type[Any] | None, types: tuple[Any, ...], is_include_origin: bool = True, is_strip_annotated: bool = False, is_instance: bool = False, collect: set[Any] | None = None, ) -> bool: """Check if a type annotation contains any of the specified types.""" if is_strip_annotated: annotation = _strip_annotated(annotation) if is_include_origin is True: origin = get_origin(annotation) if origin in types: if collect is None: return True collect.add(annotation) if is_instance and any(isinstance(origin, type_) for type_ in types): if collect is None: return True collect.add(annotation) for type_ in get_args(annotation): if ( _annotation_contains_types( type_, types, is_include_origin=True, is_strip_annotated=is_strip_annotated, is_instance=is_instance, collect=collect, ) and collect is None ): return True if is_instance and any(isinstance(annotation, type_) for type_ in types): if collect is None: return True collect.add(annotation) if annotation in types: if collect is not None: collect.add(annotation) return True return False def _strip_annotated(annotation: Any) -> Any: if typing_objects.is_annotated(get_origin(annotation)): return annotation.__origin__ else: return annotation def _annotation_enum_val_to_name(annotation: type[Any] | None, value: Any) -> str | None: for type_ in (annotation, get_origin(annotation), *get_args(annotation)): if _lenient_issubclass(type_, Enum): if value in type_.__members__.values(): return type_(value).name return None def _annotation_enum_name_to_val(annotation: type[Any] | None, name: Any) -> Any: for type_ in (annotation, get_origin(annotation), *get_args(annotation)): if _lenient_issubclass(type_, Enum): if name in type_.__members__.keys(): return type_[name] return None def _get_model_fields(model_cls: type[Any]) -> dict[str, Any]: """Get fields from a pydantic model or dataclass.""" if is_pydantic_dataclass(model_cls) and hasattr(model_cls, '__pydantic_fields__'): return model_cls.__pydantic_fields__ if is_model_class(model_cls): return model_cls.model_fields raise SettingsError(f'Error: {model_cls.__name__} is not subclass of BaseModel or pydantic.dataclasses.dataclass') def _get_alias_names( field_name: str, field_info: Any, alias_path_args: dict[str, int | None] | None = None, case_sensitive: bool = True, ) -> tuple[tuple[str, ...], bool]: """Get alias names for a field, handling alias paths and case sensitivity.""" from pydantic import AliasChoices, AliasPath alias_names: list[str] = [] is_alias_path_only: bool = True if not any((field_info.alias, field_info.validation_alias)): alias_names += [field_name] is_alias_path_only = False else: new_alias_paths: list[AliasPath] = [] for alias in (field_info.alias, field_info.validation_alias): if alias is None: continue elif isinstance(alias, str): alias_names.append(alias) is_alias_path_only = False elif isinstance(alias, AliasChoices): for name in alias.choices: if isinstance(name, str): alias_names.append(name) is_alias_path_only = False else: new_alias_paths.append(name) else: new_alias_paths.append(alias) for alias_path in new_alias_paths: name = cast(str, alias_path.path[0]) name = name.lower() if not case_sensitive else name if alias_path_args is not None: alias_path_args[name] = ( alias_path.path[1] if len(alias_path.path) > 1 and isinstance(alias_path.path[1], int) else None ) if not alias_names and is_alias_path_only: alias_names.append(name) if not case_sensitive: alias_names = [alias_name.lower() for alias_name in alias_names] return tuple(dict.fromkeys(alias_names)), is_alias_path_only def _is_function(obj: Any) -> bool: """Check if an object is a function.""" from types import BuiltinFunctionType, FunctionType return isinstance(obj, (FunctionType, BuiltinFunctionType)) __all__ = [ '_annotation_contains_types', '_annotation_enum_name_to_val', '_annotation_enum_val_to_name', '_annotation_is_complex', '_annotation_is_complex_inner', '_get_alias_names', '_get_env_var_key', '_get_model_fields', '_is_function', '_parse_env_none_str', '_resolve_type_alias', '_strip_annotated', '_union_is_complex', 'parse_env_vars', ]