From 46e87b7a483618627551447996444dd63f3411b0 Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Mon, 2 Oct 2023 13:07:06 +0200 Subject: [PATCH 1/7] Restore MRO calls to inherited constructors MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Always, always, always call the inherited constructor — even if it does nothing. If not called, it breaks the MRO chain of inheritance and complicates the mixin/class building. This must be a linting rule. --- data_diff/databases/_connect.py | 1 + data_diff/databases/base.py | 2 ++ data_diff/databases/bigquery.py | 1 + data_diff/databases/databricks.py | 2 +- data_diff/databases/duckdb.py | 1 + data_diff/databases/mysql.py | 3 +-- data_diff/databases/oracle.py | 4 +--- data_diff/databases/postgresql.py | 3 +-- data_diff/databases/presto.py | 1 + data_diff/databases/snowflake.py | 1 + data_diff/databases/trino.py | 1 + data_diff/databases/vertica.py | 3 +-- data_diff/dbt_parser.py | 2 ++ data_diff/lexicographic_space.py | 3 +++ data_diff/thread_utils.py | 1 + data_diff/utils.py | 3 +++ tests/test_database_types.py | 7 +++++++ 17 files changed, 29 insertions(+), 10 deletions(-) diff --git a/data_diff/databases/_connect.py b/data_diff/databases/_connect.py index 9abb7d54..cab8b235 100644 --- a/data_diff/databases/_connect.py +++ b/data_diff/databases/_connect.py @@ -98,6 +98,7 @@ class Connect: conn_cache: MutableMapping[Hashable, Database] def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME): + super().__init__() self.database_by_scheme = database_by_scheme self.match_uri_path = {name: MatchUriPath(cls) for name, cls in database_by_scheme.items()} self.conn_cache = weakref.WeakValueDictionary() diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 86ece124..9e4f39dc 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -180,6 +180,7 @@ class ThreadLocalInterpreter: """ def __init__(self, compiler: Compiler, gen: Generator): + super().__init__() self.gen = gen self.compiler = compiler @@ -1109,6 +1110,7 @@ class ThreadedDatabase(Database): """ def __init__(self, thread_count=1): + super().__init__() self._init_error = None self._queue = ThreadPoolExecutor(thread_count, initializer=self.set_conn) self.thread_local = threading.local() diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py index 15d60511..aaf2dbd3 100644 --- a/data_diff/databases/bigquery.py +++ b/data_diff/databases/bigquery.py @@ -224,6 +224,7 @@ class BigQuery(Database): dialect = Dialect() def __init__(self, project, *, dataset, bigquery_credentials=None, **kw): + super().__init__() credentials = bigquery_credentials bigquery = import_bigquery() diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index eba88248..c440c54c 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -104,12 +104,12 @@ class Databricks(ThreadedDatabase): CONNECT_URI_PARAMS = ["catalog", "schema"] def __init__(self, *, thread_count, **kw): + super().__init__(thread_count=thread_count) logging.getLogger("databricks.sql").setLevel(logging.WARNING) self._args = kw self.default_schema = kw.get("schema", "default") self.catalog = self._args.get("catalog", "hive_metastore") - super().__init__(thread_count=thread_count) def create_connection(self): databricks = import_databricks() diff --git a/data_diff/databases/duckdb.py b/data_diff/databases/duckdb.py index ca9f5733..b0e061ee 100644 --- a/data_diff/databases/duckdb.py +++ b/data_diff/databases/duckdb.py @@ -139,6 +139,7 @@ class DuckDB(Database): CONNECT_URI_PARAMS = ["database", "dbpath"] def __init__(self, **kw): + super().__init__() self._args = kw self._conn = self.create_connection() diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index e32d34af..a69e6e46 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -138,9 +138,8 @@ class MySQL(ThreadedDatabase): CONNECT_URI_PARAMS = ["database?"] def __init__(self, *, thread_count, **kw): - self._args = kw - super().__init__(thread_count=thread_count) + self._args = kw # In MySQL schema and database are synonymous try: diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index 3b4940c5..44681a45 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -182,12 +182,10 @@ class Oracle(ThreadedDatabase): CONNECT_URI_PARAMS = ["database?"] def __init__(self, *, host, database, thread_count, **kw): + super().__init__(thread_count=thread_count) self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw) - self.default_schema = kw.get("user").upper() - super().__init__(thread_count=thread_count) - def create_connection(self): self._oracle = import_oracle() try: diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py index f3495c99..16d6a1d1 100644 --- a/data_diff/databases/postgresql.py +++ b/data_diff/databases/postgresql.py @@ -128,9 +128,8 @@ class PostgreSQL(ThreadedDatabase): default_schema = "public" def __init__(self, *, thread_count, **kw): - self._args = kw - super().__init__(thread_count=thread_count) + self._args = kw def create_connection(self): if not self._args: diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index 0ba4e09d..a829df95 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -161,6 +161,7 @@ class Presto(Database): default_schema = "public" def __init__(self, **kw): + super().__init__() prestodb = import_presto() if kw.get("schema"): diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py index f3a70b76..31f85492 100644 --- a/data_diff/databases/snowflake.py +++ b/data_diff/databases/snowflake.py @@ -155,6 +155,7 @@ class Snowflake(Database): CONNECT_URI_KWPARAMS = ["warehouse"] def __init__(self, *, schema: str, **kw): + super().__init__() snowflake, serialization, default_backend = import_snowflake() logging.getLogger("snowflake.connector").setLevel(logging.WARNING) diff --git a/data_diff/databases/trino.py b/data_diff/databases/trino.py index e2095758..ab4913d4 100644 --- a/data_diff/databases/trino.py +++ b/data_diff/databases/trino.py @@ -40,6 +40,7 @@ class Trino(presto.Presto): CONNECT_URI_PARAMS = ["catalog", "schema"] def __init__(self, **kw): + super().__init__() trino = import_trino() if kw.get("schema"): diff --git a/data_diff/databases/vertica.py b/data_diff/databases/vertica.py index fc9edd04..f539a4df 100644 --- a/data_diff/databases/vertica.py +++ b/data_diff/databases/vertica.py @@ -159,11 +159,10 @@ class Vertica(ThreadedDatabase): default_schema = "public" def __init__(self, *, thread_count, **kw): + super().__init__(thread_count=thread_count) self._args = kw self._args["AUTOCOMMIT"] = False - super().__init__(thread_count=thread_count) - def create_connection(self): vertica = import_vertica() try: diff --git a/data_diff/dbt_parser.py b/data_diff/dbt_parser.py index d5976f74..fcb6ce24 100644 --- a/data_diff/dbt_parser.py +++ b/data_diff/dbt_parser.py @@ -101,6 +101,8 @@ def __init__( project_dir_override: Optional[str] = None, state: Optional[str] = None, ) -> None: + super().__init__() + try_set_dbt_flags() self.dbt_runner = try_get_dbt_runner() self.project_dir = Path(project_dir_override or default_project_dir()) diff --git a/data_diff/lexicographic_space.py b/data_diff/lexicographic_space.py index 88cf863d..7ef80686 100644 --- a/data_diff/lexicographic_space.py +++ b/data_diff/lexicographic_space.py @@ -63,6 +63,7 @@ class LexicographicSpace: """ def __init__(self, dims: Vector): + super().__init__() self.dims = dims def __contains__(self, v: Vector): @@ -120,6 +121,8 @@ class BoundedLexicographicSpace: """ def __init__(self, min_bound: Vector, max_bound: Vector): + super().__init__() + dims = tuple(mx - mn for mn, mx in safezip(min_bound, max_bound)) if not all(d >= 0 for d in dims): raise ValueError("Error: Negative dimension!") diff --git a/data_diff/thread_utils.py b/data_diff/thread_utils.py index 1be94ad4..c5526771 100644 --- a/data_diff/thread_utils.py +++ b/data_diff/thread_utils.py @@ -46,6 +46,7 @@ class ThreadedYielder(Iterable): """ def __init__(self, max_workers: Optional[int] = None): + super().__init__() self._pool = PriorityThreadPoolExecutor(max_workers) self._futures = deque() self._yield = deque() diff --git a/data_diff/utils.py b/data_diff/utils.py index a3ce90cb..558a18e9 100644 --- a/data_diff/utils.py +++ b/data_diff/utils.py @@ -73,6 +73,7 @@ def new(self, initial=()) -> Self: class CaseInsensitiveDict(CaseAwareMapping): def __init__(self, initial): + super().__init__() self._dict = {k.lower(): (k, v) for k, v in dict(initial).items()} def __getitem__(self, key: str) -> V: @@ -175,6 +176,8 @@ def alphanums_to_numbers(s1: str, s2: str): class ArithAlphanumeric(ArithString): def __init__(self, s: str, max_len=None): + super().__init__() + if s is None: raise ValueError("Alphanum string cannot be None") if max_len and len(s) > max_len: diff --git a/tests/test_database_types.py b/tests/test_database_types.py index 6e0e5215..70fd01aa 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -367,6 +367,7 @@ class PaginatedTable: RECORDS_PER_BATCH = 1000000 def __init__(self, table_path, conn): + super().__init__() self.table_path = table_path self.conn = conn @@ -398,6 +399,7 @@ class DateTimeFaker: ] def __init__(self, max): + super().__init__() self.max = max def __iter__(self): @@ -413,6 +415,7 @@ class IntFaker: MANUAL_FAKES = [127, -3, -9, 37, 15, 0] def __init__(self, max): + super().__init__() self.max = max def __iter__(self): @@ -428,6 +431,7 @@ class BooleanFaker: MANUAL_FAKES = [False, True, True, False] def __init__(self, max): + super().__init__() self.max = max def __iter__(self): @@ -458,6 +462,7 @@ class FloatFaker: ] def __init__(self, max): + super().__init__() self.max = max def __iter__(self): @@ -471,6 +476,7 @@ def __len__(self): class UUID_Faker: def __init__(self, max): + super().__init__() self.max = max def __len__(self): @@ -486,6 +492,7 @@ class JsonFaker: ] def __init__(self, max): + super().__init__() self.max = max def __iter__(self): From 48b29553850f28d7aefe5482b979073ac64173d2 Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Fri, 29 Sep 2023 15:21:19 +0200 Subject: [PATCH 2/7] Convert class-level constants to properties for attrs compatibility `attrs` cannot use multiple inheritance when both parents introduce their attributes (as documented). Only one side can inherit the attributes, other bases must be pure interfaces/protocols. Reimplement the `ExprNode.type` via properties to exclude it from the sight of `attrs`. --- data_diff/abcs/database_types.py | 12 +++++++--- data_diff/queries/ast_classes.py | 40 +++++++++++++++++++++++++------- data_diff/queries/extras.py | 7 ++++-- 3 files changed, 45 insertions(+), 14 deletions(-) diff --git a/data_diff/abcs/database_types.py b/data_diff/abcs/database_types.py index 43764b39..e5ec393a 100644 --- a/data_diff/abcs/database_types.py +++ b/data_diff/abcs/database_types.py @@ -15,7 +15,9 @@ @dataclass class ColType: - supported = True + @property + def supported(self) -> bool: + return True @dataclass @@ -132,7 +134,9 @@ def make_value(self, value): @dataclass class Text(StringType): - supported = False + @property + def supported(self) -> bool: + return False # In majority of DBMSes, it is called JSON/JSONB. Only in Snowflake, it is OBJECT. @@ -169,4 +173,6 @@ def __post_init__(self): class UnknownColType(ColType): text: str - supported = False + @property + def supported(self) -> bool: + return False diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 4c5c45f4..1d841237 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -28,7 +28,9 @@ class Root: class ExprNode(Compilable): "Base class for query expression nodes" - type: Any = None + @property + def type(self) -> Optional[type]: + return None def _dfs_values(self): yield self @@ -216,7 +218,10 @@ class Concat(ExprNode): class Count(ExprNode): expr: Expr = None distinct: bool = False - type = int + + @property + def type(self) -> Optional[type]: + return int class LazyOps: @@ -337,7 +342,10 @@ def then(self, then: Expr) -> CaseWhen: class IsDistinctFrom(ExprNode, LazyOps): a: Expr b: Expr - type = bool + + @property + def type(self) -> Optional[type]: + return bool @dataclass(eq=False, order=False) @@ -361,7 +369,9 @@ class UnaryOp(ExprNode, LazyOps): class BinBoolOp(BinOp): - type = bool + @property + def type(self) -> Optional[type]: + return bool @dataclass(eq=False, order=False) @@ -700,7 +710,10 @@ def __getitem__(self, name): class In(ExprNode): expr: Expr list: Sequence[Expr] - type = bool + + @property + def type(self) -> Optional[type]: + return bool @dataclass @@ -711,7 +724,9 @@ class Cast(ExprNode): @dataclass class Random(ExprNode, LazyOps): - type = float + @property + def type(self) -> Optional[type]: + return float @dataclass @@ -722,11 +737,16 @@ class ConstantTable(ExprNode): @dataclass class Explain(ExprNode, Root): select: Select - type = str + + @property + def type(self) -> Optional[type]: + return str class CurrentTimestamp(ExprNode): - type = datetime + @property + def type(self) -> Optional[type]: + return datetime @dataclass @@ -742,7 +762,9 @@ class TimeTravel(ITable): class Statement(Compilable, Root): - type = None + @property + def type(self) -> Optional[type]: + return None @dataclass diff --git a/data_diff/queries/extras.py b/data_diff/queries/extras.py index bb0c8299..556325f6 100644 --- a/data_diff/queries/extras.py +++ b/data_diff/queries/extras.py @@ -1,5 +1,5 @@ "Useful AST classes that don't quite fall within the scope of regular SQL" -from typing import Callable, Sequence +from typing import Callable, Optional, Sequence from runtype import dataclass from data_diff.abcs.database_types import ColType @@ -11,7 +11,10 @@ class NormalizeAsString(ExprNode): expr: ExprNode expr_type: ColType = None - type = str + + @property + def type(self) -> Optional[type]: + return str @dataclass From 5c10704ca808f8e71db288a1d19ef888e3abc4d4 Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Fri, 29 Sep 2023 23:00:03 +0200 Subject: [PATCH 3/7] Convert source_table & schema to overridable properties for attrs compatibility --- data_diff/queries/ast_classes.py | 61 ++++++++++++++------------------ 1 file changed, 27 insertions(+), 34 deletions(-) diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 1d841237..2c11f79a 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -83,8 +83,13 @@ def _drop_skips_dict(exprs_dict): class ITable: - source_table: Any - schema: Schema = None + @property + def source_table(self) -> "ITable": # not always Self, it can be a substitute + return self + + @property + def schema(self) -> Optional[Schema]: + return None def select(self, *exprs, distinct=SKIP, optimizer_hints=SKIP, **named_exprs) -> "ITable": """Choose new columns, based on the old ones. (aka Projection) @@ -389,11 +394,7 @@ def type(self): @dataclass class TablePath(ExprNode, ITable): path: DbPath - schema: Optional[Schema] = field(default=None, repr=False) - - @property - def source_table(self) -> Self: - return self + schema: Optional[Schema] = None # overrides the inherited property # Statement shorthands def create(self, source_table: ITable = None, *, if_not_exists: bool = False, primary_keys: List[str] = None): @@ -475,9 +476,17 @@ def time_travel( @dataclass class TableAlias(ExprNode, ITable): - source_table: ITable + table: ITable name: str + @property + def source_table(self) -> ITable: + return self.table + + @property + def schema(self) -> Schema: + return self.table.schema + @dataclass class Join(ExprNode, ITable, Root): @@ -487,11 +496,7 @@ class Join(ExprNode, ITable, Root): columns: Sequence[Expr] = None @property - def source_table(self) -> Self: - return self - - @property - def schema(self): + def schema(self) -> Schema: assert self.columns # TODO Implement SELECT * s = self.source_tables[0].schema # TODO validate types match between both tables return type(s)({c.name: c.type for c in self.columns}) @@ -533,10 +538,6 @@ class GroupBy(ExprNode, ITable, Root): values: Sequence[Expr] = None having_exprs: Sequence[Expr] = None - @property - def source_table(self): - return self - def __post_init__(self): assert self.keys or self.values @@ -564,17 +565,13 @@ class TableOp(ExprNode, ITable, Root): table1: ITable table2: ITable - @property - def source_table(self): - return self - @property def type(self): # TODO ensure types of both tables are compatible return self.table1.type @property - def schema(self): + def schema(self) -> Schema: s1 = self.table1.schema s2 = self.table2.schema assert len(s1) == len(s2) @@ -594,16 +591,12 @@ class Select(ExprNode, ITable, Root): optimizer_hints: Sequence[Expr] = None @property - def schema(self): + def schema(self) -> Schema: s = self.table.schema if s is None or self.columns is None: return s return type(s)({c.name: c.type for c in self.columns}) - @property - def source_table(self): - return self - @classmethod def make(cls, table: ITable, distinct: bool = SKIP, optimizer_hints: str = SKIP, **kwargs): assert "table" not in kwargs @@ -642,14 +635,18 @@ def make(cls, table: ITable, distinct: bool = SKIP, optimizer_hints: str = SKIP, @dataclass class Cte(ExprNode, ITable): - source_table: Expr + table: Expr name: str = None params: Sequence[str] = None @property - def schema(self): + def source_table(self) -> "ITable": + return self.table + + @property + def schema(self) -> Schema: # TODO add cte to schema - return self.source_table.schema + return self.table.schema def _named_exprs_as_aliases(named_exprs): @@ -820,7 +817,3 @@ class Param(ExprNode, ITable): """A value placeholder, to be specified at compilation time using the `cv_params` context variable.""" name: str - - @property - def source_table(self): - return self From d7d0c86c0071e1f401cad1f61dbeba41913e446c Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Fri, 29 Sep 2023 15:57:19 +0200 Subject: [PATCH 4/7] Clarify optional (nullable) fields: explicit is better than implicit We are going to do strict type checking. The default values of fields that clearly contradict the declared types is an error for MyPy and all other type checkers and IDEs. Remove the implicit behaviour and make nullable fields explicitly declared as such. --- data_diff/databases/base.py | 2 +- data_diff/info_tree.py | 10 +++--- data_diff/joindiff_tables.py | 2 +- data_diff/queries/ast_classes.py | 54 ++++++++++++++++---------------- data_diff/queries/extras.py | 4 +-- data_diff/table_segment.py | 18 +++++------ 6 files changed, 45 insertions(+), 45 deletions(-) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 9e4f39dc..e9f60653 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -815,7 +815,7 @@ def set_timezone_to_utc(self) -> str: @dataclass class QueryResult: rows: list - columns: list = None + columns: Optional[list] = None def __iter__(self): return iter(self.rows) diff --git a/data_diff/info_tree.py b/data_diff/info_tree.py index bd2282ba..b30ba2f2 100644 --- a/data_diff/info_tree.py +++ b/data_diff/info_tree.py @@ -10,13 +10,13 @@ class SegmentInfo: tables: List[TableSegment] - diff: List[Union[Tuple[Any, ...], List[Any]]] = None - diff_schema: Tuple[Tuple[str, type], ...] = None - is_diff: bool = None - diff_count: int = None + diff: Optional[List[Union[Tuple[Any, ...], List[Any]]]] = None + diff_schema: Optional[Tuple[Tuple[str, type], ...]] = None + is_diff: Optional[bool] = None + diff_count: Optional[int] = None rowcounts: Dict[int, int] = field(default_factory=dict) - max_rows: int = None + max_rows: Optional[int] = None def set_diff(self, diff: List[Union[Tuple[Any, ...], List[Any]]], schema: Optional[Tuple[Tuple[str, type]]] = None): self.diff_schema = schema diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 91e2aecd..14834ffd 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -138,7 +138,7 @@ class JoinDiffer(TableDiffer): validate_unique_key: bool = True sample_exclusive_rows: bool = False - materialize_to_table: DbPath = None + materialize_to_table: Optional[DbPath] = None materialize_all_rows: bool = False table_write_limit: int = TABLE_WRITE_LIMIT skip_null_keys: bool = False diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 2c11f79a..710f1316 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -55,7 +55,7 @@ def cast_to(self, to): @dataclass class Code(ExprNode, Root): code: str - args: Dict[str, Expr] = None + args: Optional[Dict[str, Expr]] = None def _expr_type(e: Expr) -> type: @@ -216,7 +216,7 @@ def intersect(self, other: "ITable"): @dataclass class Concat(ExprNode): exprs: list - sep: str = None + sep: Optional[str] = None @dataclass @@ -293,7 +293,7 @@ class WhenThen(ExprNode): @dataclass class CaseWhen(ExprNode): cases: Sequence[WhenThen] - else_expr: Expr = None + else_expr: Optional[Expr] = None @property def type(self): @@ -491,9 +491,9 @@ def schema(self) -> Schema: @dataclass class Join(ExprNode, ITable, Root): source_tables: Sequence[ITable] - op: str = None - on_exprs: Sequence[Expr] = None - columns: Sequence[Expr] = None + op: Optional[str] = None + on_exprs: Optional[Sequence[Expr]] = None + columns: Optional[Sequence[Expr]] = None @property def schema(self) -> Schema: @@ -534,9 +534,9 @@ def select(self, *exprs, **named_exprs) -> Union[Self, ITable]: @dataclass class GroupBy(ExprNode, ITable, Root): table: ITable - keys: Sequence[Expr] = None # IKey? - values: Sequence[Expr] = None - having_exprs: Sequence[Expr] = None + keys: Optional[Sequence[Expr]] = None # IKey? + values: Optional[Sequence[Expr]] = None + having_exprs: Optional[Sequence[Expr]] = None def __post_init__(self): assert self.keys or self.values @@ -580,15 +580,15 @@ def schema(self) -> Schema: @dataclass class Select(ExprNode, ITable, Root): - table: Expr = None - columns: Sequence[Expr] = None - where_exprs: Sequence[Expr] = None - order_by_exprs: Sequence[Expr] = None - group_by_exprs: Sequence[Expr] = None - having_exprs: Sequence[Expr] = None - limit_expr: int = None + table: Optional[Expr] = None + columns: Optional[Sequence[Expr]] = None + where_exprs: Optional[Sequence[Expr]] = None + order_by_exprs: Optional[Sequence[Expr]] = None + group_by_exprs: Optional[Sequence[Expr]] = None + having_exprs: Optional[Sequence[Expr]] = None + limit_expr: Optional[int] = None distinct: bool = False - optimizer_hints: Sequence[Expr] = None + optimizer_hints: Optional[Sequence[Expr]] = None @property def schema(self) -> Schema: @@ -636,8 +636,8 @@ def make(cls, table: ITable, distinct: bool = SKIP, optimizer_hints: str = SKIP, @dataclass class Cte(ExprNode, ITable): table: Expr - name: str = None - params: Sequence[str] = None + name: Optional[str] = None + params: Optional[Sequence[str]] = None @property def source_table(self) -> "ITable": @@ -667,7 +667,7 @@ def resolve_names(source_table, exprs): @dataclass(frozen=False, eq=False, order=False) class _ResolveColumn(ExprNode, LazyOps): resolve_name: str - resolved: Expr = None + resolved: Optional[Expr] = None def resolve(self, expr: Expr): if self.resolved is not None: @@ -750,9 +750,9 @@ def type(self) -> Optional[type]: class TimeTravel(ITable): table: TablePath before: bool = False - timestamp: datetime = None - offset: int = None - statement: str = None + timestamp: Optional[datetime] = None + offset: Optional[int] = None + statement: Optional[str] = None # DDL @@ -767,9 +767,9 @@ def type(self) -> Optional[type]: @dataclass class CreateTable(Statement): path: TablePath - source_table: Expr = None + source_table: Optional[Expr] = None if_not_exists: bool = False - primary_keys: List[str] = None + primary_keys: Optional[List[str]] = None @dataclass @@ -787,8 +787,8 @@ class TruncateTable(Statement): class InsertToTable(Statement): path: TablePath expr: Expr - columns: List[str] = None - returning_exprs: List[str] = None + columns: Optional[List[str]] = None + returning_exprs: Optional[List[str]] = None def returning(self, *exprs) -> Self: """Add a 'RETURNING' clause to the insert expression. diff --git a/data_diff/queries/extras.py b/data_diff/queries/extras.py index 556325f6..4467bd0a 100644 --- a/data_diff/queries/extras.py +++ b/data_diff/queries/extras.py @@ -10,7 +10,7 @@ @dataclass class NormalizeAsString(ExprNode): expr: ExprNode - expr_type: ColType = None + expr_type: Optional[ColType] = None @property def type(self) -> Optional[type]: @@ -20,7 +20,7 @@ def type(self) -> Optional[type]: @dataclass class ApplyFuncAndNormalizeAsString(ExprNode): expr: ExprNode - apply_func: Callable = None + apply_func: Optional[Callable] = None @dataclass diff --git a/data_diff/table_segment.py b/data_diff/table_segment.py index aaf747f6..015e5bc4 100644 --- a/data_diff/table_segment.py +++ b/data_diff/table_segment.py @@ -112,18 +112,18 @@ class TableSegment: # Columns key_columns: Tuple[str, ...] - update_column: str = None + update_column: Optional[str] = None extra_columns: Tuple[str, ...] = () # Restrict the segment - min_key: Vector = None - max_key: Vector = None - min_update: DbTime = None - max_update: DbTime = None - where: str = None - - case_sensitive: bool = True - _schema: Schema = None + min_key: Optional[Vector] = None + max_key: Optional[Vector] = None + min_update: Optional[DbTime] = None + max_update: Optional[DbTime] = None + where: Optional[str] = None + + case_sensitive: Optional[bool] = True + _schema: Optional[Schema] = None def __post_init__(self): if not self.update_column and (self.min_update or self.max_update): From 1e0f3100abe252d65990cec1a83257e642fd4309 Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Mon, 2 Oct 2023 12:31:15 +0200 Subject: [PATCH 5/7] Annotate missing fields --- data_diff/databases/_connect.py | 2 ++ data_diff/databases/base.py | 28 +++++++++++++++++----------- data_diff/databases/bigquery.py | 4 ++++ data_diff/databases/clickhouse.py | 4 +++- data_diff/databases/databricks.py | 7 +++++-- data_diff/databases/duckdb.py | 7 +++++-- data_diff/databases/mssql.py | 6 ++++-- data_diff/databases/mysql.py | 4 ++++ data_diff/databases/oracle.py | 4 +++- data_diff/databases/postgresql.py | 8 +++++--- data_diff/databases/presto.py | 4 +++- data_diff/databases/redshift.py | 5 +++-- data_diff/databases/snowflake.py | 4 +++- data_diff/databases/trino.py | 4 ++++ data_diff/databases/vertica.py | 5 +++-- data_diff/thread_utils.py | 5 +++++ data_diff/utils.py | 5 ++++- 17 files changed, 77 insertions(+), 29 deletions(-) diff --git a/data_diff/databases/_connect.py b/data_diff/databases/_connect.py index cab8b235..082e8fab 100644 --- a/data_diff/databases/_connect.py +++ b/data_diff/databases/_connect.py @@ -95,6 +95,8 @@ def match_path(self, dsn): class Connect: """Provides methods for connecting to a supported database using a URL or connection dict.""" + database_by_scheme: Dict[str, Database] + match_uri_path: Dict[str, MatchUriPath] conn_cache: MutableMapping[Hashable, Database] def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME): diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index e9f60653..3e336119 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -5,7 +5,7 @@ import math import sys import logging -from typing import Any, Callable, Dict, Generator, Tuple, Optional, Sequence, Type, List, Union, TypeVar +from typing import Any, Callable, ClassVar, Dict, Generator, Tuple, Optional, Sequence, Type, List, Union, TypeVar from functools import partial, wraps from concurrent.futures import ThreadPoolExecutor import threading @@ -179,6 +179,9 @@ class ThreadLocalInterpreter: Useful for cursor-sensitive operations, such as creating a temporary table. """ + compiler: Compiler + gen: Generator + def __init__(self, compiler: Compiler, gen: Generator): super().__init__() self.gen = gen @@ -238,9 +241,9 @@ def optimizer_hints(self, hints: str) -> str: class BaseDialect(abc.ABC): - SUPPORTS_PRIMARY_KEY = False - SUPPORTS_INDEXES = False - TYPE_CLASSES: Dict[str, type] = {} + SUPPORTS_PRIMARY_KEY: ClassVar[bool] = False + SUPPORTS_INDEXES: ClassVar[bool] = False + TYPE_CLASSES: ClassVar[Dict[str, type]] = {} PLACEHOLDER_TABLE = None # Used for Oracle @@ -835,14 +838,13 @@ class Database(abc.ABC, _RuntypeHackToFixCicularRefrencedDatabase): Instanciated using :meth:`~data_diff.connect` """ - default_schema: str = None - SUPPORTS_ALPHANUMS = True - SUPPORTS_UNIQUE_CONSTAINT = False - - CONNECT_URI_KWPARAMS = [] + SUPPORTS_ALPHANUMS: ClassVar[bool] = True + SUPPORTS_UNIQUE_CONSTAINT: ClassVar[bool] = False + CONNECT_URI_KWPARAMS: ClassVar[List[str]] = [] - _interactive = False - is_closed = False + default_schema: Optional[str] = None + _interactive: bool = False + is_closed: bool = False @property def name(self): @@ -1109,6 +1111,10 @@ class ThreadedDatabase(Database): Used for database connectors that do not support sharing their connection between different threads. """ + _init_error: Optional[Exception] + _queue: ThreadPoolExecutor + thread_local: threading.local + def __init__(self, thread_count=1): super().__init__() self._init_error = None diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py index aaf2dbd3..18140611 100644 --- a/data_diff/databases/bigquery.py +++ b/data_diff/databases/bigquery.py @@ -223,6 +223,10 @@ class BigQuery(Database): CONNECT_URI_PARAMS = ["dataset"] dialect = Dialect() + project: str + dataset: str + _client: Any + def __init__(self, project, *, dataset, bigquery_credentials=None, **kw): super().__init__() credentials = bigquery_credentials diff --git a/data_diff/databases/clickhouse.py b/data_diff/databases/clickhouse.py index 70070934..6c6f56e6 100644 --- a/data_diff/databases/clickhouse.py +++ b/data_diff/databases/clickhouse.py @@ -1,4 +1,4 @@ -from typing import Optional, Type +from typing import Any, Dict, Optional, Type from data_diff.databases.base import ( MD5_HEXDIGITS, @@ -167,6 +167,8 @@ class Clickhouse(ThreadedDatabase): CONNECT_URI_HELP = "clickhouse://:@/" CONNECT_URI_PARAMS = ["database?"] + _args: Dict[str, Any] + def __init__(self, *, thread_count: int, **kw): super().__init__(thread_count=thread_count) diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index c440c54c..a63e62aa 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -1,5 +1,5 @@ import math -from typing import Dict, Sequence +from typing import Any, Dict, Sequence import logging from data_diff.abcs.database_types import ( @@ -103,13 +103,16 @@ class Databricks(ThreadedDatabase): CONNECT_URI_HELP = "databricks://:@/" CONNECT_URI_PARAMS = ["catalog", "schema"] + catalog: str + _args: Dict[str, Any] + def __init__(self, *, thread_count, **kw): super().__init__(thread_count=thread_count) logging.getLogger("databricks.sql").setLevel(logging.WARNING) self._args = kw self.default_schema = kw.get("schema", "default") - self.catalog = self._args.get("catalog", "hive_metastore") + self.catalog = kw.get("catalog", "hive_metastore") def create_connection(self): databricks = import_databricks() diff --git a/data_diff/databases/duckdb.py b/data_diff/databases/duckdb.py index b0e061ee..48784565 100644 --- a/data_diff/databases/duckdb.py +++ b/data_diff/databases/duckdb.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Any, Dict, Union from data_diff.utils import match_regexps from data_diff.abcs.database_types import ( @@ -134,14 +134,17 @@ def current_timestamp(self) -> str: class DuckDB(Database): dialect = Dialect() SUPPORTS_UNIQUE_CONSTAINT = False # Temporary, until we implement it - default_schema = "main" CONNECT_URI_HELP = "duckdb://@" CONNECT_URI_PARAMS = ["database", "dbpath"] + _args: Dict[str, Any] + _conn: Any + def __init__(self, **kw): super().__init__() self._args = kw self._conn = self.create_connection() + self.default_schema = "main" @property def is_autocommit(self) -> bool: diff --git a/data_diff/databases/mssql.py b/data_diff/databases/mssql.py index 0e767417..968c410d 100644 --- a/data_diff/databases/mssql.py +++ b/data_diff/databases/mssql.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Any, Dict, Optional from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue from data_diff.databases.base import ( CHECKSUM_HEXDIGITS, @@ -160,10 +160,12 @@ def constant_values(self, rows) -> str: class MsSQL(ThreadedDatabase): dialect = Dialect() - # CONNECT_URI_HELP = "mssql://:@//" CONNECT_URI_PARAMS = ["database", "schema"] + default_database: str + _args: Dict[str, Any] + def __init__(self, host, port, user, password, *, database, thread_count, **kw): args = dict(server=host, port=port, database=database, user=user, password=password, **kw) self._args = {k: v for k, v in args.items() if v is not None} diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index a69e6e46..a48aa3fe 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -1,3 +1,5 @@ +from typing import Any, Dict + from data_diff.abcs.database_types import ( Datetime, Timestamp, @@ -137,6 +139,8 @@ class MySQL(ThreadedDatabase): CONNECT_URI_HELP = "mysql://:@/" CONNECT_URI_PARAMS = ["database?"] + _args: Dict[str, Any] + def __init__(self, *, thread_count, **kw): super().__init__(thread_count=thread_count) self._args = kw diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index 44681a45..e548c7c3 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from data_diff.utils import match_regexps from data_diff.abcs.database_types import ( @@ -181,6 +181,8 @@ class Oracle(ThreadedDatabase): CONNECT_URI_HELP = "oracle://:@/" CONNECT_URI_PARAMS = ["database?"] + kwargs: Dict[str, Any] + def __init__(self, *, host, database, thread_count, **kw): super().__init__(thread_count=thread_count) self.kwargs = dict(dsn=f"{host}/{database}" if database else host, **kw) diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py index 16d6a1d1..1c622739 100644 --- a/data_diff/databases/postgresql.py +++ b/data_diff/databases/postgresql.py @@ -1,5 +1,6 @@ -from typing import List +from typing import Any, ClassVar, Dict, List, Type from data_diff.abcs.database_types import ( + ColType, DbPath, JSON, Timestamp, @@ -68,7 +69,7 @@ class PostgresqlDialect( SUPPORTS_PRIMARY_KEY = True SUPPORTS_INDEXES = True - TYPE_CLASSES = { + TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = { # Timestamps "timestamp with time zone": TimestampTZ, "timestamp without time zone": Timestamp, @@ -125,11 +126,12 @@ class PostgreSQL(ThreadedDatabase): CONNECT_URI_HELP = "postgresql://:@/" CONNECT_URI_PARAMS = ["database?"] - default_schema = "public" + _args: Dict[str, Any] def __init__(self, *, thread_count, **kw): super().__init__(thread_count=thread_count) self._args = kw + self.default_schema = "public" def create_connection(self): if not self._args: diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index a829df95..7a8c7eba 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -1,5 +1,6 @@ from functools import partial import re +from typing import Any from data_diff.utils import match_regexps @@ -158,10 +159,11 @@ class Presto(Database): CONNECT_URI_HELP = "presto://@//" CONNECT_URI_PARAMS = ["catalog", "schema"] - default_schema = "public" + _conn: Any def __init__(self, **kw): super().__init__() + self.default_schema = "public" prestodb = import_presto() if kw.get("schema"): diff --git a/data_diff/databases/redshift.py b/data_diff/databases/redshift.py index d11029c0..f7aa8d06 100644 --- a/data_diff/databases/redshift.py +++ b/data_diff/databases/redshift.py @@ -1,5 +1,6 @@ -from typing import List, Dict +from typing import ClassVar, List, Dict, Type from data_diff.abcs.database_types import ( + ColType, Float, JSON, TemporalType, @@ -53,7 +54,7 @@ def normalize_json(self, value: str, _coltype: JSON) -> str: class Dialect(PostgresqlDialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "Redshift" - TYPE_CLASSES = { + TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = { **PostgresqlDialect.TYPE_CLASSES, "double": Float, "real": Float, diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py index 31f85492..19898185 100644 --- a/data_diff/databases/snowflake.py +++ b/data_diff/databases/snowflake.py @@ -1,4 +1,4 @@ -from typing import Union, List +from typing import Any, Union, List import logging from data_diff.abcs.database_types import ( @@ -154,6 +154,8 @@ class Snowflake(Database): CONNECT_URI_PARAMS = ["database", "schema"] CONNECT_URI_KWPARAMS = ["warehouse"] + _conn: Any + def __init__(self, *, schema: str, **kw): super().__init__() snowflake, serialization, default_backend = import_snowflake() diff --git a/data_diff/databases/trino.py b/data_diff/databases/trino.py index ab4913d4..bce2bddf 100644 --- a/data_diff/databases/trino.py +++ b/data_diff/databases/trino.py @@ -1,3 +1,5 @@ +from typing import Any + from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue from data_diff.abcs.database_types import TemporalType, ColType_UUID from data_diff.databases import presto @@ -39,6 +41,8 @@ class Trino(presto.Presto): CONNECT_URI_HELP = "trino://@//" CONNECT_URI_PARAMS = ["catalog", "schema"] + _conn: Any + def __init__(self, **kw): super().__init__() trino = import_trino() diff --git a/data_diff/databases/vertica.py b/data_diff/databases/vertica.py index f539a4df..6df48085 100644 --- a/data_diff/databases/vertica.py +++ b/data_diff/databases/vertica.py @@ -1,4 +1,4 @@ -from typing import List +from typing import Any, Dict, List from data_diff.utils import match_regexps from data_diff.databases.base import ( @@ -156,12 +156,13 @@ class Vertica(ThreadedDatabase): CONNECT_URI_HELP = "vertica://:@/" CONNECT_URI_PARAMS = ["database?"] - default_schema = "public" + _args: Dict[str, Any] def __init__(self, *, thread_count, **kw): super().__init__(thread_count=thread_count) self._args = kw self._args["AUTOCOMMIT"] = False + self.default_schema = "public" def create_connection(self): vertica = import_vertica() diff --git a/data_diff/thread_utils.py b/data_diff/thread_utils.py index c5526771..7a36fd58 100644 --- a/data_diff/thread_utils.py +++ b/data_diff/thread_utils.py @@ -45,6 +45,11 @@ class ThreadedYielder(Iterable): Priority for the iterator can be provided via the keyword argument 'priority'. (higher runs first) """ + _pool: ThreadPoolExecutor + _futures: deque + _yield: deque + _exception: Optional[None] + def __init__(self, max_workers: Optional[int] = None): super().__init__() self._pool = PriorityThreadPoolExecutor(max_workers) diff --git a/data_diff/utils.py b/data_diff/utils.py index 558a18e9..93e6db98 100644 --- a/data_diff/utils.py +++ b/data_diff/utils.py @@ -4,7 +4,7 @@ import re import string from abc import abstractmethod -from typing import Any, Dict, Iterable, Iterator, List, MutableMapping, Sequence, TypeVar, Union +from typing import Any, Dict, Iterable, Iterator, List, MutableMapping, Optional, Sequence, TypeVar, Union from urllib.parse import urlparse import operator import threading @@ -175,6 +175,9 @@ def alphanums_to_numbers(s1: str, s2: str): class ArithAlphanumeric(ArithString): + _str: str + _max_len: Optional[int] + def __init__(self, s: str, max_len=None): super().__init__() From c1b24ef22f1afead97377f5763cad2cf71db62de Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Fri, 29 Sep 2023 16:21:52 +0200 Subject: [PATCH 6/7] Convert the runtype classes to attrs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `attrs` is much more beneficial: * `attrs` is supported by type checkers, such as MyPy & IDEs * `attrs` is widely used and industry-proven * `attrs` is explicit in its declarations, there is no magic * `attrs` has slots But mainly for the first item — type checking by type checkers. Those that were runtype classes, are frozen. Those that were not, are unfrozen for now, but we can freeze them later if and where it works (the stricter, the better). --- data_diff/abcs/compiler.py | 4 ++ data_diff/abcs/database_types.py | 26 ++++---- data_diff/cloud/datafold_api.py | 6 +- data_diff/databases/_connect.py | 5 +- data_diff/databases/base.py | 41 ++++++------ data_diff/diff_tables.py | 13 ++-- data_diff/format.py | 72 +++++++++++---------- data_diff/hashdiff_tables.py | 17 +++-- data_diff/info_tree.py | 11 ++-- data_diff/joindiff_tables.py | 11 ++-- data_diff/queries/ast_classes.py | 107 ++++++++++++++++--------------- data_diff/queries/extras.py | 10 +-- data_diff/table_segment.py | 14 ++-- poetry.lock | 2 +- pyproject.toml | 2 +- tests/test_database.py | 5 +- tests/test_diff_tables.py | 14 ++-- tests/test_joindiff.py | 6 +- tests/test_sql.py | 7 +- 19 files changed, 191 insertions(+), 182 deletions(-) diff --git a/data_diff/abcs/compiler.py b/data_diff/abcs/compiler.py index 4a847d05..e5153b36 100644 --- a/data_diff/abcs/compiler.py +++ b/data_diff/abcs/compiler.py @@ -1,9 +1,13 @@ from abc import ABC +import attrs + +@attrs.define(frozen=False) class AbstractCompiler(ABC): pass +@attrs.define(frozen=False, eq=False) class Compilable(ABC): pass diff --git a/data_diff/abcs/database_types.py b/data_diff/abcs/database_types.py index e5ec393a..58ba9159 100644 --- a/data_diff/abcs/database_types.py +++ b/data_diff/abcs/database_types.py @@ -3,7 +3,7 @@ from typing import Tuple, Union from datetime import datetime -from runtype import dataclass +import attrs from data_diff.utils import ArithAlphanumeric, ArithUUID, Unknown @@ -13,14 +13,14 @@ DbTime = datetime -@dataclass +@attrs.define(frozen=True) class ColType: @property def supported(self) -> bool: return True -@dataclass +@attrs.define(frozen=True) class PrecisionType(ColType): precision: int rounds: Union[bool, Unknown] = Unknown @@ -50,7 +50,7 @@ class Date(TemporalType): pass -@dataclass +@attrs.define(frozen=True) class NumericType(ColType): # 'precision' signifies how many fractional digits (after the dot) we want to compare precision: int @@ -84,7 +84,7 @@ def python_type(self) -> type: return decimal.Decimal -@dataclass +@attrs.define(frozen=True) class StringType(ColType): python_type = str @@ -122,7 +122,7 @@ class String_VaryingAlphanum(String_Alphanum): pass -@dataclass +@attrs.define(frozen=True) class String_FixedAlphanum(String_Alphanum): length: int @@ -132,7 +132,7 @@ def make_value(self, value): return self.python_type(value, max_len=self.length) -@dataclass +@attrs.define(frozen=True) class Text(StringType): @property def supported(self) -> bool: @@ -140,12 +140,12 @@ def supported(self) -> bool: # In majority of DBMSes, it is called JSON/JSONB. Only in Snowflake, it is OBJECT. -@dataclass +@attrs.define(frozen=True) class JSON(ColType): pass -@dataclass +@attrs.define(frozen=True) class Array(ColType): item_type: ColType @@ -155,21 +155,21 @@ class Array(ColType): # For example, in BigQuery: # - https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#struct_type # - https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#struct_literals -@dataclass +@attrs.define(frozen=True) class Struct(ColType): pass -@dataclass +@attrs.define(frozen=True) class Integer(NumericType, IKey): precision: int = 0 python_type: type = int - def __post_init__(self): + def __attrs_post_init__(self): assert self.precision == 0 -@dataclass +@attrs.define(frozen=True) class UnknownColType(ColType): text: str diff --git a/data_diff/cloud/datafold_api.py b/data_diff/cloud/datafold_api.py index ea5a04e8..1da9e45d 100644 --- a/data_diff/cloud/datafold_api.py +++ b/data_diff/cloud/datafold_api.py @@ -1,9 +1,9 @@ import base64 -import dataclasses import enum import time from typing import Any, Dict, List, Optional, Type, Tuple +import attrs import pydantic import requests from typing_extensions import Self @@ -178,13 +178,13 @@ class TCloudApiDataSourceTestResult(pydantic.BaseModel): result: Optional[TCloudDataSourceTestResult] -@dataclasses.dataclass +@attrs.define(frozen=True) class DatafoldAPI: api_key: str host: str = "https://app.datafold.com" timeout: int = 30 - def __post_init__(self): + def __attrs_post_init__(self): self.host = self.host.rstrip("/") self.headers = { "Authorization": f"Key {self.api_key}", diff --git a/data_diff/databases/_connect.py b/data_diff/databases/_connect.py index 082e8fab..3af342cd 100644 --- a/data_diff/databases/_connect.py +++ b/data_diff/databases/_connect.py @@ -3,10 +3,11 @@ from itertools import zip_longest from contextlib import suppress import weakref + +import attrs import dsnparse import toml -from runtype import dataclass from typing_extensions import Self from data_diff.databases.base import Database, ThreadedDatabase @@ -25,7 +26,7 @@ from data_diff.databases.mssql import MsSQL -@dataclass +@attrs.define(frozen=True) class MatchUriPath: database_cls: Type[Database] diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 3e336119..26ab1703 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -1,6 +1,6 @@ import abc import functools -from dataclasses import field +import random from datetime import datetime import math import sys @@ -14,7 +14,7 @@ import decimal import contextvars -from runtype import dataclass +import attrs from typing_extensions import Self from data_diff.abcs.compiler import AbstractCompiler @@ -90,12 +90,7 @@ class CompileError(Exception): pass -# TODO: remove once switched to attrs, where ForwardRef[]/strings are resolved. -class _RuntypeHackToFixCicularRefrencedDatabase: - dialect: "BaseDialect" - - -@dataclass +@attrs.define(frozen=True) class Compiler(AbstractCompiler): """ Compiler bears the context for a single compilation. @@ -107,16 +102,16 @@ class Compiler(AbstractCompiler): # Database is needed to normalize tables. Dialect is needed for recursive compilations. # In theory, it is many-to-many relations: e.g. a generic ODBC driver with multiple dialects. # In practice, we currently bind the dialects to the specific database classes. - database: _RuntypeHackToFixCicularRefrencedDatabase + database: "Database" in_select: bool = False # Compilation runtime flag in_join: bool = False # Compilation runtime flag - _table_context: List = field(default_factory=list) # List[ITable] - _subqueries: Dict[str, Any] = field(default_factory=dict) # XXX not thread-safe + _table_context: List = attrs.field(factory=list) # List[ITable] + _subqueries: Dict[str, Any] = attrs.field(factory=dict) # XXX not thread-safe root: bool = True - _counter: List = field(default_factory=lambda: [0]) + _counter: List = attrs.field(factory=lambda: [0]) @property def dialect(self) -> "BaseDialect": @@ -136,7 +131,7 @@ def new_unique_table_name(self, prefix="tmp") -> DbPath: return self.database.dialect.parse_table_name(table_name) def add_table_context(self, *tables: Sequence, **kw) -> Self: - return self.replace(_table_context=self._table_context + list(tables), **kw) + return attrs.evolve(self, table_context=self._table_context + list(tables), **kw) def parse_table_name(t): @@ -271,7 +266,7 @@ def _compile(self, compiler: Compiler, elem) -> str: if elem is None: return "NULL" elif isinstance(elem, Compilable): - return self.render_compilable(compiler.replace(root=False), elem) + return self.render_compilable(attrs.evolve(compiler, root=False), elem) elif isinstance(elem, str): return f"'{elem}'" elif isinstance(elem, (int, float)): @@ -381,7 +376,7 @@ def render_column(self, c: Compiler, elem: Column) -> str: return self.quote(elem.name) def render_cte(self, parent_c: Compiler, elem: Cte) -> str: - c: Compiler = parent_c.replace(_table_context=[], in_select=False) + c: Compiler = attrs.evolve(parent_c, table_context=[], in_select=False) compiled = self.compile(c, elem.source_table) name = elem.name or parent_c.new_unique_name() @@ -494,7 +489,7 @@ def render_tablealias(self, c: Compiler, elem: TableAlias) -> str: return f"{self.compile(c, elem.source_table)} {self.quote(elem.name)}" def render_tableop(self, parent_c: Compiler, elem: TableOp) -> str: - c: Compiler = parent_c.replace(in_select=False) + c: Compiler = attrs.evolve(parent_c, in_select=False) table_expr = f"{self.compile(c, elem.table1)} {elem.op} {self.compile(c, elem.table2)}" if parent_c.in_select: table_expr = f"({table_expr}) {c.new_unique_name()}" @@ -506,7 +501,7 @@ def render__resolvecolumn(self, c: Compiler, elem: _ResolveColumn) -> str: return self.compile(c, elem._get_resolved()) def render_select(self, parent_c: Compiler, elem: Select) -> str: - c: Compiler = parent_c.replace(in_select=True) # .add_table_context(self.table) + c: Compiler = attrs.evolve(parent_c, in_select=True) # .add_table_context(self.table) compile_fn = functools.partial(self.compile, c) columns = ", ".join(map(compile_fn, elem.columns)) if elem.columns else "*" @@ -544,7 +539,8 @@ def render_select(self, parent_c: Compiler, elem: Select) -> str: def render_join(self, parent_c: Compiler, elem: Join) -> str: tables = [ - t if isinstance(t, TableAlias) else TableAlias(t, parent_c.new_unique_name()) for t in elem.source_tables + t if isinstance(t, TableAlias) else TableAlias(source_table=t, name=parent_c.new_unique_name()) + for t in elem.source_tables ] c = parent_c.add_table_context(*tables, in_join=True, in_select=False) op = " JOIN " if elem.op is None else f" {elem.op} JOIN " @@ -577,7 +573,8 @@ def render_groupby(self, c: Compiler, elem: GroupBy) -> str: if isinstance(elem.table, Select) and elem.table.columns is None and elem.table.group_by_exprs is None: return self.compile( c, - elem.table.replace( + attrs.evolve( + elem.table, columns=columns, group_by_exprs=[Code(k) for k in keys], having_exprs=elem.having_exprs, @@ -589,7 +586,7 @@ def render_groupby(self, c: Compiler, elem: GroupBy) -> str: having_str = ( " HAVING " + " AND ".join(map(compile_fn, elem.having_exprs)) if elem.having_exprs is not None else "" ) - select = f"SELECT {columns_str} FROM {self.compile(c.replace(in_select=True), elem.table)} GROUP BY {keys_str}{having_str}" + select = f"SELECT {columns_str} FROM {self.compile(attrs.evolve(c, in_select=True), elem.table)} GROUP BY {keys_str}{having_str}" if c.in_select: select = f"({select}) {c.new_unique_name()}" @@ -815,7 +812,7 @@ def set_timezone_to_utc(self) -> str: T = TypeVar("T", bound=BaseDialect) -@dataclass +@attrs.define(frozen=True) class QueryResult: rows: list columns: Optional[list] = None @@ -830,7 +827,7 @@ def __getitem__(self, i): return self.rows[i] -class Database(abc.ABC, _RuntypeHackToFixCicularRefrencedDatabase): +class Database(abc.ABC): """Base abstract class for databases. Used for providing connection code and implementation specific SQL utilities. diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 08c18391..26a93241 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -3,14 +3,13 @@ import time from abc import ABC, abstractmethod -from dataclasses import field from enum import Enum from contextlib import contextmanager from operator import methodcaller from typing import Dict, Tuple, Iterator, Optional from concurrent.futures import ThreadPoolExecutor, as_completed -from runtype import dataclass +import attrs from data_diff.info_tree import InfoTree, SegmentInfo from data_diff.utils import dbt_diff_string_template, run_as_daemon, safezip, getLogger, truncate_error, Vector @@ -31,7 +30,7 @@ class Algorithm(Enum): DiffResult = Iterator[Tuple[str, tuple]] # Iterator[Tuple[Literal["+", "-"], tuple]] -@dataclass +@attrs.define(frozen=True) class ThreadBase: "Provides utility methods for optional threading" @@ -72,7 +71,7 @@ def _run_in_background(self, *funcs): f.result() -@dataclass +@attrs.define(frozen=True) class DiffStats: diff_by_sign: Dict[str, int] table1_count: int @@ -82,12 +81,12 @@ class DiffStats: extra_column_diffs: Optional[Dict[str, int]] -@dataclass +@attrs.define(frozen=True) class DiffResultWrapper: diff: iter # DiffResult info_tree: InfoTree stats: dict - result_list: list = field(default_factory=list) + result_list: list = attrs.field(factory=list) def __iter__(self): yield from self.result_list @@ -203,7 +202,7 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment, info_tree: Inf def _diff_tables_wrapper(self, table1: TableSegment, table2: TableSegment, info_tree: InfoTree) -> DiffResult: if is_tracking_enabled(): - options = dict(self) + options = attrs.asdict(self, recurse=False) options["differ_name"] = type(self).__name__ event_json = create_start_event_json(options) run_as_daemon(send_event_json, event_json) diff --git a/data_diff/format.py b/data_diff/format.py index a8900e84..8eb12159 100644 --- a/data_diff/format.py +++ b/data_diff/format.py @@ -2,7 +2,7 @@ from enum import Enum from typing import Any, Optional, List, Dict, Tuple, Type -from runtype import dataclass +import attrs from data_diff.diff_tables import DiffResultWrapper from data_diff.abcs.database_types import ( JSON, @@ -21,13 +21,15 @@ def jsonify_error(table1: List[str], table2: List[str], dbt_model: str, error: str) -> "FailedDiff": - return FailedDiff( - status="failed", - model=dbt_model, - dataset1=table1, - dataset2=table2, - error=error, - ).json() + return attrs.asdict( + FailedDiff( + status="failed", + model=dbt_model, + dataset1=table1, + dataset2=table2, + error=error, + ) + ) Columns = List[Tuple[str, str, ColType]] @@ -74,19 +76,21 @@ def jsonify( or diff_rows or (columns_diff["added"] or columns_diff["removed"] or columns_diff["changed"]) ) - return JsonDiff( - status="success", - result="different" if is_different else "identical", - model=dbt_model, - dataset1=list(table1.table_path), - dataset2=list(table2.table_path), - rows=rows, - summary=summary, - columns=columns, - ).json() - - -@dataclass + return attrs.asdict( + JsonDiff( + status="success", + result="different" if is_different else "identical", + model=dbt_model, + dataset1=list(table1.table_path), + dataset2=list(table2.table_path), + rows=rows, + summary=summary, + columns=columns, + ) + ) + + +@attrs.define(frozen=True) class JsonExclusiveRowValue: """ Value of a single column in a row @@ -96,7 +100,7 @@ class JsonExclusiveRowValue: value: Any -@dataclass +@attrs.define(frozen=True) class JsonDiffRowValue: """ Pair of diffed values for 2 rows with equal PKs @@ -108,19 +112,19 @@ class JsonDiffRowValue: isPK: bool -@dataclass +@attrs.define(frozen=True) class Total: dataset1: int dataset2: int -@dataclass +@attrs.define(frozen=True) class ExclusiveRows: dataset1: int dataset2: int -@dataclass +@attrs.define(frozen=True) class Rows: total: Total exclusive: ExclusiveRows @@ -128,18 +132,18 @@ class Rows: unchanged: int -@dataclass +@attrs.define(frozen=True) class Stats: diffCounts: Dict[str, int] -@dataclass +@attrs.define(frozen=True) class JsonDiffSummary: rows: Rows stats: Stats -@dataclass +@attrs.define(frozen=True) class ExclusiveColumns: dataset1: List[str] dataset2: List[str] @@ -172,14 +176,14 @@ class ColumnKind(Enum): ] -@dataclass +@attrs.define(frozen=True) class Column: name: str type: str kind: str -@dataclass +@attrs.define(frozen=True) class JsonColumnsSummary: dataset1: List[Column] dataset2: List[Column] @@ -188,19 +192,19 @@ class JsonColumnsSummary: typeChanged: List[str] -@dataclass +@attrs.define(frozen=True) class ExclusiveDiff: dataset1: List[Dict[str, JsonExclusiveRowValue]] dataset2: List[Dict[str, JsonExclusiveRowValue]] -@dataclass +@attrs.define(frozen=True) class RowsDiff: exclusive: ExclusiveDiff diff: List[Dict[str, JsonDiffRowValue]] -@dataclass +@attrs.define(frozen=True) class FailedDiff: status: str # Literal ["failed"] model: str @@ -211,7 +215,7 @@ class FailedDiff: version: str = "1.0.0" -@dataclass +@attrs.define(frozen=True) class JsonDiff: status: str # Literal ["success"] result: str # Literal ["different", "identical"] diff --git a/data_diff/hashdiff_tables.py b/data_diff/hashdiff_tables.py index 3fc030ec..8b46c39a 100644 --- a/data_diff/hashdiff_tables.py +++ b/data_diff/hashdiff_tables.py @@ -1,12 +1,11 @@ import os -from dataclasses import field from numbers import Number import logging from collections import defaultdict from typing import Iterator from operator import attrgetter -from runtype import dataclass +import attrs from data_diff.abcs.database_types import ColType_UUID, NumericType, PrecisionType, StringType, Boolean, JSON from data_diff.info_tree import InfoTree @@ -53,7 +52,7 @@ def diff_sets(a: list, b: list, json_cols: dict = None) -> Iterator: yield from v -@dataclass +@attrs.define(frozen=True) class HashDiffer(TableDiffer): """Finds the diff between two SQL tables @@ -74,9 +73,9 @@ class HashDiffer(TableDiffer): bisection_factor: int = DEFAULT_BISECTION_FACTOR bisection_threshold: Number = DEFAULT_BISECTION_THRESHOLD # Accepts inf for tests - stats: dict = field(default_factory=dict) + stats: dict = attrs.field(factory=dict) - def __post_init__(self): + def __attrs_post_init__(self): # Validate options if self.bisection_factor >= self.bisection_threshold: raise ValueError("Incorrect param values (bisection factor must be lower than threshold)") @@ -102,8 +101,8 @@ def _validate_and_adjust_columns(self, table1, table2): if col1.precision != col2.precision: logger.warning(f"Using reduced precision {lowest} for column '{c1}'. Types={col1}, {col2}") - table1._schema[c1] = col1.replace(precision=lowest.precision, rounds=lowest.rounds) - table2._schema[c2] = col2.replace(precision=lowest.precision, rounds=lowest.rounds) + table1._schema[c1] = attrs.evolve(col1, precision=lowest.precision, rounds=lowest.rounds) + table2._schema[c2] = attrs.evolve(col2, precision=lowest.precision, rounds=lowest.rounds) elif isinstance(col1, (NumericType, Boolean)): if not isinstance(col2, (NumericType, Boolean)): @@ -115,9 +114,9 @@ def _validate_and_adjust_columns(self, table1, table2): logger.warning(f"Using reduced precision {lowest} for column '{c1}'. Types={col1}, {col2}") if lowest.precision != col1.precision: - table1._schema[c1] = col1.replace(precision=lowest.precision) + table1._schema[c1] = attrs.evolve(col1, precision=lowest.precision) if lowest.precision != col2.precision: - table2._schema[c2] = col2.replace(precision=lowest.precision) + table2._schema[c2] = attrs.evolve(col2, precision=lowest.precision) elif isinstance(col1, ColType_UUID): if not isinstance(col2, ColType_UUID): diff --git a/data_diff/info_tree.py b/data_diff/info_tree.py index b30ba2f2..d08eb16d 100644 --- a/data_diff/info_tree.py +++ b/data_diff/info_tree.py @@ -1,12 +1,11 @@ -from dataclasses import field from typing import List, Dict, Optional, Any, Tuple, Union -from runtype import dataclass +import attrs from data_diff.table_segment import TableSegment -@dataclass(frozen=False) +@attrs.define(frozen=False) class SegmentInfo: tables: List[TableSegment] @@ -15,7 +14,7 @@ class SegmentInfo: is_diff: Optional[bool] = None diff_count: Optional[int] = None - rowcounts: Dict[int, int] = field(default_factory=dict) + rowcounts: Dict[int, int] = attrs.field(factory=dict) max_rows: Optional[int] = None def set_diff(self, diff: List[Union[Tuple[Any, ...], List[Any]]], schema: Optional[Tuple[Tuple[str, type]]] = None): @@ -40,10 +39,10 @@ def update_from_children(self, child_infos): } -@dataclass +@attrs.define(frozen=True) class InfoTree: info: SegmentInfo - children: List["InfoTree"] = field(default_factory=list) + children: List["InfoTree"] = attrs.field(factory=list) def add_node(self, table1: TableSegment, table2: TableSegment, max_rows: int = None): node = InfoTree(SegmentInfo([table1, table2], max_rows=max_rows)) diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 14834ffd..12ca203e 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -1,14 +1,13 @@ """Provides classes for performing a table diff using JOIN """ -from dataclasses import field from decimal import Decimal from functools import partial import logging -from typing import List +from typing import List, Optional from itertools import chain -from runtype import dataclass +import attrs from data_diff.databases import Database, MsSQL, MySQL, BigQuery, Presto, Oracle, Snowflake from data_diff.abcs.database_types import NumericType, DbPath @@ -58,7 +57,7 @@ def sample(table_expr): def create_temp_table(c: Compiler, path: TablePath, expr: Expr) -> str: db = c.database - c: Compiler = c.replace(root=False) # we're compiling fragments, not full queries + c: Compiler = attrs.evolve(c, root=False) # we're compiling fragments, not full queries if isinstance(db, BigQuery): return f"create table {c.dialect.compile(c, path)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.dialect.compile(c, expr)}" elif isinstance(db, Presto): @@ -111,7 +110,7 @@ def json_friendly_value(v): return v -@dataclass +@attrs.define(frozen=True) class JoinDiffer(TableDiffer): """Finds the diff between two SQL tables in the same database, using JOINs. @@ -143,7 +142,7 @@ class JoinDiffer(TableDiffer): table_write_limit: int = TABLE_WRITE_LIMIT skip_null_keys: bool = False - stats: dict = field(default_factory=dict) + stats: dict = attrs.field(factory=dict) def _diff_tables_root(self, table1: TableSegment, table2: TableSegment, info_tree: InfoTree) -> DiffResult: db = table1.database diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 710f1316..e342de0e 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -1,8 +1,7 @@ -from dataclasses import field from datetime import datetime from typing import Any, Generator, List, Optional, Sequence, Union, Dict -from runtype import dataclass +import attrs from typing_extensions import Self from data_diff.utils import ArithString @@ -25,6 +24,7 @@ class Root: "Nodes inheriting from Root can be used as root statements in SQL (e.g. SELECT yes, RANDOM() no)" +@attrs.define(frozen=False, eq=False) class ExprNode(Compilable): "Base class for query expression nodes" @@ -34,7 +34,7 @@ def type(self) -> Optional[type]: def _dfs_values(self): yield self - for k, vs in dict(self).items(): # __dict__ provided by runtype.dataclass + for k, vs in attrs.asdict(self, recurse=False).items(): if k == "source_table": # Skip data-sources, we're only interested in data-parameters continue @@ -52,7 +52,7 @@ def cast_to(self, to): Expr = Union[ExprNode, str, bool, int, float, datetime, ArithString, None] -@dataclass +@attrs.define(frozen=True, eq=False) class Code(ExprNode, Root): code: str args: Optional[Dict[str, Expr]] = None @@ -64,7 +64,7 @@ def _expr_type(e: Expr) -> type: return type(e) -@dataclass +@attrs.define(frozen=True, eq=False) class Alias(ExprNode): expr: Expr name: str @@ -213,13 +213,13 @@ def intersect(self, other: "ITable"): return TableOp("INTERSECT", self, other) -@dataclass +@attrs.define(frozen=True, eq=False) class Concat(ExprNode): exprs: list sep: Optional[str] = None -@dataclass +@attrs.define(frozen=True, eq=False) class Count(ExprNode): expr: Expr = None distinct: bool = False @@ -229,6 +229,7 @@ def type(self) -> Optional[type]: return int +@attrs.define(frozen=False, eq=False) class LazyOps: def __add__(self, other): return BinOp("+", [self, other]) @@ -278,19 +279,19 @@ def min(self): return Func("MIN", [self]) -@dataclass(eq=False) -class Func(ExprNode, LazyOps): +@attrs.define(frozen=True, eq=False) +class Func(LazyOps, ExprNode): name: str args: Sequence[Expr] -@dataclass +@attrs.define(frozen=True, eq=False) class WhenThen(ExprNode): when: Expr then: Expr -@dataclass +@attrs.define(frozen=True, eq=False) class CaseWhen(ExprNode): cases: Sequence[WhenThen] else_expr: Optional[Expr] = None @@ -328,10 +329,10 @@ def else_(self, then: Expr) -> Self: if self.else_expr is not None: raise QueryBuilderError(f"Else clause already specified in {self}") - return self.replace(else_expr=then) + return attrs.evolve(self, else_expr=then) -@dataclass +@attrs.define(frozen=True, eq=False) class QB_When: "Partial case-when, used for query-building" casewhen: CaseWhen @@ -340,11 +341,11 @@ class QB_When: def then(self, then: Expr) -> CaseWhen: """Add a 'then' clause after a 'when' was added.""" case = WhenThen(self.when, then) - return self.casewhen.replace(cases=self.casewhen.cases + [case]) + return attrs.evolve(self.casewhen, cases=self.casewhen.cases + [case]) -@dataclass(eq=False, order=False) -class IsDistinctFrom(ExprNode, LazyOps): +@attrs.define(frozen=True, eq=False) +class IsDistinctFrom(LazyOps, ExprNode): a: Expr b: Expr @@ -353,8 +354,8 @@ def type(self) -> Optional[type]: return bool -@dataclass(eq=False, order=False) -class BinOp(ExprNode, LazyOps): +@attrs.define(frozen=True, eq=False) +class BinOp(LazyOps, ExprNode): op: str args: Sequence[Expr] @@ -367,8 +368,8 @@ def type(self): return t -@dataclass -class UnaryOp(ExprNode, LazyOps): +@attrs.define(frozen=True, eq=False) +class UnaryOp(LazyOps, ExprNode): op: str expr: Expr @@ -379,8 +380,8 @@ def type(self) -> Optional[type]: return bool -@dataclass(eq=False, order=False) -class Column(ExprNode, LazyOps): +@attrs.define(frozen=True, eq=False) +class Column(LazyOps, ExprNode): source_table: ITable name: str @@ -391,7 +392,7 @@ def type(self): return self.source_table.schema[self.name] -@dataclass +@attrs.define(frozen=False, eq=False) class TablePath(ExprNode, ITable): path: DbPath schema: Optional[Schema] = None # overrides the inherited property @@ -474,7 +475,7 @@ def time_travel( assert offset is None and statement is None -@dataclass +@attrs.define(frozen=True, eq=False) class TableAlias(ExprNode, ITable): table: ITable name: str @@ -488,7 +489,7 @@ def schema(self) -> Schema: return self.table.schema -@dataclass +@attrs.define(frozen=True, eq=False) class Join(ExprNode, ITable, Root): source_tables: Sequence[ITable] op: Optional[str] = None @@ -512,7 +513,7 @@ def on(self, *exprs) -> Self: if not exprs: return self - return self.replace(on_exprs=(self.on_exprs or []) + exprs) + return attrs.evolve(self, on_exprs=(self.on_exprs or []) + exprs) def select(self, *exprs, **named_exprs) -> Union[Self, ITable]: """Select fields to return from the JOIN operation @@ -528,17 +529,17 @@ def select(self, *exprs, **named_exprs) -> Union[Self, ITable]: exprs += _named_exprs_as_aliases(named_exprs) resolve_names(self.source_table, exprs) # TODO Ensure exprs <= self.columns ? - return self.replace(columns=exprs) + return attrs.evolve(self, columns=exprs) -@dataclass +@attrs.define(frozen=True, eq=False) class GroupBy(ExprNode, ITable, Root): table: ITable keys: Optional[Sequence[Expr]] = None # IKey? values: Optional[Sequence[Expr]] = None having_exprs: Optional[Sequence[Expr]] = None - def __post_init__(self): + def __attrs_post_init__(self): assert self.keys or self.values def having(self, *exprs) -> Self: @@ -549,17 +550,17 @@ def having(self, *exprs) -> Self: return self resolve_names(self.table, exprs) - return self.replace(having_exprs=(self.having_exprs or []) + exprs) + return attrs.evolve(self, having_exprs=(self.having_exprs or []) + exprs) def agg(self, *exprs) -> Self: """Select aggregated fields for the group-by.""" exprs = args_as_tuple(exprs) exprs = _drop_skips(exprs) resolve_names(self.table, exprs) - return self.replace(values=(self.values or []) + exprs) + return attrs.evolve(self, values=(self.values or []) + exprs) -@dataclass +@attrs.define(frozen=True, eq=False) class TableOp(ExprNode, ITable, Root): op: str table1: ITable @@ -578,7 +579,7 @@ def schema(self) -> Schema: return s1 -@dataclass +@attrs.define(frozen=True, eq=False) class Select(ExprNode, ITable, Root): table: Optional[Expr] = None columns: Optional[Sequence[Expr]] = None @@ -630,10 +631,10 @@ def make(cls, table: ITable, distinct: bool = SKIP, optimizer_hints: str = SKIP, else: raise ValueError(k) - return table.replace(**kwargs) + return attrs.evolve(table, **kwargs) -@dataclass +@attrs.define(frozen=True, eq=False) class Cte(ExprNode, ITable): table: Expr name: Optional[str] = None @@ -664,8 +665,8 @@ def resolve_names(source_table, exprs): i += 1 -@dataclass(frozen=False, eq=False, order=False) -class _ResolveColumn(ExprNode, LazyOps): +@attrs.define(frozen=False, eq=False) +class _ResolveColumn(LazyOps, ExprNode): resolve_name: str resolved: Optional[Expr] = None @@ -703,7 +704,7 @@ def __getitem__(self, name): return _ResolveColumn(name) -@dataclass +@attrs.define(frozen=True, eq=False) class In(ExprNode): expr: Expr list: Sequence[Expr] @@ -713,25 +714,25 @@ def type(self) -> Optional[type]: return bool -@dataclass +@attrs.define(frozen=True, eq=False) class Cast(ExprNode): expr: Expr target_type: Expr -@dataclass -class Random(ExprNode, LazyOps): +@attrs.define(frozen=True, eq=False) +class Random(LazyOps, ExprNode): @property def type(self) -> Optional[type]: return float -@dataclass +@attrs.define(frozen=True, eq=False) class ConstantTable(ExprNode): rows: Sequence[Sequence] -@dataclass +@attrs.define(frozen=True, eq=False) class Explain(ExprNode, Root): select: Select @@ -746,8 +747,8 @@ def type(self) -> Optional[type]: return datetime -@dataclass -class TimeTravel(ITable): +@attrs.define(frozen=True, eq=False) +class TimeTravel(ITable): # TODO: Unused? table: TablePath before: bool = False timestamp: Optional[datetime] = None @@ -764,7 +765,7 @@ def type(self) -> Optional[type]: return None -@dataclass +@attrs.define(frozen=True, eq=False) class CreateTable(Statement): path: TablePath source_table: Optional[Expr] = None @@ -772,18 +773,18 @@ class CreateTable(Statement): primary_keys: Optional[List[str]] = None -@dataclass +@attrs.define(frozen=True, eq=False) class DropTable(Statement): path: TablePath if_exists: bool = False -@dataclass +@attrs.define(frozen=True, eq=False) class TruncateTable(Statement): path: TablePath -@dataclass +@attrs.define(frozen=True, eq=False) class InsertToTable(Statement): path: TablePath expr: Expr @@ -804,16 +805,16 @@ def returning(self, *exprs) -> Self: return self resolve_names(self.path, exprs) - return self.replace(returning_exprs=exprs) + return attrs.evolve(self, returning_exprs=exprs) -@dataclass +@attrs.define(frozen=True, eq=False) class Commit(Statement): """Generate a COMMIT statement, if we're in the middle of a transaction, or in auto-commit. Otherwise SKIP.""" -@dataclass -class Param(ExprNode, ITable): +@attrs.define(frozen=True, eq=False) +class Param(ExprNode, ITable): # TODO: Unused? """A value placeholder, to be specified at compilation time using the `cv_params` context variable.""" name: str diff --git a/data_diff/queries/extras.py b/data_diff/queries/extras.py index 4467bd0a..7d67ae97 100644 --- a/data_diff/queries/extras.py +++ b/data_diff/queries/extras.py @@ -1,13 +1,13 @@ "Useful AST classes that don't quite fall within the scope of regular SQL" from typing import Callable, Optional, Sequence -from runtype import dataclass -from data_diff.abcs.database_types import ColType +import attrs +from data_diff.abcs.database_types import ColType from data_diff.queries.ast_classes import Expr, ExprNode -@dataclass +@attrs.define(frozen=True) class NormalizeAsString(ExprNode): expr: ExprNode expr_type: Optional[ColType] = None @@ -17,12 +17,12 @@ def type(self) -> Optional[type]: return str -@dataclass +@attrs.define(frozen=True) class ApplyFuncAndNormalizeAsString(ExprNode): expr: ExprNode apply_func: Optional[Callable] = None -@dataclass +@attrs.define(frozen=True) class Checksum(ExprNode): exprs: Sequence[Expr] diff --git a/data_diff/table_segment.py b/data_diff/table_segment.py index 015e5bc4..d8f84231 100644 --- a/data_diff/table_segment.py +++ b/data_diff/table_segment.py @@ -1,9 +1,9 @@ import time -from typing import List, Tuple +from typing import List, Optional, Tuple import logging from itertools import product -from runtype import dataclass +import attrs from typing_extensions import Self from data_diff.utils import safezip, Vector @@ -85,7 +85,7 @@ def create_mesh_from_points(*values_per_dim: list) -> List[Tuple[Vector, Vector] return res -@dataclass +@attrs.define(frozen=True) class TableSegment: """Signifies a segment of rows (and selected columns) within a table @@ -125,7 +125,7 @@ class TableSegment: case_sensitive: Optional[bool] = True _schema: Optional[Schema] = None - def __post_init__(self): + def __attrs_post_init__(self): if not self.update_column and (self.min_update or self.max_update): raise ValueError("Error: the min_update/max_update feature requires 'update_column' to be set.") @@ -142,7 +142,7 @@ def _where(self): def _with_raw_schema(self, raw_schema: dict) -> Self: schema = self.database._process_table_schema(self.table_path, raw_schema, self.relevant_columns, self._where()) - return self.new(_schema=create_schema(self.database.name, self.table_path, schema, self.case_sensitive)) + return self.new(schema=create_schema(self.database.name, self.table_path, schema, self.case_sensitive)) def with_schema(self) -> Self: "Queries the table schema from the database, and returns a new instance of TableSegment, with a schema." @@ -199,7 +199,7 @@ def segment_by_checkpoints(self, checkpoints: List[List[DbKey]]) -> List["TableS def new(self, **kwargs) -> Self: """Creates a copy of the instance using 'replace()'""" - return self.replace(**kwargs) + return attrs.evolve(self, **kwargs) def new_key_bounds(self, min_key: Vector, max_key: Vector) -> Self: if self.min_key is not None: @@ -210,7 +210,7 @@ def new_key_bounds(self, min_key: Vector, max_key: Vector) -> Self: assert min_key < self.max_key assert max_key <= self.max_key - return self.replace(min_key=min_key, max_key=max_key) + return attrs.evolve(self, min_key=min_key, max_key=max_key) @property def relevant_columns(self) -> List[str]: diff --git a/poetry.lock b/poetry.lock index afd70d75..a95c7580 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2030,4 +2030,4 @@ vertica = ["vertica-python"] [metadata] lock-version = "2.0" python-versions = "^3.7.2" -content-hash = "55cde03a00788572dac6310e7bbf61bd2522d70217056a51608bcfc429440fbf" +content-hash = "c7da70c19432ca716980f3421182d54d7f5d2e0d8bbd7e20dbaf521c8ef7d0fb" diff --git a/pyproject.toml b/pyproject.toml index 0ae8f0a4..2ceb0c27 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,6 @@ packages = [{ include = "data_diff" }] [tool.poetry.dependencies] pydantic = "1.10.12" python = "^3.7.2" -runtype = "^0.2.6" dsnparse = "<0.2.0" click = "^8.1" rich = "*" @@ -47,6 +46,7 @@ urllib3 = "<2" oracledb = {version = "*", optional=true} pyodbc = {version="^4.0.39", optional=true} typing-extensions = ">=4.0.1" +attrs = "^23.1.0" [tool.poetry.dev-dependencies] parameterized = "*" diff --git a/tests/test_database.py b/tests/test_database.py index c63fd95c..3ef06dcc 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -2,6 +2,7 @@ from datetime import datetime from typing import Callable, List, Tuple +import attrs import pytz from data_diff import connect @@ -130,8 +131,8 @@ def test_correct_timezone(self): raw_schema = db.query_table_schema(t.path) schema = db._process_table_schema(t.path, raw_schema) schema = create_schema(db.name, t, schema, case_sensitive=True) - t = t.replace(schema=schema) - t.schema["created_at"] = t.schema["created_at"].replace(precision=t.schema["created_at"].precision) + t = attrs.evolve(t, schema=schema) + t.schema["created_at"] = attrs.evolve(t.schema["created_at"], precision=t.schema["created_at"].precision) tbl = table(name, schema=t.schema) diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index b5885a26..2e48798a 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -3,6 +3,8 @@ import uuid import unittest +import attrs + from data_diff.queries.api import table, this, commit, code from data_diff.utils import ArithAlphanumeric, numberToAlphanum @@ -382,13 +384,13 @@ def test_string_keys(self): self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b)) def test_where_sampling(self): - a = self.a.replace(where="1=1") + a = attrs.evolve(self.a, where="1=1") differ = HashDiffer(bisection_factor=2) diff = list(differ.diff_tables(a, self.b)) self.assertEqual(diff, [("-", (str(self.new_uuid), "This one is different"))]) - a_empty = self.a.replace(where="1=0") + a_empty = attrs.evolve(self.a, where="1=0") self.assertRaises(ValueError, list, differ.diff_tables(a_empty, self.b)) @@ -504,9 +506,9 @@ def setUp(self) -> None: def test_table_segment(self): early = datetime(2021, 1, 1, 0, 0) late = datetime(2022, 1, 1, 0, 0) - self.assertRaises(ValueError, self.table.replace, min_update=late, max_update=early) + self.assertRaises(ValueError, attrs.evolve, self.table, min_update=late, max_update=early) - self.assertRaises(ValueError, self.table.replace, min_key=Vector((10,)), max_key=Vector((0,))) + self.assertRaises(ValueError, attrs.evolve, self.table, min_key=Vector((10,)), max_key=Vector((0,))) def test_case_awareness(self): src_table = table(self.table_src_path, schema={"id": int, "userid": int, "timestamp": datetime}) @@ -519,11 +521,11 @@ def test_case_awareness(self): [src_table.create(), src_table.insert_rows([[1, 9, time_obj], [2, 2, time_obj]], columns=cols), commit] ) - res = tuple(self.table.replace(key_columns=("Id",), case_sensitive=False).with_schema().query_key_range()) + res = tuple(attrs.evolve(self.table, key_columns=("Id",), case_sensitive=False).with_schema().query_key_range()) self.assertEqual(res, (("1",), ("2",))) self.assertRaises( - KeyError, self.table.replace(key_columns=("Id",), case_sensitive=True).with_schema().query_key_range + KeyError, attrs.evolve(self.table, key_columns=("Id",), case_sensitive=True).with_schema().query_key_range ) diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index e7e9ec86..8c00037b 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -1,6 +1,8 @@ from typing import List from datetime import datetime +import attrs + from data_diff.queries.ast_classes import TablePath from data_diff.queries.api import table, commit from data_diff.table_segment import TableSegment @@ -114,7 +116,7 @@ def test_diff_small_tables(self): # Test materialize materialize_path = self.connection.dialect.parse_table_name(f"test_mat_{random_table_suffix()}") - mdiffer = self.differ.replace(materialize_to_table=materialize_path) + mdiffer = attrs.evolve(self.differ, materialize_to_table=materialize_path) diff = list(mdiffer.diff_tables(self.table, self.table2)) self.assertEqual(expected, diff) @@ -126,7 +128,7 @@ def test_diff_small_tables(self): self.connection.query(t.drop()) # Test materialize all rows - mdiffer = mdiffer.replace(materialize_all_rows=True) + mdiffer = attrs.evolve(mdiffer, materialize_all_rows=True) diff = list(mdiffer.diff_tables(self.table, self.table2)) self.assertEqual(expected, diff) rows = self.connection.query(t.select(), List[tuple]) diff --git a/tests/test_sql.py b/tests/test_sql.py index 6293d0bd..7b63ee8f 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -1,5 +1,7 @@ import unittest +import attrs + from tests.common import TEST_MYSQL_CONN_STRING from data_diff.databases import connect @@ -19,9 +21,8 @@ def test_compile_int(self): self.assertEqual("1", self.compiler.compile(1)) def test_compile_table_name(self): - self.assertEqual( - "`marine_mammals`.`walrus`", self.compiler.replace(root=False).compile(table("marine_mammals", "walrus")) - ) + compiler = attrs.evolve(self.compiler, root=False) + self.assertEqual("`marine_mammals`.`walrus`", compiler.compile(table("marine_mammals", "walrus"))) def test_compile_select(self): expected_sql = "SELECT name FROM `marine_mammals`.`walrus`" From ee37648d2b848ae3c1182608ae5e69ba56e5eee2 Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Fri, 29 Sep 2023 16:46:51 +0200 Subject: [PATCH 7/7] Convert the remaining classes to attrs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Since we now use `attrs` for some classes, let's use `attrs` for them all — at least those belonging to the same hierarchies. This will ensure that all classes are slotted and will strictly check that we define attributes properly, especially in cases of multiple inheritance. Except for Pydantic models and Python exceptions. Despite the attrs classes are not frozen by default, we keep it explicitly stated, so that we see which classes were or were not frozen before the switch from runtype to attrs. We can later freeze more classes if/when it works (the stricter, the better). For this reason, we never unfreeze classes that were previously frozen. ->? --- data_diff/abcs/database_types.py | 16 +++++++++++++++ data_diff/abcs/mixins.py | 10 ++++++++++ data_diff/databases/_connect.py | 2 ++ data_diff/databases/base.py | 33 ++++++++++++++++--------------- data_diff/databases/bigquery.py | 9 +++++++++ data_diff/databases/clickhouse.py | 6 ++++++ data_diff/databases/databricks.py | 6 ++++++ data_diff/databases/duckdb.py | 11 +++++++++-- data_diff/databases/mssql.py | 7 +++++++ data_diff/databases/mysql.py | 6 ++++++ data_diff/databases/oracle.py | 7 +++++++ data_diff/databases/postgresql.py | 8 ++++++++ data_diff/databases/presto.py | 5 +++++ data_diff/databases/redshift.py | 7 +++++++ data_diff/databases/snowflake.py | 6 ++++++ data_diff/databases/trino.py | 4 ++++ data_diff/databases/vertica.py | 5 +++++ data_diff/dbt_parser.py | 19 +++++++++++++++++- data_diff/diff_tables.py | 1 + data_diff/lexicographic_space.py | 4 ++++ data_diff/queries/ast_classes.py | 6 ++++++ data_diff/queries/base.py | 3 +++ data_diff/thread_utils.py | 9 ++++++++- data_diff/utils.py | 21 ++++++++++---------- 24 files changed, 180 insertions(+), 31 deletions(-) diff --git a/data_diff/abcs/database_types.py b/data_diff/abcs/database_types.py index 58ba9159..0eb5bb69 100644 --- a/data_diff/abcs/database_types.py +++ b/data_diff/abcs/database_types.py @@ -26,26 +26,32 @@ class PrecisionType(ColType): rounds: Union[bool, Unknown] = Unknown +@attrs.define(frozen=True) class Boolean(ColType): precision = 0 +@attrs.define(frozen=True) class TemporalType(PrecisionType): pass +@attrs.define(frozen=True) class Timestamp(TemporalType): pass +@attrs.define(frozen=True) class TimestampTZ(TemporalType): pass +@attrs.define(frozen=True) class Datetime(TemporalType): pass +@attrs.define(frozen=True) class Date(TemporalType): pass @@ -56,14 +62,17 @@ class NumericType(ColType): precision: int +@attrs.define(frozen=True) class FractionalType(NumericType): pass +@attrs.define(frozen=True) class Float(FractionalType): python_type = float +@attrs.define(frozen=True) class IKey(ABC): "Interface for ColType, for using a column as a key in table." @@ -76,6 +85,7 @@ def make_value(self, value): return self.python_type(value) +@attrs.define(frozen=True) class Decimal(FractionalType, IKey): # Snowflake may use Decimal as a key @property def python_type(self) -> type: @@ -89,22 +99,27 @@ class StringType(ColType): python_type = str +@attrs.define(frozen=True) class ColType_UUID(ColType, IKey): python_type = ArithUUID +@attrs.define(frozen=True) class ColType_Alphanum(ColType, IKey): python_type = ArithAlphanumeric +@attrs.define(frozen=True) class Native_UUID(ColType_UUID): pass +@attrs.define(frozen=True) class String_UUID(ColType_UUID, StringType): pass +@attrs.define(frozen=True) class String_Alphanum(ColType_Alphanum, StringType): @staticmethod def test_value(value: str) -> bool: @@ -118,6 +133,7 @@ def make_value(self, value): return self.python_type(value) +@attrs.define(frozen=True) class String_VaryingAlphanum(String_Alphanum): pass diff --git a/data_diff/abcs/mixins.py b/data_diff/abcs/mixins.py index 9a30f41e..e597f480 100644 --- a/data_diff/abcs/mixins.py +++ b/data_diff/abcs/mixins.py @@ -1,4 +1,7 @@ from abc import ABC, abstractmethod + +import attrs + from data_diff.abcs.database_types import ( Array, TemporalType, @@ -13,10 +16,12 @@ from data_diff.abcs.compiler import Compilable +@attrs.define(frozen=False) class AbstractMixin(ABC): "A mixin for a database dialect" +@attrs.define(frozen=False) class AbstractMixin_NormalizeValue(AbstractMixin): @abstractmethod def to_comparable(self, value: str, coltype: ColType) -> str: @@ -108,6 +113,7 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str: return self.to_string(value) +@attrs.define(frozen=False) class AbstractMixin_MD5(AbstractMixin): """Methods for calculating an MD6 hash as an integer.""" @@ -116,6 +122,7 @@ def md5_as_int(self, s: str) -> str: "Provide SQL for computing md5 and returning an int" +@attrs.define(frozen=False) class AbstractMixin_Schema(AbstractMixin): """Methods for querying the database schema @@ -134,6 +141,7 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: """ +@attrs.define(frozen=False) class AbstractMixin_RandomSample(AbstractMixin): @abstractmethod def random_sample_n(self, tbl: str, size: int) -> str: @@ -151,6 +159,7 @@ def random_sample_ratio_approx(self, tbl: str, ratio: float) -> str: # """ +@attrs.define(frozen=False) class AbstractMixin_TimeTravel(AbstractMixin): @abstractmethod def time_travel( @@ -173,6 +182,7 @@ def time_travel( """ +@attrs.define(frozen=False) class AbstractMixin_OptimizerHints(AbstractMixin): @abstractmethod def optimizer_hints(self, optimizer_hints: str) -> str: diff --git a/data_diff/databases/_connect.py b/data_diff/databases/_connect.py index 3af342cd..be55cc2d 100644 --- a/data_diff/databases/_connect.py +++ b/data_diff/databases/_connect.py @@ -93,6 +93,7 @@ def match_path(self, dsn): } +@attrs.define(frozen=False, init=False) class Connect: """Provides methods for connecting to a supported database using a URL or connection dict.""" @@ -288,6 +289,7 @@ def __make_cache_key(self, db_conf: Union[str, dict]) -> Hashable: return db_conf +@attrs.define(frozen=False, init=False) class Connect_SetUTC(Connect): """Provides methods for connecting to a supported database using a URL or connection dict. diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 26ab1703..e4e215b7 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -168,6 +168,7 @@ def _one(seq): return x +@attrs.define(frozen=False) class ThreadLocalInterpreter: """An interpeter used to execute a sequence of queries within the same thread and cursor. @@ -177,11 +178,6 @@ class ThreadLocalInterpreter: compiler: Compiler gen: Generator - def __init__(self, compiler: Compiler, gen: Generator): - super().__init__() - self.gen = gen - self.compiler = compiler - def apply_queries(self, callback: Callable[[str], Any]): q: Expr = next(self.gen) while True: @@ -205,6 +201,7 @@ def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocal return callback(sql_code) +@attrs.define(frozen=False) class Mixin_Schema(AbstractMixin_Schema): def table_information(self) -> Compilable: return table("information_schema", "tables") @@ -221,6 +218,7 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: ) +@attrs.define(frozen=False) class Mixin_RandomSample(AbstractMixin_RandomSample): def random_sample_n(self, tbl: ITable, size: int) -> ITable: # TODO use a more efficient algorithm, when the table count is known @@ -230,15 +228,17 @@ def random_sample_ratio_approx(self, tbl: ITable, ratio: float) -> ITable: return tbl.where(Random() < ratio) +@attrs.define(frozen=False) class Mixin_OptimizerHints(AbstractMixin_OptimizerHints): def optimizer_hints(self, hints: str) -> str: return f"/*+ {hints} */ " +@attrs.define(frozen=False) class BaseDialect(abc.ABC): SUPPORTS_PRIMARY_KEY: ClassVar[bool] = False SUPPORTS_INDEXES: ClassVar[bool] = False - TYPE_CLASSES: ClassVar[Dict[str, type]] = {} + TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = {} PLACEHOLDER_TABLE = None # Used for Oracle @@ -539,7 +539,7 @@ def render_select(self, parent_c: Compiler, elem: Select) -> str: def render_join(self, parent_c: Compiler, elem: Join) -> str: tables = [ - t if isinstance(t, TableAlias) else TableAlias(source_table=t, name=parent_c.new_unique_name()) + t if isinstance(t, TableAlias) else TableAlias(t, name=parent_c.new_unique_name()) for t in elem.source_tables ] c = parent_c.add_table_context(*tables, in_join=True, in_select=False) @@ -827,6 +827,7 @@ def __getitem__(self, i): return self.rows[i] +@attrs.define(frozen=False, kw_only=True) class Database(abc.ABC): """Base abstract class for databases. @@ -1102,22 +1103,22 @@ def is_autocommit(self) -> bool: "Return whether the database autocommits changes. When false, COMMIT statements are skipped." +@attrs.define(frozen=False) class ThreadedDatabase(Database): """Access the database through singleton threads. Used for database connectors that do not support sharing their connection between different threads. """ - _init_error: Optional[Exception] - _queue: ThreadPoolExecutor - thread_local: threading.local + thread_count: int = 1 + + _init_error: Optional[Exception] = None + _queue: Optional[ThreadPoolExecutor] = None + thread_local: threading.local = attrs.field(factory=threading.local) - def __init__(self, thread_count=1): - super().__init__() - self._init_error = None - self._queue = ThreadPoolExecutor(thread_count, initializer=self.set_conn) - self.thread_local = threading.local() - logger.info(f"[{self.name}] Starting a threadpool, size={thread_count}.") + def __attrs_post_init__(self): + self._queue = ThreadPoolExecutor(self.thread_count, initializer=self.set_conn) + logger.info(f"[{self.name}] Starting a threadpool, size={self.thread_count}.") def set_conn(self): assert not hasattr(self.thread_local, "conn") diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py index 18140611..72ab3797 100644 --- a/data_diff/databases/bigquery.py +++ b/data_diff/databases/bigquery.py @@ -1,5 +1,8 @@ import re from typing import Any, List, Union + +import attrs + from data_diff.abcs.database_types import ( ColType, Array, @@ -50,11 +53,13 @@ def import_bigquery_service_account(): return service_account +@attrs.define(frozen=False) class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: return f"cast(cast( ('0x' || substr(TO_HEX(md5({s})), 18)) as int64) as numeric)" +@attrs.define(frozen=False) class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: @@ -99,6 +104,7 @@ def normalize_struct(self, value: str, _coltype: Struct) -> str: return f"to_json_string({value})" +@attrs.define(frozen=False) class Mixin_Schema(AbstractMixin_Schema): def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: return ( @@ -112,6 +118,7 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: ) +@attrs.define(frozen=False) class Mixin_TimeTravel(AbstractMixin_TimeTravel): def time_travel( self, @@ -139,6 +146,7 @@ def time_travel( ) +@attrs.define(frozen=False) class Dialect( BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue ): @@ -218,6 +226,7 @@ def parse_table_name(self, name: str) -> DbPath: return tuple(i for i in path if i is not None) +@attrs.define(frozen=False, init=False, kw_only=True) class BigQuery(Database): CONNECT_URI_HELP = "bigquery:///" CONNECT_URI_PARAMS = ["dataset"] diff --git a/data_diff/databases/clickhouse.py b/data_diff/databases/clickhouse.py index 6c6f56e6..193a4b44 100644 --- a/data_diff/databases/clickhouse.py +++ b/data_diff/databases/clickhouse.py @@ -1,5 +1,7 @@ from typing import Any, Dict, Optional, Type +import attrs + from data_diff.databases.base import ( MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, @@ -35,12 +37,14 @@ def import_clickhouse(): return clickhouse_driver +@attrs.define(frozen=False) class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: substr_idx = 1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS return f"reinterpretAsUInt128(reverse(unhex(lowerUTF8(substr(hex(MD5({s})), {substr_idx})))))" +@attrs.define(frozen=False) class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_number(self, value: str, coltype: FractionalType) -> str: # If a decimal value has trailing zeros in a fractional part, when casting to string they are dropped. @@ -99,6 +103,7 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: return f"rpad({value}, {TIMESTAMP_PRECISION_POS + 6}, '0')" +@attrs.define(frozen=False) class Dialect(BaseDialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "Clickhouse" ROUNDS_ON_PREC_LOSS = False @@ -162,6 +167,7 @@ def current_timestamp(self) -> str: return "now()" +@attrs.define(frozen=False, init=False, kw_only=True) class Clickhouse(ThreadedDatabase): dialect = Dialect() CONNECT_URI_HELP = "clickhouse://:@/" diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index a63e62aa..96772766 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -2,6 +2,8 @@ from typing import Any, Dict, Sequence import logging +import attrs + from data_diff.abcs.database_types import ( Integer, Float, @@ -34,11 +36,13 @@ def import_databricks(): return databricks +@attrs.define(frozen=False) class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: return f"cast(conv(substr(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as decimal(38, 0))" +@attrs.define(frozen=False) class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: """Databricks timestamp contains no more than 6 digits in precision""" @@ -60,6 +64,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str: return self.to_string(f"cast ({value} as int)") +@attrs.define(frozen=False) class Dialect(BaseDialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "Databricks" ROUNDS_ON_PREC_LOSS = True @@ -98,6 +103,7 @@ def parse_table_name(self, name: str) -> DbPath: return tuple(i for i in path if i is not None) +@attrs.define(frozen=False, init=False, kw_only=True) class Databricks(ThreadedDatabase): dialect = Dialect() CONNECT_URI_HELP = "databricks://:@/" diff --git a/data_diff/databases/duckdb.py b/data_diff/databases/duckdb.py index 48784565..d70395a3 100644 --- a/data_diff/databases/duckdb.py +++ b/data_diff/databases/duckdb.py @@ -1,5 +1,7 @@ from typing import Any, Dict, Union +import attrs + from data_diff.utils import match_regexps from data_diff.abcs.database_types import ( Timestamp, @@ -40,11 +42,13 @@ def import_duckdb(): return duckdb +@attrs.define(frozen=False) class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: return f"('0x' || SUBSTRING(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS},{CHECKSUM_HEXDIGITS}))::BIGINT" +@attrs.define(frozen=False) class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: # It's precision 6 by default. If precision is less than 6 -> we remove the trailing numbers. @@ -60,6 +64,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str: return self.to_string(f"{value}::INTEGER") +@attrs.define(frozen=False) class Mixin_RandomSample(AbstractMixin_RandomSample): def random_sample_n(self, tbl: ITable, size: int) -> ITable: return code("SELECT * FROM ({tbl}) USING SAMPLE {size};", tbl=tbl, size=size) @@ -68,6 +73,7 @@ def random_sample_ratio_approx(self, tbl: ITable, ratio: float) -> ITable: return code("SELECT * FROM ({tbl}) USING SAMPLE {percent}%;", tbl=tbl, percent=int(100 * ratio)) +@attrs.define(frozen=False) class Dialect( BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue ): @@ -131,14 +137,15 @@ def current_timestamp(self) -> str: return "current_timestamp" +@attrs.define(frozen=False, init=False, kw_only=True) class DuckDB(Database): dialect = Dialect() SUPPORTS_UNIQUE_CONSTAINT = False # Temporary, until we implement it CONNECT_URI_HELP = "duckdb://@" CONNECT_URI_PARAMS = ["database", "dbpath"] - _args: Dict[str, Any] - _conn: Any + _args: Dict[str, Any] = attrs.field(init=False) + _conn: Any = attrs.field(init=False) def __init__(self, **kw): super().__init__() diff --git a/data_diff/databases/mssql.py b/data_diff/databases/mssql.py index 968c410d..2666ea18 100644 --- a/data_diff/databases/mssql.py +++ b/data_diff/databases/mssql.py @@ -1,4 +1,7 @@ from typing import Any, Dict, Optional + +import attrs + from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue from data_diff.databases.base import ( CHECKSUM_HEXDIGITS, @@ -34,6 +37,7 @@ def import_mssql(): return pyodbc +@attrs.define(frozen=False) class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.precision > 0: @@ -53,11 +57,13 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str: return f"FORMAT({value}, 'N{coltype.precision}')" +@attrs.define(frozen=False) class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: return f"convert(bigint, convert(varbinary, '0x' + RIGHT(CONVERT(NVARCHAR(32), HashBytes('MD5', {s}), 2), {CHECKSUM_HEXDIGITS}), 1))" +@attrs.define(frozen=False) class Dialect( BaseDialect, Mixin_Schema, @@ -158,6 +164,7 @@ def constant_values(self, rows) -> str: return f"VALUES {values}" +@attrs.define(frozen=False, init=False, kw_only=True) class MsSQL(ThreadedDatabase): dialect = Dialect() CONNECT_URI_HELP = "mssql://:@//" diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index a48aa3fe..29b26e2a 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -1,5 +1,7 @@ from typing import Any, Dict +import attrs + from data_diff.abcs.database_types import ( Datetime, Timestamp, @@ -42,11 +44,13 @@ def import_mysql(): return mysql.connector +@attrs.define(frozen=False) class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: return f"cast(conv(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as unsigned)" +@attrs.define(frozen=False) class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: @@ -62,6 +66,7 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: return f"TRIM(CAST({value} AS char))" +@attrs.define(frozen=False) class Dialect( BaseDialect, Mixin_Schema, @@ -132,6 +137,7 @@ def set_timezone_to_utc(self) -> str: return "SET @@session.time_zone='+00:00'" +@attrs.define(frozen=False, init=False, kw_only=True) class MySQL(ThreadedDatabase): dialect = Dialect() SUPPORTS_ALPHANUMS = False diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index e548c7c3..03f94b07 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -1,5 +1,7 @@ from typing import Any, Dict, List, Optional +import attrs + from data_diff.utils import match_regexps from data_diff.abcs.database_types import ( Decimal, @@ -38,6 +40,7 @@ def import_oracle(): return oracledb +@attrs.define(frozen=False) class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: # standard_hash is faster than DBMS_CRYPTO.Hash @@ -45,6 +48,7 @@ def md5_as_int(self, s: str) -> str: return f"to_number(substr(standard_hash({s}, 'MD5'), 18), 'xxxxxxxxxxxxxxx')" +@attrs.define(frozen=False) class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: # Cast is necessary for correct MD5 (trimming not enough) @@ -68,6 +72,7 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str: return f"to_char({value}, '{format_str}')" +@attrs.define(frozen=False) class Mixin_Schema(AbstractMixin_Schema): def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: return ( @@ -80,6 +85,7 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: ) +@attrs.define(frozen=False) class Dialect( BaseDialect, Mixin_Schema, @@ -176,6 +182,7 @@ def current_timestamp(self) -> str: return "LOCALTIMESTAMP" +@attrs.define(frozen=False, init=False, kw_only=True) class Oracle(ThreadedDatabase): dialect = Dialect() CONNECT_URI_HELP = "oracle://:@/" diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py index 1c622739..2b044b36 100644 --- a/data_diff/databases/postgresql.py +++ b/data_diff/databases/postgresql.py @@ -1,4 +1,7 @@ from typing import Any, ClassVar, Dict, List, Type + +import attrs + from data_diff.abcs.database_types import ( ColType, DbPath, @@ -36,11 +39,13 @@ def import_postgresql(): return psycopg2 +@attrs.define(frozen=False) class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: return f"('x' || substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}))::bit({_CHECKSUM_BITSIZE})::bigint" +@attrs.define(frozen=False) class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: @@ -61,6 +66,7 @@ def normalize_json(self, value: str, _coltype: JSON) -> str: return f"{value}::text" +@attrs.define(frozen=False) class PostgresqlDialect( BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue ): @@ -120,6 +126,7 @@ def type_repr(self, t) -> str: return super().type_repr(t) +@attrs.define(frozen=False, init=False, kw_only=True) class PostgreSQL(ThreadedDatabase): dialect = PostgresqlDialect() SUPPORTS_UNIQUE_CONSTAINT = True @@ -127,6 +134,7 @@ class PostgreSQL(ThreadedDatabase): CONNECT_URI_PARAMS = ["database?"] _args: Dict[str, Any] + _conn: Any def __init__(self, *, thread_count, **kw): super().__init__(thread_count=thread_count) diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index 7a8c7eba..2aef9991 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -2,6 +2,8 @@ import re from typing import Any +import attrs + from data_diff.utils import match_regexps from data_diff.abcs.database_types import ( @@ -51,11 +53,13 @@ def import_presto(): return prestodb +@attrs.define(frozen=False) class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: return f"cast(from_base(substr(to_hex(md5(to_utf8({s}))), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16) as decimal(38, 0))" +@attrs.define(frozen=False) class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: # Trim doesn't work on CHAR type @@ -154,6 +158,7 @@ def current_timestamp(self) -> str: return "current_timestamp" +@attrs.define(frozen=False, init=False, kw_only=True) class Presto(Database): dialect = Dialect() CONNECT_URI_HELP = "presto://@//" diff --git a/data_diff/databases/redshift.py b/data_diff/databases/redshift.py index f7aa8d06..b6820424 100644 --- a/data_diff/databases/redshift.py +++ b/data_diff/databases/redshift.py @@ -1,4 +1,7 @@ from typing import ClassVar, List, Dict, Type + +import attrs + from data_diff.abcs.database_types import ( ColType, Float, @@ -19,11 +22,13 @@ ) +@attrs.define(frozen=False) class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38)" +@attrs.define(frozen=False) class Mixin_NormalizeValue(Mixin_NormalizeValue): def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: @@ -52,6 +57,7 @@ def normalize_json(self, value: str, _coltype: JSON) -> str: return f"nvl2({value}, json_serialize({value}), NULL)" +@attrs.define(frozen=False) class Dialect(PostgresqlDialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "Redshift" TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = { @@ -75,6 +81,7 @@ def type_repr(self, t) -> str: return super().type_repr(t) +@attrs.define(frozen=False, init=False, kw_only=True) class Redshift(PostgreSQL): dialect = Dialect() CONNECT_URI_HELP = "redshift://:@/" diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py index 19898185..d83c0f40 100644 --- a/data_diff/databases/snowflake.py +++ b/data_diff/databases/snowflake.py @@ -1,6 +1,8 @@ from typing import Any, Union, List import logging +import attrs + from data_diff.abcs.database_types import ( Timestamp, TimestampTZ, @@ -41,11 +43,13 @@ def import_snowflake(): return snowflake, serialization, default_backend +@attrs.define(frozen=False) class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: return f"BITAND(md5_number_lower64({s}), {CHECKSUM_MASK})" +@attrs.define(frozen=False) class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: @@ -62,6 +66,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str: return self.to_string(f"{value}::int") +@attrs.define(frozen=False) class Mixin_Schema(AbstractMixin_Schema): def table_information(self) -> Compilable: return table("INFORMATION_SCHEMA", "TABLES") @@ -148,6 +153,7 @@ def type_repr(self, t) -> str: return super().type_repr(t) +@attrs.define(frozen=False, init=False, kw_only=True) class Snowflake(Database): dialect = Dialect() CONNECT_URI_HELP = "snowflake://:@//?warehouse=" diff --git a/data_diff/databases/trino.py b/data_diff/databases/trino.py index bce2bddf..f26fb973 100644 --- a/data_diff/databases/trino.py +++ b/data_diff/databases/trino.py @@ -1,5 +1,7 @@ from typing import Any +import attrs + from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue from data_diff.abcs.database_types import TemporalType, ColType_UUID from data_diff.databases import presto @@ -17,6 +19,7 @@ def import_trino(): Mixin_MD5 = presto.Mixin_MD5 +@attrs.define(frozen=False) class Mixin_NormalizeValue(presto.Mixin_NormalizeValue): def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: @@ -36,6 +39,7 @@ class Dialect(presto.Dialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5 name = "Trino" +@attrs.define(frozen=False, init=False, kw_only=True) class Trino(presto.Presto): dialect = Dialect() CONNECT_URI_HELP = "trino://@//" diff --git a/data_diff/databases/vertica.py b/data_diff/databases/vertica.py index 6df48085..dda4e1dd 100644 --- a/data_diff/databases/vertica.py +++ b/data_diff/databases/vertica.py @@ -1,5 +1,7 @@ from typing import Any, Dict, List +import attrs + from data_diff.utils import match_regexps from data_diff.databases.base import ( CHECKSUM_HEXDIGITS, @@ -37,11 +39,13 @@ def import_vertica(): return vertica_python +@attrs.define(frozen=False) class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: return f"CAST(HEX_TO_INTEGER(SUBSTRING(MD5({s}), {1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS})) AS NUMERIC(38, 0))" +@attrs.define(frozen=False) class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: @@ -151,6 +155,7 @@ def current_timestamp(self) -> str: return "current_timestamp(6)" +@attrs.define(frozen=False, init=False, kw_only=True) class Vertica(ThreadedDatabase): dialect = Dialect() CONNECT_URI_HELP = "vertica://:@/" diff --git a/data_diff/dbt_parser.py b/data_diff/dbt_parser.py index fcb6ce24..c4110c84 100644 --- a/data_diff/dbt_parser.py +++ b/data_diff/dbt_parser.py @@ -2,7 +2,9 @@ from collections import defaultdict import json from pathlib import Path -from typing import List, Dict, Tuple, Set, Optional +from typing import Any, List, Dict, Tuple, Set, Optional + +import attrs import yaml from pydantic import BaseModel @@ -94,7 +96,22 @@ class TDatadiffConfig(BaseModel): datasource_id: Optional[int] = None +@attrs.define(frozen=False, init=False) class DbtParser: + dbt_runner: Optional[Any] # dbt.cli.main.dbtRunner if installed + project_dir: Path + connection: Dict[str, Any] + project_dict: Dict[str, Any] + dev_manifest_obj: ManifestJsonConfig + prod_manifest_obj: Optional[ManifestJsonConfig] + dbt_user_id: str + dbt_version: str + dbt_project_id: str + requires_upper: bool + threads: Optional[int] + unique_columns: Dict[str, Set[str]] + profiles_dir: Path + def __init__( self, profiles_dir_override: Optional[str] = None, diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 26a93241..a3f52181 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -179,6 +179,7 @@ def get_stats_dict(self, is_dbt: bool = False): return json_output +@attrs.define(frozen=True) class TableDiffer(ThreadBase, ABC): bisection_factor = 32 stats: dict = {} diff --git a/data_diff/lexicographic_space.py b/data_diff/lexicographic_space.py index 7ef80686..b7d88e36 100644 --- a/data_diff/lexicographic_space.py +++ b/data_diff/lexicographic_space.py @@ -20,6 +20,9 @@ from random import randint, randrange from typing import Tuple + +import attrs + from data_diff.utils import safezip Vector = Tuple[int] @@ -56,6 +59,7 @@ def irandrange(start, stop): return randrange(start, stop) +@attrs.define(frozen=True) class LexicographicSpace: """Lexicographic space of arbitrary dimensions. diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index e342de0e..7cf2319b 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -20,6 +20,7 @@ class QB_TypeError(QueryBuilderError): pass +@attrs.define(frozen=True) class Root: "Nodes inheriting from Root can be used as root statements in SQL (e.g. SELECT yes, RANDOM() no)" @@ -82,6 +83,7 @@ def _drop_skips_dict(exprs_dict): return {k: v for k, v in exprs_dict.items() if v is not SKIP} +@attrs.define(frozen=True) class ITable: @property def source_table(self) -> "ITable": # not always Self, it can be a substitute @@ -374,6 +376,7 @@ class UnaryOp(LazyOps, ExprNode): expr: Expr +@attrs.define(frozen=True) class BinBoolOp(BinOp): @property def type(self) -> Optional[type]: @@ -689,6 +692,7 @@ def name(self): return self._get_resolved().name +@attrs.define(frozen=True) class This: """Builder object for accessing table attributes. @@ -741,6 +745,7 @@ def type(self) -> Optional[type]: return str +@attrs.define(frozen=True) class CurrentTimestamp(ExprNode): @property def type(self) -> Optional[type]: @@ -759,6 +764,7 @@ class TimeTravel(ITable): # TODO: Unused? # DDL +@attrs.define(frozen=True) class Statement(Compilable, Root): @property def type(self) -> Optional[type]: diff --git a/data_diff/queries/base.py b/data_diff/queries/base.py index 205c2211..ca8953c4 100644 --- a/data_diff/queries/base.py +++ b/data_diff/queries/base.py @@ -1,6 +1,9 @@ from typing import Generator +import attrs + +@attrs.define(frozen=True) class _SKIP: def __repr__(self): return "SKIP" diff --git a/data_diff/thread_utils.py b/data_diff/thread_utils.py index 7a36fd58..55b2d9d5 100644 --- a/data_diff/thread_utils.py +++ b/data_diff/thread_utils.py @@ -7,6 +7,8 @@ from time import sleep from typing import Callable, Iterator, Optional +import attrs + class AutoPriorityQueue(PriorityQueue): """Overrides PriorityQueue to automatically get the priority from _WorkItem.kwargs @@ -34,10 +36,10 @@ class PriorityThreadPoolExecutor(ThreadPoolExecutor): def __init__(self, *args): super().__init__(*args) - self._work_queue = AutoPriorityQueue() +@attrs.define(frozen=False, init=False) class ThreadedYielder(Iterable): """Yields results from multiple threads into a single iterator, ordered by priority. @@ -50,6 +52,11 @@ class ThreadedYielder(Iterable): _yield: deque _exception: Optional[None] + _pool: ThreadPoolExecutor + _futures: deque + _yield: deque = attrs.field(alias="_yield") # Python keyword! + _exception: Optional[None] + def __init__(self, max_workers: Optional[int] = None): super().__init__() self._pool = PriorityThreadPoolExecutor(max_workers) diff --git a/data_diff/utils.py b/data_diff/utils.py index 93e6db98..98e0e624 100644 --- a/data_diff/utils.py +++ b/data_diff/utils.py @@ -11,6 +11,7 @@ from datetime import datetime from uuid import UUID +import attrs from packaging.version import parse as parse_version import requests from tabulate import tabulate @@ -115,6 +116,7 @@ def as_insensitive(self): alphanums = " -" + string.digits + string.ascii_uppercase + "_" + string.ascii_lowercase +@attrs.define(frozen=True) class ArithString: @classmethod def new(cls, *args, **kw) -> Self: @@ -126,6 +128,7 @@ def range(self, other: "ArithString", count: int) -> List[Self]: return [self.new(int=i) for i in checkpoints] +# @attrs.define # not as long as it inherits from UUID class ArithUUID(UUID, ArithString): "A UUID that supports basic arithmetic (add, sub)" @@ -174,25 +177,21 @@ def alphanums_to_numbers(s1: str, s2: str): return n1, n2 +@attrs.define(frozen=True) class ArithAlphanumeric(ArithString): _str: str - _max_len: Optional[int] + _max_len: Optional[int] = None - def __init__(self, s: str, max_len=None): - super().__init__() - - if s is None: + def __attrs_post_init__(self): + if self._str is None: raise ValueError("Alphanum string cannot be None") - if max_len and len(s) > max_len: - raise ValueError(f"Length of alphanum value '{str}' is longer than the expected {max_len}") + if self._max_len and len(self._str) > self._max_len: + raise ValueError(f"Length of alphanum value '{str}' is longer than the expected {self._max_len}") - for ch in s: + for ch in self._str: if ch not in alphanums: raise ValueError(f"Unexpected character {ch} in alphanum string") - self._str = s - self._max_len = max_len - # @property # def int(self): # return alphanumToNumber(self._str, alphanums)