From 09683af4b61b0edcb98b0bebabe02ffbae18b476 Mon Sep 17 00:00:00 2001 From: David Raznick <kindly@gmail.com> Date: Fri, 4 Aug 2023 23:39:10 +0100 Subject: [PATCH] Fix for more than 50 fields in Postgres. Postgres does not allow functions that have more that have more than 100 arguments. When using the concat function, this limits comparisons to less than 50 fields. Using || for concat like the Oracle variant fixes this. --- data_diff/sqeleton/databases/postgresql.py | 5 +++ tests/test_postgresql.py | 44 ++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/data_diff/sqeleton/databases/postgresql.py b/data_diff/sqeleton/databases/postgresql.py index 47c372ee..69c6fb3b 100644 --- a/data_diff/sqeleton/databases/postgresql.py +++ b/data_diff/sqeleton/databases/postgresql.py @@ -1,3 +1,4 @@ +from typing import List from ..abcs.database_types import ( DbPath, JSON, @@ -92,6 +93,10 @@ def quote(self, s: str): def to_string(self, s: str): return f"{s}::varchar" + def concat(self, items: List[str]) -> str: + joined_exprs = " || ".join(items) + return f"({joined_exprs})" + def _convert_db_precision_to_digits(self, p: int) -> int: # Subtracting 2 due to wierd precision issues in PostgreSQL return super()._convert_db_precision_to_digits(p) - 2 diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 0f24198e..2543006d 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -70,3 +70,47 @@ def test_uuid(self): self.connection.query(self.table_src.drop(True)) self.connection.query(self.table_dst.drop(True)) mysql_conn.query(self.table_dst.drop(True)) + + +class Test100Fields(unittest.TestCase): + def setUp(self) -> None: + self.connection = get_conn(db.PostgreSQL) + + table_suffix = random_table_suffix() + + self.table_src_name = f"src{table_suffix}" + self.table_dst_name = f"dst{table_suffix}" + + self.table_src = table(self.table_src_name) + self.table_dst = table(self.table_dst_name) + + def test_100_fields(self): + self.connection.query('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";', None) + + columns = [f"col{i}" for i in range(100)] + fields = " ,".join(f'"{field}" TEXT' for field in columns) + + queries = [ + self.table_src.drop(True), + self.table_dst.drop(True), + f"CREATE TABLE {self.table_src_name} (id uuid DEFAULT uuid_generate_v4 (), {fields})", + commit, + self.table_src.insert_rows([[f"{x * y}" for x in range(100)] for y in range(10)], columns=columns), + commit, + self.table_dst.create(self.table_src), + commit, + self.table_src.insert_rows([[1 for x in range(100)]], columns=columns), + commit, + ] + + for query in queries: + self.connection.query(query) + + a = TableSegment(self.connection, self.table_src.path, ("id",), extra_columns=tuple(columns)) + b = TableSegment(self.connection, self.table_dst.path, ("id",), extra_columns=tuple(columns)) + + differ = HashDiffer() + diff = list(differ.diff_tables(a, b)) + id_ = diff[0][1][0] + result = (id_,) + tuple("1" for x in range(100)) + self.assertEqual(diff, [("-", result)])