9
9
import warnings
10
10
from abc import ABC , abstractmethod
11
11
12
+ import typing_extensions
13
+
12
14
if sys .version_info >= (3 , 9 ):
13
15
from argparse import BooleanOptionalAction
14
16
from argparse import SUPPRESS , ArgumentParser , Namespace , RawDescriptionHelpFormatter , _SubParsersAction
35
37
overload ,
36
38
)
37
39
38
- import typing_extensions
39
40
from dotenv import dotenv_values
40
41
from pydantic import AliasChoices , AliasPath , BaseModel , Json , RootModel , Secret , TypeAdapter
41
42
from pydantic ._internal ._repr import Representation
42
- from pydantic ._internal ._typing_extra import WithArgsTypes , origin_is_union , typing_base
43
- from pydantic ._internal ._utils import deep_update , is_model_class , lenient_issubclass
43
+ from pydantic ._internal ._utils import deep_update , is_model_class
44
44
from pydantic .dataclasses import is_pydantic_dataclass
45
45
from pydantic .fields import FieldInfo
46
46
from pydantic_core import PydanticUndefined
47
- from typing_extensions import Annotated , _AnnotatedAlias , get_args , get_origin
47
+ from typing_extensions import Annotated , get_args , get_origin
48
+ from typing_inspection import typing_objects
49
+ from typing_inspection .introspection import is_union_origin
48
50
49
- from pydantic_settings .utils import path_type_label
51
+ from pydantic_settings .utils import _lenient_issubclass , path_type_label
50
52
51
53
if TYPE_CHECKING :
52
54
if sys .version_info >= (3 , 11 ):
@@ -482,7 +484,7 @@ def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[s
482
484
field_info .append ((v_alias , self ._apply_case_sensitive (v_alias ), False ))
483
485
484
486
if not v_alias or self .config .get ('populate_by_name' , False ):
485
- if origin_is_union (get_origin (field .annotation )) and _union_is_complex (field .annotation , field .metadata ):
487
+ if is_union_origin (get_origin (field .annotation )) and _union_is_complex (field .annotation , field .metadata ):
486
488
field_info .append ((field_name , self ._apply_case_sensitive (self .env_prefix + field_name ), True ))
487
489
else :
488
490
field_info .append ((field_name , self ._apply_case_sensitive (self .env_prefix + field_name ), False ))
@@ -528,12 +530,13 @@ class Settings(BaseSettings):
528
530
annotation = field .annotation
529
531
530
532
# If field is Optional, we need to find the actual type
531
- args = get_args (annotation )
532
- if origin_is_union (get_origin (field .annotation )) and len (args ) == 2 and type (None ) in args :
533
- for arg in args :
534
- if arg is not None :
535
- annotation = arg
536
- break
533
+ if is_union_origin (get_origin (field .annotation )):
534
+ args = get_args (annotation )
535
+ if len (args ) == 2 and type (None ) in args :
536
+ for arg in args :
537
+ if arg is not None :
538
+ annotation = arg
539
+ break
537
540
538
541
# This is here to make mypy happy
539
542
# Item "None" of "Optional[Type[Any]]" has no attribute "model_fields"
@@ -551,7 +554,7 @@ class Settings(BaseSettings):
551
554
values [name ] = value
552
555
continue
553
556
554
- if lenient_issubclass (sub_model_field .annotation , BaseModel ) and isinstance (value , dict ):
557
+ if _lenient_issubclass (sub_model_field .annotation , BaseModel ) and isinstance (value , dict ):
555
558
values [sub_model_field_name ] = self ._replace_field_names_case_insensitively (sub_model_field , value )
556
559
else :
557
560
values [sub_model_field_name ] = value
@@ -621,7 +624,7 @@ def __call__(self) -> dict[str, Any]:
621
624
field_value = None
622
625
if (
623
626
not self .case_sensitive
624
- # and lenient_issubclass (field.annotation, BaseModel)
627
+ # and _lenient_issubclass (field.annotation, BaseModel)
625
628
and isinstance (field_value , dict )
626
629
):
627
630
data [field_key ] = self ._replace_field_names_case_insensitively (field , field_value )
@@ -840,7 +843,7 @@ def _field_is_complex(self, field: FieldInfo) -> tuple[bool, bool]:
840
843
"""
841
844
if self .field_is_complex (field ):
842
845
allow_parse_failure = False
843
- elif origin_is_union (get_origin (field .annotation )) and _union_is_complex (field .annotation , field .metadata ):
846
+ elif is_union_origin (get_origin (field .annotation )) and _union_is_complex (field .annotation , field .metadata ):
844
847
allow_parse_failure = True
845
848
else :
846
849
return False , False
@@ -886,12 +889,11 @@ class Cfg(BaseSettings):
886
889
return None
887
890
888
891
annotation = field .annotation if isinstance (field , FieldInfo ) else field
889
- if origin_is_union (get_origin (annotation )) or isinstance (annotation , WithArgsTypes ):
890
- for type_ in get_args (annotation ):
891
- type_has_key = self .next_field (type_ , key , case_sensitive )
892
- if type_has_key :
893
- return type_has_key
894
- elif is_model_class (annotation ) or is_pydantic_dataclass (annotation ):
892
+ for type_ in get_args (annotation ):
893
+ type_has_key = self .next_field (type_ , key , case_sensitive )
894
+ if type_has_key :
895
+ return type_has_key
896
+ if is_model_class (annotation ) or is_pydantic_dataclass (annotation ):
895
897
fields = _get_model_fields (annotation )
896
898
# `case_sensitive is None` is here to be compatible with the old behavior.
897
899
# Has to be removed in V3.
@@ -921,7 +923,8 @@ def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[
921
923
if not self .env_nested_delimiter :
922
924
return {}
923
925
924
- is_dict = lenient_issubclass (get_origin (field .annotation ), dict )
926
+ ann = field .annotation
927
+ is_dict = ann is dict or _lenient_issubclass (get_origin (ann ), dict )
925
928
926
929
prefixes = [
927
930
f'{ env_name } { self .env_nested_delimiter } ' for _ , env_name , _ in self ._extract_field_info (field , field_name )
@@ -1063,7 +1066,7 @@ def __call__(self) -> dict[str, Any]:
1063
1066
(
1064
1067
_annotation_is_complex (field .annotation , field .metadata )
1065
1068
or (
1066
- origin_is_union (get_origin (field .annotation ))
1069
+ is_union_origin (get_origin (field .annotation ))
1067
1070
and _union_is_complex (field .annotation , field .metadata )
1068
1071
)
1069
1072
)
@@ -1380,7 +1383,7 @@ def _get_merge_parsed_list_types(
1380
1383
merge_type = self ._cli_dict_args .get (field_name , list )
1381
1384
if (
1382
1385
merge_type is list
1383
- or not origin_is_union (get_origin (merge_type ))
1386
+ or not is_union_origin (get_origin (merge_type ))
1384
1387
or not any (
1385
1388
type_
1386
1389
for type_ in get_args (merge_type )
@@ -1512,9 +1515,7 @@ def _verify_cli_flag_annotations(self, model: type[BaseModel], field_name: str,
1512
1515
1513
1516
if field_info .annotation is not bool :
1514
1517
raise SettingsError (f'{ cli_flag_name } argument { model .__name__ } .{ field_name } is not of type bool' )
1515
- elif sys .version_info < (3 , 9 ) and (
1516
- field_info .default is PydanticUndefined and field_info .default_factory is None
1517
- ):
1518
+ elif sys .version_info < (3 , 9 ) and field_info .is_required ():
1518
1519
raise SettingsError (
1519
1520
f'{ cli_flag_name } argument { model .__name__ } .{ field_name } must have default for python versions < 3.9'
1520
1521
)
@@ -1530,7 +1531,7 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]
1530
1531
alias_names , * _ = _get_alias_names (field_name , field_info )
1531
1532
if len (alias_names ) > 1 :
1532
1533
raise SettingsError (f'subcommand argument { model .__name__ } .{ field_name } has multiple aliases' )
1533
- field_types = [ type_ for type_ in get_args (field_info .annotation ) if type_ is not type (None )]
1534
+ field_types = ( type_ for type_ in get_args (field_info .annotation ) if type_ is not type (None ))
1534
1535
for field_type in field_types :
1535
1536
if not (is_model_class (field_type ) or is_pydantic_dataclass (field_type )):
1536
1537
raise SettingsError (
@@ -1996,19 +1997,26 @@ def _metavar_format_recurse(self, obj: Any) -> str:
1996
1997
return '...'
1997
1998
elif isinstance (obj , Representation ):
1998
1999
return repr (obj )
1999
- elif isinstance (obj , typing_extensions . TypeAliasType ):
2000
+ elif typing_objects . is_typealiastype (obj ):
2000
2001
return str (obj )
2001
2002
2002
- if not isinstance (obj , (typing_base , WithArgsTypes , type )):
2003
+ origin = get_origin (obj )
2004
+ if (
2005
+ origin is None
2006
+ and not isinstance (obj , type )
2007
+ and not isinstance (obj , (typing .ForwardRef , typing_extensions .ForwardRef ))
2008
+ ):
2003
2009
obj = obj .__class__
2004
2010
2005
- if origin_is_union (get_origin (obj )):
2011
+ args = get_args (obj )
2012
+
2013
+ if is_union_origin (origin ):
2006
2014
return self ._metavar_format_choices (list (map (self ._metavar_format_recurse , self ._get_modified_args (obj ))))
2007
- elif get_origin ( obj ) in ( typing_extensions . Literal , typing . Literal ):
2015
+ elif typing_objects . is_literal ( origin ):
2008
2016
return self ._metavar_format_choices (list (map (str , self ._get_modified_args (obj ))))
2009
- elif lenient_issubclass (obj , Enum ):
2017
+ elif _lenient_issubclass (obj , Enum ):
2010
2018
return self ._metavar_format_choices ([val .name for val in obj ])
2011
- elif isinstance ( obj , WithArgsTypes ) :
2019
+ elif args :
2012
2020
return self ._metavar_format_choices (
2013
2021
list (map (self ._metavar_format_recurse , self ._get_modified_args (obj ))),
2014
2022
obj_qualname = obj .__qualname__ if hasattr (obj , '__qualname__' ) else str (obj ),
@@ -2304,25 +2312,22 @@ def read_env_file(
2304
2312
def _annotation_is_complex (annotation : type [Any ] | None , metadata : list [Any ]) -> bool :
2305
2313
# If the model is a root model, the root annotation should be used to
2306
2314
# evaluate the complexity.
2307
- try :
2308
- if annotation is not None and issubclass (annotation , RootModel ):
2309
- # In some rare cases (see test_root_model_as_field),
2310
- # the root attribute is not available. For these cases, python 3.8 and 3.9
2311
- # return 'RootModelRootType'.
2312
- root_annotation = annotation .__annotations__ .get ('root' , None )
2313
- if root_annotation is not None and root_annotation != 'RootModelRootType' :
2314
- annotation = root_annotation
2315
- except TypeError :
2316
- pass
2315
+ if annotation is not None and _lenient_issubclass (annotation , RootModel ) and annotation is not RootModel :
2316
+ annotation = cast ('type[RootModel[Any]]' , annotation )
2317
+ root_annotation = annotation .model_fields ['root' ].annotation
2318
+ if root_annotation is not None :
2319
+ annotation = root_annotation
2317
2320
2318
2321
if any (isinstance (md , Json ) for md in metadata ): # type: ignore[misc]
2319
2322
return False
2323
+
2324
+ origin = get_origin (annotation )
2325
+
2320
2326
# Check if annotation is of the form Annotated[type, metadata].
2321
- if isinstance ( annotation , _AnnotatedAlias ):
2327
+ if typing_objects . is_annotated ( origin ):
2322
2328
# Return result of recursive call on inner type.
2323
2329
inner , * meta = get_args (annotation )
2324
2330
return _annotation_is_complex (inner , meta )
2325
- origin = get_origin (annotation )
2326
2331
2327
2332
if origin is Secret :
2328
2333
return False
@@ -2336,12 +2341,12 @@ def _annotation_is_complex(annotation: type[Any] | None, metadata: list[Any]) ->
2336
2341
2337
2342
2338
2343
def _annotation_is_complex_inner (annotation : type [Any ] | None ) -> bool :
2339
- if lenient_issubclass (annotation , (str , bytes )):
2344
+ if _lenient_issubclass (annotation , (str , bytes )):
2340
2345
return False
2341
2346
2342
- return lenient_issubclass ( annotation , ( BaseModel , Mapping , Sequence , tuple , set , frozenset , deque )) or is_dataclass (
2343
- annotation
2344
- )
2347
+ return _lenient_issubclass (
2348
+ annotation , ( BaseModel , Mapping , Sequence , tuple , set , frozenset , deque )
2349
+ ) or is_dataclass ( annotation )
2345
2350
2346
2351
2347
2352
def _union_is_complex (annotation : type [Any ] | None , metadata : list [Any ]) -> bool :
@@ -2365,22 +2370,23 @@ def _annotation_contains_types(
2365
2370
2366
2371
2367
2372
def _strip_annotated (annotation : Any ) -> Any :
2368
- while get_origin (annotation ) == Annotated :
2369
- annotation = get_args (annotation )[0 ]
2370
- return annotation
2373
+ if typing_objects .is_annotated (get_origin (annotation )):
2374
+ return annotation .__origin__
2375
+ else :
2376
+ return annotation
2371
2377
2372
2378
2373
2379
def _annotation_enum_val_to_name (annotation : type [Any ] | None , value : Any ) -> Optional [str ]:
2374
2380
for type_ in (annotation , get_origin (annotation ), * get_args (annotation )):
2375
- if lenient_issubclass (type_ , Enum ):
2381
+ if _lenient_issubclass (type_ , Enum ):
2376
2382
if value in tuple (val .value for val in type_ ):
2377
2383
return type_ (value ).name
2378
2384
return None
2379
2385
2380
2386
2381
2387
def _annotation_enum_name_to_val (annotation : type [Any ] | None , name : Any ) -> Any :
2382
2388
for type_ in (annotation , get_origin (annotation ), * get_args (annotation )):
2383
- if lenient_issubclass (type_ , Enum ):
2389
+ if _lenient_issubclass (type_ , Enum ):
2384
2390
if name in tuple (val .name for val in type_ ):
2385
2391
return type_ [name ]
2386
2392
return None
0 commit comments