From dc1bde176fd9a93b09c8f8d50d03cab1b5a39b3f Mon Sep 17 00:00:00 2001
From: Sergey Vasilyev <sv@datafold.com>
Date: Tue, 19 Sep 2023 19:18:05 +0200
Subject: [PATCH 1/2] Cease usage of the self-made Self, switch to Python's
 backported one

---
 data_diff/cloud/datafold_api.py           | 5 ++---
 data_diff/sqeleton/abcs/database_types.py | 5 +++--
 data_diff/sqeleton/databases/_connect.py  | 3 ++-
 data_diff/sqeleton/databases/base.py      | 3 ++-
 data_diff/sqeleton/utils.py               | 5 -----
 poetry.lock                               | 8 ++++----
 pyproject.toml                            | 1 +
 7 files changed, 14 insertions(+), 16 deletions(-)

diff --git a/data_diff/cloud/datafold_api.py b/data_diff/cloud/datafold_api.py
index b6c4531b..f0adaea3 100644
--- a/data_diff/cloud/datafold_api.py
+++ b/data_diff/cloud/datafold_api.py
@@ -2,10 +2,11 @@
 import dataclasses
 import enum
 import time
-from typing import Any, Dict, List, Optional, Type, TypeVar, Tuple
+from typing import Any, Dict, List, Optional, Type, Tuple
 
 import pydantic
 import requests
+from typing_extensions import Self
 
 from data_diff.errors import DataDiffCloudDiffFailed, DataDiffCloudDiffTimedOut, DataDiffDatasourceIdNotFoundError
 
@@ -13,8 +14,6 @@
 
 logger = getLogger(__name__)
 
-Self = TypeVar("Self", bound=pydantic.BaseModel)
-
 
 class TestDataSourceStatus(str, enum.Enum):
     SUCCESS = "ok"
diff --git a/data_diff/sqeleton/abcs/database_types.py b/data_diff/sqeleton/abcs/database_types.py
index 9bde030b..5f48b567 100644
--- a/data_diff/sqeleton/abcs/database_types.py
+++ b/data_diff/sqeleton/abcs/database_types.py
@@ -1,11 +1,12 @@
 import decimal
 from abc import ABC, abstractmethod
-from typing import Sequence, Optional, Tuple, Union, Dict, List
+from typing import Sequence, Optional, Tuple, Type, Union, Dict, List
 from datetime import datetime
 
 from runtype import dataclass
+from typing_extensions import Self
 
-from ..utils import ArithAlphanumeric, ArithUUID, Self, Unknown
+from ..utils import ArithAlphanumeric, ArithUUID, Unknown
 
 
 DbPath = Tuple[str, ...]
diff --git a/data_diff/sqeleton/databases/_connect.py b/data_diff/sqeleton/databases/_connect.py
index c6638d98..cdd3b2f1 100644
--- a/data_diff/sqeleton/databases/_connect.py
+++ b/data_diff/sqeleton/databases/_connect.py
@@ -5,9 +5,10 @@
 import toml
 
 from runtype import dataclass
+from typing_extensions import Self
 
 from ..abcs.mixins import AbstractMixin
-from ..utils import WeakCache, Self
+from ..utils import WeakCache
 from .base import Database, ThreadedDatabase
 from .postgresql import PostgreSQL
 from .mysql import MySQL
diff --git a/data_diff/sqeleton/databases/base.py b/data_diff/sqeleton/databases/base.py
index 78bfe2bf..fb0a8b89 100644
--- a/data_diff/sqeleton/databases/base.py
+++ b/data_diff/sqeleton/databases/base.py
@@ -11,8 +11,9 @@
 import decimal
 
 from runtype import dataclass
+from typing_extensions import Self
 
-from ..utils import is_uuid, safezip, Self
+from ..utils import is_uuid, safezip
 from ..queries import Expr, Compiler, table, Select, SKIP, Explain, Code, this
 from ..queries.ast_classes import Random
 from ..abcs.database_types import (
diff --git a/data_diff/sqeleton/utils.py b/data_diff/sqeleton/utils.py
index 705d1776..ed8a05c5 100644
--- a/data_diff/sqeleton/utils.py
+++ b/data_diff/sqeleton/utils.py
@@ -21,11 +21,6 @@
 
 # -- Common --
 
-try:
-    from typing import Self
-except ImportError:
-    Self = Any
-
 
 class WeakCache:
     def __init__(self):
diff --git a/poetry.lock b/poetry.lock
index a5fbd86a..afd70d75 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1888,13 +1888,13 @@ tests = ["click", "httpretty (<1.1)", "pytest", "pytest-runner", "requests-kerbe
 
 [[package]]
 name = "typing-extensions"
-version = "4.6.3"
+version = "4.7.1"
 description = "Backported and Experimental Type Hints for Python 3.7+"
 optional = false
 python-versions = ">=3.7"
 files = [
-    {file = "typing_extensions-4.6.3-py3-none-any.whl", hash = "sha256:88a4153d8505aabbb4e13aacb7c486c2b4a33ca3b3f807914a9b4c844c471c26"},
-    {file = "typing_extensions-4.6.3.tar.gz", hash = "sha256:d91d5919357fe7f681a9f2b5b4cb2a5f1ef0a1e9f59c4d8ff0d3491e05c0ffd5"},
+    {file = "typing_extensions-4.7.1-py3-none-any.whl", hash = "sha256:440d5dd3af93b060174bf433bccd69b0babc3b15b1a8dca43789fd7f61514b36"},
+    {file = "typing_extensions-4.7.1.tar.gz", hash = "sha256:b75ddc264f0ba5615db7ba217daeb99701ad295353c45f9e95963337ceeeffb2"},
 ]
 
 [[package]]
@@ -2030,4 +2030,4 @@ vertica = ["vertica-python"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.7.2"
-content-hash = "b3e3febf3233c5fb0800870c84422ad8e414d369664e195b7b3d4735028ee464"
+content-hash = "55cde03a00788572dac6310e7bbf61bd2522d70217056a51608bcfc429440fbf"
diff --git a/pyproject.toml b/pyproject.toml
index bd4d80a7..da72fa1e 100755
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,6 +46,7 @@ vertica-python = {version="*", optional=true}
 urllib3 = "<2"
 oracledb = {version = "*", optional=true}
 pyodbc = {version="^4.0.39", optional=true}
+typing-extensions = ">=4.0.1"
 
 [tool.poetry.dev-dependencies]
 parameterized = "*"

From 8040b7fd3194cd775cd38f1e63f043851fd803d7 Mon Sep 17 00:00:00 2001
From: Sergey Vasilyev <sv@datafold.com>
Date: Tue, 19 Sep 2023 19:32:07 +0200
Subject: [PATCH 2/2] Annotate types for self-cloning factories as per PEP-673

---
 data_diff/cloud/datafold_api.py           |  4 ++--
 data_diff/sqeleton/bound_exprs.py         |  5 +++--
 data_diff/sqeleton/databases/_connect.py  |  2 +-
 data_diff/sqeleton/databases/base.py      |  2 +-
 data_diff/sqeleton/queries/ast_classes.py | 19 ++++++++++---------
 data_diff/sqeleton/queries/compiler.py    |  3 ++-
 data_diff/sqeleton/utils.py               | 18 ++++++++++--------
 data_diff/table_segment.py                |  9 +++++----
 8 files changed, 34 insertions(+), 28 deletions(-)

diff --git a/data_diff/cloud/datafold_api.py b/data_diff/cloud/datafold_api.py
index f0adaea3..8ed1791a 100644
--- a/data_diff/cloud/datafold_api.py
+++ b/data_diff/cloud/datafold_api.py
@@ -29,7 +29,7 @@ class TCloudApiDataSourceSchema(pydantic.BaseModel):
     secret: List[str]
 
     @classmethod
-    def from_orm(cls: Type[Self], obj: Any) -> Self:
+    def from_orm(cls, obj: Any) -> Self:
         data_source_types_required_parameters = {
             "bigquery": ["projectId", "jsonKeyFile", "location"],
             "databricks": ["host", "http_password", "database", "http_path"],
@@ -153,7 +153,7 @@ class TCloudApiDataDiffSummaryResult(pydantic.BaseModel):
     dependencies: Optional[Dict[str, Any]]
 
     @classmethod
-    def from_orm(cls: Type[Self], obj: Any) -> Self:
+    def from_orm(cls, obj: Any) -> Self:
         pks = TSummaryResultPrimaryKeyStats(**obj["pks"]) if "pks" in obj else None
         values = TSummaryResultValueStats(**obj["values"]) if "values" in obj else None
         deps = obj["deps"] if "deps" in obj else None
diff --git a/data_diff/sqeleton/bound_exprs.py b/data_diff/sqeleton/bound_exprs.py
index 188efbca..7ef4dc11 100644
--- a/data_diff/sqeleton/bound_exprs.py
+++ b/data_diff/sqeleton/bound_exprs.py
@@ -5,6 +5,7 @@
 from typing import Union, TYPE_CHECKING
 
 from runtype import dataclass
+from typing_extensions import Self
 
 from .abcs import AbstractDatabase, AbstractCompiler
 from .queries.ast_classes import ExprNode, ITable, TablePath, Compilable
@@ -52,11 +53,11 @@ class BoundTable(BoundNode):  # ITable
     database: AbstractDatabase
     node: TablePath
 
-    def with_schema(self, schema):
+    def with_schema(self, schema) -> Self:
         table_path = self.node.replace(schema=schema)
         return self.replace(node=table_path)
 
-    def query_schema(self, *, columns=None, where=None, case_sensitive=True):
+    def query_schema(self, *, columns=None, where=None, case_sensitive=True) -> Self:
         table_path = self.node
 
         if table_path.schema:
diff --git a/data_diff/sqeleton/databases/_connect.py b/data_diff/sqeleton/databases/_connect.py
index cdd3b2f1..aee220dd 100644
--- a/data_diff/sqeleton/databases/_connect.py
+++ b/data_diff/sqeleton/databases/_connect.py
@@ -100,7 +100,7 @@ def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME)
         self.match_uri_path = {name: MatchUriPath(cls) for name, cls in database_by_scheme.items()}
         self.conn_cache = WeakCache()
 
-    def for_databases(self, *dbs):
+    def for_databases(self, *dbs) -> Self:
         database_by_scheme = {k: db for k, db in self.database_by_scheme.items() if k in dbs}
         return type(self)(database_by_scheme)
 
diff --git a/data_diff/sqeleton/databases/base.py b/data_diff/sqeleton/databases/base.py
index fb0a8b89..5030b24e 100644
--- a/data_diff/sqeleton/databases/base.py
+++ b/data_diff/sqeleton/databases/base.py
@@ -282,7 +282,7 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
         return math.floor(math.log(2**p, 10))
 
     @classmethod
-    def load_mixins(cls, *abstract_mixins) -> "Self":
+    def load_mixins(cls, *abstract_mixins) -> Self:
         mixins = {m for m in cls.MIXINS if issubclass(m, abstract_mixins)}
 
         class _DialectWithMixins(cls, *mixins, *abstract_mixins):
diff --git a/data_diff/sqeleton/queries/ast_classes.py b/data_diff/sqeleton/queries/ast_classes.py
index aba86c70..8d892232 100644
--- a/data_diff/sqeleton/queries/ast_classes.py
+++ b/data_diff/sqeleton/queries/ast_classes.py
@@ -1,8 +1,9 @@
 from dataclasses import field
 from datetime import datetime
-from typing import Any, Generator, List, Optional, Sequence, Union, Dict
+from typing import Any, Generator, List, Optional, Sequence, Type, Union, Dict
 
 from runtype import dataclass
+from typing_extensions import Self
 
 from ..utils import join_iter, ArithString
 from ..abcs import Compilable
@@ -322,7 +323,7 @@ def when(self, *whens: Expr) -> "QB_When":
             return QB_When(self, whens[0])
         return QB_When(self, BinBoolOp("AND", whens))
 
-    def else_(self, then: Expr):
+    def else_(self, then: Expr) -> Self:
         """Add an 'else' clause to the case expression.
 
         Can only be called once!
@@ -422,7 +423,7 @@ class TablePath(ExprNode, ITable):
     schema: Optional[Schema] = field(default=None, repr=False)
 
     @property
-    def source_table(self):
+    def source_table(self) -> Self:
         return self
 
     def compile(self, c: Compiler) -> str:
@@ -524,7 +525,7 @@ class Join(ExprNode, ITable, Root):
     columns: Sequence[Expr] = None
 
     @property
-    def source_table(self):
+    def source_table(self) -> Self:
         return self
 
     @property
@@ -533,7 +534,7 @@ def schema(self):
         s = self.source_tables[0].schema  # TODO validate types match between both tables
         return type(s)({c.name: c.type for c in self.columns})
 
-    def on(self, *exprs) -> "Join":
+    def on(self, *exprs) -> Self:
         """Add an ON clause, for filtering the result of the cartesian product (i.e. the JOIN)"""
         if len(exprs) == 1:
             (e,) = exprs
@@ -546,7 +547,7 @@ def on(self, *exprs) -> "Join":
 
         return self.replace(on_exprs=(self.on_exprs or []) + exprs)
 
-    def select(self, *exprs, **named_exprs) -> ITable:
+    def select(self, *exprs, **named_exprs) -> Union[Self, ITable]:
         """Select fields to return from the JOIN operation
 
         See Also: ``ITable.select()``
@@ -600,7 +601,7 @@ def source_table(self):
     def __post_init__(self):
         assert self.keys or self.values
 
-    def having(self, *exprs):
+    def having(self, *exprs) -> Self:
         """Add a 'HAVING' clause to the group-by"""
         exprs = args_as_tuple(exprs)
         exprs = _drop_skips(exprs)
@@ -610,7 +611,7 @@ def having(self, *exprs):
         resolve_names(self.table, exprs)
         return self.replace(having_exprs=(self.having_exprs or []) + exprs)
 
-    def agg(self, *exprs):
+    def agg(self, *exprs) -> Self:
         """Select aggregated fields for the group-by."""
         exprs = args_as_tuple(exprs)
         exprs = _drop_skips(exprs)
@@ -991,7 +992,7 @@ def compile(self, c: Compiler) -> str:
 
         return f"INSERT INTO {c.compile(self.path)}{columns} {expr}"
 
-    def returning(self, *exprs):
+    def returning(self, *exprs) -> Self:
         """Add a 'RETURNING' clause to the insert expression.
 
         Note: Not all databases support this feature!
diff --git a/data_diff/sqeleton/queries/compiler.py b/data_diff/sqeleton/queries/compiler.py
index 56b77c0f..1f6793ff 100644
--- a/data_diff/sqeleton/queries/compiler.py
+++ b/data_diff/sqeleton/queries/compiler.py
@@ -4,6 +4,7 @@
 from typing import Any, Dict, Sequence, List
 
 from runtype import dataclass
+from typing_extensions import Self
 
 from ..utils import ArithString
 from ..abcs import AbstractDatabase, AbstractDialect, DbPath, AbstractCompiler, Compilable
@@ -79,7 +80,7 @@ def new_unique_table_name(self, prefix="tmp") -> DbPath:
         self._counter[0] += 1
         return self.database.parse_table_name(f"{prefix}{self._counter[0]}_{'%x'%random.randrange(2**32)}")
 
-    def add_table_context(self, *tables: Sequence, **kw):
+    def add_table_context(self, *tables: Sequence, **kw) -> Self:
         return self.replace(_table_context=self._table_context + list(tables), **kw)
 
     def quote(self, s: str):
diff --git a/data_diff/sqeleton/utils.py b/data_diff/sqeleton/utils.py
index ed8a05c5..d356d18b 100644
--- a/data_diff/sqeleton/utils.py
+++ b/data_diff/sqeleton/utils.py
@@ -2,13 +2,13 @@
     Iterable,
     Iterator,
     MutableMapping,
+    Type,
     Union,
     Any,
     Sequence,
     Dict,
     Hashable,
     TypeVar,
-    TYPE_CHECKING,
     List,
 )
 from abc import abstractmethod
@@ -19,6 +19,8 @@
 from uuid import UUID
 from urllib.parse import urlparse
 
+from typing_extensions import Self
+
 # -- Common --
 
 
@@ -90,7 +92,7 @@ class CaseAwareMapping(MutableMapping[str, V]):
     def get_key(self, key: str) -> str:
         ...
 
-    def new(self, initial=()):
+    def new(self, initial=()) -> Self:
         return type(self)(initial)
 
 
@@ -139,10 +141,10 @@ def as_insensitive(self):
 
 class ArithString:
     @classmethod
-    def new(cls, *args, **kw):
+    def new(cls, *args, **kw) -> Self:
         return cls(*args, **kw)
 
-    def range(self, other: "ArithString", count: int):
+    def range(self, other: "ArithString", count: int) -> List[Self]:
         assert isinstance(other, ArithString)
         checkpoints = split_space(self.int, other.int, count)
         return [self.new(int=i) for i in checkpoints]
@@ -154,7 +156,7 @@ class ArithUUID(UUID, ArithString):
     def __int__(self):
         return self.int
 
-    def __add__(self, other: int):
+    def __add__(self, other: int) -> Self:
         if isinstance(other, int):
             return self.new(int=self.int + other)
         return NotImplemented
@@ -226,7 +228,7 @@ def __len__(self):
     def __repr__(self):
         return f'alphanum"{self._str}"'
 
-    def __add__(self, other: "Union[ArithAlphanumeric, int]") -> "ArithAlphanumeric":
+    def __add__(self, other: "Union[ArithAlphanumeric, int]") -> Self:
         if isinstance(other, int):
             if other != 1:
                 raise NotImplementedError("not implemented for arbitrary numbers")
@@ -235,7 +237,7 @@ def __add__(self, other: "Union[ArithAlphanumeric, int]") -> "ArithAlphanumeric"
 
         return NotImplemented
 
-    def range(self, other: "ArithAlphanumeric", count: int):
+    def range(self, other: "ArithAlphanumeric", count: int) -> List[Self]:
         assert isinstance(other, ArithAlphanumeric)
         n1, n2 = alphanums_to_numbers(self._str, other._str)
         split = split_space(n1, n2, count)
@@ -263,7 +265,7 @@ def __eq__(self, other):
             return NotImplemented
         return self._str == other._str
 
-    def new(self, *args, **kw):
+    def new(self, *args, **kw) -> Self:
         return type(self)(*args, **kw, max_len=self._max_len)
 
 
diff --git a/data_diff/table_segment.py b/data_diff/table_segment.py
index f997c8c5..4301f06f 100644
--- a/data_diff/table_segment.py
+++ b/data_diff/table_segment.py
@@ -4,6 +4,7 @@
 from itertools import product
 
 from runtype import dataclass
+from typing_extensions import Self
 
 from .utils import safezip, Vector
 from data_diff.sqeleton.utils import ArithString, split_space
@@ -137,11 +138,11 @@ def __post_init__(self):
     def _where(self):
         return f"({self.where})" if self.where else None
 
-    def _with_raw_schema(self, raw_schema: dict) -> "TableSegment":
+    def _with_raw_schema(self, raw_schema: dict) -> Self:
         schema = self.database._process_table_schema(self.table_path, raw_schema, self.relevant_columns, self._where())
         return self.new(_schema=create_schema(self.database, self.table_path, schema, self.case_sensitive))
 
-    def with_schema(self) -> "TableSegment":
+    def with_schema(self) -> Self:
         "Queries the table schema from the database, and returns a new instance of TableSegment, with a schema."
         if self._schema:
             return self
@@ -194,11 +195,11 @@ def segment_by_checkpoints(self, checkpoints: List[List[DbKey]]) -> List["TableS
 
         return [self.new_key_bounds(min_key=s, max_key=e) for s, e in create_mesh_from_points(*checkpoints)]
 
-    def new(self, **kwargs) -> "TableSegment":
+    def new(self, **kwargs) -> Self:
         """Creates a copy of the instance using 'replace()'"""
         return self.replace(**kwargs)
 
-    def new_key_bounds(self, min_key: Vector, max_key: Vector) -> "TableSegment":
+    def new_key_bounds(self, min_key: Vector, max_key: Vector) -> Self:
         if self.min_key is not None:
             assert self.min_key <= min_key, (self.min_key, min_key)
             assert self.min_key < max_key