Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 240db97

Browse files
authored
Merge pull request #704 from nolar/simplify-self
Annotate types for self-cloning factories as per PEP-673
2 parents 47c070e + 8040b7f commit 240db97

File tree

11 files changed

+48
-44
lines changed

11 files changed

+48
-44
lines changed

data_diff/cloud/datafold_api.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,18 @@
22
import dataclasses
33
import enum
44
import time
5-
from typing import Any, Dict, List, Optional, Type, TypeVar, Tuple
5+
from typing import Any, Dict, List, Optional, Type, Tuple
66

77
import pydantic
88
import requests
9+
from typing_extensions import Self
910

1011
from data_diff.errors import DataDiffCloudDiffFailed, DataDiffCloudDiffTimedOut, DataDiffDatasourceIdNotFoundError
1112

1213
from ..utils import getLogger
1314

1415
logger = getLogger(__name__)
1516

16-
Self = TypeVar("Self", bound=pydantic.BaseModel)
17-
1817

1918
class TestDataSourceStatus(str, enum.Enum):
2019
SUCCESS = "ok"
@@ -30,7 +29,7 @@ class TCloudApiDataSourceSchema(pydantic.BaseModel):
3029
secret: List[str]
3130

3231
@classmethod
33-
def from_orm(cls: Type[Self], obj: Any) -> Self:
32+
def from_orm(cls, obj: Any) -> Self:
3433
data_source_types_required_parameters = {
3534
"bigquery": ["projectId", "jsonKeyFile", "location"],
3635
"databricks": ["host", "http_password", "database", "http_path"],
@@ -154,7 +153,7 @@ class TCloudApiDataDiffSummaryResult(pydantic.BaseModel):
154153
dependencies: Optional[Dict[str, Any]]
155154

156155
@classmethod
157-
def from_orm(cls: Type[Self], obj: Any) -> Self:
156+
def from_orm(cls, obj: Any) -> Self:
158157
pks = TSummaryResultPrimaryKeyStats(**obj["pks"]) if "pks" in obj else None
159158
values = TSummaryResultValueStats(**obj["values"]) if "values" in obj else None
160159
deps = obj["deps"] if "deps" in obj else None

data_diff/sqeleton/abcs/database_types.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import decimal
22
from abc import ABC, abstractmethod
3-
from typing import Sequence, Optional, Tuple, Union, Dict, List
3+
from typing import Sequence, Optional, Tuple, Type, Union, Dict, List
44
from datetime import datetime
55

66
from runtype import dataclass
7+
from typing_extensions import Self
78

8-
from ..utils import ArithAlphanumeric, ArithUUID, Self, Unknown
9+
from ..utils import ArithAlphanumeric, ArithUUID, Unknown
910

1011

1112
DbPath = Tuple[str, ...]

data_diff/sqeleton/bound_exprs.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Union, TYPE_CHECKING
66

77
from runtype import dataclass
8+
from typing_extensions import Self
89

910
from .abcs import AbstractDatabase, AbstractCompiler
1011
from .queries.ast_classes import ExprNode, ITable, TablePath, Compilable
@@ -52,11 +53,11 @@ class BoundTable(BoundNode): # ITable
5253
database: AbstractDatabase
5354
node: TablePath
5455

55-
def with_schema(self, schema):
56+
def with_schema(self, schema) -> Self:
5657
table_path = self.node.replace(schema=schema)
5758
return self.replace(node=table_path)
5859

59-
def query_schema(self, *, columns=None, where=None, case_sensitive=True):
60+
def query_schema(self, *, columns=None, where=None, case_sensitive=True) -> Self:
6061
table_path = self.node
6162

6263
if table_path.schema:

data_diff/sqeleton/databases/_connect.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
import toml
66

77
from runtype import dataclass
8+
from typing_extensions import Self
89

910
from ..abcs.mixins import AbstractMixin
10-
from ..utils import WeakCache, Self
11+
from ..utils import WeakCache
1112
from .base import Database, ThreadedDatabase
1213
from .postgresql import PostgreSQL
1314
from .mysql import MySQL
@@ -99,7 +100,7 @@ def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME)
99100
self.match_uri_path = {name: MatchUriPath(cls) for name, cls in database_by_scheme.items()}
100101
self.conn_cache = WeakCache()
101102

102-
def for_databases(self, *dbs):
103+
def for_databases(self, *dbs) -> Self:
103104
database_by_scheme = {k: db for k, db in self.database_by_scheme.items() if k in dbs}
104105
return type(self)(database_by_scheme)
105106

data_diff/sqeleton/databases/base.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111
import decimal
1212

1313
from runtype import dataclass
14+
from typing_extensions import Self
1415

15-
from ..utils import is_uuid, safezip, Self
16+
from ..utils import is_uuid, safezip
1617
from ..queries import Expr, Compiler, table, Select, SKIP, Explain, Code, this
1718
from ..queries.ast_classes import Random
1819
from ..abcs.database_types import (
@@ -281,7 +282,7 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
281282
return math.floor(math.log(2**p, 10))
282283

283284
@classmethod
284-
def load_mixins(cls, *abstract_mixins) -> "Self":
285+
def load_mixins(cls, *abstract_mixins) -> Self:
285286
mixins = {m for m in cls.MIXINS if issubclass(m, abstract_mixins)}
286287

287288
class _DialectWithMixins(cls, *mixins, *abstract_mixins):

data_diff/sqeleton/queries/ast_classes.py

+10-9
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from dataclasses import field
22
from datetime import datetime
3-
from typing import Any, Generator, List, Optional, Sequence, Union, Dict
3+
from typing import Any, Generator, List, Optional, Sequence, Type, Union, Dict
44

55
from runtype import dataclass
6+
from typing_extensions import Self
67

78
from ..utils import join_iter, ArithString
89
from ..abcs import Compilable
@@ -322,7 +323,7 @@ def when(self, *whens: Expr) -> "QB_When":
322323
return QB_When(self, whens[0])
323324
return QB_When(self, BinBoolOp("AND", whens))
324325

325-
def else_(self, then: Expr):
326+
def else_(self, then: Expr) -> Self:
326327
"""Add an 'else' clause to the case expression.
327328
328329
Can only be called once!
@@ -422,7 +423,7 @@ class TablePath(ExprNode, ITable):
422423
schema: Optional[Schema] = field(default=None, repr=False)
423424

424425
@property
425-
def source_table(self):
426+
def source_table(self) -> Self:
426427
return self
427428

428429
def compile(self, c: Compiler) -> str:
@@ -524,7 +525,7 @@ class Join(ExprNode, ITable, Root):
524525
columns: Sequence[Expr] = None
525526

526527
@property
527-
def source_table(self):
528+
def source_table(self) -> Self:
528529
return self
529530

530531
@property
@@ -533,7 +534,7 @@ def schema(self):
533534
s = self.source_tables[0].schema # TODO validate types match between both tables
534535
return type(s)({c.name: c.type for c in self.columns})
535536

536-
def on(self, *exprs) -> "Join":
537+
def on(self, *exprs) -> Self:
537538
"""Add an ON clause, for filtering the result of the cartesian product (i.e. the JOIN)"""
538539
if len(exprs) == 1:
539540
(e,) = exprs
@@ -546,7 +547,7 @@ def on(self, *exprs) -> "Join":
546547

547548
return self.replace(on_exprs=(self.on_exprs or []) + exprs)
548549

549-
def select(self, *exprs, **named_exprs) -> ITable:
550+
def select(self, *exprs, **named_exprs) -> Union[Self, ITable]:
550551
"""Select fields to return from the JOIN operation
551552
552553
See Also: ``ITable.select()``
@@ -600,7 +601,7 @@ def source_table(self):
600601
def __post_init__(self):
601602
assert self.keys or self.values
602603

603-
def having(self, *exprs):
604+
def having(self, *exprs) -> Self:
604605
"""Add a 'HAVING' clause to the group-by"""
605606
exprs = args_as_tuple(exprs)
606607
exprs = _drop_skips(exprs)
@@ -610,7 +611,7 @@ def having(self, *exprs):
610611
resolve_names(self.table, exprs)
611612
return self.replace(having_exprs=(self.having_exprs or []) + exprs)
612613

613-
def agg(self, *exprs):
614+
def agg(self, *exprs) -> Self:
614615
"""Select aggregated fields for the group-by."""
615616
exprs = args_as_tuple(exprs)
616617
exprs = _drop_skips(exprs)
@@ -991,7 +992,7 @@ def compile(self, c: Compiler) -> str:
991992

992993
return f"INSERT INTO {c.compile(self.path)}{columns} {expr}"
993994

994-
def returning(self, *exprs):
995+
def returning(self, *exprs) -> Self:
995996
"""Add a 'RETURNING' clause to the insert expression.
996997
997998
Note: Not all databases support this feature!

data_diff/sqeleton/queries/compiler.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Any, Dict, Sequence, List
55

66
from runtype import dataclass
7+
from typing_extensions import Self
78

89
from ..utils import ArithString
910
from ..abcs import AbstractDatabase, AbstractDialect, DbPath, AbstractCompiler, Compilable
@@ -79,7 +80,7 @@ def new_unique_table_name(self, prefix="tmp") -> DbPath:
7980
self._counter[0] += 1
8081
return self.database.parse_table_name(f"{prefix}{self._counter[0]}_{'%x'%random.randrange(2**32)}")
8182

82-
def add_table_context(self, *tables: Sequence, **kw):
83+
def add_table_context(self, *tables: Sequence, **kw) -> Self:
8384
return self.replace(_table_context=self._table_context + list(tables), **kw)
8485

8586
def quote(self, s: str):

data_diff/sqeleton/utils.py

+10-13
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
Iterable,
33
Iterator,
44
MutableMapping,
5+
Type,
56
Union,
67
Any,
78
Sequence,
89
Dict,
910
Hashable,
1011
TypeVar,
11-
TYPE_CHECKING,
1212
List,
1313
)
1414
from abc import abstractmethod
@@ -19,12 +19,9 @@
1919
from uuid import UUID
2020
from urllib.parse import urlparse
2121

22-
# -- Common --
22+
from typing_extensions import Self
2323

24-
try:
25-
from typing import Self
26-
except ImportError:
27-
Self = Any
24+
# -- Common --
2825

2926

3027
class WeakCache:
@@ -95,7 +92,7 @@ class CaseAwareMapping(MutableMapping[str, V]):
9592
def get_key(self, key: str) -> str:
9693
...
9794

98-
def new(self, initial=()):
95+
def new(self, initial=()) -> Self:
9996
return type(self)(initial)
10097

10198

@@ -144,10 +141,10 @@ def as_insensitive(self):
144141

145142
class ArithString:
146143
@classmethod
147-
def new(cls, *args, **kw):
144+
def new(cls, *args, **kw) -> Self:
148145
return cls(*args, **kw)
149146

150-
def range(self, other: "ArithString", count: int):
147+
def range(self, other: "ArithString", count: int) -> List[Self]:
151148
assert isinstance(other, ArithString)
152149
checkpoints = split_space(self.int, other.int, count)
153150
return [self.new(int=i) for i in checkpoints]
@@ -159,7 +156,7 @@ class ArithUUID(UUID, ArithString):
159156
def __int__(self):
160157
return self.int
161158

162-
def __add__(self, other: int):
159+
def __add__(self, other: int) -> Self:
163160
if isinstance(other, int):
164161
return self.new(int=self.int + other)
165162
return NotImplemented
@@ -231,7 +228,7 @@ def __len__(self):
231228
def __repr__(self):
232229
return f'alphanum"{self._str}"'
233230

234-
def __add__(self, other: "Union[ArithAlphanumeric, int]") -> "ArithAlphanumeric":
231+
def __add__(self, other: "Union[ArithAlphanumeric, int]") -> Self:
235232
if isinstance(other, int):
236233
if other != 1:
237234
raise NotImplementedError("not implemented for arbitrary numbers")
@@ -240,7 +237,7 @@ def __add__(self, other: "Union[ArithAlphanumeric, int]") -> "ArithAlphanumeric"
240237

241238
return NotImplemented
242239

243-
def range(self, other: "ArithAlphanumeric", count: int):
240+
def range(self, other: "ArithAlphanumeric", count: int) -> List[Self]:
244241
assert isinstance(other, ArithAlphanumeric)
245242
n1, n2 = alphanums_to_numbers(self._str, other._str)
246243
split = split_space(n1, n2, count)
@@ -268,7 +265,7 @@ def __eq__(self, other):
268265
return NotImplemented
269266
return self._str == other._str
270267

271-
def new(self, *args, **kw):
268+
def new(self, *args, **kw) -> Self:
272269
return type(self)(*args, **kw, max_len=self._max_len)
273270

274271

data_diff/table_segment.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from itertools import product
55

66
from runtype import dataclass
7+
from typing_extensions import Self
78

89
from .utils import safezip, Vector
910
from data_diff.sqeleton.utils import ArithString, split_space
@@ -137,11 +138,11 @@ def __post_init__(self):
137138
def _where(self):
138139
return f"({self.where})" if self.where else None
139140

140-
def _with_raw_schema(self, raw_schema: dict) -> "TableSegment":
141+
def _with_raw_schema(self, raw_schema: dict) -> Self:
141142
schema = self.database._process_table_schema(self.table_path, raw_schema, self.relevant_columns, self._where())
142143
return self.new(_schema=create_schema(self.database, self.table_path, schema, self.case_sensitive))
143144

144-
def with_schema(self) -> "TableSegment":
145+
def with_schema(self) -> Self:
145146
"Queries the table schema from the database, and returns a new instance of TableSegment, with a schema."
146147
if self._schema:
147148
return self
@@ -194,11 +195,11 @@ def segment_by_checkpoints(self, checkpoints: List[List[DbKey]]) -> List["TableS
194195

195196
return [self.new_key_bounds(min_key=s, max_key=e) for s, e in create_mesh_from_points(*checkpoints)]
196197

197-
def new(self, **kwargs) -> "TableSegment":
198+
def new(self, **kwargs) -> Self:
198199
"""Creates a copy of the instance using 'replace()'"""
199200
return self.replace(**kwargs)
200201

201-
def new_key_bounds(self, min_key: Vector, max_key: Vector) -> "TableSegment":
202+
def new_key_bounds(self, min_key: Vector, max_key: Vector) -> Self:
202203
if self.min_key is not None:
203204
assert self.min_key <= min_key, (self.min_key, min_key)
204205
assert self.min_key < max_key

poetry.lock

+4-4
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ vertica-python = {version="*", optional=true}
4646
urllib3 = "<2"
4747
oracledb = {version = "*", optional=true}
4848
pyodbc = {version="^4.0.39", optional=true}
49+
typing-extensions = ">=4.0.1"
4950

5051
[tool.poetry.dev-dependencies]
5152
parameterized = "*"

0 commit comments

Comments
 (0)