Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 33deadb

Browse files
author
Sergey Vasilyev
committed
Convert the remaining classes to attrs
Since we now use `attrs` for some classes, let's use `attrs` for them all — at least those belonging to the same hierarchies. This will ensure that all classes are slotted and will strictly check that we define attributes properly, especially in cases of multiple inheritance. Except for Pydantic models and Python exceptions. Despite the attrs classes are not frozen by default, we keep it explicitly stated, so that we see which classes were or were not frozen before the switch from runtype to attrs. We can later freeze more classes if/when it works (the stricter, the better). For this reason, we never unfreeze classes that were previously frozen.
1 parent cb9a8f9 commit 33deadb

24 files changed

+176
-36
lines changed

data_diff/abcs/database_types.py

+16
Original file line numberDiff line numberDiff line change
@@ -26,26 +26,32 @@ class PrecisionType(ColType):
2626
rounds: Union[bool, Unknown] = Unknown
2727

2828

29+
@attrs.define(frozen=True)
2930
class Boolean(ColType):
3031
precision = 0
3132

3233

34+
@attrs.define(frozen=True)
3335
class TemporalType(PrecisionType):
3436
pass
3537

3638

39+
@attrs.define(frozen=True)
3740
class Timestamp(TemporalType):
3841
pass
3942

4043

44+
@attrs.define(frozen=True)
4145
class TimestampTZ(TemporalType):
4246
pass
4347

4448

49+
@attrs.define(frozen=True)
4550
class Datetime(TemporalType):
4651
pass
4752

4853

54+
@attrs.define(frozen=True)
4955
class Date(TemporalType):
5056
pass
5157

@@ -56,14 +62,17 @@ class NumericType(ColType):
5662
precision: int
5763

5864

65+
@attrs.define(frozen=True)
5966
class FractionalType(NumericType):
6067
pass
6168

6269

70+
@attrs.define(frozen=True)
6371
class Float(FractionalType):
6472
python_type = float
6573

6674

75+
@attrs.define(frozen=True)
6776
class IKey(ABC):
6877
"Interface for ColType, for using a column as a key in table."
6978

@@ -76,6 +85,7 @@ def make_value(self, value):
7685
return self.python_type(value)
7786

7887

88+
@attrs.define(frozen=True)
7989
class Decimal(FractionalType, IKey): # Snowflake may use Decimal as a key
8090
@property
8191
def python_type(self) -> type:
@@ -89,22 +99,27 @@ class StringType(ColType):
8999
python_type = str
90100

91101

102+
@attrs.define(frozen=True)
92103
class ColType_UUID(ColType, IKey):
93104
python_type = ArithUUID
94105

95106

107+
@attrs.define(frozen=True)
96108
class ColType_Alphanum(ColType, IKey):
97109
python_type = ArithAlphanumeric
98110

99111

112+
@attrs.define(frozen=True)
100113
class Native_UUID(ColType_UUID):
101114
pass
102115

103116

117+
@attrs.define(frozen=True)
104118
class String_UUID(ColType_UUID, StringType):
105119
pass
106120

107121

122+
@attrs.define(frozen=True)
108123
class String_Alphanum(ColType_Alphanum, StringType):
109124
@staticmethod
110125
def test_value(value: str) -> bool:
@@ -118,6 +133,7 @@ def make_value(self, value):
118133
return self.python_type(value)
119134

120135

136+
@attrs.define(frozen=True)
121137
class String_VaryingAlphanum(String_Alphanum):
122138
pass
123139

data_diff/abcs/mixins.py

+10
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
from abc import ABC, abstractmethod
2+
3+
import attrs
4+
25
from data_diff.abcs.database_types import (
36
Array,
47
TemporalType,
@@ -13,10 +16,12 @@
1316
from data_diff.abcs.compiler import Compilable
1417

1518

19+
@attrs.define(frozen=False)
1620
class AbstractMixin(ABC):
1721
"A mixin for a database dialect"
1822

1923

24+
@attrs.define(frozen=False)
2025
class AbstractMixin_NormalizeValue(AbstractMixin):
2126
@abstractmethod
2227
def to_comparable(self, value: str, coltype: ColType) -> str:
@@ -108,6 +113,7 @@ def normalize_value_by_type(self, value: str, coltype: ColType) -> str:
108113
return self.to_string(value)
109114

110115

116+
@attrs.define(frozen=False)
111117
class AbstractMixin_MD5(AbstractMixin):
112118
"""Methods for calculating an MD6 hash as an integer."""
113119

@@ -116,6 +122,7 @@ def md5_as_int(self, s: str) -> str:
116122
"Provide SQL for computing md5 and returning an int"
117123

118124

125+
@attrs.define(frozen=False)
119126
class AbstractMixin_Schema(AbstractMixin):
120127
"""Methods for querying the database schema
121128
@@ -134,6 +141,7 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
134141
"""
135142

136143

144+
@attrs.define(frozen=False)
137145
class AbstractMixin_RandomSample(AbstractMixin):
138146
@abstractmethod
139147
def random_sample_n(self, tbl: str, size: int) -> str:
@@ -151,6 +159,7 @@ def random_sample_ratio_approx(self, tbl: str, ratio: float) -> str:
151159
# """
152160

153161

162+
@attrs.define(frozen=False)
154163
class AbstractMixin_TimeTravel(AbstractMixin):
155164
@abstractmethod
156165
def time_travel(
@@ -173,6 +182,7 @@ def time_travel(
173182
"""
174183

175184

185+
@attrs.define(frozen=False)
176186
class AbstractMixin_OptimizerHints(AbstractMixin):
177187
@abstractmethod
178188
def optimizer_hints(self, optimizer_hints: str) -> str:

data_diff/databases/_connect.py

+2
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def match_path(self, dsn):
9393
}
9494

9595

96+
@attrs.define(frozen=False, init=False)
9697
class Connect:
9798
"""Provides methods for connecting to a supported database using a URL or connection dict."""
9899

@@ -288,6 +289,7 @@ def __make_cache_key(self, db_conf: Union[str, dict]) -> Hashable:
288289
return db_conf
289290

290291

292+
@attrs.define(frozen=False, init=False)
291293
class Connect_SetUTC(Connect):
292294
"""Provides methods for connecting to a supported database using a URL or connection dict.
293295

data_diff/databases/base.py

+17-16
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ def _one(seq):
168168
return x
169169

170170

171+
@attrs.define(frozen=False)
171172
class ThreadLocalInterpreter:
172173
"""An interpeter used to execute a sequence of queries within the same thread and cursor.
173174
@@ -177,11 +178,6 @@ class ThreadLocalInterpreter:
177178
compiler: Compiler
178179
gen: Generator
179180

180-
def __init__(self, compiler: Compiler, gen: Generator):
181-
super().__init__()
182-
self.gen = gen
183-
self.compiler = compiler
184-
185181
def apply_queries(self, callback: Callable[[str], Any]):
186182
q: Expr = next(self.gen)
187183
while True:
@@ -205,6 +201,7 @@ def apply_query(callback: Callable[[str], Any], sql_code: Union[str, ThreadLocal
205201
return callback(sql_code)
206202

207203

204+
@attrs.define(frozen=False)
208205
class Mixin_Schema(AbstractMixin_Schema):
209206
def table_information(self) -> Compilable:
210207
return table("information_schema", "tables")
@@ -221,6 +218,7 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
221218
)
222219

223220

221+
@attrs.define(frozen=False)
224222
class Mixin_RandomSample(AbstractMixin_RandomSample):
225223
def random_sample_n(self, tbl: ITable, size: int) -> ITable:
226224
# TODO use a more efficient algorithm, when the table count is known
@@ -230,15 +228,17 @@ def random_sample_ratio_approx(self, tbl: ITable, ratio: float) -> ITable:
230228
return tbl.where(Random() < ratio)
231229

232230

231+
@attrs.define(frozen=False)
233232
class Mixin_OptimizerHints(AbstractMixin_OptimizerHints):
234233
def optimizer_hints(self, hints: str) -> str:
235234
return f"/*+ {hints} */ "
236235

237236

237+
@attrs.define(frozen=False)
238238
class BaseDialect(abc.ABC):
239239
SUPPORTS_PRIMARY_KEY: ClassVar[bool] = False
240240
SUPPORTS_INDEXES: ClassVar[bool] = False
241-
TYPE_CLASSES: ClassVar[Dict[str, type]] = {}
241+
TYPE_CLASSES: ClassVar[Dict[str, Type[ColType]]] = {}
242242
MIXINS = frozenset()
243243

244244
PLACEHOLDER_TABLE = None # Used for Oracle
@@ -540,7 +540,7 @@ def render_select(self, parent_c: Compiler, elem: Select) -> str:
540540

541541
def render_join(self, parent_c: Compiler, elem: Join) -> str:
542542
tables = [
543-
t if isinstance(t, TableAlias) else TableAlias(source_table=t, name=parent_c.new_unique_name())
543+
t if isinstance(t, TableAlias) else TableAlias(t, name=parent_c.new_unique_name())
544544
for t in elem.source_tables
545545
]
546546
c = parent_c.add_table_context(*tables, in_join=True, in_select=False)
@@ -839,6 +839,7 @@ def __getitem__(self, i):
839839
return self.rows[i]
840840

841841

842+
@attrs.define(frozen=False)
842843
class Database(abc.ABC):
843844
"""Base abstract class for databases.
844845
@@ -1114,22 +1115,22 @@ def is_autocommit(self) -> bool:
11141115
"Return whether the database autocommits changes. When false, COMMIT statements are skipped."
11151116

11161117

1118+
@attrs.define(frozen=False)
11171119
class ThreadedDatabase(Database):
11181120
"""Access the database through singleton threads.
11191121
11201122
Used for database connectors that do not support sharing their connection between different threads.
11211123
"""
11221124

1123-
_init_error: Optional[Exception]
1124-
_queue: ThreadPoolExecutor
1125-
thread_local: threading.local
1125+
thread_count: int = 1
1126+
1127+
_init_error: Optional[Exception] = None
1128+
_queue: Optional[ThreadPoolExecutor] = None
1129+
thread_local: threading.local = attrs.field(factory=threading.local)
11261130

1127-
def __init__(self, thread_count=1):
1128-
super().__init__()
1129-
self._init_error = None
1130-
self._queue = ThreadPoolExecutor(thread_count, initializer=self.set_conn)
1131-
self.thread_local = threading.local()
1132-
logger.info(f"[{self.name}] Starting a threadpool, size={thread_count}.")
1131+
def __attrs_post_init__(self):
1132+
self._queue = ThreadPoolExecutor(self.thread_count, initializer=self.set_conn)
1133+
logger.info(f"[{self.name}] Starting a threadpool, size={self.thread_count}.")
11331134

11341135
def set_conn(self):
11351136
assert not hasattr(self.thread_local, "conn")

data_diff/databases/bigquery.py

+8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import re
22
from typing import Any, List, Union
3+
4+
import attrs
5+
36
from data_diff.abcs.database_types import (
47
ColType,
58
Array,
@@ -50,11 +53,13 @@ def import_bigquery_service_account():
5053
return service_account
5154

5255

56+
@attrs.define(frozen=False)
5357
class Mixin_MD5(AbstractMixin_MD5):
5458
def md5_as_int(self, s: str) -> str:
5559
return f"cast(cast( ('0x' || substr(TO_HEX(md5({s})), 18)) as int64) as numeric)"
5660

5761

62+
@attrs.define(frozen=False)
5863
class Mixin_NormalizeValue(AbstractMixin_NormalizeValue):
5964
def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
6065
if coltype.rounds:
@@ -99,6 +104,7 @@ def normalize_struct(self, value: str, _coltype: Struct) -> str:
99104
return f"to_json_string({value})"
100105

101106

107+
@attrs.define(frozen=False)
102108
class Mixin_Schema(AbstractMixin_Schema):
103109
def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
104110
return (
@@ -112,6 +118,7 @@ def list_tables(self, table_schema: str, like: Compilable = None) -> Compilable:
112118
)
113119

114120

121+
@attrs.define(frozen=False)
115122
class Mixin_TimeTravel(AbstractMixin_TimeTravel):
116123
def time_travel(
117124
self,
@@ -139,6 +146,7 @@ def time_travel(
139146
)
140147

141148

149+
@attrs.define(frozen=False)
142150
class Dialect(
143151
BaseDialect, Mixin_Schema, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue
144152
):

data_diff/databases/clickhouse.py

+6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Optional, Type
22

3+
import attrs
4+
35
from data_diff.databases.base import (
46
MD5_HEXDIGITS,
57
CHECKSUM_HEXDIGITS,
@@ -35,12 +37,14 @@ def import_clickhouse():
3537
return clickhouse_driver
3638

3739

40+
@attrs.define(frozen=False)
3841
class Mixin_MD5(AbstractMixin_MD5):
3942
def md5_as_int(self, s: str) -> str:
4043
substr_idx = 1 + MD5_HEXDIGITS - CHECKSUM_HEXDIGITS
4144
return f"reinterpretAsUInt128(reverse(unhex(lowerUTF8(substr(hex(MD5({s})), {substr_idx})))))"
4245

4346

47+
@attrs.define(frozen=False)
4448
class Mixin_NormalizeValue(AbstractMixin_NormalizeValue):
4549
def normalize_number(self, value: str, coltype: FractionalType) -> str:
4650
# If a decimal value has trailing zeros in a fractional part, when casting to string they are dropped.
@@ -99,6 +103,7 @@ def normalize_timestamp(self, value: str, coltype: TemporalType) -> str:
99103
return f"rpad({value}, {TIMESTAMP_PRECISION_POS + 6}, '0')"
100104

101105

106+
@attrs.define(frozen=False)
102107
class Dialect(BaseDialect, Mixin_MD5, Mixin_NormalizeValue, AbstractMixin_MD5, AbstractMixin_NormalizeValue):
103108
name = "Clickhouse"
104109
ROUNDS_ON_PREC_LOSS = False
@@ -163,6 +168,7 @@ def current_timestamp(self) -> str:
163168
return "now()"
164169

165170

171+
@attrs.define(frozen=False)
166172
class Clickhouse(ThreadedDatabase):
167173
dialect = Dialect()
168174
CONNECT_URI_HELP = "clickhouse://<user>:<password>@<host>/<database>"

0 commit comments

Comments
 (0)