Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Fix for more than 50 fields in Postgres #662

Merged
merged 1 commit into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions data_diff/sqeleton/databases/postgresql.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import List
from ..abcs.database_types import (
DbPath,
JSON,
Expand Down Expand Up @@ -92,6 +93,10 @@ def quote(self, s: str):
def to_string(self, s: str):
return f"{s}::varchar"

def concat(self, items: List[str]) -> str:
joined_exprs = " || ".join(items)
return f"({joined_exprs})"

def _convert_db_precision_to_digits(self, p: int) -> int:
# Subtracting 2 due to wierd precision issues in PostgreSQL
return super()._convert_db_precision_to_digits(p) - 2
Expand Down
44 changes: 44 additions & 0 deletions tests/test_postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,47 @@ def test_uuid(self):
self.connection.query(self.table_src.drop(True))
self.connection.query(self.table_dst.drop(True))
mysql_conn.query(self.table_dst.drop(True))


class Test100Fields(unittest.TestCase):
def setUp(self) -> None:
self.connection = get_conn(db.PostgreSQL)

table_suffix = random_table_suffix()

self.table_src_name = f"src{table_suffix}"
self.table_dst_name = f"dst{table_suffix}"

self.table_src = table(self.table_src_name)
self.table_dst = table(self.table_dst_name)

def test_100_fields(self):
self.connection.query('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";', None)

columns = [f"col{i}" for i in range(100)]
fields = " ,".join(f'"{field}" TEXT' for field in columns)

queries = [
self.table_src.drop(True),
self.table_dst.drop(True),
f"CREATE TABLE {self.table_src_name} (id uuid DEFAULT uuid_generate_v4 (), {fields})",
commit,
self.table_src.insert_rows([[f"{x * y}" for x in range(100)] for y in range(10)], columns=columns),
commit,
self.table_dst.create(self.table_src),
commit,
self.table_src.insert_rows([[1 for x in range(100)]], columns=columns),
commit,
]

for query in queries:
self.connection.query(query)

a = TableSegment(self.connection, self.table_src.path, ("id",), extra_columns=tuple(columns))
b = TableSegment(self.connection, self.table_dst.path, ("id",), extra_columns=tuple(columns))

differ = HashDiffer()
diff = list(differ.diff_tables(a, b))
id_ = diff[0][1][0]
result = (id_,) + tuple("1" for x in range(100))
self.assertEqual(diff, [("-", result)])