Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit cb9a8f9

Browse files
author
Sergey Vasilyev
committed
Convert the runtype classes to attrs
`attrs` is much more beneficial:
* `attrs` is supported by type checkers (such as MyPy) and by IDEs
* `attrs` is widely used and industry-proven
* `attrs` is explicit in its declarations; there is no magic
* `attrs` has slots

The main reason, though, is the first item: type checking by type checkers. Classes that were runtype classes are now frozen. Classes that were not are left unfrozen for now, but we can freeze them later if and where it works (the stricter, the better).
1 parent e7762f5 commit cb9a8f9

19 files changed

+191
-180
lines changed

data_diff/abcs/compiler.py

+4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
from abc import ABC
22

3+
import attrs
34

5+
6+
@attrs.define(frozen=False)
47
class AbstractCompiler(ABC):
58
pass
69

710

11+
@attrs.define(frozen=False, eq=False)
812
class Compilable(ABC):
913
pass

data_diff/abcs/database_types.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Tuple, Union
44
from datetime import datetime
55

6-
from runtype import dataclass
6+
import attrs
77

88
from data_diff.utils import ArithAlphanumeric, ArithUUID, Unknown
99

@@ -13,14 +13,14 @@
1313
DbTime = datetime
1414

1515

16-
@dataclass
16+
@attrs.define(frozen=True)
1717
class ColType:
1818
@property
1919
def supported(self) -> bool:
2020
return True
2121

2222

23-
@dataclass
23+
@attrs.define(frozen=True)
2424
class PrecisionType(ColType):
2525
precision: int
2626
rounds: Union[bool, Unknown] = Unknown
@@ -50,7 +50,7 @@ class Date(TemporalType):
5050
pass
5151

5252

53-
@dataclass
53+
@attrs.define(frozen=True)
5454
class NumericType(ColType):
5555
# 'precision' signifies how many fractional digits (after the dot) we want to compare
5656
precision: int
@@ -84,7 +84,7 @@ def python_type(self) -> type:
8484
return decimal.Decimal
8585

8686

87-
@dataclass
87+
@attrs.define(frozen=True)
8888
class StringType(ColType):
8989
python_type = str
9090

@@ -122,7 +122,7 @@ class String_VaryingAlphanum(String_Alphanum):
122122
pass
123123

124124

125-
@dataclass
125+
@attrs.define(frozen=True)
126126
class String_FixedAlphanum(String_Alphanum):
127127
length: int
128128

@@ -132,20 +132,20 @@ def make_value(self, value):
132132
return self.python_type(value, max_len=self.length)
133133

134134

135-
@dataclass
135+
@attrs.define(frozen=True)
136136
class Text(StringType):
137137
@property
138138
def supported(self) -> bool:
139139
return False
140140

141141

142142
# In majority of DBMSes, it is called JSON/JSONB. Only in Snowflake, it is OBJECT.
143-
@dataclass
143+
@attrs.define(frozen=True)
144144
class JSON(ColType):
145145
pass
146146

147147

148-
@dataclass
148+
@attrs.define(frozen=True)
149149
class Array(ColType):
150150
item_type: ColType
151151

@@ -155,21 +155,21 @@ class Array(ColType):
155155
# For example, in BigQuery:
156156
# - https://cloud.google.com/bigquery/docs/reference/standard-sql/data-types#struct_type
157157
# - https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#struct_literals
158-
@dataclass
158+
@attrs.define(frozen=True)
159159
class Struct(ColType):
160160
pass
161161

162162

163-
@dataclass
163+
@attrs.define(frozen=True)
164164
class Integer(NumericType, IKey):
165165
precision: int = 0
166166
python_type: type = int
167167

168-
def __post_init__(self):
168+
def __attrs_post_init__(self):
169169
assert self.precision == 0
170170

171171

172-
@dataclass
172+
@attrs.define(frozen=True)
173173
class UnknownColType(ColType):
174174
text: str
175175

data_diff/cloud/datafold_api.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import base64
2-
import dataclasses
32
import enum
43
import time
54
from typing import Any, Dict, List, Optional, Type, Tuple
65

6+
import attrs
77
import pydantic
88
import requests
99
from typing_extensions import Self
@@ -178,13 +178,13 @@ class TCloudApiDataSourceTestResult(pydantic.BaseModel):
178178
result: Optional[TCloudDataSourceTestResult]
179179

180180

181-
@dataclasses.dataclass
181+
@attrs.define(frozen=True)
182182
class DatafoldAPI:
183183
api_key: str
184184
host: str = "https://app.datafold.com"
185185
timeout: int = 30
186186

187-
def __post_init__(self):
187+
def __attrs_post_init__(self):
188188
self.host = self.host.rstrip("/")
189189
self.headers = {
190190
"Authorization": f"Key {self.api_key}",

data_diff/databases/_connect.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
from itertools import zip_longest
44
from contextlib import suppress
55
import weakref
6+
7+
import attrs
68
import dsnparse
79
import toml
810

9-
from runtype import dataclass
1011
from typing_extensions import Self
1112

1213
from data_diff.databases.base import Database, ThreadedDatabase
@@ -25,7 +26,7 @@
2526
from data_diff.databases.mssql import MsSQL
2627

2728

28-
@dataclass
29+
@attrs.define(frozen=True)
2930
class MatchUriPath:
3031
database_cls: Type[Database]
3132

data_diff/databases/base.py

+19-22
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import abc
22
import functools
3-
from dataclasses import field
3+
import random
44
from datetime import datetime
55
import math
66
import sys
@@ -14,7 +14,7 @@
1414
import decimal
1515
import contextvars
1616

17-
from runtype import dataclass
17+
import attrs
1818
from typing_extensions import Self
1919

2020
from data_diff.abcs.compiler import AbstractCompiler
@@ -90,12 +90,7 @@ class CompileError(Exception):
9090
pass
9191

9292

93-
# TODO: remove once switched to attrs, where ForwardRef[]/strings are resolved.
94-
class _RuntypeHackToFixCicularRefrencedDatabase:
95-
dialect: "BaseDialect"
96-
97-
98-
@dataclass
93+
@attrs.define(frozen=True)
9994
class Compiler(AbstractCompiler):
10095
"""
10196
Compiler bears the context for a single compilation.
@@ -107,16 +102,16 @@ class Compiler(AbstractCompiler):
107102
# Database is needed to normalize tables. Dialect is needed for recursive compilations.
108103
# In theory, it is many-to-many relations: e.g. a generic ODBC driver with multiple dialects.
109104
# In practice, we currently bind the dialects to the specific database classes.
110-
database: _RuntypeHackToFixCicularRefrencedDatabase
105+
database: "Database"
111106

112107
in_select: bool = False # Compilation runtime flag
113108
in_join: bool = False # Compilation runtime flag
114109

115-
_table_context: List = field(default_factory=list) # List[ITable]
116-
_subqueries: Dict[str, Any] = field(default_factory=dict) # XXX not thread-safe
110+
_table_context: List = attrs.field(factory=list) # List[ITable]
111+
_subqueries: Dict[str, Any] = attrs.field(factory=dict) # XXX not thread-safe
117112
root: bool = True
118113

119-
_counter: List = field(default_factory=lambda: [0])
114+
_counter: List = attrs.field(factory=lambda: [0])
120115

121116
@property
122117
def dialect(self) -> "BaseDialect":
@@ -136,7 +131,7 @@ def new_unique_table_name(self, prefix="tmp") -> DbPath:
136131
return self.database.dialect.parse_table_name(table_name)
137132

138133
def add_table_context(self, *tables: Sequence, **kw) -> Self:
139-
return self.replace(_table_context=self._table_context + list(tables), **kw)
134+
return attrs.evolve(self, table_context=self._table_context + list(tables), **kw)
140135

141136

142137
def parse_table_name(t):
@@ -272,7 +267,7 @@ def _compile(self, compiler: Compiler, elem) -> str:
272267
if elem is None:
273268
return "NULL"
274269
elif isinstance(elem, Compilable):
275-
return self.render_compilable(compiler.replace(root=False), elem)
270+
return self.render_compilable(attrs.evolve(compiler, root=False), elem)
276271
elif isinstance(elem, str):
277272
return f"'{elem}'"
278273
elif isinstance(elem, (int, float)):
@@ -382,7 +377,7 @@ def render_column(self, c: Compiler, elem: Column) -> str:
382377
return self.quote(elem.name)
383378

384379
def render_cte(self, parent_c: Compiler, elem: Cte) -> str:
385-
c: Compiler = parent_c.replace(_table_context=[], in_select=False)
380+
c: Compiler = attrs.evolve(parent_c, table_context=[], in_select=False)
386381
compiled = self.compile(c, elem.source_table)
387382

388383
name = elem.name or parent_c.new_unique_name()
@@ -495,7 +490,7 @@ def render_tablealias(self, c: Compiler, elem: TableAlias) -> str:
495490
return f"{self.compile(c, elem.source_table)} {self.quote(elem.name)}"
496491

497492
def render_tableop(self, parent_c: Compiler, elem: TableOp) -> str:
498-
c: Compiler = parent_c.replace(in_select=False)
493+
c: Compiler = attrs.evolve(parent_c, in_select=False)
499494
table_expr = f"{self.compile(c, elem.table1)} {elem.op} {self.compile(c, elem.table2)}"
500495
if parent_c.in_select:
501496
table_expr = f"({table_expr}) {c.new_unique_name()}"
@@ -507,7 +502,7 @@ def render__resolvecolumn(self, c: Compiler, elem: _ResolveColumn) -> str:
507502
return self.compile(c, elem._get_resolved())
508503

509504
def render_select(self, parent_c: Compiler, elem: Select) -> str:
510-
c: Compiler = parent_c.replace(in_select=True) # .add_table_context(self.table)
505+
c: Compiler = attrs.evolve(parent_c, in_select=True) # .add_table_context(self.table)
511506
compile_fn = functools.partial(self.compile, c)
512507

513508
columns = ", ".join(map(compile_fn, elem.columns)) if elem.columns else "*"
@@ -545,7 +540,8 @@ def render_select(self, parent_c: Compiler, elem: Select) -> str:
545540

546541
def render_join(self, parent_c: Compiler, elem: Join) -> str:
547542
tables = [
548-
t if isinstance(t, TableAlias) else TableAlias(t, parent_c.new_unique_name()) for t in elem.source_tables
543+
t if isinstance(t, TableAlias) else TableAlias(source_table=t, name=parent_c.new_unique_name())
544+
for t in elem.source_tables
549545
]
550546
c = parent_c.add_table_context(*tables, in_join=True, in_select=False)
551547
op = " JOIN " if elem.op is None else f" {elem.op} JOIN "
@@ -578,7 +574,8 @@ def render_groupby(self, c: Compiler, elem: GroupBy) -> str:
578574
if isinstance(elem.table, Select) and elem.table.columns is None and elem.table.group_by_exprs is None:
579575
return self.compile(
580576
c,
581-
elem.table.replace(
577+
attrs.evolve(
578+
elem.table,
582579
columns=columns,
583580
group_by_exprs=[Code(k) for k in keys],
584581
having_exprs=elem.having_exprs,
@@ -590,7 +587,7 @@ def render_groupby(self, c: Compiler, elem: GroupBy) -> str:
590587
having_str = (
591588
" HAVING " + " AND ".join(map(compile_fn, elem.having_exprs)) if elem.having_exprs is not None else ""
592589
)
593-
select = f"SELECT {columns_str} FROM {self.compile(c.replace(in_select=True), elem.table)} GROUP BY {keys_str}{having_str}"
590+
select = f"SELECT {columns_str} FROM {self.compile(attrs.evolve(c, in_select=True), elem.table)} GROUP BY {keys_str}{having_str}"
594591

595592
if c.in_select:
596593
select = f"({select}) {c.new_unique_name()}"
@@ -827,7 +824,7 @@ def set_timezone_to_utc(self) -> str:
827824
T = TypeVar("T", bound=BaseDialect)
828825

829826

830-
@dataclass
827+
@attrs.define(frozen=True)
831828
class QueryResult:
832829
rows: list
833830
columns: Optional[list] = None
@@ -842,7 +839,7 @@ def __getitem__(self, i):
842839
return self.rows[i]
843840

844841

845-
class Database(abc.ABC, _RuntypeHackToFixCicularRefrencedDatabase):
842+
class Database(abc.ABC):
846843
"""Base abstract class for databases.
847844
848845
Used for providing connection code and implementation specific SQL utilities.

data_diff/diff_tables.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,13 @@
33

44
import time
55
from abc import ABC, abstractmethod
6-
from dataclasses import field
76
from enum import Enum
87
from contextlib import contextmanager
98
from operator import methodcaller
109
from typing import Dict, Tuple, Iterator, Optional
1110
from concurrent.futures import ThreadPoolExecutor, as_completed
1211

13-
from runtype import dataclass
12+
import attrs
1413

1514
from data_diff.info_tree import InfoTree, SegmentInfo
1615
from data_diff.utils import dbt_diff_string_template, run_as_daemon, safezip, getLogger, truncate_error, Vector
@@ -31,7 +30,7 @@ class Algorithm(Enum):
3130
DiffResult = Iterator[Tuple[str, tuple]] # Iterator[Tuple[Literal["+", "-"], tuple]]
3231

3332

34-
@dataclass
33+
@attrs.define(frozen=True)
3534
class ThreadBase:
3635
"Provides utility methods for optional threading"
3736

@@ -72,7 +71,7 @@ def _run_in_background(self, *funcs):
7271
f.result()
7372

7473

75-
@dataclass
74+
@attrs.define(frozen=True)
7675
class DiffStats:
7776
diff_by_sign: Dict[str, int]
7877
table1_count: int
@@ -82,12 +81,12 @@ class DiffStats:
8281
extra_column_diffs: Optional[Dict[str, int]]
8382

8483

85-
@dataclass
84+
@attrs.define(frozen=True)
8685
class DiffResultWrapper:
8786
diff: iter # DiffResult
8887
info_tree: InfoTree
8988
stats: dict
90-
result_list: list = field(default_factory=list)
89+
result_list: list = attrs.field(factory=list)
9190

9291
def __iter__(self):
9392
yield from self.result_list
@@ -203,7 +202,7 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment, info_tree: Inf
203202

204203
def _diff_tables_wrapper(self, table1: TableSegment, table2: TableSegment, info_tree: InfoTree) -> DiffResult:
205204
if is_tracking_enabled():
206-
options = dict(self)
205+
options = attrs.asdict(self, recurse=False)
207206
options["differ_name"] = type(self).__name__
208207
event_json = create_start_event_json(options)
209208
run_as_daemon(send_event_json, event_json)

0 commit comments

Comments
 (0)