diff --git a/data_diff/databases/databricks.py b/data_diff/databases/databricks.py
index bed17dea..de5ea8b7 100644
--- a/data_diff/databases/databricks.py
+++ b/data_diff/databases/databricks.py
@@ -53,6 +53,7 @@ class Dialect(BaseDialect):
         "TIMESTAMP_NTZ": Timestamp,
         # Text
         "STRING": Text,
+        "VARCHAR": Text,
         # Boolean
         "BOOLEAN": Boolean,
     }
@@ -138,25 +139,46 @@ def create_connection(self):
             raise ConnectionError(*e.args) from e
 
     def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
-        # Databricks has INFORMATION_SCHEMA only for Databricks Runtime, not for Databricks SQL.
-        # https://docs.databricks.com/spark/latest/spark-sql/language-manual/information-schema/columns.html
-        # So, to obtain information about schema, we should use another approach.
-        conn = self.create_connection()
+        table_schema = {}
 
-        catalog, schema, table = self._normalize_table_path(path)
-        with conn.cursor() as cursor:
-            cursor.columns(catalog_name=catalog, schema_name=schema, table_name=table)
-            try:
-                rows = cursor.fetchall()
-            finally:
-                conn.close()
-            if not rows:
-                raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")
-
-        d = {r.COLUMN_NAME: (r.COLUMN_NAME, r.TYPE_NAME, r.DECIMAL_DIGITS, None, None) for r in rows}
-        assert len(d) == len(rows)
-        return d
+        try:
+            table_schema = super().query_table_schema(path)
+        except Exception as e:
+            logging.warning(f"Failed to get schema from information_schema, falling back to legacy approach: {e}")
+
+        if not table_schema:
+            # Legacy fallback for workspaces where information_schema is unavailable.
+            # Note: the cursor metadata reports parameterized types verbatim,
+            # e.g. VARCHAR(255) rather than the bare VARCHAR that information_schema returns.
+            conn = self.create_connection()
+            catalog, schema, table = self._normalize_table_path(path)
+            with conn.cursor() as cursor:
+                cursor.columns(catalog_name=catalog, schema_name=schema, table_name=table)
+                try:
+                    rows = cursor.fetchall()
+                finally:
+                    conn.close()
+                if not rows:
+                    raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")
+
+            table_schema = {r.COLUMN_NAME: (r.COLUMN_NAME, r.TYPE_NAME, r.DECIMAL_DIGITS, None, None) for r in rows}
+            assert len(table_schema) == len(rows)
+
+        return table_schema
+
+    def select_table_schema(self, path: DbPath) -> str:
+        """Provide SQL for selecting the table schema as (name, type, datetime_prec, num_prec, num_scale)."""
+        database, schema, name = self._normalize_table_path(path)
+        info_schema_path = ["information_schema", "columns"]
+        if database:
+            info_schema_path.insert(0, database)
+
+        return (
+            "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale "
+            f"FROM {'.'.join(info_schema_path)} "
+            f"WHERE table_name = '{name}' AND table_schema = '{schema}'"
+        )
 
     def _process_table_schema(
         self, path: DbPath, raw_schema: Dict[str, tuple], filter_columns: Sequence[str], where: str = None
diff --git a/data_diff/databases/redshift.py b/data_diff/databases/redshift.py
index 7a621f57..dcf061c4 100644
--- a/data_diff/databases/redshift.py
+++ b/data_diff/databases/redshift.py
@@ -121,15 +121,15 @@ def query_external_table_schema(self, path: DbPath) -> Dict[str, tuple]:
         if not rows:
             raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")
 
-        d = {r[0]: r for r in rows}
-        assert len(d) == len(rows)
-        return d
+        schema_dict = self._normalize_schema_info(rows)
+
+        return schema_dict
 
     def select_view_columns(self, path: DbPath) -> str:
         _, schema, table = self._normalize_table_path(path)
 
         return """select * from pg_get_cols('{}.{}')
-                    cols(view_schema name, view_name name, col_name name, col_type varchar, col_num int)
+                    cols(col_name name, col_type varchar)
            """.format(schema, table)
 
     def query_pg_get_cols(self, path: DbPath) -> Dict[str, tuple]:
@@ -138,10 +138,17 @@ def query_pg_get_cols(self, path: DbPath) -> Dict[str, tuple]:
         if not rows:
             raise RuntimeError(f"{self.name}: View '{'.'.join(path)}' does not exist, or has no columns")
 
-        output = {}
+        schema_dict = self._normalize_schema_info(rows)
+
+        return schema_dict
+
+    # When using a non-information_schema source, strip the (N) from type(N) etc.
+    # to match typical information_schema output.
+    def _normalize_schema_info(self, rows) -> Dict[str, tuple]:
+        schema_dict = {}
         for r in rows:
-            col_name = r[2]
-            type_info = r[3].split("(")
+            col_name = r[0]
+            type_info = r[1].split("(")
             base_type = type_info[0]
             precision = None
             scale = None
@@ -153,9 +160,8 @@ def query_pg_get_cols(self, path: DbPath) -> Dict[str, tuple]:
                 scale = int(scale)
 
             out = [col_name, base_type, None, precision, scale]
-            output[col_name] = tuple(out)
-
-        return output
+            schema_dict[col_name] = tuple(out)
+        return schema_dict
 
     def query_table_schema(self, path: DbPath) -> Dict[str, tuple]:
         try:
diff --git a/data_diff/schema.py b/data_diff/schema.py
index db17d287..67b4261f 100644
--- a/data_diff/schema.py
+++ b/data_diff/schema.py
@@ -9,7 +9,7 @@
 def create_schema(db_name: str, table_path: DbPath, schema: dict, case_sensitive: bool) -> CaseAwareMapping:
-    logger.debug(f"[{db_name}] Schema = {schema}")
+    logger.info(f"[{db_name}] Schema = {schema}")
 
     if case_sensitive:
         return CaseSensitiveDict(schema)
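
For context on the Databricks change: once `select_table_schema` is defined, the inherited information_schema path can serve Databricks too. A minimal sketch of the string it assembles, assuming a hypothetical fully qualified path `("my_catalog", "my_schema", "my_table")` (names here are illustrative, not from the PR):

```python
# Sketch of the SQL the new Databricks select_table_schema() builds.
# "my_catalog", "my_schema", "my_table" are hypothetical placeholder names.
info_schema_path = ["my_catalog", "information_schema", "columns"]
name, schema = "my_table", "my_schema"
print(
    "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale "
    f"FROM {'.'.join(info_schema_path)} "
    f"WHERE table_name = '{name}' AND table_schema = '{schema}'"
)
# SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM my_catalog.information_schema.columns WHERE table_name = 'my_table' AND table_schema = 'my_schema'
```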
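The Redshift refactor centralizes the type-suffix stripping in `_normalize_schema_info`, so both the external-table and `pg_get_cols` paths produce the same `(name, type, datetime_prec, num_prec, num_scale)` shape as information_schema queries. Below is a minimal standalone sketch of that normalization; the function name and sample rows are illustrative, and the hidden loop body in the hunk context may differ in detail:

```python
from typing import Dict

def normalize_schema_info(rows) -> Dict[str, tuple]:
    """Strip '(N)' / '(N,M)' suffixes from raw type names, mirroring the
    (name, type, datetime_prec, num_prec, num_scale) tuples that
    information_schema-based schema queries return."""
    schema_dict = {}
    for col_name, col_type in rows:
        base_type, _, params = col_type.partition("(")
        precision = scale = None
        if params:
            parts = params.rstrip(")").split(",")
            precision = int(parts[0])
            if len(parts) > 1:
                scale = int(parts[1])
        schema_dict[col_name] = (col_name, base_type, None, precision, scale)
    return schema_dict

print(normalize_schema_info([
    ("id", "bigint"),
    ("amount", "numeric(10,2)"),
    ("name", "character varying(256)"),
]))
# {'id': ('id', 'bigint', None, None, None),
#  'amount': ('amount', 'numeric', None, 10, 2),
#  'name': ('name', 'character varying', None, 256, None)}
```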