diff --git a/stdlib/2and3/warnings.pyi b/stdlib/2and3/warnings.pyi index 138b71341d1d..2e95533f7fc1 100644 --- a/stdlib/2and3/warnings.pyi +++ b/stdlib/2and3/warnings.pyi @@ -1,7 +1,13 @@ # Stubs for warnings -from typing import Any, Dict, List, NamedTuple, Optional, overload, TextIO, Tuple, Type, Union -from types import ModuleType, TracebackType +import sys +from typing import Any, Dict, List, NamedTuple, Optional, overload, TextIO, Tuple, Type, Union, ContextManager +from types import ModuleType + +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal @overload def warn(message: str, category: Optional[Type[Warning]] = ..., stacklevel: int = ...) -> None: ... @@ -37,10 +43,12 @@ class _Record(NamedTuple): file: Optional[TextIO] line: Optional[str] -class catch_warnings: - def __init__(self, *, record: bool = ..., - module: Optional[ModuleType] = ...) -> None: ... - def __enter__(self) -> Optional[List[_Record]]: ... - def __exit__(self, exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType]) -> None: ... + +@overload +def catch_warnings(*, record: Literal[False] = ..., module: Optional[ModuleType] = ...) -> ContextManager[None]: ... + +@overload +def catch_warnings(*, record: Literal[True], module: Optional[ModuleType] = ...) -> ContextManager[List[_Record]]: ... + +@overload +def catch_warnings(*, record: bool, module: Optional[ModuleType] = ...) -> ContextManager[Optional[List[_Record]]]: ...