diff --git a/data_diff/abcs/compiler.py b/data_diff/abcs/compiler.py index 4a847d05..a134c845 100644 --- a/data_diff/abcs/compiler.py +++ b/data_diff/abcs/compiler.py @@ -1,9 +1,13 @@ from abc import ABC +import attrs + +@attrs.define class AbstractCompiler(ABC): pass +@attrs.define(eq=False) class Compilable(ABC): pass diff --git a/data_diff/abcs/database_types.py b/data_diff/abcs/database_types.py index 43764b39..844d99c5 100644 --- a/data_diff/abcs/database_types.py +++ b/data_diff/abcs/database_types.py @@ -3,7 +3,7 @@ from typing import Tuple, Union from datetime import datetime -from runtype import dataclass +import attrs from data_diff.utils import ArithAlphanumeric, ArithUUID, Unknown @@ -13,55 +13,67 @@ DbTime = datetime -@dataclass +@attrs.define class ColType: - supported = True + @property + def supported(self) -> bool: + return True -@dataclass +@attrs.define class PrecisionType(ColType): precision: int rounds: Union[bool, Unknown] = Unknown + +@attrs.define class Boolean(ColType): precision = 0 +@attrs.define class TemporalType(PrecisionType): pass +@attrs.define class Timestamp(TemporalType): pass +@attrs.define class TimestampTZ(TemporalType): pass +@attrs.define class Datetime(TemporalType): pass +@attrs.define class Date(TemporalType): pass -@dataclass +@attrs.define class NumericType(ColType): # 'precision' signifies how many fractional digits (after the dot) we want to compare precision: int +@attrs.define class FractionalType(NumericType): pass +@attrs.define class Float(FractionalType): python_type = float +@attrs.define class IKey(ABC): "Interface for ColType, for using a column as a key in table." 
@@ -74,6 +86,7 @@ def make_value(self, value): return self.python_type(value) +@attrs.define class Decimal(FractionalType, IKey): # Snowflake may use Decimal as a key @property def python_type(self) -> type: @@ -82,27 +95,32 @@ def python_type(self) -> type: return decimal.Decimal -@dataclass +@attrs.define class StringType(ColType): python_type = str +@attrs.define class ColType_UUID(ColType, IKey): python_type = ArithUUID +@attrs.define class ColType_Alphanum(ColType, IKey): python_type = ArithAlphanumeric +@attrs.define class Native_UUID(ColType_UUID): pass +@attrs.define class String_UUID(ColType_UUID, StringType): pass +@attrs.define class String_Alphanum(ColType_Alphanum, StringType): @staticmethod def test_value(value: str) -> bool: @@ -116,11 +134,12 @@ def make_value(self, value): return self.python_type(value) +@attrs.define class String_VaryingAlphanum(String_Alphanum): pass -@dataclass +@attrs.define class String_FixedAlphanum(String_Alphanum): length: int @@ -130,18 +149,20 @@ def make_value(self, value): return self.python_type(value, max_len=self.length) -@dataclass +@attrs.define class Text(StringType): - supported = False + @property + def supported(self) -> bool: + return False # In majority of DBMSes, it is called JSON/JSONB. Only in Snowflake, it is OBJECT. 
-@dataclass +@attrs.define class JSON(ColType): pass -@dataclass +@attrs.define class Array(ColType): item_type: ColType @@ -151,22 +172,24 @@ class Array(ColType): # For example, in BigQuery: # - https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#struct_type # - https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#struct_literals -@dataclass +@attrs.define class Struct(ColType): pass -@dataclass +@attrs.define class Integer(NumericType, IKey): precision: int = 0 python_type: type = int - def __post_init__(self): + def __attrs_post_init__(self): assert self.precision == 0 -@dataclass +@attrs.define class UnknownColType(ColType): text: str - supported = False + @property + def supported(self) -> bool: + return False diff --git a/data_diff/abcs/mixins.py b/data_diff/abcs/mixins.py index 9a30f41e..1789da99 100644 --- a/data_diff/abcs/mixins.py +++ b/data_diff/abcs/mixins.py @@ -1,4 +1,7 @@ from abc import ABC, abstractmethod + +import attrs + from data_diff.abcs.database_types import ( Array, TemporalType, @@ -13,10 +16,12 @@ from data_diff.abcs.compiler import Compilable +@attrs.define class AbstractMixin(ABC): "A mixin for a database dialect" +@attrs.define class AbstractMixin_NormalizeValue(AbstractMixin): @abstractmethod def to_comparable(self, value: str, coltype: ColType) -> str: @@ -108,6 +113,7 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str: return self.to_string(value) +@attrs.define class AbstractMixin_MD5(AbstractMixin): """Methods for calculating an MD6 hash as an integer.""" @@ -116,6 +122,7 @@ def md5_as_int(self, s: str) -> str: "Provide SQL for computing md5 and returning an int" +@attrs.define class AbstractMixin_Schema(AbstractMixin): """Methods for querying the database schema @@ -134,6 +141,7 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: """ +@attrs.define class AbstractMixin_RandomSample(AbstractMixin): @abstractmethod def random_sample_n(self, 
tbl: str, size: int) -> str: @@ -151,6 +159,7 @@ def random_sample_ratio_approx(self, tbl: str, ratio: float) -> str: # """ +@attrs.define class AbstractMixin_TimeTravel(AbstractMixin): @abstractmethod def time_travel( @@ -173,6 +182,7 @@ def time_travel( """ +@attrs.define class AbstractMixin_OptimizerHints(AbstractMixin): @abstractmethod def optimizer_hints(self, optimizer_hints: str) -> str: diff --git a/data_diff/cloud/datafold_api.py b/data_diff/cloud/datafold_api.py index ea5a04e8..99c908ee 100644 --- a/data_diff/cloud/datafold_api.py +++ b/data_diff/cloud/datafold_api.py @@ -1,9 +1,9 @@ import base64 -import dataclasses import enum import time from typing import Any, Dict, List, Optional, Type, Tuple +import attrs import pydantic import requests from typing_extensions import Self @@ -178,13 +178,13 @@ class TCloudApiDataSourceTestResult(pydantic.BaseModel): result: Optional[TCloudDataSourceTestResult] -@dataclasses.dataclass +@attrs.define class DatafoldAPI: api_key: str host: str = "https://app.datafold.com" timeout: int = 30 - def __post_init__(self): + def __attrs_post_init__(self): self.host = self.host.rstrip("/") self.headers = { "Authorization": f"Key {self.api_key}", diff --git a/data_diff/databases/_connect.py b/data_diff/databases/_connect.py index 8f842123..d0e61582 100644 --- a/data_diff/databases/_connect.py +++ b/data_diff/databases/_connect.py @@ -3,10 +3,11 @@ from itertools import zip_longest from contextlib import suppress import weakref + +import attrs import dsnparse import toml -from runtype import dataclass from typing_extensions import Self from data_diff.databases.base import Database, ThreadedDatabase @@ -25,7 +26,7 @@ from data_diff.databases.mssql import MsSQL -@dataclass +@attrs.define class MatchUriPath: database_cls: Type[Database] @@ -92,8 +93,11 @@ def match_path(self, dsn): } +@attrs.define(init=False) class Connect: """Provides methods for connecting to a supported database using a URL or connection dict.""" + 
database_by_scheme: Dict[str, Database] + match_uri_path: Dict[str, MatchUriPath] conn_cache: MutableMapping[Hashable, Database] def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME): @@ -283,6 +287,7 @@ def __make_cache_key(self, db_conf: Union[str, dict]) -> Hashable: return db_conf +@attrs.define(init=False) class Connect_SetUTC(Connect): """Provides methods for connecting to a supported database using a URL or connection dict. diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index d55c59a6..082555cc 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -1,11 +1,11 @@ import abc import functools -from dataclasses import field +import random from datetime import datetime import math import sys import logging -from typing import Any, Callable, Dict, Generator, Tuple, Optional, Sequence, Type, List, Union, TypeVar +from typing import Any, Callable, ClassVar, Dict, Generator, Tuple, Optional, Sequence, Type, List, Union, TypeVar from functools import partial, wraps from concurrent.futures import ThreadPoolExecutor import threading @@ -14,7 +14,7 @@ import decimal import contextvars -from runtype import dataclass +import attrs from typing_extensions import Self from data_diff.abcs.compiler import AbstractCompiler @@ -78,7 +78,7 @@ class CompileError(Exception): # and be used only as a CompilingContext (a counter/data-bearing class). # As a result, it becomes low-level util, and the circular dependency auto-resolves. # Meanwhile, the easy fix is to simply move the Compiler here. -@dataclass +@attrs.define class Compiler(AbstractCompiler): """ Compiler bears the context for a single compilation. 
@@ -95,11 +95,11 @@ class Compiler(AbstractCompiler): in_select: bool = False # Compilation runtime flag in_join: bool = False # Compilation runtime flag - _table_context: List = field(default_factory=list) # List[ITable] - _subqueries: Dict[str, Any] = field(default_factory=dict) # XXX not thread-safe + _table_context: List = attrs.field(factory=list) # List[ITable] + _subqueries: Dict[str, Any] = attrs.field(factory=dict) # XXX not thread-safe root: bool = True - _counter: List = field(default_factory=lambda: [0]) + _counter: List = attrs.field(factory=lambda: [0]) @property def dialect(self) -> "Dialect": @@ -119,7 +119,7 @@ def new_unique_table_name(self, prefix="tmp") -> DbPath: return self.database.dialect.parse_table_name(table_name) def add_table_context(self, *tables: Sequence, **kw) -> Self: - return self.replace(_table_context=self._table_context + list(tables), **kw) + return attrs.evolve(self, table_context=self._table_context + list(tables), **kw) def parse_table_name(t): @@ -156,15 +156,14 @@ def _one(seq): return x +@attrs.define class ThreadLocalInterpreter: """An interpeter used to execute a sequence of queries within the same thread and cursor. Useful for cursor-sensitive operations, such as creating a temporary table. 
""" - - def __init__(self, compiler: Compiler, gen: Generator): - self.gen = gen - self.compiler = compiler + compiler: Compiler + gen: Generator def apply_queries(self, callback: Callable[[str], Any]): q: Expr = next(self.gen) @@ -189,6 +188,7 @@ def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocal return callback(sql_code) +@attrs.define class Mixin_Schema(AbstractMixin_Schema): def table_information(self) -> Compilable: return table("information_schema", "tables") @@ -205,6 +205,7 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: ) +@attrs.define class Mixin_RandomSample(AbstractMixin_RandomSample): def random_sample_n(self, tbl: ITable, size: int) -> ITable: # TODO use a more efficient algorithm, when the table count is known @@ -214,15 +215,17 @@ def random_sample_ratio_approx(self, tbl: ITable, ratio: float) -> ITable: return tbl.where(Random() < ratio) +@attrs.define class Mixin_OptimizerHints(AbstractMixin_OptimizerHints): def optimizer_hints(self, hints: str) -> str: return f"/*+ {hints} */ " +@attrs.define class BaseDialect(abc.ABC): SUPPORTS_PRIMARY_KEY = False SUPPORTS_INDEXES = False - TYPE_CLASSES: Dict[str, type] = {} + TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = {} MIXINS = frozenset() PLACEHOLDER_TABLE = None # Used for Oracle @@ -251,7 +254,7 @@ def _compile(self, compiler: Compiler, elem) -> str: if elem is None: return "NULL" elif isinstance(elem, Compilable): - return self.render_compilable(compiler.replace(root=False), elem) + return self.render_compilable(attrs.evolve(compiler, root=False), elem) elif isinstance(elem, str): return f"'{elem}'" elif isinstance(elem, (int, float)): @@ -361,7 +364,7 @@ def render_column(self, c: Compiler, elem: Column) -> str: return self.quote(elem.name) def render_cte(self, parent_c: Compiler, elem: Cte) -> str: - c: Compiler = parent_c.replace(_table_context=[], in_select=False) + c: Compiler = attrs.evolve(parent_c, table_context=[], 
in_select=False) compiled = self.compile(c, elem.source_table) name = elem.name or parent_c.new_unique_name() @@ -472,7 +475,7 @@ def render_tablealias(self, c: Compiler, elem: TableAlias) -> str: return f"{self.compile(c, elem.source_table)} {self.quote(elem.name)}" def render_tableop(self, parent_c: Compiler, elem: TableOp) -> str: - c: Compiler = parent_c.replace(in_select=False) + c: Compiler = attrs.evolve(parent_c, in_select=False) table_expr = f"{self.compile(c, elem.table1)} {elem.op} {self.compile(c, elem.table2)}" if parent_c.in_select: table_expr = f"({table_expr}) {c.new_unique_name()}" @@ -484,7 +487,7 @@ def render__resolvecolumn(self, c: Compiler, elem: _ResolveColumn) -> str: return self.compile(c, elem._get_resolved()) def render_select(self, parent_c: Compiler, elem: Select) -> str: - c: Compiler = parent_c.replace(in_select=True) # .add_table_context(self.table) + c: Compiler = attrs.evolve(parent_c, in_select=True) # .add_table_context(self.table) compile_fn = functools.partial(self.compile, c) columns = ", ".join(map(compile_fn, elem.columns)) if elem.columns else "*" @@ -522,7 +525,7 @@ def render_select(self, parent_c: Compiler, elem: Select) -> str: def render_join(self, parent_c: Compiler, elem: Join) -> str: tables = [ - t if isinstance(t, TableAlias) else TableAlias(t, parent_c.new_unique_name()) for t in elem.source_tables + t if isinstance(t, TableAlias) else TableAlias(t, name=parent_c.new_unique_name()) for t in elem.source_tables ] c = parent_c.add_table_context(*tables, in_join=True, in_select=False) op = " JOIN " if elem.op is None else f" {elem.op} JOIN " @@ -555,7 +558,8 @@ def render_groupby(self, c: Compiler, elem: GroupBy) -> str: if isinstance(elem.table, Select) and elem.table.columns is None and elem.table.group_by_exprs is None: return self.compile( c, - elem.table.replace( + attrs.evolve( + elem.table, columns=columns, group_by_exprs=[Code(k) for k in keys], having_exprs=elem.having_exprs, @@ -568,7 +572,7 @@ def 
render_groupby(self, c: Compiler, elem: GroupBy) -> str: " HAVING " + " AND ".join(map(compile_fn, elem.having_exprs)) if elem.having_exprs is not None else "" ) select = ( - f"SELECT {columns_str} FROM {self.compile(c.replace(in_select=True), elem.table)} GROUP BY {keys_str}{having_str}" + f"SELECT {columns_str} FROM {self.compile(attrs.evolve(c, in_select=True), elem.table)} GROUP BY {keys_str}{having_str}" ) if c.in_select: @@ -807,10 +811,10 @@ def set_timezone_to_utc(self) -> str: T = TypeVar("T", bound=BaseDialect) -@dataclass +@attrs.define class QueryResult: rows: list - columns: list = None + columns: Optional[list] = None def __iter__(self): return iter(self.rows) @@ -822,6 +826,7 @@ def __getitem__(self, i): return self.rows[i] +@attrs.define class Database(abc.ABC): """Base abstract class for databases. @@ -1098,6 +1103,7 @@ def is_autocommit(self) -> bool: "Return whether the database autocommits changes. When false, COMMIT statements are skipped." +@attrs.define(init=False, slots=False) class ThreadedDatabase(Database): """Access the database through singleton threads. 
diff --git a/data_diff/databases/bigquery.py b/data_diff/databases/bigquery.py index feb98bde..b164f2b0 100644 --- a/data_diff/databases/bigquery.py +++ b/data_diff/databases/bigquery.py @@ -1,5 +1,8 @@ import re from typing import Any, List, Union + +import attrs + from data_diff.abcs.database_types import ( ColType, Array, @@ -50,11 +53,13 @@ def import_bigquery_service_account(): return service_account +@attrs.define class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: return f"cast(cast( ('0x' || substr(TO_HEX(md5({s})), 18)) as int64) as numeric)" +@attrs.define class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: @@ -99,6 +104,7 @@ def normalize_struct(self, value: str, _coltype: Struct) -> str: return f"to_json_string({value})" +@attrs.define class Mixin_Schema(AbstractMixin_Schema): def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: return ( @@ -112,6 +118,7 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: ) +@attrs.define class Mixin_TimeTravel(AbstractMixin_TimeTravel): def time_travel( self, @@ -139,6 +146,7 @@ def time_travel( ) +@attrs.define class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "BigQuery" ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation diff --git a/data_diff/databases/clickhouse.py b/data_diff/databases/clickhouse.py index 9366b922..0c03536c 100644 --- a/data_diff/databases/clickhouse.py +++ b/data_diff/databases/clickhouse.py @@ -1,5 +1,7 @@ from typing import Optional, Type +import attrs + from data_diff.databases.base import ( MD5_HEXDIGITS, CHECKSUM_HEXDIGITS, @@ -35,12 +37,14 @@ def import_clickhouse(): return clickhouse_driver +@attrs.define class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: substr_idx = 
1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS return f"reinterpretAsUInt128(reverse(unhex(lowerUTF8(substr(hex(MD5({s})), {substr_idx})))))" +@attrs.define class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_number(self, value: str, coltype: FractionalType) -> str: # If a decimal value has trailing zeros in a fractional part, when casting to string they are dropped. @@ -99,6 +103,7 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: return f"rpad({value}, {TIMESTAMP_PRECISION_POS + 6}, '0')" +@attrs.define class Dialect(BaseDialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "Clickhouse" ROUNDS_ON_PREC_LOSS = False @@ -163,6 +168,7 @@ def current_timestamp(self) -> str: return "now()" +@attrs.define class Clickhouse(ThreadedDatabase): dialect = Dialect() CONNECT_URI_HELP = "clickhouse://:@/" diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py index 67d0528d..d1102df1 100644 --- a/data_diff/databases/databricks.py +++ b/data_diff/databases/databricks.py @@ -2,6 +2,8 @@ from typing import Dict, Sequence import logging +import attrs + from data_diff.abcs.database_types import ( Integer, Float, @@ -34,11 +36,13 @@ def import_databricks(): return databricks +@attrs.define class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: return f"cast(conv(substr(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as decimal(38, 0))" +@attrs.define class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: """Databricks timestamp contains no more than 6 digits in precision""" @@ -60,6 +64,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str: return self.to_string(f"cast ({value} as int)") +@attrs.define class Dialect(BaseDialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "Databricks" ROUNDS_ON_PREC_LOSS = True 
@@ -99,6 +104,7 @@ def parse_table_name(self, name: str) -> DbPath: return tuple(i for i in path if i is not None) +@attrs.define(init=False) class Databricks(ThreadedDatabase): dialect = Dialect() CONNECT_URI_HELP = "databricks://:@/" diff --git a/data_diff/databases/duckdb.py b/data_diff/databases/duckdb.py index ba6afd63..1b2c3376 100644 --- a/data_diff/databases/duckdb.py +++ b/data_diff/databases/duckdb.py @@ -1,5 +1,7 @@ from typing import Union +import attrs + from data_diff.utils import match_regexps from data_diff.abcs.database_types import ( Timestamp, @@ -40,11 +42,13 @@ def import_duckdb(): return duckdb +@attrs.define class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: return f"('0x' || SUBSTRING(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS},{CHECKSUM_HEXDIGITS}))::BIGINT" +@attrs.define class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: # It's precision 6 by default. If precision is less than 6 -> we remove the trailing numbers. 
@@ -60,6 +64,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str: return self.to_string(f"{value}::INTEGER") +@attrs.define class Mixin_RandomSample(AbstractMixin_RandomSample): def random_sample_n(self, tbl: ITable, size: int) -> ITable: return code("SELECT * FROM ({tbl}) USING SAMPLE {size};", tbl=tbl, size=size) @@ -68,6 +73,7 @@ def random_sample_ratio_approx(self, tbl: ITable, ratio: float) -> ITable: return code("SELECT * FROM ({tbl}) USING SAMPLE {percent}%;", tbl=tbl, percent=int(100 * ratio)) +@attrs.define class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "DuckDB" ROUNDS_ON_PREC_LOSS = False @@ -130,6 +136,7 @@ def current_timestamp(self) -> str: return "current_timestamp" +@attrs.define class DuckDB(Database): dialect = Dialect() SUPPORTS_UNIQUE_CONSTAINT = False # Temporary, until we implement it diff --git a/data_diff/databases/mssql.py b/data_diff/databases/mssql.py index 28d67c99..6dcd465f 100644 --- a/data_diff/databases/mssql.py +++ b/data_diff/databases/mssql.py @@ -1,4 +1,7 @@ from typing import Optional + +import attrs + from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue from data_diff.databases.base import ( CHECKSUM_HEXDIGITS, @@ -34,6 +37,7 @@ def import_mssql(): return pyodbc +@attrs.define class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.precision > 0: @@ -53,11 +57,13 @@ def normalize_number(self, value: str, coltype: FractionalType) -> str: return f"FORMAT({value}, 'N{coltype.precision}')" +@attrs.define class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: return f"convert(bigint, convert(varbinary, '0x' + RIGHT(CONVERT(NVARCHAR(32), HashBytes('MD5', {s}), 2), {CHECKSUM_HEXDIGITS}), 1))" +@attrs.define class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_NormalizeValue, 
AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "MsSQL" ROUNDS_ON_PREC_LOSS = True diff --git a/data_diff/databases/mysql.py b/data_diff/databases/mysql.py index d6dcba9e..5f1ebe6d 100644 --- a/data_diff/databases/mysql.py +++ b/data_diff/databases/mysql.py @@ -1,3 +1,5 @@ +import attrs + from data_diff.abcs.database_types import ( Datetime, Timestamp, @@ -40,11 +42,13 @@ def import_mysql(): return mysql.connector +@attrs.define class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: return f"cast(conv(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as unsigned)" +@attrs.define class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: @@ -60,6 +64,7 @@ def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: return f"TRIM(CAST({value} AS char))" +@attrs.define class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "MySQL" ROUNDS_ON_PREC_LOSS = True diff --git a/data_diff/databases/oracle.py b/data_diff/databases/oracle.py index f0309c11..03f2a2cd 100644 --- a/data_diff/databases/oracle.py +++ b/data_diff/databases/oracle.py @@ -1,5 +1,7 @@ from typing import Dict, List, Optional +import attrs + from data_diff.utils import match_regexps from data_diff.abcs.database_types import ( Decimal, @@ -38,6 +40,7 @@ def import_oracle(): return oracledb +@attrs.define class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: # standard_hash is faster than DBMS_CRYPTO.Hash @@ -45,6 +48,7 @@ def md5_as_int(self, s: str) -> str: return f"to_number(substr(standard_hash({s}, 'MD5'), 18), 'xxxxxxxxxxxxxxx')" +@attrs.define class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: # Cast is necessary for correct MD5 (trimming not enough) @@ -68,6 +72,7 @@ def 
normalize_number(self, value: str, coltype: FractionalType) -> str: return f"to_char({value}, '{format_str}')" +@attrs.define class Mixin_Schema(AbstractMixin_Schema): def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: return ( @@ -80,6 +85,7 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable: ) +@attrs.define class Dialect(BaseDialect, Mixin_Schema, Mixin_OptimizerHints, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "Oracle" SUPPORTS_PRIMARY_KEY = True diff --git a/data_diff/databases/postgresql.py b/data_diff/databases/postgresql.py index dec9b9d3..30455b2f 100644 --- a/data_diff/databases/postgresql.py +++ b/data_diff/databases/postgresql.py @@ -1,6 +1,9 @@ -from typing import List +from typing import ClassVar, Dict, List, Type + +import attrs + from data_diff.abcs.database_types import ( - DbPath, + ColType, DbPath, JSON, Timestamp, TimestampTZ, @@ -35,11 +38,13 @@ def import_postgresql(): return psycopg2 +@attrs.define class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: return f"('x' || substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}))::bit({_CHECKSUM_BITSIZE})::bigint" +@attrs.define class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: @@ -60,6 +65,7 @@ def normalize_json(self, value: str, _coltype: JSON) -> str: return f"{value}::text" +@attrs.define class PostgresqlDialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "PostgreSQL" ROUNDS_ON_PREC_LOSS = True @@ -67,7 +73,7 @@ class PostgresqlDialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeVal SUPPORTS_INDEXES = True MIXINS = {Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, Mixin_RandomSample} - TYPE_CLASSES = { + TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = { # Timestamps "timestamp with time 
zone": TimestampTZ, "timestamp without time zone": Timestamp, @@ -118,6 +124,7 @@ def type_repr(self, t) -> str: return super().type_repr(t) +@attrs.define class PostgreSQL(ThreadedDatabase): dialect = PostgresqlDialect() SUPPORTS_UNIQUE_CONSTAINT = True diff --git a/data_diff/databases/presto.py b/data_diff/databases/presto.py index b4c45751..6148b134 100644 --- a/data_diff/databases/presto.py +++ b/data_diff/databases/presto.py @@ -1,6 +1,8 @@ from functools import partial import re +import attrs + from data_diff.utils import match_regexps from data_diff.abcs.database_types import ( @@ -50,11 +52,13 @@ def import_presto(): return prestodb +@attrs.define class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: return f"cast(from_base(substr(to_hex(md5(to_utf8({s}))), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16) as decimal(38, 0))" +@attrs.define class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_uuid(self, value: str, coltype: ColType_UUID) -> str: # Trim doesn't work on CHAR type diff --git a/data_diff/databases/redshift.py b/data_diff/databases/redshift.py index d11029c0..2930ba2d 100644 --- a/data_diff/databases/redshift.py +++ b/data_diff/databases/redshift.py @@ -1,6 +1,9 @@ -from typing import List, Dict +from typing import ClassVar, List, Dict, Type + +import attrs + from data_diff.abcs.database_types import ( - Float, + ColType, Float, JSON, TemporalType, FractionalType, @@ -18,11 +21,13 @@ ) +@attrs.define class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38)" +@attrs.define class Mixin_NormalizeValue(Mixin_NormalizeValue): def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: @@ -51,9 +56,10 @@ def normalize_json(self, value: str, _coltype: JSON) -> str: return f"nvl2({value}, json_serialize({value}), NULL)" +@attrs.define class Dialect(PostgresqlDialect, Mixin_MD5, 
Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue): name = "Redshift" - TYPE_CLASSES = { + TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = { **PostgresqlDialect.TYPE_CLASSES, "double": Float, "real": Float, @@ -74,6 +80,7 @@ def type_repr(self, t) -> str: return super().type_repr(t) +@attrs.define class Redshift(PostgreSQL): dialect = Dialect() CONNECT_URI_HELP = "redshift://:@/" diff --git a/data_diff/databases/snowflake.py b/data_diff/databases/snowflake.py index 3a558425..8e73b6b7 100644 --- a/data_diff/databases/snowflake.py +++ b/data_diff/databases/snowflake.py @@ -1,6 +1,8 @@ from typing import Union, List import logging +import attrs + from data_diff.abcs.database_types import ( Timestamp, TimestampTZ, @@ -41,11 +43,13 @@ def import_snowflake(): return snowflake, serialization, default_backend +@attrs.define class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: return f"BITAND(md5_number_lower64({s}), {CHECKSUM_MASK})" +@attrs.define class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: @@ -62,6 +66,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str: return self.to_string(f"{value}::int") +@attrs.define class Mixin_Schema(AbstractMixin_Schema): def table_information(self) -> Compilable: return table("INFORMATION_SCHEMA", "TABLES") diff --git a/data_diff/databases/trino.py b/data_diff/databases/trino.py index e2095758..a2b4442b 100644 --- a/data_diff/databases/trino.py +++ b/data_diff/databases/trino.py @@ -1,3 +1,5 @@ +import attrs + from data_diff.abcs.mixins import AbstractMixin_MD5, AbstractMixin_NormalizeValue from data_diff.abcs.database_types import TemporalType, ColType_UUID from data_diff.databases import presto @@ -15,6 +17,7 @@ def import_trino(): Mixin_MD5 = presto.Mixin_MD5 +@attrs.define class Mixin_NormalizeValue(presto.Mixin_NormalizeValue): def normalize_timestamp(self, value: str, 
coltype: TemporalType) -> str: if coltype.rounds: diff --git a/data_diff/databases/vertica.py b/data_diff/databases/vertica.py index e8fe9ec2..12a46228 100644 --- a/data_diff/databases/vertica.py +++ b/data_diff/databases/vertica.py @@ -1,5 +1,7 @@ from typing import List +import attrs + from data_diff.utils import match_regexps from data_diff.databases.base import ( CHECKSUM_HEXDIGITS, @@ -37,11 +39,13 @@ def import_vertica(): return vertica_python +@attrs.define class Mixin_MD5(AbstractMixin_MD5): def md5_as_int(self, s: str) -> str: return f"CAST(HEX_TO_INTEGER(SUBSTRING(MD5({s}), {1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS})) AS NUMERIC(38, 0))" +@attrs.define class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: diff --git a/data_diff/dbt_parser.py b/data_diff/dbt_parser.py index d5976f74..18394f4f 100644 --- a/data_diff/dbt_parser.py +++ b/data_diff/dbt_parser.py @@ -3,6 +3,8 @@ import json from pathlib import Path from typing import List, Dict, Tuple, Set, Optional + +import attrs import yaml from pydantic import BaseModel @@ -94,6 +96,7 @@ class TDatadiffConfig(BaseModel): datasource_id: Optional[int] = None +@attrs.define(init=False) class DbtParser: def __init__( self, diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 08c18391..45d5697d 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -3,14 +3,13 @@ import time from abc import ABC, abstractmethod -from dataclasses import field from enum import Enum from contextlib import contextmanager from operator import methodcaller from typing import Dict, Tuple, Iterator, Optional from concurrent.futures import ThreadPoolExecutor, as_completed -from runtype import dataclass +import attrs from data_diff.info_tree import InfoTree, SegmentInfo from data_diff.utils import dbt_diff_string_template, run_as_daemon, safezip, getLogger, truncate_error, Vector @@ -31,7 +30,7 @@ class 
Algorithm(Enum): DiffResult = Iterator[Tuple[str, tuple]] # Iterator[Tuple[Literal["+", "-"], tuple]] -@dataclass +@attrs.define class ThreadBase: "Provides utility methods for optional threading" @@ -72,7 +71,7 @@ def _run_in_background(self, *funcs): f.result() -@dataclass +@attrs.define class DiffStats: diff_by_sign: Dict[str, int] table1_count: int @@ -82,12 +81,12 @@ class DiffStats: extra_column_diffs: Optional[Dict[str, int]] -@dataclass +@attrs.define class DiffResultWrapper: diff: iter # DiffResult info_tree: InfoTree stats: dict - result_list: list = field(default_factory=list) + result_list: list = attrs.field(factory=list) def __iter__(self): yield from self.result_list @@ -180,6 +179,7 @@ def get_stats_dict(self, is_dbt: bool = False): return json_output +@attrs.define class TableDiffer(ThreadBase, ABC): bisection_factor = 32 stats: dict = {} @@ -203,7 +203,7 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment, info_tree: Inf def _diff_tables_wrapper(self, table1: TableSegment, table2: TableSegment, info_tree: InfoTree) -> DiffResult: if is_tracking_enabled(): - options = dict(self) + options = attrs.asdict(self, recurse=False) options["differ_name"] = type(self).__name__ event_json = create_start_event_json(options) run_as_daemon(send_event_json, event_json) diff --git a/data_diff/format.py b/data_diff/format.py index 8a515e1b..5c3761be 100644 --- a/data_diff/format.py +++ b/data_diff/format.py @@ -2,7 +2,7 @@ from enum import Enum from typing import Any, Optional, List, Dict, Tuple, Type -from runtype import dataclass +import attrs from data_diff.diff_tables import DiffResultWrapper from data_diff.abcs.database_types import ( JSON, @@ -86,7 +86,7 @@ def jsonify( ).json() -@dataclass +@attrs.define class JsonExclusiveRowValue: """ Value of a single column in a row @@ -96,7 +96,7 @@ class JsonExclusiveRowValue: value: Any -@dataclass +@attrs.define class JsonDiffRowValue: """ Pair of diffed values for 2 rows with equal PKs @@ -108,19 
+108,19 @@ class JsonDiffRowValue: isPK: bool -@dataclass +@attrs.define class Total: dataset1: int dataset2: int -@dataclass +@attrs.define class ExclusiveRows: dataset1: int dataset2: int -@dataclass +@attrs.define class Rows: total: Total exclusive: ExclusiveRows @@ -128,18 +128,18 @@ class Rows: unchanged: int -@dataclass +@attrs.define class Stats: diffCounts: Dict[str, int] -@dataclass +@attrs.define class JsonDiffSummary: rows: Rows stats: Stats -@dataclass +@attrs.define class ExclusiveColumns: dataset1: List[str] dataset2: List[str] @@ -172,14 +172,14 @@ class ColumnKind(Enum): ] -@dataclass +@attrs.define class Column: name: str type: str kind: str -@dataclass +@attrs.define class JsonColumnsSummary: dataset1: List[Column] dataset2: List[Column] @@ -188,19 +188,19 @@ class JsonColumnsSummary: typeChanged: List[str] -@dataclass +@attrs.define class ExclusiveDiff: dataset1: List[Dict[str, JsonExclusiveRowValue]] dataset2: List[Dict[str, JsonExclusiveRowValue]] -@dataclass +@attrs.define class RowsDiff: exclusive: ExclusiveDiff diff: List[Dict[str, JsonDiffRowValue]] -@dataclass +@attrs.define class FailedDiff: status: str # Literal ["failed"] model: str @@ -211,7 +211,7 @@ class FailedDiff: version: str = "1.0.0" -@dataclass +@attrs.define class JsonDiff: status: str # Literal ["success"] result: str # Literal ["different", "identical"] diff --git a/data_diff/hashdiff_tables.py b/data_diff/hashdiff_tables.py index 3fc030ec..c6914c96 100644 --- a/data_diff/hashdiff_tables.py +++ b/data_diff/hashdiff_tables.py @@ -1,12 +1,11 @@ import os -from dataclasses import field from numbers import Number import logging from collections import defaultdict from typing import Iterator from operator import attrgetter -from runtype import dataclass +import attrs from data_diff.abcs.database_types import ColType_UUID, NumericType, PrecisionType, StringType, Boolean, JSON from data_diff.info_tree import InfoTree @@ -53,7 +52,7 @@ def diff_sets(a: list, b: list, json_cols: 
dict = None) -> Iterator: yield from v -@dataclass +@attrs.define class HashDiffer(TableDiffer): """Finds the diff between two SQL tables @@ -74,9 +73,9 @@ class HashDiffer(TableDiffer): bisection_factor: int = DEFAULT_BISECTION_FACTOR bisection_threshold: Number = DEFAULT_BISECTION_THRESHOLD # Accepts inf for tests - stats: dict = field(default_factory=dict) + stats: dict = attrs.field(factory=dict) - def __post_init__(self): + def __attrs_post_init__(self): # Validate options if self.bisection_factor >= self.bisection_threshold: raise ValueError("Incorrect param values (bisection factor must be lower than threshold)") @@ -102,8 +101,8 @@ def _validate_and_adjust_columns(self, table1, table2): if col1.precision != col2.precision: logger.warning(f"Using reduced precision {lowest} for column '{c1}'. Types={col1}, {col2}") - table1._schema[c1] = col1.replace(precision=lowest.precision, rounds=lowest.rounds) - table2._schema[c2] = col2.replace(precision=lowest.precision, rounds=lowest.rounds) + table1._schema[c1] = attrs.evolve(col1, precision=lowest.precision, rounds=lowest.rounds) + table2._schema[c2] = attrs.evolve(col2, precision=lowest.precision, rounds=lowest.rounds) elif isinstance(col1, (NumericType, Boolean)): if not isinstance(col2, (NumericType, Boolean)): @@ -115,9 +114,9 @@ def _validate_and_adjust_columns(self, table1, table2): logger.warning(f"Using reduced precision {lowest} for column '{c1}'. 
Types={col1}, {col2}") if lowest.precision != col1.precision: - table1._schema[c1] = col1.replace(precision=lowest.precision) + table1._schema[c1] = attrs.evolve(col1, precision=lowest.precision) if lowest.precision != col2.precision: - table2._schema[c2] = col2.replace(precision=lowest.precision) + table2._schema[c2] = attrs.evolve(col2, precision=lowest.precision) elif isinstance(col1, ColType_UUID): if not isinstance(col2, ColType_UUID): diff --git a/data_diff/info_tree.py b/data_diff/info_tree.py index bd2282ba..abed5bae 100644 --- a/data_diff/info_tree.py +++ b/data_diff/info_tree.py @@ -1,22 +1,21 @@ -from dataclasses import field from typing import List, Dict, Optional, Any, Tuple, Union -from runtype import dataclass +import attrs from data_diff.table_segment import TableSegment -@dataclass(frozen=False) +@attrs.define class SegmentInfo: tables: List[TableSegment] - diff: List[Union[Tuple[Any, ...], List[Any]]] = None - diff_schema: Tuple[Tuple[str, type], ...] = None - is_diff: bool = None - diff_count: int = None + diff: Optional[List[Union[Tuple[Any, ...], List[Any]]]] = None + diff_schema: Optional[Tuple[Tuple[str, type], ...]] = None + is_diff: Optional[bool] = None + diff_count: Optional[int] = None - rowcounts: Dict[int, int] = field(default_factory=dict) - max_rows: int = None + rowcounts: Dict[int, int] = attrs.field(factory=dict) + max_rows: Optional[int] = None def set_diff(self, diff: List[Union[Tuple[Any, ...], List[Any]]], schema: Optional[Tuple[Tuple[str, type]]] = None): self.diff_schema = schema @@ -40,10 +39,10 @@ def update_from_children(self, child_infos): } -@dataclass +@attrs.define class InfoTree: info: SegmentInfo - children: List["InfoTree"] = field(default_factory=list) + children: List["InfoTree"] = attrs.field(factory=list) def add_node(self, table1: TableSegment, table2: TableSegment, max_rows: int = None): node = InfoTree(SegmentInfo([table1, table2], max_rows=max_rows)) diff --git a/data_diff/joindiff_tables.py 
b/data_diff/joindiff_tables.py index 91e2aecd..39cf6e5c 100644 --- a/data_diff/joindiff_tables.py +++ b/data_diff/joindiff_tables.py @@ -1,14 +1,13 @@ """Provides classes for performing a table diff using JOIN """ -from dataclasses import field from decimal import Decimal from functools import partial import logging -from typing import List +from typing import List, Optional from itertools import chain -from runtype import dataclass +import attrs from data_diff.databases import Database, MsSQL, MySQL, BigQuery, Presto, Oracle, Snowflake from data_diff.abcs.database_types import NumericType, DbPath @@ -58,7 +57,7 @@ def sample(table_expr): def create_temp_table(c: Compiler, path: TablePath, expr: Expr) -> str: db = c.database - c: Compiler = c.replace(root=False) # we're compiling fragments, not full queries + c: Compiler = attrs.evolve(c, root=False) # we're compiling fragments, not full queries if isinstance(db, BigQuery): return f"create table {c.dialect.compile(c, path)} OPTIONS(expiration_timestamp=TIMESTAMP_ADD(CURRENT_TIMESTAMP(), INTERVAL 1 DAY)) as {c.dialect.compile(c, expr)}" elif isinstance(db, Presto): @@ -111,7 +110,7 @@ def json_friendly_value(v): return v -@dataclass +@attrs.define class JoinDiffer(TableDiffer): """Finds the diff between two SQL tables in the same database, using JOINs. 
@@ -138,12 +137,12 @@ class JoinDiffer(TableDiffer): validate_unique_key: bool = True sample_exclusive_rows: bool = False - materialize_to_table: DbPath = None + materialize_to_table: Optional[DbPath] = None materialize_all_rows: bool = False table_write_limit: int = TABLE_WRITE_LIMIT skip_null_keys: bool = False - stats: dict = field(default_factory=dict) + stats: dict = attrs.field(factory=dict) def _diff_tables_root(self, table1: TableSegment, table2: TableSegment, info_tree: InfoTree) -> DiffResult: db = table1.database diff --git a/data_diff/lexicographic_space.py b/data_diff/lexicographic_space.py index 88cf863d..2989c2a9 100644 --- a/data_diff/lexicographic_space.py +++ b/data_diff/lexicographic_space.py @@ -20,6 +20,9 @@ from random import randint, randrange from typing import Tuple + +import attrs + from data_diff.utils import safezip Vector = Tuple[int] @@ -56,6 +59,7 @@ def irandrange(start, stop): return randrange(start, stop) +@attrs.define class LexicographicSpace: """Lexicographic space of arbitrary dimensions. 
diff --git a/data_diff/queries/api.py b/data_diff/queries/api.py index 82786871..22dc1a90 100644 --- a/data_diff/queries/api.py +++ b/data_diff/queries/api.py @@ -69,7 +69,7 @@ def table(*path: str, schema: Union[dict, CaseAwareMapping] = None) -> TablePath if schema and not isinstance(schema, CaseAwareMapping): assert isinstance(schema, dict) schema = CaseSensitiveDict(schema) - return TablePath(path, schema) + return TablePath(path, schema=schema) def or_(*exprs: Expr): diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 56efdb20..ce1b28f1 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -1,8 +1,7 @@ -from dataclasses import field from datetime import datetime from typing import Any, Generator, List, Optional, Sequence, Union, Dict -from runtype import dataclass +import attrs from typing_extensions import Self from data_diff.utils import ArithString @@ -21,18 +20,22 @@ class QB_TypeError(QueryBuilderError): pass +@attrs.define class Root: "Nodes inheriting from Root can be used as root statements in SQL (e.g. 
SELECT yes, RANDOM() no)" +@attrs.define(eq=False) class ExprNode(Compilable): "Base class for query expression nodes" - type: Any = None + @property + def type(self) -> Optional[type]: + return None def _dfs_values(self): yield self - for k, vs in dict(self).items(): # __dict__ provided by runtype.dataclass + for k, vs in attrs.asdict(self, recurse=False).items(): if k == "source_table": # Skip data-sources, we're only interested in data-parameters continue @@ -50,10 +53,10 @@ def cast_to(self, to): Expr = Union[ExprNode, str, bool, int, float, datetime, ArithString, None] -@dataclass +@attrs.define(eq=False) class Code(ExprNode, Root): code: str - args: Dict[str, Expr] = None + args: Optional[Dict[str, Expr]] = None def _expr_type(e: Expr) -> type: @@ -62,7 +65,7 @@ def _expr_type(e: Expr) -> type: return type(e) -@dataclass +@attrs.define(eq=False) class Alias(ExprNode): expr: Expr name: str @@ -80,9 +83,16 @@ def _drop_skips_dict(exprs_dict): return {k: v for k, v in exprs_dict.items() if v is not SKIP} +@attrs.define class ITable: - source_table: Any - schema: Schema = None + + @property + def source_table(self) -> "ITable": # not always Self, it can be a substitute + return self + + @property + def schema(self) -> Schema: + raise NotImplementedError def select(self, *exprs, distinct=SKIP, optimizer_hints=SKIP, **named_exprs) -> "ITable": """Choose new columns, based on the old ones. 
(aka Projection) @@ -206,19 +216,23 @@ def intersect(self, other: "ITable"): return TableOp("INTERSECT", self, other) -@dataclass +@attrs.define(eq=False) class Concat(ExprNode): exprs: list - sep: str = None + sep: Optional[str] = None -@dataclass +@attrs.define(eq=False) class Count(ExprNode): expr: Expr = None distinct: bool = False - type = int + + @property + def type(self) -> Optional[type]: + return int +@attrs.define(eq=False) class LazyOps: def __add__(self, other): return BinOp("+", [self, other]) @@ -268,22 +282,22 @@ def min(self): return Func("MIN", [self]) -@dataclass(eq=False) -class Func(ExprNode, LazyOps): +@attrs.define(eq=False) +class Func(LazyOps, ExprNode): name: str args: Sequence[Expr] -@dataclass +@attrs.define(eq=False) class WhenThen(ExprNode): when: Expr then: Expr -@dataclass +@attrs.define(eq=False) class CaseWhen(ExprNode): cases: Sequence[WhenThen] - else_expr: Expr = None + else_expr: Optional[Expr] = None @property def type(self): @@ -318,10 +332,10 @@ def else_(self, then: Expr) -> Self: if self.else_expr is not None: raise QueryBuilderError(f"Else clause already specified in {self}") - return self.replace(else_expr=then) + return attrs.evolve(self, else_expr=then) -@dataclass +@attrs.define(eq=False) class QB_When: "Partial case-when, used for query-building" casewhen: CaseWhen @@ -330,18 +344,21 @@ class QB_When: def then(self, then: Expr) -> CaseWhen: """Add a 'then' clause after a 'when' was added.""" case = WhenThen(self.when, then) - return self.casewhen.replace(cases=self.casewhen.cases + [case]) + return attrs.evolve(self.casewhen, cases=self.casewhen.cases + [case]) -@dataclass(eq=False, order=False) -class IsDistinctFrom(ExprNode, LazyOps): +@attrs.define(eq=False) +class IsDistinctFrom(LazyOps, ExprNode): a: Expr b: Expr - type = bool + @property + def type(self) -> Optional[type]: + return bool -@dataclass(eq=False, order=False) -class BinOp(ExprNode, LazyOps): + +@attrs.define(eq=False) +class BinOp(LazyOps, 
ExprNode): op: str args: Sequence[Expr] @@ -354,18 +371,21 @@ def type(self): return t -@dataclass -class UnaryOp(ExprNode, LazyOps): +@attrs.define(eq=False) +class UnaryOp(LazyOps, ExprNode): op: str expr: Expr +@attrs.define class BinBoolOp(BinOp): - type = bool + @property + def type(self) -> Optional[type]: + return bool -@dataclass(eq=False, order=False) -class Column(ExprNode, LazyOps): +@attrs.define(eq=False) +class Column(LazyOps, ExprNode): source_table: ITable name: str @@ -376,14 +396,10 @@ def type(self): return self.source_table.schema[self.name] -@dataclass +@attrs.define(eq=False) class TablePath(ExprNode, ITable): path: DbPath - schema: Optional[Schema] = field(default=None, repr=False) - - @property - def source_table(self) -> Self: - return self + schema: Schema # overrides the inherited property # Statement shorthands def create(self, source_table: ITable = None, *, if_not_exists: bool = False, primary_keys: List[str] = None): @@ -463,25 +479,29 @@ def time_travel( assert offset is None and statement is None -@dataclass +@attrs.define(eq=False) class TableAlias(ExprNode, ITable): - source_table: ITable + table: ITable name: str + @property + def source_table(self) -> ITable: + return self.table -@dataclass + @property + def schema(self) -> Schema: + return self.table.schema + + +@attrs.define(eq=False) class Join(ExprNode, ITable, Root): source_tables: Sequence[ITable] - op: str = None - on_exprs: Sequence[Expr] = None - columns: Sequence[Expr] = None - - @property - def source_table(self) -> Self: - return self + op: Optional[str] = None + on_exprs: Optional[Sequence[Expr]] = None + columns: Optional[Sequence[Expr]] = None @property - def schema(self): + def schema(self) -> Schema: assert self.columns # TODO Implement SELECT * s = self.source_tables[0].schema # TODO validate types match between both tables return type(s)({c.name: c.type for c in self.columns}) @@ -497,7 +517,7 @@ def on(self, *exprs) -> Self: if not exprs: return self - return 
self.replace(on_exprs=(self.on_exprs or []) + exprs) + return attrs.evolve(self, on_exprs=(self.on_exprs or []) + exprs) def select(self, *exprs, **named_exprs) -> Union[Self, ITable]: """Select fields to return from the JOIN operation @@ -513,21 +533,17 @@ def select(self, *exprs, **named_exprs) -> Union[Self, ITable]: exprs += _named_exprs_as_aliases(named_exprs) resolve_names(self.source_table, exprs) # TODO Ensure exprs <= self.columns ? - return self.replace(columns=exprs) + return attrs.evolve(self, columns=exprs) -@dataclass +@attrs.define(eq=False) class GroupBy(ExprNode, ITable, Root): table: ITable - keys: Sequence[Expr] = None # IKey? - values: Sequence[Expr] = None - having_exprs: Sequence[Expr] = None - - @property - def source_table(self): - return self + keys: Optional[Sequence[Expr]] = None # IKey? + values: Optional[Sequence[Expr]] = None + having_exprs: Optional[Sequence[Expr]] = None - def __post_init__(self): + def __attrs_post_init__(self): assert self.keys or self.values def having(self, *exprs) -> Self: @@ -538,62 +554,54 @@ def having(self, *exprs) -> Self: return self resolve_names(self.table, exprs) - return self.replace(having_exprs=(self.having_exprs or []) + exprs) + return attrs.evolve(self, having_exprs=(self.having_exprs or []) + exprs) def agg(self, *exprs) -> Self: """Select aggregated fields for the group-by.""" exprs = args_as_tuple(exprs) exprs = _drop_skips(exprs) resolve_names(self.table, exprs) - return self.replace(values=(self.values or []) + exprs) + return attrs.evolve(self, values=(self.values or []) + exprs) -@dataclass +@attrs.define(eq=False) class TableOp(ExprNode, ITable, Root): op: str table1: ITable table2: ITable - @property - def source_table(self): - return self - @property def type(self): # TODO ensure types of both tables are compatible return self.table1.type @property - def schema(self): + def schema(self) -> Schema: s1 = self.table1.schema s2 = self.table2.schema assert len(s1) == len(s2) return s1 
-@dataclass +@attrs.define(eq=False) class Select(ExprNode, ITable, Root): - table: Expr = None - columns: Sequence[Expr] = None - where_exprs: Sequence[Expr] = None - order_by_exprs: Sequence[Expr] = None - group_by_exprs: Sequence[Expr] = None - having_exprs: Sequence[Expr] = None - limit_expr: int = None + table: Optional[Expr] = None + columns: Optional[Sequence[Expr]] = None + where_exprs: Optional[Sequence[Expr]] = None + order_by_exprs: Optional[Sequence[Expr]] = None + group_by_exprs: Optional[Sequence[Expr]] = None + having_exprs: Optional[Sequence[Expr]] = None + limit_expr: Optional[int] = None distinct: bool = False - optimizer_hints: Sequence[Expr] = None + optimizer_hints: Optional[Sequence[Expr]] = None @property - def schema(self): + def schema(self) -> Schema: s = self.table.schema if s is None or self.columns is None: return s return type(s)({c.name: c.type for c in self.columns}) - @property - def source_table(self): - return self - @classmethod def make(cls, table: ITable, distinct: bool = SKIP, optimizer_hints: str = SKIP, **kwargs): assert "table" not in kwargs @@ -627,19 +635,23 @@ def make(cls, table: ITable, distinct: bool = SKIP, optimizer_hints: str = SKIP, else: raise ValueError(k) - return table.replace(**kwargs) + return attrs.evolve(table, **kwargs) -@dataclass +@attrs.define(eq=False) class Cte(ExprNode, ITable): - source_table: Expr - name: str = None - params: Sequence[str] = None + table: Expr + name: Optional[str] = None + params: Optional[Sequence[str]] = None + + @property + def source_table(self) -> "ITable": + return self.table @property - def schema(self): + def schema(self) -> Schema: # TODO add cte to schema - return self.source_table.schema + return self.table.schema def _named_exprs_as_aliases(named_exprs): @@ -657,10 +669,10 @@ def resolve_names(source_table, exprs): i += 1 -@dataclass(frozen=False, eq=False, order=False) -class _ResolveColumn(ExprNode, LazyOps): +@attrs.define(eq=False) +class _ResolveColumn(LazyOps, 
ExprNode): resolve_name: str - resolved: Expr = None + resolved: Optional[Expr] = None def resolve(self, expr: Expr): if self.resolved is not None: @@ -681,6 +693,7 @@ def name(self): return self._get_resolved().name +@attrs.define class This: """Builder object for accessing table attributes. @@ -696,80 +709,94 @@ def __getitem__(self, name): return _ResolveColumn(name) -@dataclass +@attrs.define(eq=False) class In(ExprNode): expr: Expr list: Sequence[Expr] - type = bool + + @property + def type(self) -> Optional[type]: + return bool -@dataclass +@attrs.define(eq=False) class Cast(ExprNode): expr: Expr target_type: Expr -@dataclass -class Random(ExprNode, LazyOps): - type = float +@attrs.define(eq=False) +class Random(LazyOps, ExprNode): + @property + def type(self) -> Optional[type]: + return float -@dataclass +@attrs.define(eq=False) class ConstantTable(ExprNode): rows: Sequence[Sequence] -@dataclass +@attrs.define(eq=False) class Explain(ExprNode, Root): select: Select - type = str + + @property + def type(self) -> Optional[type]: + return str +@attrs.define class CurrentTimestamp(ExprNode): - type = datetime + @property + def type(self) -> Optional[type]: + return datetime -@dataclass -class TimeTravel(ITable): +@attrs.define(eq=False) +class TimeTravel(ITable): # TODO: Unused? 
table: TablePath before: bool = False - timestamp: datetime = None - offset: int = None - statement: str = None + timestamp: Optional[datetime] = None + offset: Optional[int] = None + statement: Optional[str] = None # DDL +@attrs.define class Statement(Compilable, Root): - type = None + @property + def type(self) -> Optional[type]: + return None -@dataclass +@attrs.define(eq=False) class CreateTable(Statement): path: TablePath - source_table: Expr = None + source_table: Optional[Expr] = None if_not_exists: bool = False - primary_keys: List[str] = None + primary_keys: Optional[List[str]] = None -@dataclass +@attrs.define(eq=False) class DropTable(Statement): path: TablePath if_exists: bool = False -@dataclass +@attrs.define(eq=False) class TruncateTable(Statement): path: TablePath -@dataclass +@attrs.define(eq=False) class InsertToTable(Statement): path: TablePath expr: Expr - columns: List[str] = None - returning_exprs: List[str] = None + columns: Optional[List[str]] = None + returning_exprs: Optional[List[str]] = None def returning(self, *exprs) -> Self: """Add a 'RETURNING' clause to the insert expression. @@ -785,19 +812,15 @@ def returning(self, *exprs) -> Self: return self resolve_names(self.path, exprs) - return self.replace(returning_exprs=exprs) + return attrs.evolve(self, returning_exprs=exprs) -@dataclass +@attrs.define(eq=False) class Commit(Statement): """Generate a COMMIT statement, if we're in the middle of a transaction, or in auto-commit. Otherwise SKIP.""" -@dataclass -class Param(ExprNode, ITable): +@attrs.define(eq=False) +class Param(ExprNode, ITable): # TODO: Unused? 
"""A value placeholder, to be specified at compilation time using the `cv_params` context variable.""" name: str - - @property - def source_table(self): - return self diff --git a/data_diff/queries/base.py b/data_diff/queries/base.py index 205c2211..e895cd8c 100644 --- a/data_diff/queries/base.py +++ b/data_diff/queries/base.py @@ -1,6 +1,9 @@ from typing import Generator +import attrs + +@attrs.define class _SKIP: def __repr__(self): return "SKIP" diff --git a/data_diff/queries/extras.py b/data_diff/queries/extras.py index bb0c8299..5441146f 100644 --- a/data_diff/queries/extras.py +++ b/data_diff/queries/extras.py @@ -1,25 +1,28 @@ "Useful AST classes that don't quite fall within the scope of regular SQL" -from typing import Callable, Sequence -from runtype import dataclass +from typing import Callable, Optional, Sequence -from data_diff.abcs.database_types import ColType +import attrs +from data_diff.abcs.database_types import ColType from data_diff.queries.ast_classes import Expr, ExprNode -@dataclass +@attrs.define class NormalizeAsString(ExprNode): expr: ExprNode - expr_type: ColType = None - type = str + expr_type: Optional[ColType] = None + + @property + def type(self) -> Optional[type]: + return str -@dataclass +@attrs.define class ApplyFuncAndNormalizeAsString(ExprNode): expr: ExprNode - apply_func: Callable = None + apply_func: Optional[Callable] = None -@dataclass +@attrs.define class Checksum(ExprNode): exprs: Sequence[Expr] diff --git a/data_diff/table_segment.py b/data_diff/table_segment.py index aaf747f6..405415a8 100644 --- a/data_diff/table_segment.py +++ b/data_diff/table_segment.py @@ -1,9 +1,9 @@ import time -from typing import List, Tuple +from typing import List, Optional, Tuple import logging from itertools import product -from runtype import dataclass +import attrs from typing_extensions import Self from data_diff.utils import safezip, Vector @@ -85,7 +85,7 @@ def create_mesh_from_points(*values_per_dim: list) -> List[Tuple[Vector, Vector] 
return res -@dataclass +@attrs.define class TableSegment: """Signifies a segment of rows (and selected columns) within a table @@ -112,20 +112,20 @@ class TableSegment: # Columns key_columns: Tuple[str, ...] - update_column: str = None + update_column: Optional[str] = None extra_columns: Tuple[str, ...] = () # Restrict the segment - min_key: Vector = None - max_key: Vector = None - min_update: DbTime = None - max_update: DbTime = None - where: str = None + min_key: Optional[Vector] = None + max_key: Optional[Vector] = None + min_update: Optional[DbTime] = None + max_update: Optional[DbTime] = None + where: Optional[str] = None - case_sensitive: bool = True - _schema: Schema = None + case_sensitive: Optional[bool] = True + _schema: Optional[Schema] = None - def __post_init__(self): + def __attrs_post_init__(self): if not self.update_column and (self.min_update or self.max_update): raise ValueError("Error: the min_update/max_update feature requires 'update_column' to be set.") @@ -142,7 +142,7 @@ def _where(self): def _with_raw_schema(self, raw_schema: dict) -> Self: schema = self.database._process_table_schema(self.table_path, raw_schema, self.relevant_columns, self._where()) - return self.new(_schema=create_schema(self.database.name, self.table_path, schema, self.case_sensitive)) + return self.new(schema=create_schema(self.database.name, self.table_path, schema, self.case_sensitive)) def with_schema(self) -> Self: "Queries the table schema from the database, and returns a new instance of TableSegment, with a schema." 
@@ -199,7 +199,7 @@ def segment_by_checkpoints(self, checkpoints: List[List[DbKey]]) -> List["TableS def new(self, **kwargs) -> Self: """Creates a copy of the instance using 'replace()'""" - return self.replace(**kwargs) + return attrs.evolve(self, **kwargs) def new_key_bounds(self, min_key: Vector, max_key: Vector) -> Self: if self.min_key is not None: @@ -210,7 +210,7 @@ def new_key_bounds(self, min_key: Vector, max_key: Vector) -> Self: assert min_key < self.max_key assert max_key <= self.max_key - return self.replace(min_key=min_key, max_key=max_key) + return attrs.evolve(self, min_key=min_key, max_key=max_key) @property def relevant_columns(self) -> List[str]: diff --git a/data_diff/thread_utils.py b/data_diff/thread_utils.py index 1be94ad4..4a7348f2 100644 --- a/data_diff/thread_utils.py +++ b/data_diff/thread_utils.py @@ -7,6 +7,8 @@ from time import sleep from typing import Callable, Iterator, Optional +import attrs + class AutoPriorityQueue(PriorityQueue): """Overrides PriorityQueue to automatically get the priority from _WorkItem.kwargs @@ -31,19 +33,22 @@ class PriorityThreadPoolExecutor(ThreadPoolExecutor): XXX WARNING: Might break in future versions of Python """ - def __init__(self, *args): super().__init__(*args) - self._work_queue = AutoPriorityQueue() +@attrs.define(init=False) class ThreadedYielder(Iterable): """Yields results from multiple threads into a single iterator, ordered by priority. To add a source iterator, call ``submit()`` with a function that returns an iterator. Priority for the iterator can be provided via the keyword argument 'priority'. (higher runs first) """ + _pool: ThreadPoolExecutor + _futures: deque + _yield: deque = attrs.field(alias='_yield') # Python keyword! 
+    _exception: Optional[Exception]
 
     def __init__(self, max_workers: Optional[int] = None):
         self._pool = PriorityThreadPoolExecutor(max_workers)
diff --git a/data_diff/utils.py b/data_diff/utils.py
index b725285e..59ede253 100644
--- a/data_diff/utils.py
+++ b/data_diff/utils.py
@@ -3,13 +3,14 @@
 import re
 import string
 from abc import abstractmethod
-from typing import Any, Dict, Iterable, Iterator, List, MutableMapping, Sequence, TypeVar, Union
+from typing import Any, Dict, Iterable, Iterator, List, MutableMapping, Optional, Sequence, TypeVar, Union
 from urllib.parse import urlparse
 import operator
 import threading
 from datetime import datetime
 from uuid import UUID
 
+import attrs
 from packaging.version import parse as parse_version
 import requests
 from tabulate import tabulate
@@ -61,6 +62,7 @@ def match_regexps(regexps: Dict[str, Any], s: str) -> Sequence[tuple]:
 V = TypeVar("V")
 
 
+@attrs.define
 class CaseAwareMapping(MutableMapping[str, V]):
     @abstractmethod
     def get_key(self, key: str) -> str:
@@ -70,7 +72,10 @@ def new(self, initial=()) -> Self:
         return type(self)(initial)
 
 
+@attrs.define(repr=False, init=False)
 class CaseInsensitiveDict(CaseAwareMapping):
+    _dict: Dict[str, Any]
+
     def __init__(self, initial):
         self._dict = {k.lower(): (k, v) for k, v in dict(initial).items()}
 
@@ -99,6 +104,7 @@ def __repr__(self) -> str:
         return repr(dict(self.items()))
 
 
+@attrs.define
 class CaseSensitiveDict(dict, CaseAwareMapping):
     def get_key(self, key):
         self[key]  # Throw KeyError if key doesn't exist
@@ -114,6 +120,7 @@ def as_insensitive(self):
 alphanums = " -" + string.digits + string.ascii_uppercase + "_" + string.ascii_lowercase
 
 
+@attrs.define
 class ArithString:
     @classmethod
     def new(cls, *args, **kw) -> Self:
@@ -125,6 +132,7 @@ def range(self, other: "ArithString", count: int) -> List[Self]:
         return [self.new(int=i) for i in checkpoints]
 
 
+@attrs.define
 class ArithUUID(UUID, ArithString):
     "A UUID that supports basic arithmetic (add, sub)"
 
@@ -173,20 +181,21 @@ def 
alphanums_to_numbers(s1: str, s2: str):
     return n1, n2
 
 
+@attrs.define
 class ArithAlphanumeric(ArithString):
-    def __init__(self, s: str, max_len=None):
-        if s is None:
+    _str: str
+    _max_len: Optional[int] = None
+
+    def __attrs_post_init__(self):
+        if self._str is None:
             raise ValueError("Alphanum string cannot be None")
-        if max_len and len(s) > max_len:
-            raise ValueError(f"Length of alphanum value '{str}' is longer than the expected 26,526")
+        if self._max_len and len(self._str) > self._max_len:
+            raise ValueError(f"Length of alphanum value '{self._str}' is longer than the expected {self._max_len}")
 
-        for ch in s:
+        for ch in self._str:
             if ch not in alphanums:
                 raise ValueError(f"Unexpected character {ch} in alphanum string")
 
-        self._str = s
-        self._max_len = max_len
-
     # @property
     # def int(self):
     #     return alphanumToNumber(self._str, alphanums)
diff --git a/poetry.lock b/poetry.lock
index afd70d75..a95c7580 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2030,4 +2030,4 @@ vertica = ["vertica-python"]
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.7.2"
-content-hash = "55cde03a00788572dac6310e7bbf61bd2522d70217056a51608bcfc429440fbf"
+content-hash = "c7da70c19432ca716980f3421182d54d7f5d2e0d8bbd7e20dbaf521c8ef7d0fb"
diff --git a/pyproject.toml b/pyproject.toml
index 0ae8f0a4..2ceb0c27 100755
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,7 +25,6 @@ packages = [{ include = "data_diff" }]
 [tool.poetry.dependencies]
 pydantic = "1.10.12"
 python = "^3.7.2"
-runtype = "^0.2.6"
 dsnparse = "<0.2.0"
 click = "^8.1"
 rich = "*"
@@ -47,6 +46,7 @@ urllib3 = "<2"
 oracledb = {version = "*", optional=true}
 pyodbc = {version="^4.0.39", optional=true}
 typing-extensions = ">=4.0.1"
+attrs = "^23.1.0"
 
 [tool.poetry.dev-dependencies]
 parameterized = "*"
diff --git a/tests/test_database.py b/tests/test_database.py
index b17cb7f0..d0c1d3a4 100644
--- a/tests/test_database.py
+++ b/tests/test_database.py
@@ -2,6 +2,7 @@
 from datetime import datetime
 from typing import Callable, List, 
Tuple +import attrs import pytz from data_diff import connect @@ -128,8 +129,8 @@ def test_correct_timezone(self): raw_schema = db.query_table_schema(t.path) schema = db._process_table_schema(t.path, raw_schema) schema = create_schema(self.database.name, t, schema, case_sensitive=True) - t = t.replace(schema=schema) - t.schema["created_at"] = t.schema["created_at"].replace(precision=t.schema["created_at"].precision) + t = attrs.evolve(t, schema=schema) + t.schema["created_at"] = attrs.evolve(t.schema["created_at"], precision=t.schema["created_at"].precision) tbl = table(name, schema=t.schema) diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index b5885a26..26d555dc 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -3,6 +3,8 @@ import uuid import unittest +import attrs + from data_diff.queries.api import table, this, commit, code from data_diff.utils import ArithAlphanumeric, numberToAlphanum @@ -382,13 +384,13 @@ def test_string_keys(self): self.assertRaises(ValueError, list, differ.diff_tables(self.a, self.b)) def test_where_sampling(self): - a = self.a.replace(where="1=1") + a = attrs.evolve(self.a, where="1=1") differ = HashDiffer(bisection_factor=2) diff = list(differ.diff_tables(a, self.b)) self.assertEqual(diff, [("-", (str(self.new_uuid), "This one is different"))]) - a_empty = self.a.replace(where="1=0") + a_empty = attrs.evolve(self.a, where="1=0") self.assertRaises(ValueError, list, differ.diff_tables(a_empty, self.b)) @@ -519,11 +521,11 @@ def test_case_awareness(self): [src_table.create(), src_table.insert_rows([[1, 9, time_obj], [2, 2, time_obj]], columns=cols), commit] ) - res = tuple(self.table.replace(key_columns=("Id",), case_sensitive=False).with_schema().query_key_range()) + res = tuple(attrs.evolve(self.table, key_columns=("Id",), case_sensitive=False).with_schema().query_key_range()) self.assertEqual(res, (("1",), ("2",))) self.assertRaises( - KeyError, self.table.replace(key_columns=("Id",), 
case_sensitive=True).with_schema().query_key_range + KeyError, attrs.evolve(self.table, key_columns=("Id",), case_sensitive=True).with_schema().query_key_range ) diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index b2c5c419..dd424017 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -1,6 +1,8 @@ from typing import List from datetime import datetime +import attrs + from data_diff.queries.ast_classes import TablePath from data_diff.queries.api import table, commit from data_diff.table_segment import TableSegment @@ -114,7 +116,7 @@ def test_diff_small_tables(self): # Test materialize materialize_path = self.connection.parse_table_name(f"test_mat_{random_table_suffix()}") - mdiffer = self.differ.replace(materialize_to_table=materialize_path) + mdiffer = attrs.evolve(self.differ, materialize_to_table=materialize_path) diff = list(mdiffer.diff_tables(self.table, self.table2)) self.assertEqual(expected, diff) @@ -126,7 +128,7 @@ def test_diff_small_tables(self): self.connection.query(t.drop()) # Test materialize all rows - mdiffer = mdiffer.replace(materialize_all_rows=True) + mdiffer = attrs.evolve(mdiffer, materialize_all_rows=True) diff = list(mdiffer.diff_tables(self.table, self.table2)) self.assertEqual(expected, diff) rows = self.connection.query(t.select(), List[tuple]) diff --git a/tests/test_sql.py b/tests/test_sql.py index 2dcab403..f45931ec 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -1,5 +1,7 @@ import unittest +import attrs + from tests.common import TEST_MYSQL_CONN_STRING from data_diff.databases import connect @@ -18,8 +20,9 @@ def test_compile_int(self): self.assertEqual("1", self.compiler.compile(1)) def test_compile_table_name(self): + compiler = attrs.evolve(self.compiler, root=False) self.assertEqual( - "`marine_mammals`.`walrus`", self.compiler.replace(root=False).compile(table("marine_mammals", "walrus")) + "`marine_mammals`.`walrus`", compiler.compile(table("marine_mammals", "walrus")) ) def 
test_compile_select(self):