Skip to content

Commit d93b372

Browse files
committed
Switch to typing-inspection
1 parent 1a4f3f4 commit d93b372

File tree

4 files changed

+82
-57
lines changed

4 files changed

+82
-57
lines changed

pydantic_settings/sources.py

Lines changed: 61 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import warnings
1010
from abc import ABC, abstractmethod
1111

12+
import typing_extensions
13+
1214
if sys.version_info >= (3, 9):
1315
from argparse import BooleanOptionalAction
1416
from argparse import SUPPRESS, ArgumentParser, Namespace, RawDescriptionHelpFormatter, _SubParsersAction
@@ -35,18 +37,18 @@
3537
overload,
3638
)
3739

38-
import typing_extensions
3940
from dotenv import dotenv_values
4041
from pydantic import AliasChoices, AliasPath, BaseModel, Json, RootModel, Secret, TypeAdapter
4142
from pydantic._internal._repr import Representation
42-
from pydantic._internal._typing_extra import WithArgsTypes, origin_is_union, typing_base
43-
from pydantic._internal._utils import deep_update, is_model_class, lenient_issubclass
43+
from pydantic._internal._utils import deep_update, is_model_class
4444
from pydantic.dataclasses import is_pydantic_dataclass
4545
from pydantic.fields import FieldInfo
4646
from pydantic_core import PydanticUndefined
47-
from typing_extensions import Annotated, _AnnotatedAlias, get_args, get_origin
47+
from typing_extensions import Annotated, get_args, get_origin
48+
from typing_inspection import typing_objects
49+
from typing_inspection.introspection import is_union_origin
4850

49-
from pydantic_settings.utils import path_type_label
51+
from pydantic_settings.utils import _lenient_issubclass, path_type_label
5052

5153
if TYPE_CHECKING:
5254
if sys.version_info >= (3, 11):
@@ -482,7 +484,7 @@ def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[s
482484
field_info.append((v_alias, self._apply_case_sensitive(v_alias), False))
483485

484486
if not v_alias or self.config.get('populate_by_name', False):
485-
if origin_is_union(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata):
487+
if is_union_origin(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata):
486488
field_info.append((field_name, self._apply_case_sensitive(self.env_prefix + field_name), True))
487489
else:
488490
field_info.append((field_name, self._apply_case_sensitive(self.env_prefix + field_name), False))
@@ -528,12 +530,13 @@ class Settings(BaseSettings):
528530
annotation = field.annotation
529531

530532
# If field is Optional, we need to find the actual type
531-
args = get_args(annotation)
532-
if origin_is_union(get_origin(field.annotation)) and len(args) == 2 and type(None) in args:
533-
for arg in args:
534-
if arg is not None:
535-
annotation = arg
536-
break
533+
if is_union_origin(get_origin(field.annotation)):
534+
args = get_args(annotation)
535+
if len(args) == 2 and type(None) in args:
536+
for arg in args:
537+
if arg is not None:
538+
annotation = arg
539+
break
537540

538541
# This is here to make mypy happy
539542
# Item "None" of "Optional[Type[Any]]" has no attribute "model_fields"
@@ -551,7 +554,7 @@ class Settings(BaseSettings):
551554
values[name] = value
552555
continue
553556

554-
if lenient_issubclass(sub_model_field.annotation, BaseModel) and isinstance(value, dict):
557+
if _lenient_issubclass(sub_model_field.annotation, BaseModel) and isinstance(value, dict):
555558
values[sub_model_field_name] = self._replace_field_names_case_insensitively(sub_model_field, value)
556559
else:
557560
values[sub_model_field_name] = value
@@ -621,7 +624,7 @@ def __call__(self) -> dict[str, Any]:
621624
field_value = None
622625
if (
623626
not self.case_sensitive
624-
# and lenient_issubclass(field.annotation, BaseModel)
627+
# and _lenient_issubclass(field.annotation, BaseModel)
625628
and isinstance(field_value, dict)
626629
):
627630
data[field_key] = self._replace_field_names_case_insensitively(field, field_value)
@@ -840,7 +843,7 @@ def _field_is_complex(self, field: FieldInfo) -> tuple[bool, bool]:
840843
"""
841844
if self.field_is_complex(field):
842845
allow_parse_failure = False
843-
elif origin_is_union(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata):
846+
elif is_union_origin(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata):
844847
allow_parse_failure = True
845848
else:
846849
return False, False
@@ -886,12 +889,11 @@ class Cfg(BaseSettings):
886889
return None
887890

888891
annotation = field.annotation if isinstance(field, FieldInfo) else field
889-
if origin_is_union(get_origin(annotation)) or isinstance(annotation, WithArgsTypes):
890-
for type_ in get_args(annotation):
891-
type_has_key = self.next_field(type_, key, case_sensitive)
892-
if type_has_key:
893-
return type_has_key
894-
elif is_model_class(annotation) or is_pydantic_dataclass(annotation):
892+
for type_ in get_args(annotation):
893+
type_has_key = self.next_field(type_, key, case_sensitive)
894+
if type_has_key:
895+
return type_has_key
896+
if is_model_class(annotation) or is_pydantic_dataclass(annotation):
895897
fields = _get_model_fields(annotation)
896898
# `case_sensitive is None` is here to be compatible with the old behavior.
897899
# Has to be removed in V3.
@@ -921,7 +923,8 @@ def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[
921923
if not self.env_nested_delimiter:
922924
return {}
923925

924-
is_dict = lenient_issubclass(get_origin(field.annotation), dict)
926+
ann = field.annotation
927+
is_dict = ann is dict or _lenient_issubclass(get_origin(ann), dict)
925928

926929
prefixes = [
927930
f'{env_name}{self.env_nested_delimiter}' for _, env_name, _ in self._extract_field_info(field, field_name)
@@ -1063,7 +1066,7 @@ def __call__(self) -> dict[str, Any]:
10631066
(
10641067
_annotation_is_complex(field.annotation, field.metadata)
10651068
or (
1066-
origin_is_union(get_origin(field.annotation))
1069+
is_union_origin(get_origin(field.annotation))
10671070
and _union_is_complex(field.annotation, field.metadata)
10681071
)
10691072
)
@@ -1380,7 +1383,7 @@ def _get_merge_parsed_list_types(
13801383
merge_type = self._cli_dict_args.get(field_name, list)
13811384
if (
13821385
merge_type is list
1383-
or not origin_is_union(get_origin(merge_type))
1386+
or not is_union_origin(get_origin(merge_type))
13841387
or not any(
13851388
type_
13861389
for type_ in get_args(merge_type)
@@ -1512,9 +1515,7 @@ def _verify_cli_flag_annotations(self, model: type[BaseModel], field_name: str,
15121515

15131516
if field_info.annotation is not bool:
15141517
raise SettingsError(f'{cli_flag_name} argument {model.__name__}.{field_name} is not of type bool')
1515-
elif sys.version_info < (3, 9) and (
1516-
field_info.default is PydanticUndefined and field_info.default_factory is None
1517-
):
1518+
elif sys.version_info < (3, 9) and field_info.is_required():
15181519
raise SettingsError(
15191520
f'{cli_flag_name} argument {model.__name__}.{field_name} must have default for python versions < 3.9'
15201521
)
@@ -1530,7 +1531,7 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]
15301531
alias_names, *_ = _get_alias_names(field_name, field_info)
15311532
if len(alias_names) > 1:
15321533
raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has multiple aliases')
1533-
field_types = [type_ for type_ in get_args(field_info.annotation) if type_ is not type(None)]
1534+
field_types = (type_ for type_ in get_args(field_info.annotation) if type_ is not type(None))
15341535
for field_type in field_types:
15351536
if not (is_model_class(field_type) or is_pydantic_dataclass(field_type)):
15361537
raise SettingsError(
@@ -1996,19 +1997,26 @@ def _metavar_format_recurse(self, obj: Any) -> str:
19961997
return '...'
19971998
elif isinstance(obj, Representation):
19981999
return repr(obj)
1999-
elif isinstance(obj, typing_extensions.TypeAliasType):
2000+
elif typing_objects.is_typealiastype(obj):
20002001
return str(obj)
20012002

2002-
if not isinstance(obj, (typing_base, WithArgsTypes, type)):
2003+
origin = get_origin(obj)
2004+
if (
2005+
origin is None
2006+
and not isinstance(obj, type)
2007+
and not isinstance(obj, (typing.ForwardRef, typing_extensions.ForwardRef))
2008+
):
20032009
obj = obj.__class__
20042010

2005-
if origin_is_union(get_origin(obj)):
2011+
args = get_args(obj)
2012+
2013+
if is_union_origin(origin):
20062014
return self._metavar_format_choices(list(map(self._metavar_format_recurse, self._get_modified_args(obj))))
2007-
elif get_origin(obj) in (typing_extensions.Literal, typing.Literal):
2015+
elif typing_objects.is_literal(origin):
20082016
return self._metavar_format_choices(list(map(str, self._get_modified_args(obj))))
2009-
elif lenient_issubclass(obj, Enum):
2017+
elif _lenient_issubclass(obj, Enum):
20102018
return self._metavar_format_choices([val.name for val in obj])
2011-
elif isinstance(obj, WithArgsTypes):
2019+
elif args:
20122020
return self._metavar_format_choices(
20132021
list(map(self._metavar_format_recurse, self._get_modified_args(obj))),
20142022
obj_qualname=obj.__qualname__ if hasattr(obj, '__qualname__') else str(obj),
@@ -2304,25 +2312,22 @@ def read_env_file(
23042312
def _annotation_is_complex(annotation: type[Any] | None, metadata: list[Any]) -> bool:
23052313
# If the model is a root model, the root annotation should be used to
23062314
# evaluate the complexity.
2307-
try:
2308-
if annotation is not None and issubclass(annotation, RootModel):
2309-
# In some rare cases (see test_root_model_as_field),
2310-
# the root attribute is not available. For these cases, python 3.8 and 3.9
2311-
# return 'RootModelRootType'.
2312-
root_annotation = annotation.__annotations__.get('root', None)
2313-
if root_annotation is not None and root_annotation != 'RootModelRootType':
2314-
annotation = root_annotation
2315-
except TypeError:
2316-
pass
2315+
if annotation is not None and _lenient_issubclass(annotation, RootModel) and annotation is not RootModel:
2316+
annotation = cast('type[RootModel[Any]]', annotation)
2317+
root_annotation = annotation.model_fields['root'].annotation
2318+
if root_annotation is not None:
2319+
annotation = root_annotation
23172320

23182321
if any(isinstance(md, Json) for md in metadata): # type: ignore[misc]
23192322
return False
2323+
2324+
origin = get_origin(annotation)
2325+
23202326
# Check if annotation is of the form Annotated[type, metadata].
2321-
if isinstance(annotation, _AnnotatedAlias):
2327+
if typing_objects.is_annotated(origin):
23222328
# Return result of recursive call on inner type.
23232329
inner, *meta = get_args(annotation)
23242330
return _annotation_is_complex(inner, meta)
2325-
origin = get_origin(annotation)
23262331

23272332
if origin is Secret:
23282333
return False
@@ -2336,12 +2341,12 @@ def _annotation_is_complex(annotation: type[Any] | None, metadata: list[Any]) ->
23362341

23372342

23382343
def _annotation_is_complex_inner(annotation: type[Any] | None) -> bool:
2339-
if lenient_issubclass(annotation, (str, bytes)):
2344+
if _lenient_issubclass(annotation, (str, bytes)):
23402345
return False
23412346

2342-
return lenient_issubclass(annotation, (BaseModel, Mapping, Sequence, tuple, set, frozenset, deque)) or is_dataclass(
2343-
annotation
2344-
)
2347+
return _lenient_issubclass(
2348+
annotation, (BaseModel, Mapping, Sequence, tuple, set, frozenset, deque)
2349+
) or is_dataclass(annotation)
23452350

23462351

23472352
def _union_is_complex(annotation: type[Any] | None, metadata: list[Any]) -> bool:
@@ -2365,22 +2370,23 @@ def _annotation_contains_types(
23652370

23662371

23672372
def _strip_annotated(annotation: Any) -> Any:
2368-
while get_origin(annotation) == Annotated:
2369-
annotation = get_args(annotation)[0]
2370-
return annotation
2373+
if typing_objects.is_annotated(get_origin(annotation)):
2374+
return annotation.__origin__
2375+
else:
2376+
return annotation
23712377

23722378

23732379
def _annotation_enum_val_to_name(annotation: type[Any] | None, value: Any) -> Optional[str]:
23742380
for type_ in (annotation, get_origin(annotation), *get_args(annotation)):
2375-
if lenient_issubclass(type_, Enum):
2381+
if _lenient_issubclass(type_, Enum):
23762382
if value in tuple(val.value for val in type_):
23772383
return type_(value).name
23782384
return None
23792385

23802386

23812387
def _annotation_enum_name_to_val(annotation: type[Any] | None, name: Any) -> Any:
23822388
for type_ in (annotation, get_origin(annotation), *get_args(annotation)):
2383-
if lenient_issubclass(type_, Enum):
2389+
if _lenient_issubclass(type_, Enum):
23842390
if name in tuple(val.name for val in type_):
23852391
return type_[name]
23862392
return None

pydantic_settings/utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
from pathlib import Path
2+
from typing import Any
3+
4+
from typing_extensions import get_origin
25

36
_PATH_TYPE_LABELS = {
47
Path.is_dir: 'directory',
@@ -22,3 +25,16 @@ def path_type_label(p: Path) -> str:
2225
return name
2326

2427
return 'unknown'
28+
29+
30+
# TODO remove and replace usage by `isinstance(cls, type) and issubclass(cls, class_or_tuple)`
31+
# once we drop support for Python 3.10.
32+
def _lenient_issubclass(cls: Any, class_or_tuple: Any) -> bool: # pragma: no cover
33+
try:
34+
return isinstance(cls, type) and issubclass(cls, class_or_tuple)
35+
except TypeError:
36+
if get_origin(cls) is not None:
37+
# Up until Python 3.10, isinstance(<generic_alias>, type) is True
38+
# (e.g. list[int])
39+
return False
40+
raise

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ requires-python = '>=3.8'
4343
dependencies = [
4444
'pydantic>=2.7.0',
4545
'python-dotenv>=0.21.0',
46+
'typing-inspection>=0.4.0',
4647
]
4748
dynamic = ['version']
4849

requirements/pyproject.txt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#
2-
# This file is autogenerated by pip-compile with Python 3.8
2+
# This file is autogenerated by pip-compile with Python 3.13
33
# by the following command:
44
#
55
# pip-compile --extra=azure-key-vault --extra=toml --extra=yaml --no-emit-index-url --output-file=requirements/pyproject.txt pyproject.toml
@@ -63,11 +63,13 @@ tomli==2.0.1
6363
# via pydantic-settings (pyproject.toml)
6464
typing-extensions==4.12.2
6565
# via
66-
# annotated-types
6766
# azure-core
6867
# azure-identity
6968
# azure-keyvault-secrets
7069
# pydantic
7170
# pydantic-core
71+
# typing-inspection
72+
typing-inspection==0.4.0
73+
# via pydantic-settings (pyproject.toml)
7274
urllib3==2.2.2
7375
# via requests

0 commit comments

Comments
 (0)