Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit d02fad8

Browse files
author
Sergey Vasilyev
committed
Convert the remaining classes to attrs
Since we now use `attrs` for some classes, let's use them for them all — at least those belonging to the same hierarchies. This will ensure that all classes are slotted and will strictly check that we define attributes properly, especially in cases of multiple inheritance. Except for Pydantic models and Python exceptions.
1 parent 10614f2 commit d02fad8

25 files changed

+197
-68
lines changed

data_diff/abcs/database_types.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,26 +26,33 @@ class PrecisionType(ColType):
2626
rounds: Union[bool, Unknown] = Unknown
2727

2828

29+
30+
@attrs.define
2931
class Boolean(ColType):
3032
precision = 0
3133

3234

35+
@attrs.define
3336
class TemporalType(PrecisionType):
3437
pass
3538

3639

40+
@attrs.define
3741
class Timestamp(TemporalType):
3842
pass
3943

4044

45+
@attrs.define
4146
class TimestampTZ(TemporalType):
4247
pass
4348

4449

50+
@attrs.define
4551
class Datetime(TemporalType):
4652
pass
4753

4854

55+
@attrs.define
4956
class Date(TemporalType):
5057
pass
5158

@@ -56,14 +63,17 @@ class NumericType(ColType):
5663
precision: int
5764

5865

66+
@attrs.define
5967
class FractionalType(NumericType):
6068
pass
6169

6270

71+
@attrs.define
6372
class Float(FractionalType):
6473
python_type = float
6574

6675

76+
@attrs.define
6777
class IKey(ABC):
6878
"Interface for ColType, for using a column as a key in table."
6979

@@ -76,6 +86,7 @@ def make_value(self, value):
7686
return self.python_type(value)
7787

7888

89+
@attrs.define
7990
class Decimal(FractionalType, IKey): # Snowflake may use Decimal as a key
8091
@property
8192
def python_type(self) -> type:
@@ -89,22 +100,27 @@ class StringType(ColType):
89100
python_type = str
90101

91102

103+
@attrs.define
92104
class ColType_UUID(ColType, IKey):
93105
python_type = ArithUUID
94106

95107

108+
@attrs.define
96109
class ColType_Alphanum(ColType, IKey):
97110
python_type = ArithAlphanumeric
98111

99112

113+
@attrs.define
100114
class Native_UUID(ColType_UUID):
101115
pass
102116

103117

118+
@attrs.define
104119
class String_UUID(ColType_UUID, StringType):
105120
pass
106121

107122

123+
@attrs.define
108124
class String_Alphanum(ColType_Alphanum, StringType):
109125
@staticmethod
110126
def test_value(value: str) -> bool:
@@ -118,6 +134,7 @@ def make_value(self, value):
118134
return self.python_type(value)
119135

120136

137+
@attrs.define
121138
class String_VaryingAlphanum(String_Alphanum):
122139
pass
123140

data_diff/abcs/mixins.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
from abc import ABC, abstractmethod
2+
3+
import attrs
4+
25
from data_diff.abcs.database_types import (
36
Array,
47
TemporalType,
@@ -13,10 +16,12 @@
1316
from data_diff.abcs.compiler import Compilable
1417

1518

19+
@attrs.define
1620
class AbstractMixin(ABC):
1721
"A mixin for a database dialect"
1822

1923

24+
@attrs.define
2025
class AbstractMixin_NormalizeValue(AbstractMixin):
2126
@abstractmethod
2227
def to_comparable(self, value: str, coltype: ColType) -> str:
@@ -108,6 +113,7 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
108113
return self.to_string(value)
109114

110115

116+
@attrs.define
111117
class AbstractMixin_MD5(AbstractMixin):
112118
"""Methods for calculating an MD6 hash as an integer."""
113119

@@ -116,6 +122,7 @@ def md5_as_int(self, s: str) -> str:
116122
"Provide SQL for computing md5 and returning an int"
117123

118124

125+
@attrs.define
119126
class AbstractMixin_Schema(AbstractMixin):
120127
"""Methods for querying the database schema
121128
@@ -134,6 +141,7 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
134141
"""
135142

136143

144+
@attrs.define
137145
class AbstractMixin_RandomSample(AbstractMixin):
138146
@abstractmethod
139147
def random_sample_n(self, tbl: str, size: int) -> str:
@@ -151,6 +159,7 @@ def random_sample_ratio_approx(self, tbl: str, ratio: float) -> str:
151159
# """
152160

153161

162+
@attrs.define
154163
class AbstractMixin_TimeTravel(AbstractMixin):
155164
@abstractmethod
156165
def time_travel(
@@ -173,6 +182,7 @@ def time_travel(
173182
"""
174183

175184

185+
@attrs.define
176186
class AbstractMixin_OptimizerHints(AbstractMixin):
177187
@abstractmethod
178188
def optimizer_hints(self, optimizer_hints: str) -> str:

data_diff/databases/_connect.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,11 @@ def match_path(self, dsn):
9393
}
9494

9595

96+
@attrs.define(init=False)
9697
class Connect:
9798
"""Provides methods for connecting to a supported database using a URL or connection dict."""
99+
database_by_scheme: Dict[str, Database]
100+
match_uri_path: Dict[str, MatchUriPath]
98101
conn_cache: MutableMapping[Hashable, Database]
99102

100103
def __init__(self, database_by_scheme: Dict[str, Database] = DATABASE_BY_SCHEME):
@@ -284,6 +287,7 @@ def __make_cache_key(self, db_conf: Union[str, dict]) -> Hashable:
284287
return db_conf
285288

286289

290+
@attrs.define(init=False)
287291
class Connect_SetUTC(Connect):
288292
"""Provides methods for connecting to a supported database using a URL or connection dict.
289293

data_diff/databases/base.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import math
66
import sys
77
import logging
8-
from typing import Any, Callable, Dict, Generator, Tuple, Optional, Sequence, Type, List, Union, TypeVar
8+
from typing import Any, Callable, ClassVar, Dict, Generator, Tuple, Optional, Sequence, Type, List, Union, TypeVar
99
from functools import partial, wraps
1010
from concurrent.futures import ThreadPoolExecutor
1111
import threading
@@ -156,15 +156,14 @@ def _one(seq):
156156
return x
157157

158158

159+
@attrs.define
159160
class ThreadLocalInterpreter:
160161
"""An interpeter used to execute a sequence of queries within the same thread and cursor.
161162
162163
Useful for cursor-sensitive operations, such as creating a temporary table.
163164
"""
164-
165-
def __init__(self, compiler: Compiler, gen: Generator):
166-
self.gen = gen
167-
self.compiler = compiler
165+
compiler: Compiler
166+
gen: Generator
168167

169168
def apply_queries(self, callback: Callable[[str], Any]):
170169
q: Expr = next(self.gen)
@@ -189,6 +188,7 @@ def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocal
189188
return callback(sql_code)
190189

191190

191+
@attrs.define
192192
class Mixin_Schema(AbstractMixin_Schema):
193193
def table_information(self) -> Compilable:
194194
return table("information_schema", "tables")
@@ -205,6 +205,7 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
205205
)
206206

207207

208+
@attrs.define
208209
class Mixin_RandomSample(AbstractMixin_RandomSample):
209210
def random_sample_n(self, tbl: ITable, size: int) -> ITable:
210211
# TODO use a more efficient algorithm, when the table count is known
@@ -214,15 +215,17 @@ def random_sample_ratio_approx(self, tbl: ITable, ratio: float) -> ITable:
214215
return tbl.where(Random() < ratio)
215216

216217

218+
@attrs.define
217219
class Mixin_OptimizerHints(AbstractMixin_OptimizerHints):
218220
def optimizer_hints(self, hints: str) -> str:
219221
return f"/*+ {hints} */ "
220222

221223

224+
@attrs.define
222225
class BaseDialect(abc.ABC):
223226
SUPPORTS_PRIMARY_KEY = False
224227
SUPPORTS_INDEXES = False
225-
TYPE_CLASSES: Dict[str, type] = {}
228+
TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = {}
226229
MIXINS = frozenset()
227230

228231
PLACEHOLDER_TABLE = None # Used for Oracle
@@ -522,7 +525,7 @@ def render_select(self, parent_c: Compiler, elem: Select) -> str:
522525

523526
def render_join(self, parent_c: Compiler, elem: Join) -> str:
524527
tables = [
525-
t if isinstance(t, TableAlias) else TableAlias(source_table=t, name=parent_c.new_unique_name()) for t in elem.source_tables
528+
t if isinstance(t, TableAlias) else TableAlias(source_table_=t, name=parent_c.new_unique_name()) for t in elem.source_tables
526529
]
527530
c = parent_c.add_table_context(*tables, in_join=True, in_select=False)
528531
op = " JOIN " if elem.op is None else f" {elem.op} JOIN "
@@ -823,6 +826,7 @@ def __getitem__(self, i):
823826
return self.rows[i]
824827

825828

829+
@attrs.define
826830
class Database(abc.ABC):
827831
"""Base abstract class for databases.
828832
@@ -1099,6 +1103,7 @@ def is_autocommit(self) -> bool:
10991103
"Return whether the database autocommits changes. When false, COMMIT statements are skipped."
11001104

11011105

1106+
@attrs.define(init=False, slots=False)
11021107
class ThreadedDatabase(Database):
11031108
"""Access the database through singleton threads.
11041109

data_diff/databases/bigquery.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import re
22
from typing import Any, List, Union
3+
4+
import attrs
5+
36
from data_diff.abcs.database_types import (
47
ColType,
58
Array,
@@ -50,11 +53,13 @@ def import_bigquery_service_account():
5053
return service_account
5154

5255

56+
@attrs.define
5357
class Mixin_MD5(AbstractMixin_MD5):
5458
def md5_as_int(self, s: str) -> str:
5559
return f"cast(cast( ('0x' || substr(TO_HEX(md5({s})), 18)) as int64) as numeric)"
5660

5761

62+
@attrs.define
5863
class Mixin_NormalizeValue(AbstractMixin_NormalizeValue):
5964
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
6065
if coltype.rounds:
@@ -99,6 +104,7 @@ def normalize_struct(self, value: str, _coltype: Struct) -> str:
99104
return f"to_json_string({value})"
100105

101106

107+
@attrs.define
102108
class Mixin_Schema(AbstractMixin_Schema):
103109
def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
104110
return (
@@ -112,6 +118,7 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
112118
)
113119

114120

121+
@attrs.define
115122
class Mixin_TimeTravel(AbstractMixin_TimeTravel):
116123
def time_travel(
117124
self,
@@ -139,6 +146,7 @@ def time_travel(
139146
)
140147

141148

149+
@attrs.define
142150
class Dialect(BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
143151
name = "BigQuery"
144152
ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation

data_diff/databases/clickhouse.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Optional, Type
22

3+
import attrs
4+
35
from data_diff.databases.base import (
46
MD5_HEXDIGITS,
57
CHECKSUM_HEXDIGITS,
@@ -35,12 +37,14 @@ def import_clickhouse():
3537
return clickhouse_driver
3638

3739

40+
@attrs.define
3841
class Mixin_MD5(AbstractMixin_MD5):
3942
def md5_as_int(self, s: str) -> str:
4043
substr_idx = 1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS
4144
return f"reinterpretAsUInt128(reverse(unhex(lowerUTF8(substr(hex(MD5({s})), {substr_idx})))))"
4245

4346

47+
@attrs.define
4448
class Mixin_NormalizeValue(AbstractMixin_NormalizeValue):
4549
def normalize_number(self, value: str, coltype: FractionalType) -> str:
4650
# If a decimal value has trailing zeros in a fractional part, when casting to string they are dropped.
@@ -99,6 +103,7 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
99103
return f"rpad({value}, {TIMESTAMP_PRECISION_POS + 6}, '0')"
100104

101105

106+
@attrs.define
102107
class Dialect(BaseDialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
103108
name = "Clickhouse"
104109
ROUNDS_ON_PREC_LOSS = False
@@ -163,6 +168,7 @@ def current_timestamp(self) -> str:
163168
return "now()"
164169

165170

171+
@attrs.define
166172
class Clickhouse(ThreadedDatabase):
167173
dialect = Dialect()
168174
CONNECT_URI_HELP = "clickhouse://<user>:<password>@<host>/<database>"

data_diff/databases/databricks.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from typing import Dict, Sequence
33
import logging
44

5+
import attrs
6+
57
from data_diff.abcs.database_types import (
68
Integer,
79
Float,
@@ -34,11 +36,13 @@ def import_databricks():
3436
return databricks
3537

3638

39+
@attrs.define
3740
class Mixin_MD5(AbstractMixin_MD5):
3841
def md5_as_int(self, s: str) -> str:
3942
return f"cast(conv(substr(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16, 10) as decimal(38, 0))"
4043

4144

45+
@attrs.define
4246
class Mixin_NormalizeValue(AbstractMixin_NormalizeValue):
4347
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
4448
"""Databricks timestamp contains no more than 6 digits in precision"""
@@ -60,6 +64,7 @@ def normalize_boolean(self, value: str, _coltype: Boolean) -> str:
6064
return self.to_string(f"cast ({value} as int)")
6165

6266

67+
@attrs.define
6368
class Dialect(BaseDialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
6469
name = "Databricks"
6570
ROUNDS_ON_PREC_LOSS = True
@@ -99,6 +104,7 @@ def parse_table_name(self, name: str) -> DbPath:
99104
return tuple(i for i in path if i is not None)
100105

101106

107+
@attrs.define(init=False)
102108
class Databricks(ThreadedDatabase):
103109
dialect = Dialect()
104110
CONNECT_URI_HELP = "databricks://:<access_token>@<server_hostname>/<http_path>"

0 commit comments

Comments
 (0)