diff --git a/target_postgres/db_sync.py b/target_postgres/db_sync.py
index c04c1bf..0d563d0 100644
--- a/target_postgres/db_sync.py
+++ b/target_postgres/db_sync.py
@@ -181,6 +181,14 @@ def stream_name_to_dict(stream_name, separator='-'):
     }
 
 
+def csv_quote(s):
+    if s is None:
+        return ""
+    if isinstance(s, int):
+        return str(s)
+    return '"' + str(s).replace("\\", "\\\\").replace('"', '\\"') + '"'
+
+
 # pylint: disable=too-many-public-methods,too-many-instance-attributes
 class DbSync:
     def __init__(self, connection_config, stream_schema_message=None):
@@ -347,11 +355,7 @@ def record_primary_key_string(self, record):
     def record_to_csv_line(self, record):
         flatten = flatten_record(record, self.flatten_schema, max_level=self.data_flattening_max_level)
         return ','.join(
-            [
-                json.dumps(flatten[name], ensure_ascii=False)
-                if name in flatten and (flatten[name] == 0 or flatten[name]) else ''
-                for name in self.flatten_schema
-            ]
+            csv_quote(flatten.get(name)) for name in self.flatten_schema
         )
 
     def load_csv(self, file, count, size_bytes):
diff --git a/tests/unit/test_db_sync.py b/tests/unit/test_db_sync.py
index 85bb503..b78d1a6 100644
--- a/tests/unit/test_db_sync.py
+++ b/tests/unit/test_db_sync.py
@@ -1,4 +1,5 @@
 import unittest
+import unittest.mock
 
 import target_postgres
 
@@ -320,3 +321,61 @@ def test_flatten_record_with_flatten_schema(self):
         for idx, (should_use_flatten_schema, record, expected_output) in enumerate(test_cases):
             output = flatten_record(record, flatten_schema if should_use_flatten_schema else None)
             assert output == expected_output
+
+    def test_record_to_csv_line(self):
+        dbsync = unittest.mock.MagicMock()
+        dbsync.flatten_schema = {
+            "c_pk": {"type": ["null", "integer"]},
+            "c_varchar": {"type": ["null", "string"]},
+            "c_int": {"type": ["null", "integer"]}}
+        dbsync.data_flattening_max_level = 0
+
+        test_cases = [
+            (
+                {
+                    "c_pk": 123,
+                    "c_varchar": "hello",
+                    "c_int": 456,
+                },
+                '123,"hello",456',
+            ),
+            (
+                {
+                    "c_pk": 999,
+                    "c_varchar": "hello\nworld",
+                    "c_int": None,
+                },
+                '999,"hello\nworld",',
+            ),
+            (
+                {
+                    "c_pk": 1,
+                    "c_varchar": 'some "quotes" and \\backslashes\\',
+                    "c_int": 555,
+                },
+                '1,"some \\"quotes\\" and \\\\backslashes\\\\",555',
+            ),
+            (
+                {
+                    "c_pk": 1,
+                    "c_varchar": "",
+                    "c_int": 2,
+                },
+                '1,"",2',
+            ),
+            (
+                {
+                    "c_pk": 1,
+                    "c_varchar": None,
+                    "c_int": 2,
+                },
+                '1,,2',
+            ),
+            (
+                {},
+                ',,',
+            ),
+        ]
+
+        for idx, (record, expected_output) in enumerate(test_cases):
+            assert target_postgres.db_sync.DbSync.record_to_csv_line(dbsync, record) == expected_output