Skip to content

Commit e690afe

Browse files
authored
add option to not write tar.gz for oneAPI and Quartus (#1189)
* add option to not write tar.gz for oneAPI and Quartus * fix the docsting for initial config * ignore extra parameters in oneAPI initial config
1 parent 4d23e9f commit e690afe

File tree

4 files changed

+34
-6
lines changed

4 files changed

+34
-6
lines changed

hls4ml/backends/oneapi/oneapi_backend.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,13 +129,30 @@ def get_default_flow(self):
129129
def get_writer_flow(self):
130130
return self._writer_flow
131131

132-
def create_initial_config(self, part='Arria10', clock_period=5, io_type='io_parallel'):
132+
def create_initial_config(self, part='Arria10', clock_period=5, io_type='io_parallel', write_tar=False, **_):
133+
"""Create initial configuration of the oneAPI backend.
134+
135+
Args:
136+
part (str, optional): The FPGA part to be used. Defaults to 'Arria10'.
137+
clock_period (int, optional): The clock period. Defaults to 5.
138+
io_type (str, optional): Type of implementation used. One of
139+
'io_parallel' or 'io_stream'. Defaults to 'io_parallel'.
140+
write_tar (bool, optional): If True, compresses the output directory into a .tar.gz file. Defaults to False.
141+
142+
Returns:
143+
dict: initial configuration.
144+
"""
145+
133146
config = {}
134147

135148
config['Part'] = part if part is not None else 'Arria10'
136149
config['ClockPeriod'] = clock_period
137150
config['IOType'] = io_type
138151
config['HLSConfig'] = {}
152+
config['WriterConfig'] = {
153+
# TODO: add namespace
154+
'WriteTar': write_tar,
155+
}
139156

140157
return config
141158

hls4ml/backends/quartus/quartus_backend.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,13 +131,16 @@ def get_default_flow(self):
131131
def get_writer_flow(self):
132132
return self._writer_flow
133133

134-
def create_initial_config(self, part='Arria10', clock_period=5, io_type='io_parallel', **_):
134+
def create_initial_config(self, part='Arria10', clock_period=5, io_type='io_parallel', write_tar=False, **_):
135135
config = {}
136136

137137
config['Part'] = part if part is not None else 'Arria10'
138138
config['ClockPeriod'] = clock_period if clock_period is not None else 5
139139
config['IOType'] = io_type if io_type is not None else 'io_parallel'
140140
config['HLSConfig'] = {}
141+
config['WriterConfig'] = {
142+
'WriteTar': write_tar,
143+
}
141144

142145
return config
143146

hls4ml/writer/oneapi_writer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -955,8 +955,12 @@ def write_tar(self, model):
955955
model (ModelGraph): the hls4ml model.
956956
"""
957957

958-
with tarfile.open(model.config.get_output_dir() + '.tar.gz', mode='w:gz') as archive:
959-
archive.add(model.config.get_output_dir(), recursive=True)
958+
if model.config.get_writer_config().get('WriteTar', False):
959+
tar_path = model.config.get_output_dir() + '.tar.gz'
960+
if os.path.exists(tar_path):
961+
os.remove(tar_path)
962+
with tarfile.open(model.config.get_output_dir() + '.tar.gz', mode='w:gz') as archive:
963+
archive.add(model.config.get_output_dir(), recursive=True)
960964

961965
def write_hls(self, model):
962966
print('Writing HLS project')

hls4ml/writer/quartus_writer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1345,8 +1345,12 @@ def write_tar(self, model):
13451345
model (ModelGraph): the hls4ml model.
13461346
"""
13471347

1348-
with tarfile.open(model.config.get_output_dir() + '.tar.gz', mode='w:gz') as archive:
1349-
archive.add(model.config.get_output_dir(), recursive=True)
1348+
if model.config.get_writer_config().get('WriteTar', False):
1349+
tar_path = model.config.get_output_dir() + '.tar.gz'
1350+
if os.path.exists(tar_path):
1351+
os.remove(tar_path)
1352+
with tarfile.open(model.config.get_output_dir() + '.tar.gz', mode='w:gz') as archive:
1353+
archive.add(model.config.get_output_dir(), recursive=True)
13501354

13511355
def write_hls(self, model):
13521356
print('Writing HLS project')

0 commit comments

Comments
 (0)