From 7a7b88e8f32919640bc61c894f122a15c7dc3a7e Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Fri, 29 Sep 2023 15:21:19 +0200 Subject: [PATCH 1/5] Convert class-level constants to properties for attrs compatibility `attrs` cannot use multiple inheritance when both parents introduce their attributes (as documented). Only one side can inherit the attributes; the other bases must be pure interfaces/protocols. Reimplement `ExprNode.type` via properties to hide it from `attrs`. --- data_diff/abcs/database_types.py | 12 +++++++--- data_diff/queries/ast_classes.py | 40 +++++++++++++++++++++++++------- data_diff/queries/extras.py | 7 ++++-- 3 files changed, 45 insertions(+), 14 deletions(-) diff --git a/data_diff/abcs/database_types.py b/data_diff/abcs/database_types.py index 43764b39..e5ec393a 100644 --- a/data_diff/abcs/database_types.py +++ b/data_diff/abcs/database_types.py @@ -15,7 +15,9 @@ @dataclass class ColType: - supported = True + @property + def supported(self) -> bool: + return True @dataclass @@ -132,7 +134,9 @@ def make_value(self, value): @dataclass class Text(StringType): - supported = False + @property + def supported(self) -> bool: + return False # In majority of DBMSes, it is called JSON/JSONB. Only in Snowflake, it is OBJECT. 
@@ -169,4 +173,6 @@ def __post_init__(self): class UnknownColType(ColType): text: str - supported = False + @property + def supported(self) -> bool: + return False diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 56efdb20..b06439d5 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -28,7 +28,9 @@ class Root: class ExprNode(Compilable): "Base class for query expression nodes" - type: Any = None + @property + def type(self) -> Optional[type]: + return None def _dfs_values(self): yield self @@ -216,7 +218,10 @@ class Concat(ExprNode): class Count(ExprNode): expr: Expr = None distinct: bool = False - type = int + + @property + def type(self) -> Optional[type]: + return int class LazyOps: @@ -337,7 +342,10 @@ def then(self, then: Expr) -> CaseWhen: class IsDistinctFrom(ExprNode, LazyOps): a: Expr b: Expr - type = bool + + @property + def type(self) -> Optional[type]: + return bool @dataclass(eq=False, order=False) @@ -361,7 +369,9 @@ class UnaryOp(ExprNode, LazyOps): class BinBoolOp(BinOp): - type = bool + @property + def type(self) -> Optional[type]: + return bool @dataclass(eq=False, order=False) @@ -700,7 +710,10 @@ def __getitem__(self, name): class In(ExprNode): expr: Expr list: Sequence[Expr] - type = bool + + @property + def type(self) -> Optional[type]: + return bool @dataclass @@ -711,7 +724,9 @@ class Cast(ExprNode): @dataclass class Random(ExprNode, LazyOps): - type = float + @property + def type(self) -> Optional[type]: + return float @dataclass @@ -722,11 +737,16 @@ class ConstantTable(ExprNode): @dataclass class Explain(ExprNode, Root): select: Select - type = str + + @property + def type(self) -> Optional[type]: + return str class CurrentTimestamp(ExprNode): - type = datetime + @property + def type(self) -> Optional[type]: + return datetime @dataclass @@ -742,7 +762,9 @@ class TimeTravel(ITable): class Statement(Compilable, Root): - type = None + @property + def type(self) -> 
Optional[type]: + return None @dataclass diff --git a/data_diff/queries/extras.py b/data_diff/queries/extras.py index bb0c8299..556325f6 100644 --- a/data_diff/queries/extras.py +++ b/data_diff/queries/extras.py @@ -1,5 +1,5 @@ "Useful AST classes that don't quite fall within the scope of regular SQL" -from typing import Callable, Sequence +from typing import Callable, Optional, Sequence from runtype import dataclass from data_diff.abcs.database_types import ColType @@ -11,7 +11,10 @@ class NormalizeAsString(ExprNode): expr: ExprNode expr_type: ColType = None - type = str + + @property + def type(self) -> Optional[type]: + return str @dataclass From 5b95ce7ca62d2810e52dd39fdf3bab766c3c01e4 Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Fri, 29 Sep 2023 23:00:03 +0200 Subject: [PATCH 2/5] Convert source_table & schema to overridable properties for attrs compatibility --- data_diff/queries/ast_classes.py | 62 +++++++++++++++----------------- 1 file changed, 28 insertions(+), 34 deletions(-) diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index b06439d5..ba282b59 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -83,8 +83,14 @@ def _drop_skips_dict(exprs_dict): class ITable: - source_table: Any - schema: Schema = None + + @property + def source_table(self) -> "ITable": # not always Self, it can be a substitute + return self + + @property + def schema(self) -> Schema: + raise NotImplementedError def select(self, *exprs, distinct=SKIP, optimizer_hints=SKIP, **named_exprs) -> "ITable": """Choose new columns, based on the old ones. 
(aka Projection) @@ -389,11 +395,7 @@ def type(self): @dataclass class TablePath(ExprNode, ITable): path: DbPath - schema: Optional[Schema] = field(default=None, repr=False) - - @property - def source_table(self) -> Self: - return self + schema: Schema # overrides the inherited property # Statement shorthands def create(self, source_table: ITable = None, *, if_not_exists: bool = False, primary_keys: List[str] = None): @@ -475,9 +477,17 @@ def time_travel( @dataclass class TableAlias(ExprNode, ITable): - source_table: ITable + table: ITable name: str + @property + def source_table(self) -> ITable: + return self.table + + @property + def schema(self) -> Schema: + return self.table.schema + @dataclass class Join(ExprNode, ITable, Root): @@ -487,11 +497,7 @@ class Join(ExprNode, ITable, Root): columns: Sequence[Expr] = None @property - def source_table(self) -> Self: - return self - - @property - def schema(self): + def schema(self) -> Schema: assert self.columns # TODO Implement SELECT * s = self.source_tables[0].schema # TODO validate types match between both tables return type(s)({c.name: c.type for c in self.columns}) @@ -533,10 +539,6 @@ class GroupBy(ExprNode, ITable, Root): values: Sequence[Expr] = None having_exprs: Sequence[Expr] = None - @property - def source_table(self): - return self - def __post_init__(self): assert self.keys or self.values @@ -564,17 +566,13 @@ class TableOp(ExprNode, ITable, Root): table1: ITable table2: ITable - @property - def source_table(self): - return self - @property def type(self): # TODO ensure types of both tables are compatible return self.table1.type @property - def schema(self): + def schema(self) -> Schema: s1 = self.table1.schema s2 = self.table2.schema assert len(s1) == len(s2) @@ -594,16 +592,12 @@ class Select(ExprNode, ITable, Root): optimizer_hints: Sequence[Expr] = None @property - def schema(self): + def schema(self) -> Schema: s = self.table.schema if s is None or self.columns is None: return s return 
type(s)({c.name: c.type for c in self.columns}) - @property - def source_table(self): - return self - @classmethod def make(cls, table: ITable, distinct: bool = SKIP, optimizer_hints: str = SKIP, **kwargs): assert "table" not in kwargs @@ -642,14 +636,18 @@ def make(cls, table: ITable, distinct: bool = SKIP, optimizer_hints: str = SKIP, @dataclass class Cte(ExprNode, ITable): - source_table: Expr + table: Expr name: str = None params: Sequence[str] = None @property - def schema(self): + def source_table(self) -> "ITable": + return self.table + + @property + def schema(self) -> Schema: # TODO add cte to schema - return self.source_table.schema + return self.table.schema def _named_exprs_as_aliases(named_exprs): @@ -819,7 +817,3 @@ class Commit(Statement): class Param(ExprNode, ITable): """A value placeholder, to be specified at compilation time using the `cv_params` context variable.""" name: str - - @property - def source_table(self): - return self From 22e8ff2ea6548273ab8b25cc922b20cfba869897 Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Fri, 29 Sep 2023 15:57:19 +0200 Subject: [PATCH 3/5] Clarify optional (nullable) fields: explicit is better than implicit We are going to do strict type checking. The default values of fields that clearly contradict the declared types are an error for MyPy and all other type checkers and IDEs. Remove the implicit behaviour and make nullable fields explicitly declared as such. 
--- data_diff/databases/base.py | 2 +- data_diff/info_tree.py | 10 +++--- data_diff/joindiff_tables.py | 2 +- data_diff/queries/ast_classes.py | 54 ++++++++++++++++---------------- data_diff/queries/extras.py | 4 +-- data_diff/table_segment.py | 18 +++++------ 6 files changed, 45 insertions(+), 45 deletions(-) diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index d55c59a6..62c3c18a 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -810,7 +810,7 @@ def set_timezone_to_utc(self) -> str: @dataclass class QueryResult: rows: list - columns: list = None + columns: Optional[list] = None def __iter__(self): return iter(self.rows) diff --git a/data_diff/info_tree.py b/data_diff/info_tree.py index bd2282ba..b30ba2f2 100644 --- a/data_diff/info_tree.py +++ b/data_diff/info_tree.py @@ -10,13 +10,13 @@ class SegmentInfo: tables: List[TableSegment] - diff: List[Union[Tuple[Any, ...], List[Any]]] = None - diff_schema: Tuple[Tuple[str, type], ...] = None - is_diff: bool = None - diff_count: int = None + diff: Optional[List[Union[Tuple[Any, ...], List[Any]]]] = None + diff_schema: Optional[Tuple[Tuple[str, type], ...]] = None + is_diff: Optional[bool] = None + diff_count: Optional[int] = None rowcounts: Dict[int, int] = field(default_factory=dict) - max_rows: int = None + max_rows: Optional[int] = None def set_diff(self, diff: List[Union[Tuple[Any, ...], List[Any]]], schema: Optional[Tuple[Tuple[str, type]]] = None): self.diff_schema = schema diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 91e2aecd..14834ffd 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -138,7 +138,7 @@ class JoinDiffer(TableDiffer): validate_unique_key: bool = True sample_exclusive_rows: bool = False - materialize_to_table: DbPath = None + materialize_to_table: Optional[DbPath] = None materialize_all_rows: bool = False table_write_limit: int = TABLE_WRITE_LIMIT skip_null_keys: bool = False diff --git 
a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index ba282b59..3036bc9b 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -55,7 +55,7 @@ def cast_to(self, to): @dataclass class Code(ExprNode, Root): code: str - args: Dict[str, Expr] = None + args: Optional[Dict[str, Expr]] = None def _expr_type(e: Expr) -> type: @@ -217,7 +217,7 @@ def intersect(self, other: "ITable"): @dataclass class Concat(ExprNode): exprs: list - sep: str = None + sep: Optional[str] = None @dataclass @@ -294,7 +294,7 @@ class WhenThen(ExprNode): @dataclass class CaseWhen(ExprNode): cases: Sequence[WhenThen] - else_expr: Expr = None + else_expr: Optional[Expr] = None @property def type(self): @@ -492,9 +492,9 @@ def schema(self) -> Schema: @dataclass class Join(ExprNode, ITable, Root): source_tables: Sequence[ITable] - op: str = None - on_exprs: Sequence[Expr] = None - columns: Sequence[Expr] = None + op: Optional[str] = None + on_exprs: Optional[Sequence[Expr]] = None + columns: Optional[Sequence[Expr]] = None @property def schema(self) -> Schema: @@ -535,9 +535,9 @@ def select(self, *exprs, **named_exprs) -> Union[Self, ITable]: @dataclass class GroupBy(ExprNode, ITable, Root): table: ITable - keys: Sequence[Expr] = None # IKey? - values: Sequence[Expr] = None - having_exprs: Sequence[Expr] = None + keys: Optional[Sequence[Expr]] = None # IKey? 
+ values: Optional[Sequence[Expr]] = None + having_exprs: Optional[Sequence[Expr]] = None def __post_init__(self): assert self.keys or self.values @@ -581,15 +581,15 @@ def schema(self) -> Schema: @dataclass class Select(ExprNode, ITable, Root): - table: Expr = None - columns: Sequence[Expr] = None - where_exprs: Sequence[Expr] = None - order_by_exprs: Sequence[Expr] = None - group_by_exprs: Sequence[Expr] = None - having_exprs: Sequence[Expr] = None - limit_expr: int = None + table: Optional[Expr] = None + columns: Optional[Sequence[Expr]] = None + where_exprs: Optional[Sequence[Expr]] = None + order_by_exprs: Optional[Sequence[Expr]] = None + group_by_exprs: Optional[Sequence[Expr]] = None + having_exprs: Optional[Sequence[Expr]] = None + limit_expr: Optional[int] = None distinct: bool = False - optimizer_hints: Sequence[Expr] = None + optimizer_hints: Optional[Sequence[Expr]] = None @property def schema(self) -> Schema: @@ -637,8 +637,8 @@ def make(cls, table: ITable, distinct: bool = SKIP, optimizer_hints: str = SKIP, @dataclass class Cte(ExprNode, ITable): table: Expr - name: str = None - params: Sequence[str] = None + name: Optional[str] = None + params: Optional[Sequence[str]] = None @property def source_table(self) -> "ITable": @@ -668,7 +668,7 @@ def resolve_names(source_table, exprs): @dataclass(frozen=False, eq=False, order=False) class _ResolveColumn(ExprNode, LazyOps): resolve_name: str - resolved: Expr = None + resolved: Optional[Expr] = None def resolve(self, expr: Expr): if self.resolved is not None: @@ -751,9 +751,9 @@ def type(self) -> Optional[type]: class TimeTravel(ITable): table: TablePath before: bool = False - timestamp: datetime = None - offset: int = None - statement: str = None + timestamp: Optional[datetime] = None + offset: Optional[int] = None + statement: Optional[str] = None # DDL @@ -768,9 +768,9 @@ def type(self) -> Optional[type]: @dataclass class CreateTable(Statement): path: TablePath - source_table: Expr = None + source_table: 
Optional[Expr] = None if_not_exists: bool = False - primary_keys: List[str] = None + primary_keys: Optional[List[str]] = None @dataclass @@ -788,8 +788,8 @@ class TruncateTable(Statement): class InsertToTable(Statement): path: TablePath expr: Expr - columns: List[str] = None - returning_exprs: List[str] = None + columns: Optional[List[str]] = None + returning_exprs: Optional[List[str]] = None def returning(self, *exprs) -> Self: """Add a 'RETURNING' clause to the insert expression. diff --git a/data_diff/queries/extras.py b/data_diff/queries/extras.py index 556325f6..4467bd0a 100644 --- a/data_diff/queries/extras.py +++ b/data_diff/queries/extras.py @@ -10,7 +10,7 @@ @dataclass class NormalizeAsString(ExprNode): expr: ExprNode - expr_type: ColType = None + expr_type: Optional[ColType] = None @property def type(self) -> Optional[type]: @@ -20,7 +20,7 @@ def type(self) -> Optional[type]: @dataclass class ApplyFuncAndNormalizeAsString(ExprNode): expr: ExprNode - apply_func: Callable = None + apply_func: Optional[Callable] = None @dataclass diff --git a/data_diff/table_segment.py b/data_diff/table_segment.py index aaf747f6..015e5bc4 100644 --- a/data_diff/table_segment.py +++ b/data_diff/table_segment.py @@ -112,18 +112,18 @@ class TableSegment: # Columns key_columns: Tuple[str, ...] - update_column: str = None + update_column: Optional[str] = None extra_columns: Tuple[str, ...] 
= () # Restrict the segment - min_key: Vector = None - max_key: Vector = None - min_update: DbTime = None - max_update: DbTime = None - where: str = None - - case_sensitive: bool = True - _schema: Schema = None + min_key: Optional[Vector] = None + max_key: Optional[Vector] = None + min_update: Optional[DbTime] = None + max_update: Optional[DbTime] = None + where: Optional[str] = None + + case_sensitive: Optional[bool] = True + _schema: Optional[Schema] = None def __post_init__(self): if not self.update_column and (self.min_update or self.max_update): From 8f8b8446295a89066b9814ae6d17e8cc35fab407 Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Fri, 29 Sep 2023 16:21:52 +0200 Subject: [PATCH 4/5] Convert all runtypes to attrs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `attrs` is much more beneficial: * `attrs` is supported by type checkers, such as MyPy & IDEs * `attrs` is widely used and industry-proven * `attrs` is explicit in its declarations, there is no magic * `attrs` has slots But mainly for the first item — type checking by type checkers. 
--- data_diff/abcs/compiler.py | 4 ++ data_diff/abcs/database_types.py | 26 ++++---- data_diff/cloud/datafold_api.py | 6 +- data_diff/databases/_connect.py | 5 +- data_diff/databases/base.py | 31 ++++----- data_diff/diff_tables.py | 13 ++-- data_diff/format.py | 30 ++++----- data_diff/hashdiff_tables.py | 17 +++-- data_diff/info_tree.py | 11 ++-- data_diff/joindiff_tables.py | 11 ++-- data_diff/queries/api.py | 2 +- data_diff/queries/ast_classes.py | 107 ++++++++++++++++--------------- data_diff/queries/extras.py | 10 +-- data_diff/table_segment.py | 14 ++-- poetry.lock | 2 +- pyproject.toml | 2 +- tests/test_database.py | 5 +- tests/test_diff_tables.py | 10 +-- tests/test_joindiff.py | 6 +- tests/test_sql.py | 5 +- 20 files changed, 164 insertions(+), 153 deletions(-) diff --git a/data_diff/abcs/compiler.py b/data_diff/abcs/compiler.py index 4a847d05..a134c845 100644 --- a/data_diff/abcs/compiler.py +++ b/data_diff/abcs/compiler.py @@ -1,9 +1,13 @@ from abc import ABC +import attrs + +@attrs.define class AbstractCompiler(ABC): pass +@attrs.define(eq=False) class Compilable(ABC): pass diff --git a/data_diff/abcs/database_types.py b/data_diff/abcs/database_types.py index e5ec393a..8a60cb2c 100644 --- a/data_diff/abcs/database_types.py +++ b/data_diff/abcs/database_types.py @@ -3,7 +3,7 @@ from typing import Tuple, Union from datetime import datetime -from runtype import dataclass +import attrs from data_diff.utils import ArithAlphanumeric, ArithUUID, Unknown @@ -13,14 +13,14 @@ DbTime = datetime -@dataclass +@attrs.define class ColType: @property def supported(self) -> bool: return True -@dataclass +@attrs.define class PrecisionType(ColType): precision: int rounds: Union[bool, Unknown] = Unknown @@ -50,7 +50,7 @@ class Date(TemporalType): pass -@dataclass +@attrs.define class NumericType(ColType): # 'precision' signifies how many fractional digits (after the dot) we want to compare precision: int @@ -84,7 +84,7 @@ def python_type(self) -> type: return 
decimal.Decimal -@dataclass +@attrs.define class StringType(ColType): python_type = str @@ -122,7 +122,7 @@ class String_VaryingAlphanum(String_Alphanum): pass -@dataclass +@attrs.define class String_FixedAlphanum(String_Alphanum): length: int @@ -132,7 +132,7 @@ def make_value(self, value): return self.python_type(value, max_len=self.length) -@dataclass +@attrs.define class Text(StringType): @property def supported(self) -> bool: @@ -140,12 +140,12 @@ def supported(self) -> bool: # In majority of DBMSes, it is called JSON/JSONB. Only in Snowflake, it is OBJECT. -@dataclass +@attrs.define class JSON(ColType): pass -@dataclass +@attrs.define class Array(ColType): item_type: ColType @@ -155,21 +155,21 @@ class Array(ColType): # For example, in BigQuery: # - https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#struct_type # - https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#struct_literals -@dataclass +@attrs.define class Struct(ColType): pass -@dataclass +@attrs.define class Integer(NumericType, IKey): precision: int = 0 python_type: type = int - def __post_init__(self): + def __attrs_post_init__(self): assert self.precision == 0 -@dataclass +@attrs.define class UnknownColType(ColType): text: str diff --git a/data_diff/cloud/datafold_api.py b/data_diff/cloud/datafold_api.py index ea5a04e8..99c908ee 100644 --- a/data_diff/cloud/datafold_api.py +++ b/data_diff/cloud/datafold_api.py @@ -1,9 +1,9 @@ import base64 -import dataclasses import enum import time from typing import Any, Dict, List, Optional, Type, Tuple +import attrs import pydantic import requests from typing_extensions import Self @@ -178,13 +178,13 @@ class TCloudApiDataSourceTestResult(pydantic.BaseModel): result: Optional[TCloudDataSourceTestResult] -@dataclasses.dataclass +@attrs.define class DatafoldAPI: api_key: str host: str = "https://app.datafold.com" timeout: int = 30 - def __post_init__(self): + def __attrs_post_init__(self): self.host = 
self.host.rstrip("/") self.headers = { "Authorization": f"Key {self.api_key}", diff --git a/data_diff/databases/_connect.py b/data_diff/databases/_connect.py index 8f842123..4c6a2a4f 100644 --- a/data_diff/databases/_connect.py +++ b/data_diff/databases/_connect.py @@ -3,10 +3,11 @@ from itertools import zip_longest from contextlib import suppress import weakref + +import attrs import dsnparse import toml -from runtype import dataclass from typing_extensions import Self from data_diff.databases.base import Database, ThreadedDatabase @@ -25,7 +26,7 @@ from data_diff.databases.mssql import MsSQL -@dataclass +@attrs.define class MatchUriPath: database_cls: Type[Database] diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 62c3c18a..7d5889d1 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -1,6 +1,6 @@ import abc import functools -from dataclasses import field +import random from datetime import datetime import math import sys @@ -14,7 +14,7 @@ import decimal import contextvars -from runtype import dataclass +import attrs from typing_extensions import Self from data_diff.abcs.compiler import AbstractCompiler @@ -78,7 +78,7 @@ class CompileError(Exception): # and be used only as a CompilingContext (a counter/data-bearing class). # As a result, it becomes low-level util, and the circular dependency auto-resolves. # Meanwhile, the easy fix is to simply move the Compiler here. -@dataclass +@attrs.define class Compiler(AbstractCompiler): """ Compiler bears the context for a single compilation. 
@@ -95,11 +95,11 @@ class Compiler(AbstractCompiler): in_select: bool = False # Compilation runtime flag in_join: bool = False # Compilation runtime flag - _table_context: List = field(default_factory=list) # List[ITable] - _subqueries: Dict[str, Any] = field(default_factory=dict) # XXX not thread-safe + _table_context: List = attrs.field(factory=list) # List[ITable] + _subqueries: Dict[str, Any] = attrs.field(factory=dict) # XXX not thread-safe root: bool = True - _counter: List = field(default_factory=lambda: [0]) + _counter: List = attrs.field(factory=lambda: [0]) @property def dialect(self) -> "Dialect": @@ -119,7 +119,7 @@ def new_unique_table_name(self, prefix="tmp") -> DbPath: return self.database.dialect.parse_table_name(table_name) def add_table_context(self, *tables: Sequence, **kw) -> Self: - return self.replace(_table_context=self._table_context + list(tables), **kw) + return attrs.evolve(self, table_context=self._table_context + list(tables), **kw) def parse_table_name(t): @@ -251,7 +251,7 @@ def _compile(self, compiler: Compiler, elem) -> str: if elem is None: return "NULL" elif isinstance(elem, Compilable): - return self.render_compilable(compiler.replace(root=False), elem) + return self.render_compilable(attrs.evolve(compiler, root=False), elem) elif isinstance(elem, str): return f"'{elem}'" elif isinstance(elem, (int, float)): @@ -361,7 +361,7 @@ def render_column(self, c: Compiler, elem: Column) -> str: return self.quote(elem.name) def render_cte(self, parent_c: Compiler, elem: Cte) -> str: - c: Compiler = parent_c.replace(_table_context=[], in_select=False) + c: Compiler = attrs.evolve(parent_c, table_context=[], in_select=False) compiled = self.compile(c, elem.source_table) name = elem.name or parent_c.new_unique_name() @@ -472,7 +472,7 @@ def render_tablealias(self, c: Compiler, elem: TableAlias) -> str: return f"{self.compile(c, elem.source_table)} {self.quote(elem.name)}" def render_tableop(self, parent_c: Compiler, elem: TableOp) -> str: - 
c: Compiler = parent_c.replace(in_select=False) + c: Compiler = attrs.evolve(parent_c, in_select=False) table_expr = f"{self.compile(c, elem.table1)} {elem.op} {self.compile(c, elem.table2)}" if parent_c.in_select: table_expr = f"({table_expr}) {c.new_unique_name()}" @@ -484,7 +484,7 @@ def render__resolvecolumn(self, c: Compiler, elem: _ResolveColumn) -> str: return self.compile(c, elem._get_resolved()) def render_select(self, parent_c: Compiler, elem: Select) -> str: - c: Compiler = parent_c.replace(in_select=True) # .add_table_context(self.table) + c: Compiler = attrs.evolve(parent_c, in_select=True) # .add_table_context(self.table) compile_fn = functools.partial(self.compile, c) columns = ", ".join(map(compile_fn, elem.columns)) if elem.columns else "*" @@ -522,7 +522,7 @@ def render_select(self, parent_c: Compiler, elem: Select) -> str: def render_join(self, parent_c: Compiler, elem: Join) -> str: tables = [ - t if isinstance(t, TableAlias) else TableAlias(t, parent_c.new_unique_name()) for t in elem.source_tables + t if isinstance(t, TableAlias) else TableAlias(source_table=t, name=parent_c.new_unique_name()) for t in elem.source_tables ] c = parent_c.add_table_context(*tables, in_join=True, in_select=False) op = " JOIN " if elem.op is None else f" {elem.op} JOIN " @@ -555,7 +555,8 @@ def render_groupby(self, c: Compiler, elem: GroupBy) -> str: if isinstance(elem.table, Select) and elem.table.columns is None and elem.table.group_by_exprs is None: return self.compile( c, - elem.table.replace( + attrs.evolve( + elem.table, columns=columns, group_by_exprs=[Code(k) for k in keys], having_exprs=elem.having_exprs, @@ -568,7 +569,7 @@ def render_groupby(self, c: Compiler, elem: GroupBy) -> str: " HAVING " + " AND ".join(map(compile_fn, elem.having_exprs)) if elem.having_exprs is not None else "" ) select = ( - f"SELECT {columns_str} FROM {self.compile(c.replace(in_select=True), elem.table)} GROUP BY {keys_str}{having_str}" + f"SELECT {columns_str} FROM 
{self.compile(attrs.evolve(c, in_select=True), elem.table)} GROUP BY {keys_str}{having_str}" ) if c.in_select: @@ -807,7 +808,7 @@ def set_timezone_to_utc(self) -> str: T = TypeVar("T", bound=BaseDialect) -@dataclass +@attrs.define class QueryResult: rows: list columns: Optional[list] = None diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 08c18391..71bdfb38 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -3,14 +3,13 @@ import time from abc import ABC, abstractmethod -from dataclasses import field from enum import Enum from contextlib import contextmanager from operator import methodcaller from typing import Dict, Tuple, Iterator, Optional from concurrent.futures import ThreadPoolExecutor, as_completed -from runtype import dataclass +import attrs from data_diff.info_tree import InfoTree, SegmentInfo from data_diff.utils import dbt_diff_string_template, run_as_daemon, safezip, getLogger, truncate_error, Vector @@ -31,7 +30,7 @@ class Algorithm(Enum): DiffResult = Iterator[Tuple[str, tuple]] # Iterator[Tuple[Literal["+", "-"], tuple]] -@dataclass +@attrs.define class ThreadBase: "Provides utility methods for optional threading" @@ -72,7 +71,7 @@ def _run_in_background(self, *funcs): f.result() -@dataclass +@attrs.define class DiffStats: diff_by_sign: Dict[str, int] table1_count: int @@ -82,12 +81,12 @@ class DiffStats: extra_column_diffs: Optional[Dict[str, int]] -@dataclass +@attrs.define class DiffResultWrapper: diff: iter # DiffResult info_tree: InfoTree stats: dict - result_list: list = field(default_factory=list) + result_list: list = attrs.field(factory=list) def __iter__(self): yield from self.result_list @@ -203,7 +202,7 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment, info_tree: Inf def _diff_tables_wrapper(self, table1: TableSegment, table2: TableSegment, info_tree: InfoTree) -> DiffResult: if is_tracking_enabled(): - options = dict(self) + options = attrs.asdict(self, recurse=False) 
options["differ_name"] = type(self).__name__ event_json = create_start_event_json(options) run_as_daemon(send_event_json, event_json) diff --git a/data_diff/format.py b/data_diff/format.py index 8a515e1b..5c3761be 100644 --- a/data_diff/format.py +++ b/data_diff/format.py @@ -2,7 +2,7 @@ from enum import Enum from typing import Any, Optional, List, Dict, Tuple, Type -from runtype import dataclass +import attrs from data_diff.diff_tables import DiffResultWrapper from data_diff.abcs.database_types import ( JSON, @@ -86,7 +86,7 @@ def jsonify( ).json() -@dataclass +@attrs.define class JsonExclusiveRowValue: """ Value of a single column in a row @@ -96,7 +96,7 @@ class JsonExclusiveRowValue: value: Any -@dataclass +@attrs.define class JsonDiffRowValue: """ Pair of diffed values for 2 rows with equal PKs @@ -108,19 +108,19 @@ class JsonDiffRowValue: isPK: bool -@dataclass +@attrs.define class Total: dataset1: int dataset2: int -@dataclass +@attrs.define class ExclusiveRows: dataset1: int dataset2: int -@dataclass +@attrs.define class Rows: total: Total exclusive: ExclusiveRows @@ -128,18 +128,18 @@ class Rows: unchanged: int -@dataclass +@attrs.define class Stats: diffCounts: Dict[str, int] -@dataclass +@attrs.define class JsonDiffSummary: rows: Rows stats: Stats -@dataclass +@attrs.define class ExclusiveColumns: dataset1: List[str] dataset2: List[str] @@ -172,14 +172,14 @@ class ColumnKind(Enum): ] -@dataclass +@attrs.define class Column: name: str type: str kind: str -@dataclass +@attrs.define class JsonColumnsSummary: dataset1: List[Column] dataset2: List[Column] @@ -188,19 +188,19 @@ class JsonColumnsSummary: typeChanged: List[str] -@dataclass +@attrs.define class ExclusiveDiff: dataset1: List[Dict[str, JsonExclusiveRowValue]] dataset2: List[Dict[str, JsonExclusiveRowValue]] -@dataclass +@attrs.define class RowsDiff: exclusive: ExclusiveDiff diff: List[Dict[str, JsonDiffRowValue]] -@dataclass +@attrs.define class FailedDiff: status: str # Literal ["failed"] model: 
str @@ -211,7 +211,7 @@ class FailedDiff: version: str = "1.0.0" -@dataclass +@attrs.define class JsonDiff: status: str # Literal ["success"] result: str # Literal ["different", "identical"] diff --git a/data_diff/hashdiff_tables.py b/data_diff/hashdiff_tables.py index 3fc030ec..c6914c96 100644 --- a/data_diff/hashdiff_tables.py +++ b/data_diff/hashdiff_tables.py @@ -1,12 +1,11 @@ import os -from dataclasses import field from numbers import Number import logging from collections import defaultdict from typing import Iterator from operator import attrgetter -from runtype import dataclass +import attrs from data_diff.abcs.database_types import ColType_UUID, NumericType, PrecisionType, StringType, Boolean, JSON from data_diff.info_tree import InfoTree @@ -53,7 +52,7 @@ def diff_sets(a: list, b: list, json_cols: dict = None) -> Iterator: yield from v -@dataclass +@attrs.define class HashDiffer(TableDiffer): """Finds the diff between two SQL tables @@ -74,9 +73,9 @@ class HashDiffer(TableDiffer): bisection_factor: int = DEFAULT_BISECTION_FACTOR bisection_threshold: Number = DEFAULT_BISECTION_THRESHOLD # Accepts inf for tests - stats: dict = field(default_factory=dict) + stats: dict = attrs.field(factory=dict) - def __post_init__(self): + def __attrs_post_init__(self): # Validate options if self.bisection_factor >= self.bisection_threshold: raise ValueError("Incorrect param values (bisection factor must be lower than threshold)") @@ -102,8 +101,8 @@ def _validate_and_adjust_columns(self, table1, table2): if col1.precision != col2.precision: logger.warning(f"Using reduced precision {lowest} for column '{c1}'. 
Types={col1}, {col2}") - table1._schema[c1] = col1.replace(precision=lowest.precision, rounds=lowest.rounds) - table2._schema[c2] = col2.replace(precision=lowest.precision, rounds=lowest.rounds) + table1._schema[c1] = attrs.evolve(col1, precision=lowest.precision, rounds=lowest.rounds) + table2._schema[c2] = attrs.evolve(col2, precision=lowest.precision, rounds=lowest.rounds) elif isinstance(col1, (NumericType, Boolean)): if not isinstance(col2, (NumericType, Boolean)): @@ -115,9 +114,9 @@ def _validate_and_adjust_columns(self, table1, table2): logger.warning(f"Using reduced precision {lowest} for column '{c1}'. Types={col1}, {col2}") if lowest.precision != col1.precision: - table1._schema[c1] = col1.replace(precision=lowest.precision) + table1._schema[c1] = attrs.evolve(col1, precision=lowest.precision) if lowest.precision != col2.precision: - table2._schema[c2] = col2.replace(precision=lowest.precision) + table2._schema[c2] = attrs.evolve(col2, precision=lowest.precision) elif isinstance(col1, ColType_UUID): if not isinstance(col2, ColType_UUID): diff --git a/data_diff/info_tree.py b/data_diff/info_tree.py index b30ba2f2..abed5bae 100644 --- a/data_diff/info_tree.py +++ b/data_diff/info_tree.py @@ -1,12 +1,11 @@ -from dataclasses import field from typing import List, Dict, Optional, Any, Tuple, Union -from runtype import dataclass +import attrs from data_diff.table_segment import TableSegment -@dataclass(frozen=False) +@attrs.define class SegmentInfo: tables: List[TableSegment] @@ -15,7 +14,7 @@ class SegmentInfo: is_diff: Optional[bool] = None diff_count: Optional[int] = None - rowcounts: Dict[int, int] = field(default_factory=dict) + rowcounts: Dict[int, int] = attrs.field(factory=dict) max_rows: Optional[int] = None def set_diff(self, diff: List[Union[Tuple[Any, ...], List[Any]]], schema: Optional[Tuple[Tuple[str, type]]] = None): @@ -40,10 +39,10 @@ def update_from_children(self, child_infos): } -@dataclass +@attrs.define class InfoTree: info: SegmentInfo - 
children: List["InfoTree"] = field(default_factory=list) + children: List["InfoTree"] = attrs.field(factory=list) def add_node(self, table1: TableSegment, table2: TableSegment, max_rows: int = None): node = InfoTree(SegmentInfo([table1, table2], max_rows=max_rows)) diff --git a/data_diff/joindiff_tables.py b/data_diff/joindiff_tables.py index 14834ffd..39cf6e5c 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -1,14 +1,13 @@ """Provides classes for performing a table diff using JOIN """ -from dataclasses import field from decimal import Decimal from functools import partial import logging -from typing import List +from typing import List, Optional from itertools import chain -from runtype import dataclass +import attrs from data_diff.databases import Database, MsSQL, MySQL, BigQuery, Presto, Oracle, Snowflake from data_diff.abcs.database_types import NumericType, DbPath @@ -58,7 +57,7 @@ def sample(table_expr): def create_temp_table(c: Compiler, path: TablePath, expr: Expr) -> str: db = c.database - c: Compiler = c.replace(root=False) # we're compiling fragments, not full queries + c: Compiler = attrs.evolve(c, root=False) # we're compiling fragments, not full queries if isinstance(db, BigQuery): return f"create table {c.dialect.compile(c, path)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.dialect.compile(c, expr)}" elif isinstance(db, Presto): @@ -111,7 +110,7 @@ def json_friendly_value(v): return v -@dataclass +@attrs.define class JoinDiffer(TableDiffer): """Finds the diff between two SQL tables in the same database, using JOINs. 
@@ -143,7 +142,7 @@ class JoinDiffer(TableDiffer): table_write_limit: int = TABLE_WRITE_LIMIT skip_null_keys: bool = False - stats: dict = field(default_factory=dict) + stats: dict = attrs.field(factory=dict) def _diff_tables_root(self, table1: TableSegment, table2: TableSegment, info_tree: InfoTree) -> DiffResult: db = table1.database diff --git a/data_diff/queries/api.py b/data_diff/queries/api.py index 82786871..22dc1a90 100644 --- a/data_diff/queries/api.py +++ b/data_diff/queries/api.py @@ -69,7 +69,7 @@ def table(*path: str, schema: Union[dict, CaseAwareMapping] = None) -> TablePath if schema and not isinstance(schema, CaseAwareMapping): assert isinstance(schema, dict) schema = CaseSensitiveDict(schema) - return TablePath(path, schema) + return TablePath(path, schema=schema) def or_(*exprs: Expr): diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 3036bc9b..ab21f60a 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -1,8 +1,7 @@ -from dataclasses import field from datetime import datetime from typing import Any, Generator, List, Optional, Sequence, Union, Dict -from runtype import dataclass +import attrs from typing_extensions import Self from data_diff.utils import ArithString @@ -25,6 +24,7 @@ class Root: "Nodes inheriting from Root can be used as root statements in SQL (e.g. 
SELECT yes, RANDOM() no)" +@attrs.define(eq=False) class ExprNode(Compilable): "Base class for query expression nodes" @@ -34,7 +34,7 @@ def type(self) -> Optional[type]: def _dfs_values(self): yield self - for k, vs in dict(self).items(): # __dict__ provided by runtype.dataclass + for k, vs in attrs.asdict(self, recurse=False).items(): if k == "source_table": # Skip data-sources, we're only interested in data-parameters continue @@ -52,7 +52,7 @@ def cast_to(self, to): Expr = Union[ExprNode, str, bool, int, float, datetime, ArithString, None] -@dataclass +@attrs.define(eq=False) class Code(ExprNode, Root): code: str args: Optional[Dict[str, Expr]] = None @@ -64,7 +64,7 @@ def _expr_type(e: Expr) -> type: return type(e) -@dataclass +@attrs.define(eq=False) class Alias(ExprNode): expr: Expr name: str @@ -214,13 +214,13 @@ def intersect(self, other: "ITable"): return TableOp("INTERSECT", self, other) -@dataclass +@attrs.define(eq=False) class Concat(ExprNode): exprs: list sep: Optional[str] = None -@dataclass +@attrs.define(eq=False) class Count(ExprNode): expr: Expr = None distinct: bool = False @@ -230,6 +230,7 @@ def type(self) -> Optional[type]: return int +@attrs.define(eq=False) class LazyOps: def __add__(self, other): return BinOp("+", [self, other]) @@ -279,19 +280,19 @@ def min(self): return Func("MIN", [self]) -@dataclass(eq=False) -class Func(ExprNode, LazyOps): +@attrs.define(eq=False) +class Func(LazyOps, ExprNode): name: str args: Sequence[Expr] -@dataclass +@attrs.define(eq=False) class WhenThen(ExprNode): when: Expr then: Expr -@dataclass +@attrs.define(eq=False) class CaseWhen(ExprNode): cases: Sequence[WhenThen] else_expr: Optional[Expr] = None @@ -329,10 +330,10 @@ def else_(self, then: Expr) -> Self: if self.else_expr is not None: raise QueryBuilderError(f"Else clause already specified in {self}") - return self.replace(else_expr=then) + return attrs.evolve(self, else_expr=then) -@dataclass +@attrs.define(eq=False) class QB_When: "Partial 
case-when, used for query-building" casewhen: CaseWhen @@ -341,11 +342,11 @@ class QB_When: def then(self, then: Expr) -> CaseWhen: """Add a 'then' clause after a 'when' was added.""" case = WhenThen(self.when, then) - return self.casewhen.replace(cases=self.casewhen.cases + [case]) + return attrs.evolve(self.casewhen, cases=self.casewhen.cases + [case]) -@dataclass(eq=False, order=False) -class IsDistinctFrom(ExprNode, LazyOps): +@attrs.define(eq=False) +class IsDistinctFrom(LazyOps, ExprNode): a: Expr b: Expr @@ -354,8 +355,8 @@ def type(self) -> Optional[type]: return bool -@dataclass(eq=False, order=False) -class BinOp(ExprNode, LazyOps): +@attrs.define(eq=False) +class BinOp(LazyOps, ExprNode): op: str args: Sequence[Expr] @@ -368,8 +369,8 @@ def type(self): return t -@dataclass -class UnaryOp(ExprNode, LazyOps): +@attrs.define(eq=False) +class UnaryOp(LazyOps, ExprNode): op: str expr: Expr @@ -380,8 +381,8 @@ def type(self) -> Optional[type]: return bool -@dataclass(eq=False, order=False) -class Column(ExprNode, LazyOps): +@attrs.define(eq=False) +class Column(LazyOps, ExprNode): source_table: ITable name: str @@ -392,7 +393,7 @@ def type(self): return self.source_table.schema[self.name] -@dataclass +@attrs.define(eq=False) class TablePath(ExprNode, ITable): path: DbPath schema: Schema # overrides the inherited property @@ -475,7 +476,7 @@ def time_travel( assert offset is None and statement is None -@dataclass +@attrs.define(eq=False) class TableAlias(ExprNode, ITable): table: ITable name: str @@ -489,7 +490,7 @@ def schema(self) -> Schema: return self.table.schema -@dataclass +@attrs.define(eq=False) class Join(ExprNode, ITable, Root): source_tables: Sequence[ITable] op: Optional[str] = None @@ -513,7 +514,7 @@ def on(self, *exprs) -> Self: if not exprs: return self - return self.replace(on_exprs=(self.on_exprs or []) + exprs) + return attrs.evolve(self, on_exprs=(self.on_exprs or []) + exprs) def select(self, *exprs, **named_exprs) -> Union[Self, ITable]: 
"""Select fields to return from the JOIN operation @@ -529,17 +530,17 @@ def select(self, *exprs, **named_exprs) -> Union[Self, ITable]: exprs += _named_exprs_as_aliases(named_exprs) resolve_names(self.source_table, exprs) # TODO Ensure exprs <= self.columns ? - return self.replace(columns=exprs) + return attrs.evolve(self, columns=exprs) -@dataclass +@attrs.define(eq=False) class GroupBy(ExprNode, ITable, Root): table: ITable keys: Optional[Sequence[Expr]] = None # IKey? values: Optional[Sequence[Expr]] = None having_exprs: Optional[Sequence[Expr]] = None - def __post_init__(self): + def __attrs_post_init__(self): assert self.keys or self.values def having(self, *exprs) -> Self: @@ -550,17 +551,17 @@ def having(self, *exprs) -> Self: return self resolve_names(self.table, exprs) - return self.replace(having_exprs=(self.having_exprs or []) + exprs) + return attrs.evolve(self, having_exprs=(self.having_exprs or []) + exprs) def agg(self, *exprs) -> Self: """Select aggregated fields for the group-by.""" exprs = args_as_tuple(exprs) exprs = _drop_skips(exprs) resolve_names(self.table, exprs) - return self.replace(values=(self.values or []) + exprs) + return attrs.evolve(self, values=(self.values or []) + exprs) -@dataclass +@attrs.define(eq=False) class TableOp(ExprNode, ITable, Root): op: str table1: ITable @@ -579,7 +580,7 @@ def schema(self) -> Schema: return s1 -@dataclass +@attrs.define(eq=False) class Select(ExprNode, ITable, Root): table: Optional[Expr] = None columns: Optional[Sequence[Expr]] = None @@ -631,10 +632,10 @@ def make(cls, table: ITable, distinct: bool = SKIP, optimizer_hints: str = SKIP, else: raise ValueError(k) - return table.replace(**kwargs) + return attrs.evolve(table, **kwargs) -@dataclass +@attrs.define(eq=False) class Cte(ExprNode, ITable): table: Expr name: Optional[str] = None @@ -665,8 +666,8 @@ def resolve_names(source_table, exprs): i += 1 -@dataclass(frozen=False, eq=False, order=False) -class _ResolveColumn(ExprNode, LazyOps): 
+@attrs.define(eq=False) +class _ResolveColumn(LazyOps, ExprNode): resolve_name: str resolved: Optional[Expr] = None @@ -704,7 +705,7 @@ def __getitem__(self, name): return _ResolveColumn(name) -@dataclass +@attrs.define(eq=False) class In(ExprNode): expr: Expr list: Sequence[Expr] @@ -714,25 +715,25 @@ def type(self) -> Optional[type]: return bool -@dataclass +@attrs.define(eq=False) class Cast(ExprNode): expr: Expr target_type: Expr -@dataclass -class Random(ExprNode, LazyOps): +@attrs.define(eq=False) +class Random(LazyOps, ExprNode): @property def type(self) -> Optional[type]: return float -@dataclass +@attrs.define(eq=False) class ConstantTable(ExprNode): rows: Sequence[Sequence] -@dataclass +@attrs.define(eq=False) class Explain(ExprNode, Root): select: Select @@ -747,8 +748,8 @@ def type(self) -> Optional[type]: return datetime -@dataclass -class TimeTravel(ITable): +@attrs.define(eq=False) +class TimeTravel(ITable): # TODO: Unused? table: TablePath before: bool = False timestamp: Optional[datetime] = None @@ -765,7 +766,7 @@ def type(self) -> Optional[type]: return None -@dataclass +@attrs.define(eq=False) class CreateTable(Statement): path: TablePath source_table: Optional[Expr] = None @@ -773,18 +774,18 @@ class CreateTable(Statement): primary_keys: Optional[List[str]] = None -@dataclass +@attrs.define(eq=False) class DropTable(Statement): path: TablePath if_exists: bool = False -@dataclass +@attrs.define(eq=False) class TruncateTable(Statement): path: TablePath -@dataclass +@attrs.define(eq=False) class InsertToTable(Statement): path: TablePath expr: Expr @@ -805,15 +806,15 @@ def returning(self, *exprs) -> Self: return self resolve_names(self.path, exprs) - return self.replace(returning_exprs=exprs) + return attrs.evolve(self, returning_exprs=exprs) -@dataclass +@attrs.define(eq=False) class Commit(Statement): """Generate a COMMIT statement, if we're in the middle of a transaction, or in auto-commit. 
Otherwise SKIP.""" -@dataclass -class Param(ExprNode, ITable): +@attrs.define(eq=False) +class Param(ExprNode, ITable): # TODO: Unused? """A value placeholder, to be specified at compilation time using the `cv_params` context variable.""" name: str diff --git a/data_diff/queries/extras.py b/data_diff/queries/extras.py index 4467bd0a..5441146f 100644 --- a/data_diff/queries/extras.py +++ b/data_diff/queries/extras.py @@ -1,13 +1,13 @@ "Useful AST classes that don't quite fall within the scope of regular SQL" from typing import Callable, Optional, Sequence -from runtype import dataclass -from data_diff.abcs.database_types import ColType +import attrs +from data_diff.abcs.database_types import ColType from data_diff.queries.ast_classes import Expr, ExprNode -@dataclass +@attrs.define class NormalizeAsString(ExprNode): expr: ExprNode expr_type: Optional[ColType] = None @@ -17,12 +17,12 @@ def type(self) -> Optional[type]: return str -@dataclass +@attrs.define class ApplyFuncAndNormalizeAsString(ExprNode): expr: ExprNode apply_func: Optional[Callable] = None -@dataclass +@attrs.define class Checksum(ExprNode): exprs: Sequence[Expr] diff --git a/data_diff/table_segment.py b/data_diff/table_segment.py index 015e5bc4..405415a8 100644 --- a/data_diff/table_segment.py +++ b/data_diff/table_segment.py @@ -1,9 +1,9 @@ import time -from typing import List, Tuple +from typing import List, Optional, Tuple import logging from itertools import product -from runtype import dataclass +import attrs from typing_extensions import Self from data_diff.utils import safezip, Vector @@ -85,7 +85,7 @@ def create_mesh_from_points(*values_per_dim: list) -> List[Tuple[Vector, Vector] return res -@dataclass +@attrs.define class TableSegment: """Signifies a segment of rows (and selected columns) within a table @@ -125,7 +125,7 @@ class TableSegment: case_sensitive: Optional[bool] = True _schema: Optional[Schema] = None - def __post_init__(self): + def __attrs_post_init__(self): if not 
self.update_column and (self.min_update or self.max_update): raise ValueError("Error: the min_update/max_update feature requires 'update_column' to be set.") @@ -142,7 +142,7 @@ def _where(self): def _with_raw_schema(self, raw_schema: dict) -> Self: schema = self.database._process_table_schema(self.table_path, raw_schema, self.relevant_columns, self._where()) - return self.new(_schema=create_schema(self.database.name, self.table_path, schema, self.case_sensitive)) + return self.new(schema=create_schema(self.database.name, self.table_path, schema, self.case_sensitive)) def with_schema(self) -> Self: "Queries the table schema from the database, and returns a new instance of TableSegment, with a schema." @@ -199,7 +199,7 @@ def segment_by_checkpoints(self, checkpoints: List[List[DbKey]]) -> List["TableS def new(self, **kwargs) -> Self: """Creates a copy of the instance using 'replace()'""" - return self.replace(**kwargs) + return attrs.evolve(self, **kwargs) def new_key_bounds(self, min_key: Vector, max_key: Vector) -> Self: if self.min_key is not None: @@ -210,7 +210,7 @@ def new_key_bounds(self, min_key: Vector, max_key: Vector) -> Self: assert min_key < self.max_key assert max_key <= self.max_key - return self.replace(min_key=min_key, max_key=max_key) + return attrs.evolve(self, min_key=min_key, max_key=max_key) @property def relevant_columns(self) -> List[str]: diff --git a/poetry.lock b/poetry.lock index afd70d75..a95c7580 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2030,4 +2030,4 @@ vertica = ["vertica-python"] [metadata] lock-version = "2.0" python-versions = "^3.7.2" -content-hash = "55cde03a00788572dac6310e7bbf61bd2522d70217056a51608bcfc429440fbf" +content-hash = "c7da70c19432ca716980f3421182d54d7f5d2e0d8bbd7e20dbaf521c8ef7d0fb" diff --git a/pyproject.toml b/pyproject.toml index 0ae8f0a4..2ceb0c27 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,6 @@ packages = [{ include = "data_diff" }] [tool.poetry.dependencies] pydantic = "1.10.12" python 
= "^3.7.2" -runtype = "^0.2.6" dsnparse = "<0.2.0" click = "^8.1" rich = "*" @@ -47,6 +46,7 @@ urllib3 = "<2" oracledb = {version = "*", optional=true} pyodbc = {version="^4.0.39", optional=true} typing-extensions = ">=4.0.1" +attrs = "^23.1.0" [tool.poetry.dev-dependencies] parameterized = "*" diff --git a/tests/test_database.py b/tests/test_database.py index b17cb7f0..d0c1d3a4 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -2,6 +2,7 @@ from datetime import datetime from typing import Callable, List, Tuple +import attrs import pytz from data_diff import connect @@ -128,8 +129,8 @@ def test_correct_timezone(self): raw_schema = db.query_table_schema(t.path) schema = db._process_table_schema(t.path, raw_schema) schema = create_schema(self.database.name, t, schema, case_sensitive=True) - t = t.replace(schema=schema) - t.schema["created_at"] = t.schema["created_at"].replace(precision=t.schema["created_at"].precision) + t = attrs.evolve(t, schema=schema) + t.schema["created_at"] = attrs.evolve(t.schema["created_at"], precision=t.schema["created_at"].precision) tbl = table(name, schema=t.schema) diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index b5885a26..26d555dc 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -3,6 +3,8 @@ import uuid import unittest +import attrs + from data_diff.queries.api import table, this, commit, code from data_diff.utils import ArithAlphanumeric, numberToAlphanum @@ -382,13 +384,13 @@ def test_string_keys(self): self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b)) def test_where_sampling(self): - a = self.a.replace(where="1=1") + a = attrs.evolve(self.a, where="1=1") differ = HashDiffer(bisection_factor=2) diff = list(differ.diff_tables(a, self.b)) self.assertEqual(diff, [("-", (str(self.new_uuid), "This one is different"))]) - a_empty = self.a.replace(where="1=0") + a_empty = attrs.evolve(self.a, where="1=0") self.assertRaises(ValueError, list, 
differ.diff_tables(a_empty, self.b)) @@ -519,11 +521,11 @@ def test_case_awareness(self): [src_table.create(), src_table.insert_rows([[1, 9, time_obj], [2, 2, time_obj]], columns=cols), commit] ) - res = tuple(self.table.replace(key_columns=("Id",), case_sensitive=False).with_schema().query_key_range()) + res = tuple(attrs.evolve(self.table, key_columns=("Id",), case_sensitive=False).with_schema().query_key_range()) self.assertEqual(res, (("1",), ("2",))) self.assertRaises( - KeyError, self.table.replace(key_columns=("Id",), case_sensitive=True).with_schema().query_key_range + KeyError, attrs.evolve(self.table, key_columns=("Id",), case_sensitive=True).with_schema().query_key_range ) diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index b2c5c419..dd424017 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -1,6 +1,8 @@ from typing import List from datetime import datetime +import attrs + from data_diff.queries.ast_classes import TablePath from data_diff.queries.api import table, commit from data_diff.table_segment import TableSegment @@ -114,7 +116,7 @@ def test_diff_small_tables(self): # Test materialize materialize_path = self.connection.parse_table_name(f"test_mat_{random_table_suffix()}") - mdiffer = self.differ.replace(materialize_to_table=materialize_path) + mdiffer = attrs.evolve(self.differ, materialize_to_table=materialize_path) diff = list(mdiffer.diff_tables(self.table, self.table2)) self.assertEqual(expected, diff) @@ -126,7 +128,7 @@ def test_diff_small_tables(self): self.connection.query(t.drop()) # Test materialize all rows - mdiffer = mdiffer.replace(materialize_all_rows=True) + mdiffer = attrs.evolve(mdiffer, materialize_all_rows=True) diff = list(mdiffer.diff_tables(self.table, self.table2)) self.assertEqual(expected, diff) rows = self.connection.query(t.select(), List[tuple]) diff --git a/tests/test_sql.py b/tests/test_sql.py index 2dcab403..f45931ec 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -1,5 +1,7 
@@ import unittest +import attrs + from tests.common import TEST_MYSQL_CONN_STRING from data_diff.databases import connect @@ -18,8 +20,9 @@ def test_compile_int(self): self.assertEqual("1", self.compiler.compile(1)) def test_compile_table_name(self): + compiler = attrs.evolve(self.compiler, root=False) self.assertEqual( - "`marine_mammals`.`walrus`", self.compiler.replace(root=False).compile(table("marine_mammals", "walrus")) + "`marine_mammals`.`walrus`", compiler.compile(table("marine_mammals", "walrus")) ) def test_compile_select(self): From c8b1989eb576df2c3ae4227e8c44d33f7badcc90 Mon Sep 17 00:00:00 2001 From: Sergey Vasilyev Date: Fri, 29 Sep 2023 16:46:51 +0200 Subject: [PATCH 5/5] Convert the remaining classes to attrs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Since we now use `attrs` for some classes, let's use them for them all — at least those belonging to the same hierarchies. This will ensure that all classes are slotted and will strictly check that we define attributes properly, especially in cases of multiple inheritance. Except for Pydantic models and Python exceptions. 
--- data_diff/abcs/database_types.py | 17 +++++++++++++++++ data_diff/abcs/mixins.py | 10 ++++++++++ data_diff/databases/_connect.py | 4 ++++ data_diff/databases/base.py | 19 ++++++++++++------- data_diff/databases/bigquery.py | 8 ++++++++ data_diff/databases/clickhouse.py | 6 ++++++ data_diff/databases/databricks.py | 6 ++++++ data_diff/databases/duckdb.py | 7 +++++++ data_diff/databases/mssql.py | 6 ++++++ data_diff/databases/mysql.py | 5 +++++ data_diff/databases/oracle.py | 6 ++++++ data_diff/databases/postgresql.py | 13 ++++++++++--- data_diff/databases/presto.py | 4 ++++ data_diff/databases/redshift.py | 13 ++++++++++--- data_diff/databases/snowflake.py | 5 +++++ data_diff/databases/trino.py | 3 +++ data_diff/databases/vertica.py | 4 ++++ data_diff/dbt_parser.py | 3 +++ data_diff/diff_tables.py | 1 + data_diff/lexicographic_space.py | 4 ++++ data_diff/queries/ast_classes.py | 6 ++++++ data_diff/queries/base.py | 3 +++ data_diff/thread_utils.py | 9 +++++++-- data_diff/utils.py | 27 ++++++++++++++++++--------- 24 files changed, 165 insertions(+), 24 deletions(-) diff --git a/data_diff/abcs/database_types.py b/data_diff/abcs/database_types.py index 8a60cb2c..844d99c5 100644 --- a/data_diff/abcs/database_types.py +++ b/data_diff/abcs/database_types.py @@ -26,26 +26,33 @@ class PrecisionType(ColType): rounds: Union[bool, Unknown] = Unknown + +@attrs.define class Boolean(ColType): precision = 0 +@attrs.define class TemporalType(PrecisionType): pass +@attrs.define class Timestamp(TemporalType): pass +@attrs.define class TimestampTZ(TemporalType): pass +@attrs.define class Datetime(TemporalType): pass +@attrs.define class Date(TemporalType): pass @@ -56,14 +63,17 @@ class NumericType(ColType): precision: int +@attrs.define class FractionalType(NumericType): pass +@attrs.define class Float(FractionalType): python_type = float +@attrs.define class IKey(ABC): "Interface for ColType, for using a column as a key in table." 
@@ -76,6 +86,7 @@ def make_value(self, value): return self.python_type(value) +@attrs.define class Decimal(FractionalType, IKey): # Snowflake may use Decimal as a key @property def python_type(self) -> type: @@ -89,22 +100,27 @@ class StringType(ColType): python_type = str +@attrs.define class ColType_UUID(ColType, IKey): python_type = ArithUUID +@attrs.define class ColType_Alphanum(ColType, IKey): python_type = ArithAlphanumeric +@attrs.define class Native_UUID(ColType_UUID): pass +@attrs.define class String_UUID(ColType_UUID, StringType): pass +@attrs.define class String_Alphanum(ColType_Alphanum, StringType): @staticmethod def test_value(value: str) -> bool: @@ -118,6 +134,7 @@ def make_value(self, value): return self.python_type(value) +@attrs.define class String_VaryingAlphanum(String_Alphanum): pass diff --git a/data_diff/abcs/mixins.py b/data_diff/abcs/mixins.py index 9a30f41e..1789da99 100644 --- a/data_diff/abcs/mixins.py +++ b/data_diff/abcs/mixins.py @@ -1,4 +1,7 @@ from abc import ABC, abstractmethod + +import attrs + from data_diff.abcs.database_types import ( Array, TemporalType, @@ -13,10 +16,12 @@ from data_diff.abcs.compiler import Compilable +@attrs.define class AbstractMixin(ABC): "A mixin for a database dialect" +@attrs.define class AbstractMixin_NormalizeValue(AbstractMixin): @abstractmethod def to_comparable(self, value: str, coltype: ColType) -> str: @@ -108,6 +113,7 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str: return self.to_string(value) +@attrs.define class AbstractMixin_MD5(AbstractMixin): """Methods for calculating an MD6 hash as an integer.""" @@ -116,6 +122,7 @@ def md5_as_int(self, s: str) -> str: "Provide SQL for computing md5 and returning an int" +@attrs.define class AbstractMixin_Schema(AbstractMixin): """Methods for querying the database schema @@ -134,6 +141,7 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: """ +@attrs.define class 
AbstractMixin_RandomSample(AbstractMixin): @abstractmethod def random_sample_n(self, tbl: str, size: int) -> str: @@ -151,6 +159,7 @@ def random_sample_ratio_approx(self, tbl: str, ratio: float) -> str: # """ +@attrs.define class AbstractMixin_TimeTravel(AbstractMixin): @abstractmethod def time_travel( @@ -173,6 +182,7 @@ def time_travel( """ +@attrs.define class AbstractMixin_OptimizerHints(AbstractMixin): @abstractmethod def optimizer_hints(self, optimizer_hints: str) -> str: diff --git a/data_diff/databases/_connect.py b/data_diff/databases/_connect.py index 4c6a2a4f..d0e61582 100644 --- a/data_diff/databases/_connect.py +++ b/data_diff/databases/_connect.py @@ -93,8 +93,11 @@ def match_path(self, dsn): } +@attrs.define(init=False) class Connect: """Provides methods for connecting to a supported database using a URL or connection dict.""" + database_by_scheme: Dict[str, Database] + match_uri_path: Dict[str, MatchUriPath] conn_cache: MutableMapping[Hashable, Database] def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME): @@ -284,6 +287,7 @@ def __make_cache_key(self, db_conf: Union[str, dict]) -> Hashable: return db_conf +@attrs.define(init=False) class Connect_SetUTC(Connect): """Provides methods for connecting to a supported database using a URL or connection dict. 
diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 7d5889d1..082555cc 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -5,7 +5,7 @@ import math import sys import logging -from typing import Any, Callable, Dict, Generator, Tuple, Optional, Sequence, Type, List, Union, TypeVar +from typing import Any, Callable, ClassVar, Dict, Generator, Tuple, Optional, Sequence, Type, List, Union, TypeVar from functools import partial, wraps from concurrent.futures import ThreadPoolExecutor import threading @@ -156,15 +156,14 @@ def _one(seq): return x +@attrs.define class ThreadLocalInterpreter: """An interpeter used to execute a sequence of queries within the same thread and cursor. Useful for cursor-sensitive operations, such as creating a temporary table. """ - - def __init__(self, compiler: Compiler, gen: Generator): - self.gen = gen - self.compiler = compiler + compiler: Compiler + gen: Generator def apply_queries(self, callback: Callable[[str], Any]): q: Expr = next(self.gen) @@ -189,6 +188,7 @@ def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocal return callback(sql_code) +@attrs.define class Mixin_Schema(AbstractMixin_Schema): def table_information(self) -> Compilable: return table("information_schema", "tables") @@ -205,6 +205,7 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: ) +@attrs.define class Mixin_RandomSample(AbstractMixin_RandomSample): def random_sample_n(self, tbl: ITable, size: int) -> ITable: # TODO use a more efficient algorithm, when the table count is known @@ -214,15 +215,17 @@ def random_sample_ratio_approx(self, tbl: ITable, ratio: float) -> ITable: return tbl.where(Random() < ratio) +@attrs.define class Mixin_OptimizerHints(AbstractMixin_OptimizerHints): def optimizer_hints(self, hints: str) -> str: return f"/*+ {hints} */ " +@attrs.define class BaseDialect(abc.ABC): SUPPORTS_PRIMARY_KEY = False SUPPORTS_INDEXES = False - TYPE_CLASSES: 
Dict[str, type] = {} + TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = {} MIXINS = frozenset() PLACEHOLDER_TABLE = None # Used for Oracle @@ -522,7 +525,7 @@ def render_select(self, parent_c: Compiler, elem: Select) -> str: def render_join(self, parent_c: Compiler, elem: Join) -> str: tables = [ - t if isinstance(t, TableAlias) else TableAlias(source_table=t, name=parent_c.new_unique_name()) for t in elem.source_tables + t if isinstance(t, TableAlias) else TableAlias(t, name=parent_c.new_unique_name()) for t in elem.source_tables ] c = parent_c.add_table_context(*tables, in_join=True, in_select=False) op = " JOIN " if elem.op is None else f" {elem.op} JOIN " @@ -823,6 +826,7 @@ def __getitem__(self, i): return self.rows[i] +@attrs.define class Database(abc.ABC): """Base abstract class for databases. @@ -1099,6 +1103,7 @@ def is_autocommit(self) -> bool: "Return whether the database autocommits changes. When false, COMMIT statements are skipped." +@attrs.define(init=False, slots=False) class ThreadedDatabase(Database): """Access the database through singleton threads. 
diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py index feb98bde..b164f2b0 100644 --- a/data_diff/databases/bigquery.py +++ b/data_diff/databases/bigquery.py @@ -1,5 +1,8 @@ import re from typing import Any, List, Union + +import attrs + from data_diff.abcs.database_types import ( ColType, Array, @@ -50,11 +53,13 @@ def import_bigquery_service_account(): return service_account +@attrs.define class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: return f"cast(cast( ('0x' || substr(TO_HEX(md5({s})), 18)) as int64) as numeric)" +@attrs.define class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: @@ -99,6 +104,7 @@ def normalize_struct(self, value: str, _coltype: Struct) -> str: return f"to_json_string({value})" +@attrs.define class Mixin_Schema(AbstractMixin_Schema): def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: return ( @@ -112,6 +118,7 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: ) +@attrs.define class Mixin_TimeTravel(AbstractMixin_TimeTravel): def time_travel( self, @@ -139,6 +146,7 @@ def time_travel( ) +@attrs.define class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "BigQuery" ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation diff --git a/data_diff/databases/clickhouse.py b/data_diff/databases/clickhouse.py index 9366b922..0c03536c 100644 --- a/data_diff/databases/clickhouse.py +++ b/data_diff/databases/clickhouse.py @@ -1,5 +1,7 @@ from typing import Optional, Type +import attrs + from data_diff.databases.base import ( MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, @@ -35,12 +37,14 @@ def import_clickhouse(): return clickhouse_driver +@attrs.define class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: substr_idx = 
1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS return f"reinterpretAsUInt128(reverse(unhex(lowerUTF8(substr(hex(MD5({s})), {substr_idx})))))" +@attrs.define class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_number(self, value: str, coltype: FractionalType) -> str: # If a decimal value has trailing zeros in a fractional part, when casting to string they are dropped. @@ -99,6 +103,7 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: return f"rpad({value}, {TIMESTAMP_PRECISION_POS + 6}, '0')" +@attrs.define class Dialect(BaseDialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "Clickhouse" ROUNDS_ON_PREC_LOSS = False @@ -163,6 +168,7 @@ def current_timestamp(self) -> str: return "now()" +@attrs.define class Clickhouse(ThreadedDatabase): dialect = Dialect() CONNECT_URI_HELP = "clickhouse://:@/" diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index 67d0528d..d1102df1 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -2,6 +2,8 @@ from typing import Dict, Sequence import logging +import attrs + from data_diff.abcs.database_types import ( Integer, Float, @@ -34,11 +36,13 @@ def import_databricks(): return databricks +@attrs.define class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: return f"cast(conv(substr(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as decimal(38, 0))" +@attrs.define class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: """Databricks timestamp contains no more than 6 digits in precision""" @@ -60,6 +64,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str: return self.to_string(f"cast ({value} as int)") +@attrs.define class Dialect(BaseDialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "Databricks" ROUNDS_ON_PREC_LOSS = True 
@@ -99,6 +104,7 @@ def parse_table_name(self, name: str) -> DbPath: return tuple(i for i in path if i is not None) +@attrs.define(init=False) class Databricks(ThreadedDatabase): dialect = Dialect() CONNECT_URI_HELP = "databricks://:@/" diff --git a/data_diff/databases/duckdb.py b/data_diff/databases/duckdb.py index ba6afd63..1b2c3376 100644 --- a/data_diff/databases/duckdb.py +++ b/data_diff/databases/duckdb.py @@ -1,5 +1,7 @@ from typing import Union +import attrs + from data_diff.utils import match_regexps from data_diff.abcs.database_types import ( Timestamp, @@ -40,11 +42,13 @@ def import_duckdb(): return duckdb +@attrs.define class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: return f"('0x' || SUBSTRING(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS},{CHECKSUM_HEXDIGITS}))::BIGINT" +@attrs.define class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: # It's precision 6 by default. If precision is less than 6 -> we remove the trailing numbers. 
@@ -60,6 +64,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str: return self.to_string(f"{value}::INTEGER") +@attrs.define class Mixin_RandomSample(AbstractMixin_RandomSample): def random_sample_n(self, tbl: ITable, size: int) -> ITable: return code("SELECT * FROM ({tbl}) USING SAMPLE {size};", tbl=tbl, size=size) @@ -68,6 +73,7 @@ def random_sample_ratio_approx(self, tbl: ITable, ratio: float) -> ITable: return code("SELECT * FROM ({tbl}) USING SAMPLE {percent}%;", tbl=tbl, percent=int(100 * ratio)) +@attrs.define class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "DuckDB" ROUNDS_ON_PREC_LOSS = False @@ -130,6 +136,7 @@ def current_timestamp(self) -> str: return "current_timestamp" +@attrs.define class DuckDB(Database): dialect = Dialect() SUPPORTS_UNIQUE_CONSTAINT = False # Temporary, until we implement it diff --git a/data_diff/databases/mssql.py b/data_diff/databases/mssql.py index 28d67c99..6dcd465f 100644 --- a/data_diff/databases/mssql.py +++ b/data_diff/databases/mssql.py @@ -1,4 +1,7 @@ from typing import Optional + +import attrs + from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue from data_diff.databases.base import ( CHECKSUM_HEXDIGITS, @@ -34,6 +37,7 @@ def import_mssql(): return pyodbc +@attrs.define class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.precision > 0: @@ -53,11 +57,13 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str: return f"FORMAT({value}, 'N{coltype.precision}')" +@attrs.define class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: return f"convert(bigint, convert(varbinary, '0x' + RIGHT(CONVERT(NVARCHAR(32), HashBytes('MD5', {s}), 2), {CHECKSUM_HEXDIGITS}), 1))" +@attrs.define class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_NormalizeValue, 
AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "MsSQL" ROUNDS_ON_PREC_LOSS = True diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index d6dcba9e..5f1ebe6d 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -1,3 +1,5 @@ +import attrs + from data_diff.abcs.database_types import ( Datetime, Timestamp, @@ -40,11 +42,13 @@ def import_mysql(): return mysql.connector +@attrs.define class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: return f"cast(conv(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as unsigned)" +@attrs.define class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: @@ -60,6 +64,7 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: return f"TRIM(CAST({value} AS char))" +@attrs.define class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "MySQL" ROUNDS_ON_PREC_LOSS = True diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index f0309c11..03f2a2cd 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -1,5 +1,7 @@ from typing import Dict, List, Optional +import attrs + from data_diff.utils import match_regexps from data_diff.abcs.database_types import ( Decimal, @@ -38,6 +40,7 @@ def import_oracle(): return oracledb +@attrs.define class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: # standard_hash is faster than DBMS_CRYPTO.Hash @@ -45,6 +48,7 @@ def md5_as_int(self, s: str) -> str: return f"to_number(substr(standard_hash({s}, 'MD5'), 18), 'xxxxxxxxxxxxxxx')" +@attrs.define class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: # Cast is necessary for correct MD5 (trimming not enough) @@ -68,6 +72,7 @@ def 
normalize_number(self, value: str, coltype: FractionalType) -> str: return f"to_char({value}, '{format_str}')" +@attrs.define class Mixin_Schema(AbstractMixin_Schema): def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: return ( @@ -80,6 +85,7 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: ) +@attrs.define class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "Oracle" SUPPORTS_PRIMARY_KEY = True diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py index dec9b9d3..30455b2f 100644 --- a/data_diff/databases/postgresql.py +++ b/data_diff/databases/postgresql.py @@ -1,6 +1,9 @@ -from typing import List +from typing import ClassVar, Dict, List, Type + +import attrs + from data_diff.abcs.database_types import ( - DbPath, + ColType, DbPath, JSON, Timestamp, TimestampTZ, @@ -35,11 +38,13 @@ def import_postgresql(): return psycopg2 +@attrs.define class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: return f"('x' || substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}))::bit({_CHECKSUM_BITSIZE})::bigint" +@attrs.define class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: @@ -60,6 +65,7 @@ def normalize_json(self, value: str, _coltype: JSON) -> str: return f"{value}::text" +@attrs.define class PostgresqlDialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "PostgreSQL" ROUNDS_ON_PREC_LOSS = True @@ -67,7 +73,7 @@ class PostgresqlDialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeVal SUPPORTS_INDEXES = True MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} - TYPE_CLASSES = { + TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = { # Timestamps "timestamp with time 
zone": TimestampTZ, "timestamp without time zone": Timestamp, @@ -118,6 +124,7 @@ def type_repr(self, t) -> str: return super().type_repr(t) +@attrs.define class PostgreSQL(ThreadedDatabase): dialect = PostgresqlDialect() SUPPORTS_UNIQUE_CONSTAINT = True diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index b4c45751..6148b134 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -1,6 +1,8 @@ from functools import partial import re +import attrs + from data_diff.utils import match_regexps from data_diff.abcs.database_types import ( @@ -50,11 +52,13 @@ def import_presto(): return prestodb +@attrs.define class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: return f"cast(from_base(substr(to_hex(md5(to_utf8({s}))), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16) as decimal(38, 0))" +@attrs.define class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: # Trim doesn't work on CHAR type diff --git a/data_diff/databases/redshift.py b/data_diff/databases/redshift.py index d11029c0..2930ba2d 100644 --- a/data_diff/databases/redshift.py +++ b/data_diff/databases/redshift.py @@ -1,6 +1,9 @@ -from typing import List, Dict +from typing import ClassVar, List, Dict, Type + +import attrs + from data_diff.abcs.database_types import ( - Float, + ColType, Float, JSON, TemporalType, FractionalType, @@ -18,11 +21,13 @@ ) +@attrs.define class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38)" +@attrs.define class Mixin_NormalizeValue(Mixin_NormalizeValue): def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: @@ -51,9 +56,10 @@ def normalize_json(self, value: str, _coltype: JSON) -> str: return f"nvl2({value}, json_serialize({value}), NULL)" +@attrs.define class Dialect(PostgresqlDialect, Mixin_MD5, 
Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "Redshift" - TYPE_CLASSES = { + TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = { **PostgresqlDialect.TYPE_CLASSES, "double": Float, "real": Float, @@ -74,6 +80,7 @@ def type_repr(self, t) -> str: return super().type_repr(t) +@attrs.define class Redshift(PostgreSQL): dialect = Dialect() CONNECT_URI_HELP = "redshift://:@/" diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py index 3a558425..8e73b6b7 100644 --- a/data_diff/databases/snowflake.py +++ b/data_diff/databases/snowflake.py @@ -1,6 +1,8 @@ from typing import Union, List import logging +import attrs + from data_diff.abcs.database_types import ( Timestamp, TimestampTZ, @@ -41,11 +43,13 @@ def import_snowflake(): return snowflake, serialization, default_backend +@attrs.define class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: return f"BITAND(md5_number_lower64({s}), {CHECKSUM_MASK})" +@attrs.define class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: @@ -62,6 +66,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str: return self.to_string(f"{value}::int") +@attrs.define class Mixin_Schema(AbstractMixin_Schema): def table_information(self) -> Compilable: return table("INFORMATION_SCHEMA", "TABLES") diff --git a/data_diff/databases/trino.py b/data_diff/databases/trino.py index e2095758..a2b4442b 100644 --- a/data_diff/databases/trino.py +++ b/data_diff/databases/trino.py @@ -1,3 +1,5 @@ +import attrs + from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue from data_diff.abcs.database_types import TemporalType, ColType_UUID from data_diff.databases import presto @@ -15,6 +17,7 @@ def import_trino(): Mixin_MD5 = presto.Mixin_MD5 +@attrs.define class Mixin_NormalizeValue(presto.Mixin_NormalizeValue): def normalize_timestamp(self, value: str, 
coltype: TemporalType) -> str: if coltype.rounds: diff --git a/data_diff/databases/vertica.py b/data_diff/databases/vertica.py index e8fe9ec2..12a46228 100644 --- a/data_diff/databases/vertica.py +++ b/data_diff/databases/vertica.py @@ -1,5 +1,7 @@ from typing import List +import attrs + from data_diff.utils import match_regexps from data_diff.databases.base import ( CHECKSUM_HEXDIGITS, @@ -37,11 +39,13 @@ def import_vertica(): return vertica_python +@attrs.define class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: return f"CAST(HEX_TO_INTEGER(SUBSTRING(MD5({s}), {1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS})) AS NUMERIC(38, 0))" +@attrs.define class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: diff --git a/data_diff/dbt_parser.py b/data_diff/dbt_parser.py index d5976f74..18394f4f 100644 --- a/data_diff/dbt_parser.py +++ b/data_diff/dbt_parser.py @@ -3,6 +3,8 @@ import json from pathlib import Path from typing import List, Dict, Tuple, Set, Optional + +import attrs import yaml from pydantic import BaseModel @@ -94,6 +96,7 @@ class TDatadiffConfig(BaseModel): datasource_id: Optional[int] = None +@attrs.define(init=False) class DbtParser: def __init__( self, diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 71bdfb38..45d5697d 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -179,6 +179,7 @@ def get_stats_dict(self, is_dbt: bool = False): return json_output +@attrs.define class TableDiffer(ThreadBase, ABC): bisection_factor = 32 stats: dict = {} diff --git a/data_diff/lexicographic_space.py b/data_diff/lexicographic_space.py index 88cf863d..2989c2a9 100644 --- a/data_diff/lexicographic_space.py +++ b/data_diff/lexicographic_space.py @@ -20,6 +20,9 @@ from random import randint, randrange from typing import Tuple + +import attrs + from data_diff.utils import safezip Vector = Tuple[int] @@ -56,6 +59,7 @@ def 
irandrange(start, stop): return randrange(start, stop) +@attrs.define class LexicographicSpace: """Lexicographic space of arbitrary dimensions. diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index ab21f60a..ce1b28f1 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -20,6 +20,7 @@ class QB_TypeError(QueryBuilderError): pass +@attrs.define class Root: "Nodes inheriting from Root can be used as root statements in SQL (e.g. SELECT yes, RANDOM() no)" @@ -82,6 +83,7 @@ def _drop_skips_dict(exprs_dict): return {k: v for k, v in exprs_dict.items() if v is not SKIP} +@attrs.define class ITable: @property @@ -375,6 +377,7 @@ class UnaryOp(LazyOps, ExprNode): expr: Expr +@attrs.define class BinBoolOp(BinOp): @property def type(self) -> Optional[type]: @@ -690,6 +693,7 @@ def name(self): return self._get_resolved().name +@attrs.define class This: """Builder object for accessing table attributes. @@ -742,6 +746,7 @@ def type(self) -> Optional[type]: return str +@attrs.define class CurrentTimestamp(ExprNode): @property def type(self) -> Optional[type]: @@ -760,6 +765,7 @@ class TimeTravel(ITable): # TODO: Unused? 
# DDL +@attrs.define class Statement(Compilable, Root): @property def type(self) -> Optional[type]: diff --git a/data_diff/queries/base.py b/data_diff/queries/base.py index 205c2211..e895cd8c 100644 --- a/data_diff/queries/base.py +++ b/data_diff/queries/base.py @@ -1,6 +1,9 @@ from typing import Generator +import attrs + +@attrs.define class _SKIP: def __repr__(self): return "SKIP" diff --git a/data_diff/thread_utils.py b/data_diff/thread_utils.py index 1be94ad4..4a7348f2 100644 --- a/data_diff/thread_utils.py +++ b/data_diff/thread_utils.py @@ -7,6 +7,8 @@ from time import sleep from typing import Callable, Iterator, Optional +import attrs + class AutoPriorityQueue(PriorityQueue): """Overrides PriorityQueue to automatically get the priority from _WorkItem.kwargs @@ -31,19 +33,22 @@ class PriorityThreadPoolExecutor(ThreadPoolExecutor): XXX WARNING: Might break in future versions of Python """ - def __init__(self, *args): super().__init__(*args) - self._work_queue = AutoPriorityQueue() +@attrs.define(init=False) class ThreadedYielder(Iterable): """Yields results from multiple threads into a single iterator, ordered by priority. To add a source iterator, call ``submit()`` with a function that returns an iterator. Priority for the iterator can be provided via the keyword argument 'priority'. (higher runs first) """ + _pool: ThreadPoolExecutor + _futures: deque + _yield: deque = attrs.field(alias='_yield') # Python keyword! 
+ _exception: Optional[None] def __init__(self, max_workers: Optional[int] = None): self._pool = PriorityThreadPoolExecutor(max_workers) diff --git a/data_diff/utils.py b/data_diff/utils.py index b725285e..59ede253 100644 --- a/data_diff/utils.py +++ b/data_diff/utils.py @@ -3,13 +3,14 @@ import re import string from abc import abstractmethod -from typing import Any, Dict, Iterable, Iterator, List, MutableMapping, Sequence, TypeVar, Union +from typing import Any, Dict, Iterable, Iterator, List, MutableMapping, Optional, Sequence, TypeVar, Union from urllib.parse import urlparse import operator import threading from datetime import datetime from uuid import UUID +import attrs from packaging.version import parse as parse_version import requests from tabulate import tabulate @@ -61,6 +62,7 @@ def match_regexps(regexps: Dict[str, Any], s: str) -> Sequence[tuple]: V = TypeVar("V") +@attrs.define class CaseAwareMapping(MutableMapping[str, V]): @abstractmethod def get_key(self, key: str) -> str: @@ -70,7 +72,10 @@ def new(self, initial=()) -> Self: return type(self)(initial) +@attrs.define(repr=False, init=False) class CaseInsensitiveDict(CaseAwareMapping): + _dict: Dict[str, Any] + def __init__(self, initial): self._dict = {k.lower(): (k, v) for k, v in dict(initial).items()} @@ -99,6 +104,7 @@ def __repr__(self) -> str: return repr(dict(self.items())) +@attrs.define class CaseSensitiveDict(dict, CaseAwareMapping): def get_key(self, key): self[key] # Throw KeyError if key doesn't exist @@ -114,6 +120,7 @@ def as_insensitive(self): alphanums = " -" + string.digits + string.ascii_uppercase + "_" + string.ascii_lowercase +@attrs.define class ArithString: @classmethod def new(cls, *args, **kw) -> Self: @@ -125,6 +132,7 @@ def range(self, other: "ArithString", count: int) -> List[Self]: return [self.new(int=i) for i in checkpoints] +@attrs.define class ArithUUID(UUID, ArithString): "A UUID that supports basic arithmetic (add, sub)" @@ -173,20 +181,21 @@ def 
alphanums_to_numbers(s1: str, s2: str): return n1, n2 +@attrs.define class ArithAlphanumeric(ArithString): - def __init__(self, s: str, max_len=None): - if s is None: + _str: str + _max_len: Optional[int] = None + + def __attrs_post_init__(self): + if self._str is None: raise ValueError("Alphanum string cannot be None") - if max_len and len(s) > max_len: - raise ValueError(f"Length of alphanum value '{str}' is longer than the expected {max_len}") + if self._max_len and len(self._str) > self._max_len: + raise ValueError(f"Length of alphanum value '{str}' is longer than the expected {self._max_len}") - for ch in s: + for ch in self._str: if ch not in alphanums: raise ValueError(f"Unexpected character {ch} in alphanum string") - self._str = s - self._max_len = max_len - # @property # def int(self): # return alphanumToNumber(self._str, alphanums)