Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Annotate types for self-cloning factories as per PEP-673 #704

Merged: 2 commits, on Sep 22, 2023.
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions data_diff/cloud/datafold_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,18 @@
import dataclasses
import enum
import time
from typing import Any, Dict, List, Optional, Type, TypeVar, Tuple
from typing import Any, Dict, List, Optional, Type, Tuple

import pydantic
import requests
from typing_extensions import Self

from data_diff.errors import DataDiffCloudDiffFailed, DataDiffCloudDiffTimedOut, DataDiffDatasourceIdNotFoundError

from ..utils import getLogger

logger = getLogger(__name__)

Self = TypeVar("Self", bound=pydantic.BaseModel)


class TestDataSourceStatus(str, enum.Enum):
SUCCESS = "ok"
Expand All @@ -30,7 +29,7 @@ class TCloudApiDataSourceSchema(pydantic.BaseModel):
secret: List[str]

@classmethod
def from_orm(cls: Type[Self], obj: Any) -> Self:
def from_orm(cls, obj: Any) -> Self:
data_source_types_required_parameters = {
"bigquery": ["projectId", "jsonKeyFile", "location"],
"databricks": ["host", "http_password", "database", "http_path"],
Expand Down Expand Up @@ -154,7 +153,7 @@ class TCloudApiDataDiffSummaryResult(pydantic.BaseModel):
dependencies: Optional[Dict[str, Any]]

@classmethod
def from_orm(cls: Type[Self], obj: Any) -> Self:
def from_orm(cls, obj: Any) -> Self:
pks = TSummaryResultPrimaryKeyStats(**obj["pks"]) if "pks" in obj else None
values = TSummaryResultValueStats(**obj["values"]) if "values" in obj else None
deps = obj["deps"] if "deps" in obj else None
Expand Down
5 changes: 3 additions & 2 deletions data_diff/sqeleton/abcs/database_types.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import decimal
from abc import ABC, abstractmethod
from typing import Sequence, Optional, Tuple, Union, Dict, List
from typing import Sequence, Optional, Tuple, Type, Union, Dict, List
from datetime import datetime

from runtype import dataclass
from typing_extensions import Self

from ..utils import ArithAlphanumeric, ArithUUID, Self, Unknown
from ..utils import ArithAlphanumeric, ArithUUID, Unknown


DbPath = Tuple[str, ...]
Expand Down
5 changes: 3 additions & 2 deletions data_diff/sqeleton/bound_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Union, TYPE_CHECKING

from runtype import dataclass
from typing_extensions import Self

from .abcs import AbstractDatabase, AbstractCompiler
from .queries.ast_classes import ExprNode, ITable, TablePath, Compilable
Expand Down Expand Up @@ -52,11 +53,11 @@ class BoundTable(BoundNode): # ITable
database: AbstractDatabase
node: TablePath

def with_schema(self, schema):
def with_schema(self, schema) -> Self:
table_path = self.node.replace(schema=schema)
return self.replace(node=table_path)

def query_schema(self, *, columns=None, where=None, case_sensitive=True):
def query_schema(self, *, columns=None, where=None, case_sensitive=True) -> Self:
table_path = self.node

if table_path.schema:
Expand Down
5 changes: 3 additions & 2 deletions data_diff/sqeleton/databases/_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import toml

from runtype import dataclass
from typing_extensions import Self

from ..abcs.mixins import AbstractMixin
from ..utils import WeakCache, Self
from ..utils import WeakCache
from .base import Database, ThreadedDatabase
from .postgresql import PostgreSQL
from .mysql import MySQL
Expand Down Expand Up @@ -99,7 +100,7 @@ def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME)
self.match_uri_path = {name: MatchUriPath(cls) for name, cls in database_by_scheme.items()}
self.conn_cache = WeakCache()

def for_databases(self, *dbs):
def for_databases(self, *dbs) -> Self:
database_by_scheme = {k: db for k, db in self.database_by_scheme.items() if k in dbs}
return type(self)(database_by_scheme)

Expand Down
5 changes: 3 additions & 2 deletions data_diff/sqeleton/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
import decimal

from runtype import dataclass
from typing_extensions import Self

from ..utils import is_uuid, safezip, Self
from ..utils import is_uuid, safezip
from ..queries import Expr, Compiler, table, Select, SKIP, Explain, Code, this
from ..queries.ast_classes import Random
from ..abcs.database_types import (
Expand Down Expand Up @@ -281,7 +282,7 @@ def _convert_db_precision_to_digits(self, p: int) -> int:
return math.floor(math.log(2**p, 10))

@classmethod
def load_mixins(cls, *abstract_mixins) -> "Self":
def load_mixins(cls, *abstract_mixins) -> Self:
mixins = {m for m in cls.MIXINS if issubclass(m, abstract_mixins)}

class _DialectWithMixins(cls, *mixins, *abstract_mixins):
Expand Down
19 changes: 10 additions & 9 deletions data_diff/sqeleton/queries/ast_classes.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from dataclasses import field
from datetime import datetime
from typing import Any, Generator, List, Optional, Sequence, Union, Dict
from typing import Any, Generator, List, Optional, Sequence, Type, Union, Dict

from runtype import dataclass
from typing_extensions import Self

from ..utils import join_iter, ArithString
from ..abcs import Compilable
Expand Down Expand Up @@ -322,7 +323,7 @@ def when(self, *whens: Expr) -> "QB_When":
return QB_When(self, whens[0])
return QB_When(self, BinBoolOp("AND", whens))

def else_(self, then: Expr):
def else_(self, then: Expr) -> Self:
"""Add an 'else' clause to the case expression.

Can only be called once!
Expand Down Expand Up @@ -422,7 +423,7 @@ class TablePath(ExprNode, ITable):
schema: Optional[Schema] = field(default=None, repr=False)

@property
def source_table(self):
def source_table(self) -> Self:
return self

def compile(self, c: Compiler) -> str:
Expand Down Expand Up @@ -524,7 +525,7 @@ class Join(ExprNode, ITable, Root):
columns: Sequence[Expr] = None

@property
def source_table(self):
def source_table(self) -> Self:
return self

@property
Expand All @@ -533,7 +534,7 @@ def schema(self):
s = self.source_tables[0].schema # TODO validate types match between both tables
return type(s)({c.name: c.type for c in self.columns})

def on(self, *exprs) -> "Join":
def on(self, *exprs) -> Self:
"""Add an ON clause, for filtering the result of the cartesian product (i.e. the JOIN)"""
if len(exprs) == 1:
(e,) = exprs
Expand All @@ -546,7 +547,7 @@ def on(self, *exprs) -> "Join":

return self.replace(on_exprs=(self.on_exprs or []) + exprs)

def select(self, *exprs, **named_exprs) -> ITable:
def select(self, *exprs, **named_exprs) -> Union[Self, ITable]:
"""Select fields to return from the JOIN operation

See Also: ``ITable.select()``
Expand Down Expand Up @@ -600,7 +601,7 @@ def source_table(self):
def __post_init__(self):
assert self.keys or self.values

def having(self, *exprs):
def having(self, *exprs) -> Self:
"""Add a 'HAVING' clause to the group-by"""
exprs = args_as_tuple(exprs)
exprs = _drop_skips(exprs)
Expand All @@ -610,7 +611,7 @@ def having(self, *exprs):
resolve_names(self.table, exprs)
return self.replace(having_exprs=(self.having_exprs or []) + exprs)

def agg(self, *exprs):
def agg(self, *exprs) -> Self:
"""Select aggregated fields for the group-by."""
exprs = args_as_tuple(exprs)
exprs = _drop_skips(exprs)
Expand Down Expand Up @@ -991,7 +992,7 @@ def compile(self, c: Compiler) -> str:

return f"INSERT INTO {c.compile(self.path)}{columns} {expr}"

def returning(self, *exprs):
def returning(self, *exprs) -> Self:
"""Add a 'RETURNING' clause to the insert expression.

Note: Not all databases support this feature!
Expand Down
3 changes: 2 additions & 1 deletion data_diff/sqeleton/queries/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any, Dict, Sequence, List

from runtype import dataclass
from typing_extensions import Self

from ..utils import ArithString
from ..abcs import AbstractDatabase, AbstractDialect, DbPath, AbstractCompiler, Compilable
Expand Down Expand Up @@ -79,7 +80,7 @@ def new_unique_table_name(self, prefix="tmp") -> DbPath:
self._counter[0] += 1
return self.database.parse_table_name(f"{prefix}{self._counter[0]}_{'%x'%random.randrange(2**32)}")

def add_table_context(self, *tables: Sequence, **kw):
def add_table_context(self, *tables: Sequence, **kw) -> Self:
return self.replace(_table_context=self._table_context + list(tables), **kw)

def quote(self, s: str):
Expand Down
23 changes: 10 additions & 13 deletions data_diff/sqeleton/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
Iterable,
Iterator,
MutableMapping,
Type,
Union,
Any,
Sequence,
Dict,
Hashable,
TypeVar,
TYPE_CHECKING,
List,
)
from abc import abstractmethod
Expand All @@ -19,12 +19,9 @@
from uuid import UUID
from urllib.parse import urlparse

# -- Common --
from typing_extensions import Self

try:
from typing import Self
except ImportError:
Self = Any
# -- Common --


class WeakCache:
Expand Down Expand Up @@ -95,7 +92,7 @@ class CaseAwareMapping(MutableMapping[str, V]):
def get_key(self, key: str) -> str:
...

def new(self, initial=()):
def new(self, initial=()) -> Self:
return type(self)(initial)


Expand Down Expand Up @@ -144,10 +141,10 @@ def as_insensitive(self):

class ArithString:
@classmethod
def new(cls, *args, **kw):
def new(cls, *args, **kw) -> Self:
return cls(*args, **kw)

def range(self, other: "ArithString", count: int):
def range(self, other: "ArithString", count: int) -> List[Self]:
assert isinstance(other, ArithString)
checkpoints = split_space(self.int, other.int, count)
return [self.new(int=i) for i in checkpoints]
Expand All @@ -159,7 +156,7 @@ class ArithUUID(UUID, ArithString):
def __int__(self):
return self.int

def __add__(self, other: int):
def __add__(self, other: int) -> Self:
if isinstance(other, int):
return self.new(int=self.int + other)
return NotImplemented
Expand Down Expand Up @@ -231,7 +228,7 @@ def __len__(self):
def __repr__(self):
return f'alphanum"{self._str}"'

def __add__(self, other: "Union[ArithAlphanumeric, int]") -> "ArithAlphanumeric":
def __add__(self, other: "Union[ArithAlphanumeric, int]") -> Self:
if isinstance(other, int):
if other != 1:
raise NotImplementedError("not implemented for arbitrary numbers")
Expand All @@ -240,7 +237,7 @@ def __add__(self, other: "Union[ArithAlphanumeric, int]") -> "ArithAlphanumeric"

return NotImplemented

def range(self, other: "ArithAlphanumeric", count: int):
def range(self, other: "ArithAlphanumeric", count: int) -> List[Self]:
assert isinstance(other, ArithAlphanumeric)
n1, n2 = alphanums_to_numbers(self._str, other._str)
split = split_space(n1, n2, count)
Expand Down Expand Up @@ -268,7 +265,7 @@ def __eq__(self, other):
return NotImplemented
return self._str == other._str

def new(self, *args, **kw):
def new(self, *args, **kw) -> Self:
return type(self)(*args, **kw, max_len=self._max_len)


Expand Down
9 changes: 5 additions & 4 deletions data_diff/table_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from itertools import product

from runtype import dataclass
from typing_extensions import Self

from .utils import safezip, Vector
from data_diff.sqeleton.utils import ArithString, split_space
Expand Down Expand Up @@ -137,11 +138,11 @@ def __post_init__(self):
def _where(self):
return f"({self.where})" if self.where else None

def _with_raw_schema(self, raw_schema: dict) -> "TableSegment":
def _with_raw_schema(self, raw_schema: dict) -> Self:
schema = self.database._process_table_schema(self.table_path, raw_schema, self.relevant_columns, self._where())
return self.new(_schema=create_schema(self.database, self.table_path, schema, self.case_sensitive))

def with_schema(self) -> "TableSegment":
def with_schema(self) -> Self:
"Queries the table schema from the database, and returns a new instance of TableSegment, with a schema."
if self._schema:
return self
Expand Down Expand Up @@ -194,11 +195,11 @@ def segment_by_checkpoints(self, checkpoints: List[List[DbKey]]) -> List["TableS

return [self.new_key_bounds(min_key=s, max_key=e) for s, e in create_mesh_from_points(*checkpoints)]

def new(self, **kwargs) -> "TableSegment":
def new(self, **kwargs) -> Self:
"""Creates a copy of the instance using 'replace()'"""
return self.replace(**kwargs)

def new_key_bounds(self, min_key: Vector, max_key: Vector) -> "TableSegment":
def new_key_bounds(self, min_key: Vector, max_key: Vector) -> Self:
if self.min_key is not None:
assert self.min_key <= min_key, (self.min_key, min_key)
assert self.min_key < max_key
Expand Down
8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ vertica-python = {version="*", optional=true}
urllib3 = "<2"
oracledb = {version = "*", optional=true}
pyodbc = {version="^4.0.39", optional=true}
typing-extensions = ">=4.0.1"

[tool.poetry.dev-dependencies]
parameterized = "*"
Expand Down