Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit f3dc52b

Browse files
authored
Merge pull request #713 from datafold/simplify-ast-compilation
Compile all AST elements always via dialects, never directly
2 parents 0a7d2cb + 795bb0e commit f3dc52b

File tree

12 files changed

+503
-404
lines changed

12 files changed

+503
-404
lines changed

data_diff/__init__.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Sequence, Tuple, Iterator, Optional, Union
22

33
from data_diff.abcs.database_types import DbTime, DbPath
4+
from data_diff.databases import Database
45
from data_diff.tracking import disable_tracking
56
from data_diff.databases._connect import connect
67
from data_diff.diff_tables import Algorithm
@@ -31,10 +32,10 @@ def connect_to_table(
3132
if isinstance(key_columns, str):
3233
key_columns = (key_columns,)
3334

34-
db = connect(db_info, thread_count=thread_count)
35+
db: Database = connect(db_info, thread_count=thread_count)
3536

3637
if isinstance(table_name, str):
37-
table_name = db.parse_table_name(table_name)
38+
table_name = db.dialect.parse_table_name(table_name)
3839

3940
return TableSegment(db, table_name, key_columns, **kwargs)
4041

@@ -161,7 +162,8 @@ def diff_tables(
161162
)
162163
elif algorithm == Algorithm.JOINDIFF:
163164
if isinstance(materialize_to_table, str):
164-
materialize_to_table = table1.database.parse_table_name(eval_name_template(materialize_to_table))
165+
table_name = eval_name_template(materialize_to_table)
166+
materialize_to_table = table1.database.dialect.parse_table_name(table_name)
165167
differ = JoinDiffer(
166168
threaded=threaded,
167169
max_threadpool_size=max_threadpool_size,

data_diff/__main__.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@
66
import json
77
import logging
88
from itertools import islice
9-
from typing import Dict, Optional
9+
from typing import Dict, Optional, Tuple
1010

1111
import rich
1212
from rich.logging import RichHandler
1313
import click
1414

15+
from data_diff import Database
1516
from data_diff.schema import create_schema
1617
from data_diff.queries.api import current_timestamp
1718

@@ -425,7 +426,7 @@ def _data_diff(
425426
logging.error(f"Error while parsing age expression: {e}")
426427
return
427428

428-
dbs = db1, db2
429+
dbs: Tuple[Database, Database] = db1, db2
429430

430431
if interactive:
431432
for db in dbs:
@@ -444,7 +445,7 @@ def _data_diff(
444445
materialize_all_rows=materialize_all_rows,
445446
table_write_limit=table_write_limit,
446447
materialize_to_table=materialize_to_table
447-
and db1.parse_table_name(eval_name_template(materialize_to_table)),
448+
and db1.dialect.parse_table_name(eval_name_template(materialize_to_table)),
448449
)
449450
else:
450451
assert algorithm == Algorithm.HASHDIFF
@@ -456,7 +457,7 @@ def _data_diff(
456457
)
457458

458459
table_names = table1, table2
459-
table_paths = [db.parse_table_name(t) for db, t in safezip(dbs, table_names)]
460+
table_paths = [db.dialect.parse_table_name(t) for db, t in safezip(dbs, table_names)]
460461

461462
schemas = list(differ._thread_map(_get_schema, safezip(dbs, table_paths)))
462463
schema1, schema2 = schemas = [

data_diff/abcs/compiler.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,9 @@
1-
from typing import Any, Dict
2-
from abc import ABC, abstractmethod
1+
from abc import ABC
32

43

54
class AbstractCompiler(ABC):
6-
@abstractmethod
7-
def compile(self, elem: Any, params: Dict[str, Any] = None) -> str:
8-
...
5+
pass
96

107

118
class Compilable(ABC):
12-
# TODO generic syntax, so we can write Compilable[T] for expressions returning a value of type T
13-
@abstractmethod
14-
def compile(self, c: AbstractCompiler) -> str:
15-
...
9+
pass

data_diff/abcs/database_types.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import decimal
22
from abc import ABC, abstractmethod
3-
from typing import Sequence, Optional, Tuple, Type, Union, Dict, List
3+
from typing import Sequence, Optional, Tuple, Union, Dict, List
44
from datetime import datetime
55

66
from runtype import dataclass
77
from typing_extensions import Self
88

9+
from data_diff.abcs.compiler import AbstractCompiler
910
from data_diff.utils import ArithAlphanumeric, ArithUUID, Unknown
1011

1112

@@ -176,6 +177,14 @@ class UnknownColType(ColType):
176177
class AbstractDialect(ABC):
177178
"""Dialect-dependent query expressions"""
178179

180+
@abstractmethod
181+
def compile(self, compiler: AbstractCompiler, elem, params=None) -> str:
182+
raise NotImplementedError
183+
184+
@abstractmethod
185+
def parse_table_name(self, name: str) -> DbPath:
186+
"Parse the given table name into a DbPath"
187+
179188
@property
180189
@abstractmethod
181190
def name(self) -> str:
@@ -319,10 +328,6 @@ def _process_table_schema(
319328
320329
"""
321330

322-
@abstractmethod
323-
def parse_table_name(self, name: str) -> DbPath:
324-
"Parse the given table name into a DbPath"
325-
326331
@abstractmethod
327332
def close(self):
328333
"Close connection(s) to the database instance. Querying will stop functioning."

data_diff/bound_exprs.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from typing_extensions import Self
99

1010
from data_diff.abcs.database_types import AbstractDatabase
11-
from data_diff.abcs.compiler import AbstractCompiler
1211
from data_diff.queries.ast_classes import ExprNode, TablePath, Compilable
1312
from data_diff.queries.api import table
1413
from data_diff.schema import create_schema
@@ -37,10 +36,6 @@ def query(self, res_type=list):
3736
def type(self):
3837
return self.node.type
3938

40-
def compile(self, c: AbstractCompiler) -> str:
41-
assert c.database is self.database
42-
return self.node.compile(c)
43-
4439

4540
def bind_node(node, database):
4641
return BoundNode(database, node)

0 commit comments

Comments
 (0)