Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Resolve circular imports between schemas, databases, and compilers #719

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion data_diff/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ def _data_diff(

schemas = list(differ._thread_map(_get_schema, safezip(dbs, table_paths)))
schema1, schema2 = schemas = [
create_schema(db, table_path, schema, case_sensitive)
create_schema(db.name, table_path, schema, case_sensitive)
for db, table_path, schema in safezip(dbs, table_paths, schemas)
]

Expand Down
2 changes: 1 addition & 1 deletion data_diff/joindiff_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from data_diff.databases import Database, MsSQL, MySQL, BigQuery, Presto, Oracle, Snowflake
from data_diff.abcs.database_types import NumericType, DbPath
from data_diff.databases.base import Compiler
from data_diff.queries.api import (
table,
sum_,
Expand All @@ -23,7 +24,6 @@
rightjoin,
this,
when,
Compiler,
)
from data_diff.queries.ast_classes import Concat, Count, Expr, Random, TablePath, Code, ITable
from data_diff.queries.extras import NormalizeAsString
Expand Down
7 changes: 3 additions & 4 deletions data_diff/schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging

from data_diff import Database
from data_diff.utils import CaseAwareMapping, CaseInsensitiveDict, CaseSensitiveDict
from data_diff.abcs.database_types import DbPath

Expand All @@ -9,13 +8,13 @@
Schema = CaseAwareMapping


def create_schema(db: Database, table_path: DbPath, schema: dict, case_sensitive: bool) -> CaseAwareMapping:
logger.debug(f"[{db.name}] Schema = {schema}")
def create_schema(db_name: str, table_path: DbPath, schema: dict, case_sensitive: bool) -> CaseAwareMapping:
logger.debug(f"[{db_name}] Schema = {schema}")

if case_sensitive:
return CaseSensitiveDict(schema)

if len({k.lower() for k in schema}) < len(schema):
logger.warning(f'Ambiguous schema for {db}:{".".join(table_path)} | Columns = {", ".join(list(schema))}')
logger.warning(f'Ambiguous schema for {db_name}:{".".join(table_path)} | Columns = {", ".join(list(schema))}')
logger.warning("We recommend to disable case-insensitivity (set --case-sensitive).")
return CaseInsensitiveDict(schema)
2 changes: 1 addition & 1 deletion data_diff/table_segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def _where(self):

def _with_raw_schema(self, raw_schema: dict) -> Self:
schema = self.database._process_table_schema(self.table_path, raw_schema, self.relevant_columns, self._where())
return self.new(_schema=create_schema(self.database, self.table_path, schema, self.case_sensitive))
return self.new(_schema=create_schema(self.database.name, self.table_path, schema, self.case_sensitive))

def with_schema(self) -> Self:
"Queries the table schema from the database, and returns a new instance of TableSegment, with a schema."
Expand Down
2 changes: 1 addition & 1 deletion tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def test_correct_timezone(self):
t = table(name)
raw_schema = db.query_table_schema(t.path)
schema = db._process_table_schema(t.path, raw_schema)
schema = create_schema(self.database, t, schema, case_sensitive=True)
schema = create_schema(self.database.name, t, schema, case_sensitive=True)
t = t.replace(schema=schema)
t.schema["created_at"] = t.schema["created_at"].replace(precision=t.schema["created_at"].precision)

Expand Down