diff --git a/target_postgres/__init__.py b/target_postgres/__init__.py
index 15a12b8..5dc7036 100644
--- a/target_postgres/__init__.py
+++ b/target_postgres/__init__.py
@@ -32,16 +32,13 @@ def persist_lines(config, lines):
     state = None
     schemas = {}
     key_properties = {}
-    headers = {}
     validators = {}
+    records_to_load = {}
     csv_files_to_load = {}
     row_count = {}
     stream_to_sync = {}
-    primary_key_exists = {}
     batch_size = config['batch_size'] if 'batch_size' in config else 100000
 
-    now = datetime.now().strftime('%Y%m%dT%H%M%S')
-
     # Loop over lines from stdin
     for line in lines:
         try:
@@ -67,22 +64,18 @@ def persist_lines(config, lines):
             # Validate record
             validators[stream].validate(o['record'])
 
-            sync = stream_to_sync[stream]
+            primary_key_string = stream_to_sync[stream].record_primary_key_string(o['record'])
+            if not primary_key_string:
+                primary_key_string = 'RID-{}'.format(row_count[stream])
 
-            primary_key_string = sync.record_primary_key_string(o['record'])
-            if stream not in primary_key_exists:
-                primary_key_exists[stream] = {}
-            if primary_key_string and primary_key_string in primary_key_exists[stream]:
-                flush_records(o, csv_files_to_load, row_count, primary_key_exists, sync)
+            if stream not in records_to_load:
+                records_to_load[stream] = {}
 
-            csv_line = sync.record_to_csv_line(o['record'])
-            csv_files_to_load[o['stream']].write(bytes(csv_line + '\n', 'UTF-8'))
-            row_count[o['stream']] += 1
-            if primary_key_string:
-                primary_key_exists[stream][primary_key_string] = True
+            records_to_load[stream][primary_key_string] = o['record']
+            row_count[stream] = len(records_to_load[stream])
 
-            if row_count[o['stream']] >= batch_size:
-                flush_records(o, csv_files_to_load, row_count, primary_key_exists, sync)
+            if row_count[stream] >= batch_size:
+                flush_records(stream, records_to_load, row_count, stream_to_sync)
 
             state = None
         elif t == 'STATE':
@@ -108,20 +101,24 @@ def persist_lines(config, lines):
             raise Exception("Unknown message type {} in message {}"
                             .format(o['type'], o))
 
-    for (stream_name, count) in row_count.items():
+    for (stream, count) in row_count.items():
         if count > 0:
-            stream_to_sync[stream_name].load_csv(csv_files_to_load[stream_name], count)
+            flush_records(stream, records_to_load, row_count, stream_to_sync)
 
     return state
 
 
-def flush_records(o, csv_files_to_load, row_count, primary_key_exists, sync):
-    stream = o['stream']
-    sync.load_csv(csv_files_to_load[stream], row_count[stream])
-    row_count[stream] = 0
-    primary_key_exists[stream] = {}
-    csv_files_to_load[stream] = TemporaryFile(mode='w+b')
+def flush_records(stream, records_to_load, row_count, stream_to_sync):
+    sync = stream_to_sync[stream]
+    csv_file = TemporaryFile(mode='w+b')
+    for record in records_to_load[stream].values():
+        csv_line = sync.record_to_csv_line(record)
+        csv_file.write(bytes(csv_line + '\n', 'UTF-8'))
+
+    sync.load_csv(csv_file, row_count[stream])
+    row_count[stream] = 0
+    records_to_load[stream] = {}
 
 
 def main():
     parser = argparse.ArgumentParser()
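
For reference, below is a minimal, self-contained sketch of the buffering strategy this patch introduces: instead of writing every record straight to a CSV file and flushing early on a repeated primary key, records are staged in a dict keyed by primary key, so a later version of a row silently replaces an earlier one, and only the final version of each row is serialized to CSV when the batch flushes. FakeSync and the sample records are hypothetical stand-ins for the real sync object (stream_to_sync[stream] in the diff); only the control flow mirrors the patch.

from tempfile import TemporaryFile


class FakeSync:
    """Hypothetical stand-in for the stream's sync object."""

    def record_primary_key_string(self, record):
        return str(record.get('id', ''))

    def record_to_csv_line(self, record):
        return '{},{}'.format(record['id'], record['name'])

    def load_csv(self, csv_file, count):
        csv_file.seek(0)
        print('loading {} rows:'.format(count))
        print(csv_file.read().decode('UTF-8'))


def stage_and_flush(records, sync, batch_size=2):
    buffered = {}
    for record in records:
        # Fall back to a synthetic key when there is no primary key,
        # mirroring the 'RID-{}' fallback in the patch.
        key = sync.record_primary_key_string(record) or 'RID-{}'.format(len(buffered))
        buffered[key] = record  # same key -> overwrite, not append
        if len(buffered) >= batch_size:
            flush(buffered, sync)
            buffered = {}
    if buffered:  # final partial batch
        flush(buffered, sync)


def flush(buffered, sync):
    # CSV serialization is deferred to flush time, as in the new flush_records().
    csv_file = TemporaryFile(mode='w+b')
    for record in buffered.values():
        csv_file.write(bytes(sync.record_to_csv_line(record) + '\n', 'UTF-8'))
    sync.load_csv(csv_file, len(buffered))


stage_and_flush(
    [{'id': 1, 'name': 'a'}, {'id': 1, 'name': 'b'}, {'id': 2, 'name': 'c'}],
    FakeSync(),
)

Running the sketch loads two rows, not three: the second version of id 1 ("b") replaces the first, which is the deduplication the dict keyed by primary key buys over the old primary_key_exists bookkeeping.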