Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit be24cb2

Browse files
authored
Merge pull request #824 from datafold/overriddeable-infotree-classes
Make InfoTree classes overrideable
2 parents d4589c7 + bf5f60d commit be24cb2

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

data_diff/diff_tables.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from enum import Enum
77
from contextlib import contextmanager
88
from operator import methodcaller
9-
from typing import Dict, Set, List, Tuple, Iterator, Optional
9+
from typing import Dict, Set, List, Tuple, Iterator, Optional, Union
1010
from concurrent.futures import ThreadPoolExecutor, as_completed
1111

1212
import attrs
@@ -182,6 +182,8 @@ def get_stats_dict(self, is_dbt: bool = False):
182182

183183
@attrs.define(frozen=False)
184184
class TableDiffer(ThreadBase, ABC):
185+
INFO_TREE_CLASS = InfoTree
186+
185187
bisection_factor = 32
186188
stats: dict = {}
187189

@@ -204,7 +206,8 @@ def diff_tables(self, table1: TableSegment, table2: TableSegment, info_tree: Inf
204206
Where `row` is a tuple of values, corresponding to the diffed columns.
205207
"""
206208
if info_tree is None:
207-
info_tree = InfoTree(SegmentInfo([table1, table2]))
209+
segment_info = self.INFO_TREE_CLASS.SEGMENT_INFO_CLASS([table1, table2])
210+
info_tree = self.INFO_TREE_CLASS(segment_info)
208211
return DiffResultWrapper(self._diff_tables_wrapper(table1, table2, info_tree), info_tree, self.stats)
209212

210213
def _diff_tables_wrapper(self, table1: TableSegment, table2: TableSegment, info_tree: InfoTree) -> DiffResult:
@@ -259,7 +262,7 @@ def _validate_and_adjust_columns(self, table1: TableSegment, table2: TableSegmen
259262

260263
def _diff_tables_root(
261264
self, table1: TableSegment, table2: TableSegment, info_tree: InfoTree
262-
) -> DiffResult | DiffResultList:
265+
) -> Union[DiffResult, DiffResultList]:
263266
return self._bisect_and_diff_tables(table1, table2, info_tree)
264267

265268
@abstractmethod

data_diff/info_tree.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import List, Dict, Optional, Any, Tuple, Union
22

33
import attrs
4+
from typing_extensions import Self
45

56
from data_diff.table_segment import TableSegment
67

@@ -41,11 +42,14 @@ def update_from_children(self, child_infos):
4142

4243
@attrs.define(frozen=True)
4344
class InfoTree:
45+
SEGMENT_INFO_CLASS = SegmentInfo
46+
4447
info: SegmentInfo
4548
children: List["InfoTree"] = attrs.field(factory=list)
4649

47-
def add_node(self, table1: TableSegment, table2: TableSegment, max_rows: int = None):
48-
node = InfoTree(SegmentInfo([table1, table2], max_rows=max_rows))
50+
def add_node(self, table1: TableSegment, table2: TableSegment, max_rows: Optional[int] = None) -> Self:
51+
cls = self.__class__
52+
node = cls(cls.SEGMENT_INFO_CLASS([table1, table2], max_rows=max_rows))
4953
self.children.append(node)
5054
return node
5155

0 commit comments

Comments
 (0)