diff --git a/data_diff/sqeleton/databases/postgresql.py b/data_diff/sqeleton/databases/postgresql.py index 47c372ee..69c6fb3b 100644 --- a/data_diff/sqeleton/databases/postgresql.py +++ b/data_diff/sqeleton/databases/postgresql.py @@ -1,3 +1,4 @@ +from typing import List from ..abcs.database_types import ( DbPath, JSON, @@ -92,6 +93,10 @@ def quote(self, s: str): def to_string(self, s: str): return f"{s}::varchar" + def concat(self, items: List[str]) -> str: + joined_exprs = " || ".join(items) + return f"({joined_exprs})" + def _convert_db_precision_to_digits(self, p: int) -> int: # Subtracting 2 due to wierd precision issues in PostgreSQL return super()._convert_db_precision_to_digits(p) - 2 diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 0f24198e..2543006d 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -70,3 +70,47 @@ def test_uuid(self): self.connection.query(self.table_src.drop(True)) self.connection.query(self.table_dst.drop(True)) mysql_conn.query(self.table_dst.drop(True)) + + +class Test100Fields(unittest.TestCase): + def setUp(self) -> None: + self.connection = get_conn(db.PostgreSQL) + + table_suffix = random_table_suffix() + + self.table_src_name = f"src{table_suffix}" + self.table_dst_name = f"dst{table_suffix}" + + self.table_src = table(self.table_src_name) + self.table_dst = table(self.table_dst_name) + + def test_100_fields(self): + self.connection.query('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";', None) + + columns = [f"col{i}" for i in range(100)] + fields = " ,".join(f'"{field}" TEXT' for field in columns) + + queries = [ + self.table_src.drop(True), + self.table_dst.drop(True), + f"CREATE TABLE {self.table_src_name} (id uuid DEFAULT uuid_generate_v4 (), {fields})", + commit, + self.table_src.insert_rows([[f"{x * y}" for x in range(100)] for y in range(10)], columns=columns), + commit, + self.table_dst.create(self.table_src), + commit, + self.table_src.insert_rows([[1 for x in range(100)]], columns=columns), + commit, + ] + + for query in queries: + self.connection.query(query) + + a = TableSegment(self.connection, self.table_src.path, ("id",), extra_columns=tuple(columns)) + b = TableSegment(self.connection, self.table_dst.path, ("id",), extra_columns=tuple(columns)) + + differ = HashDiffer() + diff = list(differ.diff_tables(a, b)) + id_ = diff[0][1][0] + result = (id_,) + tuple("1" for x in range(100)) + self.assertEqual(diff, [("-", result)])