From f94cfd339a5e2b548b42abf116b925fe12a7581c Mon Sep 17 00:00:00 2001 From: Akuli Date: Fri, 7 May 2021 16:06:37 +0300 Subject: [PATCH 1/4] make DictWriter generic to accept non-string fieldnames --- stdlib/csv.pyi | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/stdlib/csv.pyi b/stdlib/csv.pyi index 606694dca533..dcefd1a1d27f 100644 --- a/stdlib/csv.pyi +++ b/stdlib/csv.pyi @@ -18,9 +18,9 @@ from _csv import ( writer as writer, ) from collections import OrderedDict -from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, Sequence, Text, Type +from typing import Any, Dict, Generic, Iterable, Iterator, List, Mapping, Optional, Sequence, Text, Type, TypeVar -_DictRow = Mapping[str, Any] +_T = TypeVar("_T") class excel(Dialect): delimiter: str @@ -42,13 +42,12 @@ if sys.version_info >= (3,): lineterminator: str quoting: int -if sys.version_info >= (3, 8): +if sys.version_info >= (3, 8) or sys.version_info < (3, 6): _DRMapping = Dict[str, str] -elif sys.version_info >= (3, 6): - _DRMapping = OrderedDict[str, str] else: - _DRMapping = Dict[str, str] + _DRMapping = OrderedDict[str, str] +# TODO: make keys generic, defaulting to string if no fieldnames are given (#4800) class DictReader(Iterator[_DRMapping]): restkey: Optional[str] restval: Optional[str] @@ -72,15 +71,15 @@ class DictReader(Iterator[_DRMapping]): else: def next(self) -> _DRMapping: ... -class DictWriter(object): - fieldnames: Sequence[str] +class DictWriter(Generic[_T]): + fieldnames: Sequence[_T] restval: Optional[Any] extrasaction: str writer: _writer def __init__( self, f: Any, - fieldnames: Iterable[str], + fieldnames: Iterable[_T], restval: Optional[Any] = ..., extrasaction: str = ..., dialect: _DialectLike = ..., @@ -91,8 +90,8 @@ class DictWriter(object): def writeheader(self) -> Any: ... else: def writeheader(self) -> None: ... - def writerow(self, rowdict: _DictRow) -> Any: ... - def writerows(self, rowdicts: Iterable[_DictRow]) -> None: ... + def writerow(self, rowdict: Mapping[_T, Any]) -> Any: ... + def writerows(self, rowdicts: Iterable[Mapping[_T, Any]]) -> None: ... class Sniffer(object): preferred: List[str] From f7bf620c255218bc1f3b30c4cf43e9b4d4dda89e Mon Sep 17 00:00:00 2001 From: Akuli Date: Fri, 7 May 2021 16:12:59 +0300 Subject: [PATCH 2/4] make DictReader generic to suport non-string fieldnames --- stdlib/csv.pyi | 34 ++++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/stdlib/csv.pyi b/stdlib/csv.pyi index dcefd1a1d27f..56acb0d7bef8 100644 --- a/stdlib/csv.pyi +++ b/stdlib/csv.pyi @@ -17,8 +17,12 @@ from _csv import ( unregister_dialect as unregister_dialect, writer as writer, ) -from collections import OrderedDict -from typing import Any, Dict, Generic, Iterable, Iterator, List, Mapping, Optional, Sequence, Text, Type, TypeVar +from typing import Any, Dict, Generic, Iterable, Iterator, List, Mapping, Optional, Sequence, Text, Type, TypeVar, overload + +if sys.version_info >= (3, 8) or sys.version_info < (3, 6): + from typing import Dict as _DictReadMapping +else: + from collections import OrderedDict as _DictReadMapping _T = TypeVar("_T") @@ -42,22 +46,28 @@ if sys.version_info >= (3,): lineterminator: str quoting: int -if sys.version_info >= (3, 8) or sys.version_info < (3, 6): - _DRMapping = Dict[str, str] -else: - _DRMapping = OrderedDict[str, str] - -# TODO: make keys generic, defaulting to string if no fieldnames are given (#4800) -class DictReader(Iterator[_DRMapping]): +class DictReader(Generic[_T], Iterator[_DictReadMapping[_T, str]]): + fieldnames: Optional[Sequence[_T]] restkey: Optional[str] restval: Optional[str] reader: _reader dialect: _DialectLike line_num: int - fieldnames: Optional[Sequence[str]] + @overload def __init__( self, f: Iterable[Text], + fieldnames: Sequence[_T], + restkey: Optional[str] = ..., + restval: Optional[str] = ..., + dialect: _DialectLike = ..., + *args: Any, + **kwds: Any, + ) -> None: ... + @overload + def __init__( + self: DictReader[str], + f: Iterable[Text], fieldnames: Optional[Sequence[str]] = ..., restkey: Optional[str] = ..., restval: Optional[str] = ..., @@ -67,9 +77,9 @@ class DictReader(Iterator[_DRMapping]): ) -> None: ... def __iter__(self) -> DictReader: ... if sys.version_info >= (3,): - def __next__(self) -> _DRMapping: ... + def __next__(self) -> _DictReadMapping[_T, str]: ... else: - def next(self) -> _DRMapping: ... + def next(self) -> _DictReadMapping[_T, str]: ... class DictWriter(Generic[_T]): fieldnames: Sequence[_T] From 3b9214d9ff2f731720c3260f15fb1eb38ea43a84 Mon Sep 17 00:00:00 2001 From: Akuli Date: Fri, 7 May 2021 16:20:42 +0300 Subject: [PATCH 3/4] fixing --- stdlib/csv.pyi | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stdlib/csv.pyi b/stdlib/csv.pyi index 56acb0d7bef8..6239ca2e69c2 100644 --- a/stdlib/csv.pyi +++ b/stdlib/csv.pyi @@ -17,7 +17,7 @@ from _csv import ( unregister_dialect as unregister_dialect, writer as writer, ) -from typing import Any, Dict, Generic, Iterable, Iterator, List, Mapping, Optional, Sequence, Text, Type, TypeVar, overload +from typing import Any, Generic, Iterable, Iterator, List, Mapping, Optional, Sequence, Text, Type, TypeVar, overload if sys.version_info >= (3, 8) or sys.version_info < (3, 6): from typing import Dict as _DictReadMapping @@ -75,7 +75,7 @@ class DictReader(Generic[_T], Iterator[_DictReadMapping[_T, str]]): *args: Any, **kwds: Any, ) -> None: ... - def __iter__(self) -> DictReader: ... + def __iter__(self) -> DictReader[_T]: ... if sys.version_info >= (3,): def __next__(self) -> _DictReadMapping[_T, str]: ... else: From fcf3f353947146c5e715c7d02035727a358176bf Mon Sep 17 00:00:00 2001 From: Akuli Date: Fri, 7 May 2021 17:01:45 +0300 Subject: [PATCH 4/4] require Sequence --- stdlib/csv.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stdlib/csv.pyi b/stdlib/csv.pyi index 6239ca2e69c2..db48e9e6c78a 100644 --- a/stdlib/csv.pyi +++ b/stdlib/csv.pyi @@ -89,7 +89,7 @@ class DictWriter(Generic[_T]): def __init__( self, f: Any, - fieldnames: Iterable[_T], + fieldnames: Sequence[_T], restval: Optional[Any] = ..., extrasaction: str = ..., dialect: _DialectLike = ...,