Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Compile all AST elements always via dialects, never directly #713

Merged
merged 4 commits into from
Sep 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions data_diff/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Sequence, Tuple, Iterator, Optional, Union

from data_diff.abcs.database_types import DbTime, DbPath
from data_diff.databases import Database
from data_diff.tracking import disable_tracking
from data_diff.databases._connect import connect
from data_diff.diff_tables import Algorithm
Expand Down Expand Up @@ -31,10 +32,10 @@ def connect_to_table(
if isinstance(key_columns, str):
key_columns = (key_columns,)

db = connect(db_info, thread_count=thread_count)
db: Database = connect(db_info, thread_count=thread_count)

if isinstance(table_name, str):
table_name = db.parse_table_name(table_name)
table_name = db.dialect.parse_table_name(table_name)

return TableSegment(db, table_name, key_columns, **kwargs)

Expand Down Expand Up @@ -161,7 +162,8 @@ def diff_tables(
)
elif algorithm == Algorithm.JOINDIFF:
if isinstance(materialize_to_table, str):
materialize_to_table = table1.database.parse_table_name(eval_name_template(materialize_to_table))
table_name = eval_name_template(materialize_to_table)
materialize_to_table = table1.database.dialect.parse_table_name(table_name)
differ = JoinDiffer(
threaded=threaded,
max_threadpool_size=max_threadpool_size,
Expand Down
9 changes: 5 additions & 4 deletions data_diff/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
import json
import logging
from itertools import islice
from typing import Dict, Optional
from typing import Dict, Optional, Tuple

import rich
from rich.logging import RichHandler
import click

from data_diff import Database
from data_diff.schema import create_schema
from data_diff.queries.api import current_timestamp

Expand Down Expand Up @@ -425,7 +426,7 @@ def _data_diff(
logging.error(f"Error while parsing age expression: {e}")
return

dbs = db1, db2
dbs: Tuple[Database, Database] = db1, db2

if interactive:
for db in dbs:
Expand All @@ -444,7 +445,7 @@ def _data_diff(
materialize_all_rows=materialize_all_rows,
table_write_limit=table_write_limit,
materialize_to_table=materialize_to_table
and db1.parse_table_name(eval_name_template(materialize_to_table)),
and db1.dialect.parse_table_name(eval_name_template(materialize_to_table)),
)
else:
assert algorithm == Algorithm.HASHDIFF
Expand All @@ -456,7 +457,7 @@ def _data_diff(
)

table_names = table1, table2
table_paths = [db.parse_table_name(t) for db, t in safezip(dbs, table_names)]
table_paths = [db.dialect.parse_table_name(t) for db, t in safezip(dbs, table_names)]

schemas = list(differ._thread_map(_get_schema, safezip(dbs, table_paths)))
schema1, schema2 = schemas = [
Expand Down
12 changes: 3 additions & 9 deletions data_diff/abcs/compiler.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
from typing import Any, Dict
from abc import ABC, abstractmethod
from abc import ABC


class AbstractCompiler(ABC):
@abstractmethod
def compile(self, elem: Any, params: Dict[str, Any] = None) -> str:
...
pass


class Compilable(ABC):
# TODO generic syntax, so we can write Compilable[T] for expressions returning a value of type T
@abstractmethod
def compile(self, c: AbstractCompiler) -> str:
...
pass
15 changes: 10 additions & 5 deletions data_diff/abcs/database_types.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import decimal
from abc import ABC, abstractmethod
from typing import Sequence, Optional, Tuple, Type, Union, Dict, List
from typing import Sequence, Optional, Tuple, Union, Dict, List
from datetime import datetime

from runtype import dataclass
from typing_extensions import Self

from data_diff.abcs.compiler import AbstractCompiler
from data_diff.utils import ArithAlphanumeric, ArithUUID, Unknown


Expand Down Expand Up @@ -176,6 +177,14 @@ class UnknownColType(ColType):
class AbstractDialect(ABC):
"""Dialect-dependent query expressions"""

@abstractmethod
def compile(self, compiler: AbstractCompiler, elem, params=None) -> str:
raise NotImplementedError

@abstractmethod
def parse_table_name(self, name: str) -> DbPath:
"Parse the given table name into a DbPath"

@property
@abstractmethod
def name(self) -> str:
Expand Down Expand Up @@ -319,10 +328,6 @@ def _process_table_schema(

"""

@abstractmethod
def parse_table_name(self, name: str) -> DbPath:
"Parse the given table name into a DbPath"

@abstractmethod
def close(self):
"Close connection(s) to the database instance. Querying will stop functioning."
Expand Down
5 changes: 0 additions & 5 deletions data_diff/bound_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing_extensions import Self

from data_diff.abcs.database_types import AbstractDatabase
from data_diff.abcs.compiler import AbstractCompiler
from data_diff.queries.ast_classes import ExprNode, TablePath, Compilable
from data_diff.queries.api import table
from data_diff.schema import create_schema
Expand Down Expand Up @@ -37,10 +36,6 @@ def query(self, res_type=list):
def type(self):
return self.node.type

def compile(self, c: AbstractCompiler) -> str:
assert c.database is self.database
return self.node.compile(c)


def bind_node(node, database):
return BoundNode(database, node)
Expand Down
Loading