diff --git a/target_postgres/__init__.py b/target_postgres/__init__.py index 5283317..b643b9a 100644 --- a/target_postgres/__init__.py +++ b/target_postgres/__init__.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import argparse +import csv import io import json import os @@ -344,10 +345,14 @@ def flush_records(stream, records_to_load, row_count, db_sync, temp_dir=None): size_bytes = 0 csv_fd, csv_file = mkstemp(suffix='.csv', prefix=f'{stream}_', dir=temp_dir) - with open(csv_fd, 'w+b') as f: + with open(csv_fd, 'w') as csvfile: + writer = csv.DictWriter( + csvfile, + fieldnames=list(db_sync.flatten_schema.keys()), + extrasaction='ignore', + ) for record in records_to_load.values(): - csv_line = db_sync.record_to_csv_line(record) - f.write(bytes(csv_line + '\n', 'UTF-8')) + writer.writerow(db_sync.record_to_flattened_record(record)) size_bytes = os.path.getsize(csv_file) db_sync.load_csv(csv_file, row_count, size_bytes) diff --git a/target_postgres/db_sync.py b/target_postgres/db_sync.py index c04c1bf..08b8693 100644 --- a/target_postgres/db_sync.py +++ b/target_postgres/db_sync.py @@ -1,13 +1,15 @@ import json import sys -import psycopg2 -import psycopg2.extras -import inflection import re import uuid import itertools import time from collections.abc import MutableMapping +from typing import Dict + +import psycopg2 +import psycopg2.extras +import inflection from singer import get_logger @@ -344,15 +346,8 @@ def record_primary_key_string(self, record): raise exc return ','.join(key_props) - def record_to_csv_line(self, record): - flatten = flatten_record(record, self.flatten_schema, max_level=self.data_flattening_max_level) - return ','.join( - [ - json.dumps(flatten[name], ensure_ascii=False) - if name in flatten and (flatten[name] == 0 or flatten[name]) else '' - for name in self.flatten_schema - ] - ) + def record_to_flattened_record(self, record: Dict) -> Dict: + return flatten_record(record, self.flatten_schema, max_level=self.data_flattening_max_level) def load_csv(self, file, count, size_bytes): stream_schema_message = self.stream_schema_message @@ -367,7 +362,7 @@ def load_csv(self, file, count, size_bytes): temp_table = self.table_name(stream_schema_message['stream'], is_temporary=True) cur.execute(self.create_table_query(table_name=temp_table, is_temporary=True)) - copy_sql = "COPY {} ({}) FROM STDIN WITH (FORMAT CSV, ESCAPE '\\')".format( + copy_sql = "COPY {} ({}) FROM STDIN WITH (FORMAT CSV)".format( temp_table, ', '.join(self.column_names()) )