This repository was archived by the owner on May 17, 2024. It is now read-only.

Normalize schema info databricks redshift #781

Merged (5 commits) on Nov 22, 2023
57 changes: 40 additions & 17 deletions data_diff/databases/databricks.py
@@ -53,6 +53,7 @@ class Dialect(BaseDialect):
         "TIMESTAMP_NTZ": Timestamp,
         # Text
         "STRING": Text,
+        "VARCHAR": Text,
         # Boolean
         "BOOLEAN": Boolean,
     }
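
Review note: a raw-type map like this is consulted with the base type name reported by the warehouse, so adding "VARCHAR" lets Databricks VARCHAR columns resolve to Text just as STRING does. A minimal runnable sketch of that lookup, with hypothetical names that are not data-diff's actual API:

# Hypothetical sketch, not data-diff's real type registry.
TYPE_MAP = {"STRING": "Text", "VARCHAR": "Text", "BOOLEAN": "Boolean"}

def resolve_type(raw_type: str) -> str:
    base = raw_type.split("(")[0].strip().upper()  # "varchar(255)" -> "VARCHAR"
    return TYPE_MAP.get(base, "UnknownColType")

assert resolve_type("varchar(255)") == "Text"
assert resolve_type("STRING") == "Text"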
@@ -138,25 +139,47 @@ def create_connection(self):
             raise ConnectionError(*e.args) from e

     def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
         # Databricks has INFORMATION_SCHEMA only for Databricks Runtime, not for Databricks SQL.
         # https://docs.databricks.com/spark/latest/spark-sql/language-manual/information-schema/columns.html
         # So, to obtain information about schema, we should use another approach.

-        conn = self.create_connection()
-        catalog, schema, table = self._normalize_table_path(path)
-        with conn.cursor() as cursor:
-            cursor.columns(catalog_name=catalog, schema_name=schema, table_name=table)
-            try:
-                rows = cursor.fetchall()
-            finally:
-                conn.close()
-            if not rows:
-                raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")
-
-            d = {r.COLUMN_NAME: (r.COLUMN_NAME, r.TYPE_NAME, r.DECIMAL_DIGITS, None, None) for r in rows}
-            assert len(d) == len(rows)
-            return d
+        table_schema = {}
+        try:
+            table_schema = super().query_table_schema(path)
+        except Exception:
+            logging.warning("Failed to get schema from information_schema, falling back to legacy approach.")
+
+        if not table_schema:
+            # The legacy cursor-based approach can cause bugs: e.g. it reports the full
+            # type VARCHAR(255) rather than the expected base type VARCHAR. We don't
+            # expect to hit this fallback often, but when we do, keep that caveat in mind.
+            catalog, schema, table = self._normalize_table_path(path)
+            conn = self.create_connection()  # only open a direct connection for the fallback
+            with conn.cursor() as cursor:
+                cursor.columns(catalog_name=catalog, schema_name=schema, table_name=table)
+                try:
+                    rows = cursor.fetchall()
+                finally:
+                    conn.close()
+                if not rows:
+                    raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")
+
+                table_schema = {r.COLUMN_NAME: (r.COLUMN_NAME, r.TYPE_NAME, r.DECIMAL_DIGITS, None, None) for r in rows}
+                assert len(table_schema) == len(rows)
+        return table_schema
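
Review note: the control flow above reduces to "prefer information_schema, fall back to cursor metadata when it errors or returns nothing". A minimal standalone sketch of that pattern, with illustrative names that are not part of data-diff:

from typing import Callable, Dict

def schema_with_fallback(
    primary: Callable[[], Dict[str, tuple]],
    fallback: Callable[[], Dict[str, tuple]],
) -> Dict[str, tuple]:
    # Treat an exception from the primary source the same as an empty result.
    try:
        result = primary()
    except Exception:
        result = {}
    return result or fallback()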

+    def select_table_schema(self, path: DbPath) -> str:
+        """Provide SQL for selecting the table schema as (name, type, datetime_prec, num_prec, num_scale)"""
+        database, schema, name = self._normalize_table_path(path)
+        info_schema_path = ["information_schema", "columns"]
+        if database:
+            info_schema_path.insert(0, database)
+
+        return (
+            "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale "
+            f"FROM {'.'.join(info_schema_path)} "
+            f"WHERE table_name = '{name}' AND table_schema = '{schema}'"
+        )
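
Review note: for a hypothetical three-part path this renders a fully qualified information_schema query. A runnable sketch of the same string-building, with illustrative values:

database, schema, name = "dev", "analytics", "users"  # hypothetical path parts
info_schema_path = ["information_schema", "columns"]
if database:
    info_schema_path.insert(0, database)
query = (
    "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale "
    f"FROM {'.'.join(info_schema_path)} "
    f"WHERE table_name = '{name}' AND table_schema = '{schema}'"
)
# -> SELECT ... FROM dev.information_schema.columns
#    WHERE table_name = 'users' AND table_schema = 'analytics'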

     def _process_table_schema(
         self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None
26 changes: 16 additions & 10 deletions data_diff/databases/redshift.py
@@ -121,15 +121,15 @@ def query_external_table_schema(self, path: DbPath) -> Dict[str, tuple]:
         if not rows:
             raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")

-        d = {r[0]: r for r in rows}
-        assert len(d) == len(rows)
-        return d
+        schema_dict = self._normalize_schema_info(rows)
+
+        return schema_dict

     def select_view_columns(self, path: DbPath) -> str:
         _, schema, table = self._normalize_table_path(path)

         return """select * from pg_get_cols('{}.{}')
-            cols(view_schema name, view_name name, col_name name, col_type varchar, col_num int)
+            cols(col_name name, col_type varchar)
             """.format(schema, table)

     def query_pg_get_cols(self, path: DbPath) -> Dict[str, tuple]:
@@ -138,10 +138,17 @@ def query_pg_get_cols(self, path: DbPath) -> Dict[str, tuple]:
         if not rows:
             raise RuntimeError(f"{self.name}: View '{'.'.join(path)}' does not exist, or has no columns")

-        output = {}
+        schema_dict = self._normalize_schema_info(rows)
+
+        return schema_dict
+
+    # When reading schema from a non-information_schema source, strip "(N)" from
+    # "type(N)" etc. so the output matches typical information_schema output.
+    def _normalize_schema_info(self, rows) -> Dict[str, tuple]:
+        schema_dict = {}
         for r in rows:
-            col_name = r[2]
-            type_info = r[3].split("(")
+            col_name = r[0]
+            type_info = r[1].split("(")
             base_type = type_info[0]
             precision = None
             scale = None
@@ -153,9 +160,8 @@ def query_pg_get_cols(self, path: DbPath) -> Dict[str, tuple]:
                 scale = int(scale)

             out = [col_name, base_type, None, precision, scale]
-            output[col_name] = tuple(out)
-
-        return output
+            schema_dict[col_name] = tuple(out)
+        return schema_dict
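
Review note: a worked example of the normalization with hypothetical rows, assuming the elided middle of the loop parses precision and scale out of the parenthesized suffix:

rows = [("full_name", "varchar(255)"), ("amount", "numeric(10,2)"), ("active", "boolean")]
# _normalize_schema_info(rows) would yield, per the
# (name, base_type, datetime_precision, precision, scale) tuple layout:
# {
#     "full_name": ("full_name", "varchar", None, 255, None),
#     "amount":    ("amount", "numeric", None, 10, 2),
#     "active":    ("active", "boolean", None, None, None),
# }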

     def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
         try:
2 changes: 1 addition & 1 deletion data_diff/schema.py
@@ -9,7 +9,7 @@


 def create_schema(db_name: str, table_path: DbPath, schema: dict, case_sensitive: bool) -> CaseAwareMapping:
-    logger.debug(f"[{db_name}] Schema = {schema}")
+    logger.info(f"[{db_name}] Schema = {schema}")

     if case_sensitive:
         return CaseSensitiveDict(schema)