From dc1bde176fd9a93b09c8f8d50d03cab1b5a39b3f Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev <sv@datafold.com> Date: Tue, 19 Sep 2023 19:18:05 +0200 Subject: [PATCH 1/2] Cease usage of the self-made Self , switch to Python's backported one --- data_diff/cloud/datafold_api.py | 5 ++--- data_diff/sqeleton/abcs/database_types.py | 5 +++-- data_diff/sqeleton/databases/_connect.py | 3 ++- data_diff/sqeleton/databases/base.py | 3 ++- data_diff/sqeleton/utils.py | 5 ----- poetry.lock | 8 ++++---- pyproject.toml | 1 + 7 files changed, 14 insertions(+), 16 deletions(-) diff --git a/data_diff/cloud/datafold_api.py b/data_diff/cloud/datafold_api.py index b6c4531b..f0adaea3 100644 --- a/data_diff/cloud/datafold_api.py +++ b/data_diff/cloud/datafold_api.py @@ -2,10 +2,11 @@ import dataclasses import enum import time -from typing import Any, Dict, List, Optional, Type, TypeVar, Tuple +from typing import Any, Dict, List, Optional, Type, Tuple import pydantic import requests +from typing_extensions import Self from data_diff.errors import DataDiffCloudDiffFailed, DataDiffCloudDiffTimedOut, DataDiffDatasourceIdNotFoundError @@ -13,8 +14,6 @@ logger = getLogger(__name__) -Self = TypeVar("Self", bound=pydantic.BaseModel) - class TestDataSourceStatus(str, enum.Enum): SUCCESS = "ok" diff --git a/data_diff/sqeleton/abcs/database_types.py b/data_diff/sqeleton/abcs/database_types.py index 9bde030b..5f48b567 100644 --- a/data_diff/sqeleton/abcs/database_types.py +++ b/data_diff/sqeleton/abcs/database_types.py @@ -1,11 +1,12 @@ import decimal from abc import ABC, abstractmethod -from typing import Sequence, Optional, Tuple, Union, Dict, List +from typing import Sequence, Optional, Tuple, Type, Union, Dict, List from datetime import datetime from runtype import dataclass +from typing_extensions import Self -from ..utils import ArithAlphanumeric, ArithUUID, Self, Unknown +from ..utils import ArithAlphanumeric, ArithUUID, Unknown DbPath = Tuple[str, ...] diff --git a/data_diff/sqeleton/databases/_connect.py b/data_diff/sqeleton/databases/_connect.py index c6638d98..cdd3b2f1 100644 --- a/data_diff/sqeleton/databases/_connect.py +++ b/data_diff/sqeleton/databases/_connect.py @@ -5,9 +5,10 @@ import toml from runtype import dataclass +from typing_extensions import Self from ..abcs.mixins import AbstractMixin -from ..utils import WeakCache, Self +from ..utils import WeakCache from .base import Database, ThreadedDatabase from .postgresql import PostgreSQL from .mysql import MySQL diff --git a/data_diff/sqeleton/databases/base.py b/data_diff/sqeleton/databases/base.py index 78bfe2bf..fb0a8b89 100644 --- a/data_diff/sqeleton/databases/base.py +++ b/data_diff/sqeleton/databases/base.py @@ -11,8 +11,9 @@ import decimal from runtype import dataclass +from typing_extensions import Self -from ..utils import is_uuid, safezip, Self +from ..utils import is_uuid, safezip from ..queries import Expr, Compiler, table, Select, SKIP, Explain, Code, this from ..queries.ast_classes import Random from ..abcs.database_types import ( diff --git a/data_diff/sqeleton/utils.py b/data_diff/sqeleton/utils.py index 705d1776..ed8a05c5 100644 --- a/data_diff/sqeleton/utils.py +++ b/data_diff/sqeleton/utils.py @@ -21,11 +21,6 @@ # -- Common -- -try: - from typing import Self -except ImportError: - Self = Any - class WeakCache: def __init__(self): diff --git a/poetry.lock b/poetry.lock index a5fbd86a..afd70d75 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1888,13 +1888,13 @@ tests = ["click", "httpretty (<1.1)", "pytest", "pytest-runner", "requests-kerbe [[package]] name = "typing-extensions" -version = "4.6.3" +version = "4.7.1" description = "Backported and Experimental Type Hints for Python 3.7+" optional = false python-versions = ">=3.7" files = [ - {file = "typing_extensions-4.6.3-py3-none-any.whl", hash = "sha256:88a4153d8505aabbb4e13aacb7c486c2b4a33ca3b3f807914a9b4c844c471c26"}, - {file = "typing_extensions-4.6.3.tar.gz", hash = "sha256:d91d5919357fe7f681a9f2b5b4cb2a5f1ef0a1e9f59c4d8ff0d3491e05c0ffd5"}, + {file = "typing_extensions-4.7.1-py3-none-any.whl", hash = "sha256:440d5dd3af93b060174bf433bccd69b0babc3b15b1a8dca43789fd7f61514b36"}, + {file = "typing_extensions-4.7.1.tar.gz", hash = "sha256:b75ddc264f0ba5615db7ba217daeb99701ad295353c45f9e95963337ceeeffb2"}, ] [[package]] @@ -2030,4 +2030,4 @@ vertica = ["vertica-python"] [metadata] lock-version = "2.0" python-versions = "^3.7.2" -content-hash = "b3e3febf3233c5fb0800870c84422ad8e414d369664e195b7b3d4735028ee464" +content-hash = "55cde03a00788572dac6310e7bbf61bd2522d70217056a51608bcfc429440fbf" diff --git a/pyproject.toml b/pyproject.toml index bd4d80a7..da72fa1e 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ vertica-python = {version="*", optional=true} urllib3 = "<2" oracledb = {version = "*", optional=true} pyodbc = {version="^4.0.39", optional=true} +typing-extensions = ">=4.0.1" [tool.poetry.dev-dependencies] parameterized = "*" From 8040b7fd3194cd775cd38f1e63f043851fd803d7 Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev <sv@datafold.com> Date: Tue, 19 Sep 2023 19:32:07 +0200 Subject: [PATCH 2/2] Annotate types for self-cloning factories as per PEP-673 --- data_diff/cloud/datafold_api.py | 4 ++-- data_diff/sqeleton/bound_exprs.py | 5 +++-- data_diff/sqeleton/databases/_connect.py | 2 +- data_diff/sqeleton/databases/base.py | 2 +- data_diff/sqeleton/queries/ast_classes.py | 19 ++++++++++--------- data_diff/sqeleton/queries/compiler.py | 3 ++- data_diff/sqeleton/utils.py | 18 ++++++++++-------- data_diff/table_segment.py | 9 +++++---- 8 files changed, 34 insertions(+), 28 deletions(-) diff --git a/data_diff/cloud/datafold_api.py b/data_diff/cloud/datafold_api.py index f0adaea3..8ed1791a 100644 --- a/data_diff/cloud/datafold_api.py +++ b/data_diff/cloud/datafold_api.py @@ -29,7 +29,7 @@ class TCloudApiDataSourceSchema(pydantic.BaseModel): secret: List[str] @classmethod - def from_orm(cls: Type[Self], obj: Any) -> Self: + def from_orm(cls, obj: Any) -> Self: data_source_types_required_parameters = { "bigquery": ["projectId", "jsonKeyFile", "location"], "databricks": ["host", "http_password", "database", "http_path"], @@ -153,7 +153,7 @@ class TCloudApiDataDiffSummaryResult(pydantic.BaseModel): dependencies: Optional[Dict[str, Any]] @classmethod - def from_orm(cls: Type[Self], obj: Any) -> Self: + def from_orm(cls, obj: Any) -> Self: pks = TSummaryResultPrimaryKeyStats(**obj["pks"]) if "pks" in obj else None values = TSummaryResultValueStats(**obj["values"]) if "values" in obj else None deps = obj["deps"] if "deps" in obj else None diff --git a/data_diff/sqeleton/bound_exprs.py b/data_diff/sqeleton/bound_exprs.py index 188efbca..7ef4dc11 100644 --- a/data_diff/sqeleton/bound_exprs.py +++ b/data_diff/sqeleton/bound_exprs.py @@ -5,6 +5,7 @@ from typing import Union, TYPE_CHECKING from runtype import dataclass +from typing_extensions import Self from .abcs import AbstractDatabase, AbstractCompiler from .queries.ast_classes import ExprNode, ITable, TablePath, Compilable @@ -52,11 +53,11 @@ class BoundTable(BoundNode): # ITable database: AbstractDatabase node: TablePath - def with_schema(self, schema): + def with_schema(self, schema) -> Self: table_path = self.node.replace(schema=schema) return self.replace(node=table_path) - def query_schema(self, *, columns=None, where=None, case_sensitive=True): + def query_schema(self, *, columns=None, where=None, case_sensitive=True) -> Self: table_path = self.node if table_path.schema: diff --git a/data_diff/sqeleton/databases/_connect.py b/data_diff/sqeleton/databases/_connect.py index cdd3b2f1..aee220dd 100644 --- a/data_diff/sqeleton/databases/_connect.py +++ b/data_diff/sqeleton/databases/_connect.py @@ -100,7 +100,7 @@ def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME) self.match_uri_path = {name: MatchUriPath(cls) for name, cls in database_by_scheme.items()} self.conn_cache = WeakCache() - def for_databases(self, *dbs): + def for_databases(self, *dbs) -> Self: database_by_scheme = {k: db for k, db in self.database_by_scheme.items() if k in dbs} return type(self)(database_by_scheme) diff --git a/data_diff/sqeleton/databases/base.py b/data_diff/sqeleton/databases/base.py index fb0a8b89..5030b24e 100644 --- a/data_diff/sqeleton/databases/base.py +++ b/data_diff/sqeleton/databases/base.py @@ -282,7 +282,7 @@ def _convert_db_precision_to_digits(self, p: int) -> int: return math.floor(math.log(2**p, 10)) @classmethod - def load_mixins(cls, *abstract_mixins) -> "Self": + def load_mixins(cls, *abstract_mixins) -> Self: mixins = {m for m in cls.MIXINS if issubclass(m, abstract_mixins)} class _DialectWithMixins(cls, *mixins, *abstract_mixins): diff --git a/data_diff/sqeleton/queries/ast_classes.py b/data_diff/sqeleton/queries/ast_classes.py index aba86c70..8d892232 100644 --- a/data_diff/sqeleton/queries/ast_classes.py +++ b/data_diff/sqeleton/queries/ast_classes.py @@ -1,8 +1,9 @@ from dataclasses import field from datetime import datetime -from typing import Any, Generator, List, Optional, Sequence, Union, Dict +from typing import Any, Generator, List, Optional, Sequence, Type, Union, Dict from runtype import dataclass +from typing_extensions import Self from ..utils import join_iter, ArithString from ..abcs import Compilable @@ -322,7 +323,7 @@ def when(self, *whens: Expr) -> "QB_When": return QB_When(self, whens[0]) return QB_When(self, BinBoolOp("AND", whens)) - def else_(self, then: Expr): + def else_(self, then: Expr) -> Self: """Add an 'else' clause to the case expression. Can only be called once! @@ -422,7 +423,7 @@ class TablePath(ExprNode, ITable): schema: Optional[Schema] = field(default=None, repr=False) @property - def source_table(self): + def source_table(self) -> Self: return self def compile(self, c: Compiler) -> str: @@ -524,7 +525,7 @@ class Join(ExprNode, ITable, Root): columns: Sequence[Expr] = None @property - def source_table(self): + def source_table(self) -> Self: return self @property @@ -533,7 +534,7 @@ def schema(self): s = self.source_tables[0].schema # TODO validate types match between both tables return type(s)({c.name: c.type for c in self.columns}) - def on(self, *exprs) -> "Join": + def on(self, *exprs) -> Self: """Add an ON clause, for filtering the result of the cartesian product (i.e. the JOIN)""" if len(exprs) == 1: (e,) = exprs @@ -546,7 +547,7 @@ def on(self, *exprs) -> "Join": return self.replace(on_exprs=(self.on_exprs or []) + exprs) - def select(self, *exprs, **named_exprs) -> ITable: + def select(self, *exprs, **named_exprs) -> Union[Self, ITable]: """Select fields to return from the JOIN operation See Also: ``ITable.select()`` @@ -600,7 +601,7 @@ def source_table(self): def __post_init__(self): assert self.keys or self.values - def having(self, *exprs): + def having(self, *exprs) -> Self: """Add a 'HAVING' clause to the group-by""" exprs = args_as_tuple(exprs) exprs = _drop_skips(exprs) @@ -610,7 +611,7 @@ def having(self, *exprs): resolve_names(self.table, exprs) return self.replace(having_exprs=(self.having_exprs or []) + exprs) - def agg(self, *exprs): + def agg(self, *exprs) -> Self: """Select aggregated fields for the group-by.""" exprs = args_as_tuple(exprs) exprs = _drop_skips(exprs) @@ -991,7 +992,7 @@ def compile(self, c: Compiler) -> str: return f"INSERT INTO {c.compile(self.path)}{columns} {expr}" - def returning(self, *exprs): + def returning(self, *exprs) -> Self: """Add a 'RETURNING' clause to the insert expression. Note: Not all databases support this feature! diff --git a/data_diff/sqeleton/queries/compiler.py b/data_diff/sqeleton/queries/compiler.py index 56b77c0f..1f6793ff 100644 --- a/data_diff/sqeleton/queries/compiler.py +++ b/data_diff/sqeleton/queries/compiler.py @@ -4,6 +4,7 @@ from typing import Any, Dict, Sequence, List from runtype import dataclass +from typing_extensions import Self from ..utils import ArithString from ..abcs import AbstractDatabase, AbstractDialect, DbPath, AbstractCompiler, Compilable @@ -79,7 +80,7 @@ def new_unique_table_name(self, prefix="tmp") -> DbPath: self._counter[0] += 1 return self.database.parse_table_name(f"{prefix}{self._counter[0]}_{'%x'%random.randrange(2**32)}") - def add_table_context(self, *tables: Sequence, **kw): + def add_table_context(self, *tables: Sequence, **kw) -> Self: return self.replace(_table_context=self._table_context + list(tables), **kw) def quote(self, s: str): diff --git a/data_diff/sqeleton/utils.py b/data_diff/sqeleton/utils.py index ed8a05c5..d356d18b 100644 --- a/data_diff/sqeleton/utils.py +++ b/data_diff/sqeleton/utils.py @@ -2,13 +2,13 @@ Iterable, Iterator, MutableMapping, + Type, Union, Any, Sequence, Dict, Hashable, TypeVar, - TYPE_CHECKING, List, ) from abc import abstractmethod @@ -19,6 +19,8 @@ from uuid import UUID from urllib.parse import urlparse +from typing_extensions import Self + # -- Common -- @@ -90,7 +92,7 @@ class CaseAwareMapping(MutableMapping[str, V]): def get_key(self, key: str) -> str: ... - def new(self, initial=()): + def new(self, initial=()) -> Self: return type(self)(initial) @@ -139,10 +141,10 @@ def as_insensitive(self): class ArithString: @classmethod - def new(cls, *args, **kw): + def new(cls, *args, **kw) -> Self: return cls(*args, **kw) - def range(self, other: "ArithString", count: int): + def range(self, other: "ArithString", count: int) -> List[Self]: assert isinstance(other, ArithString) checkpoints = split_space(self.int, other.int, count) return [self.new(int=i) for i in checkpoints] @@ -154,7 +156,7 @@ class ArithUUID(UUID, ArithString): def __int__(self): return self.int - def __add__(self, other: int): + def __add__(self, other: int) -> Self: if isinstance(other, int): return self.new(int=self.int + other) return NotImplemented @@ -226,7 +228,7 @@ def __len__(self): def __repr__(self): return f'alphanum"{self._str}"' - def __add__(self, other: "Union[ArithAlphanumeric, int]") -> "ArithAlphanumeric": + def __add__(self, other: "Union[ArithAlphanumeric, int]") -> Self: if isinstance(other, int): if other != 1: raise NotImplementedError("not implemented for arbitrary numbers") @@ -235,7 +237,7 @@ def __add__(self, other: "Union[ArithAlphanumeric, int]") -> "ArithAlphanumeric" return NotImplemented - def range(self, other: "ArithAlphanumeric", count: int): + def range(self, other: "ArithAlphanumeric", count: int) -> List[Self]: assert isinstance(other, ArithAlphanumeric) n1, n2 = alphanums_to_numbers(self._str, other._str) split = split_space(n1, n2, count) @@ -263,7 +265,7 @@ def __eq__(self, other): return NotImplemented return self._str == other._str - def new(self, *args, **kw): + def new(self, *args, **kw) -> Self: return type(self)(*args, **kw, max_len=self._max_len) diff --git a/data_diff/table_segment.py b/data_diff/table_segment.py index f997c8c5..4301f06f 100644 --- a/data_diff/table_segment.py +++ b/data_diff/table_segment.py @@ -4,6 +4,7 @@ from itertools import product from runtype import dataclass +from typing_extensions import Self from .utils import safezip, Vector from data_diff.sqeleton.utils import ArithString, split_space @@ -137,11 +138,11 @@ def __post_init__(self): def _where(self): return f"({self.where})" if self.where else None - def _with_raw_schema(self, raw_schema: dict) -> "TableSegment": + def _with_raw_schema(self, raw_schema: dict) -> Self: schema = self.database._process_table_schema(self.table_path, raw_schema, self.relevant_columns, self._where()) return self.new(_schema=create_schema(self.database, self.table_path, schema, self.case_sensitive)) - def with_schema(self) -> "TableSegment": + def with_schema(self) -> Self: "Queries the table schema from the database, and returns a new instance of TableSegment, with a schema." if self._schema: return self @@ -194,11 +195,11 @@ def segment_by_checkpoints(self, checkpoints: List[List[DbKey]]) -> List["TableS return [self.new_key_bounds(min_key=s, max_key=e) for s, e in create_mesh_from_points(*checkpoints)] - def new(self, **kwargs) -> "TableSegment": + def new(self, **kwargs) -> Self: """Creates a copy of the instance using 'replace()'""" return self.replace(**kwargs) - def new_key_bounds(self, min_key: Vector, max_key: Vector) -> "TableSegment": + def new_key_bounds(self, min_key: Vector, max_key: Vector) -> Self: if self.min_key is not None: assert self.min_key <= min_key, (self.min_key, min_key) assert self.min_key < max_key