Skip to content

Commit 77e7024

Browse files
author
abdelabd
committed
added 'check_forward_dict'
1 parent eccb69d commit 77e7024

File tree

3 files changed

+5
-4
lines changed

3 files changed

+5
-4
lines changed

hls4ml/converters/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,11 +270,14 @@ def convert_from_pytorch_model(model, input_shape, output_dir='my-hls-test', pro
270270

271271
return pytorch_to_hls(config)
272272

273+
def check_forward_dict(model, forward_dictionary):
274+
for key in forward_dictionary:
275+
assert(hasattr(model, key))
273276
def convert_from_pyg_model(model, n_node, node_dim, n_edge, edge_dim,
274277
forward_dictionary=None, activate_final=None,
275278
output_dir='my-hls-test', project_name='myproject',
276279
fpga_part='xcku115-flvb2104-2-i', clock_period=5, io_type='io_parallel', hls_config={}):
277-
280+
check_forward_dict(model, forward_dictionary)
278281
config = create_vivado_config(
279282
output_dir=output_dir,
280283
project_name=project_name,

hls4ml/converters/pyg/__init__.py

Whitespace-only changes.

hls4ml/converters/pyg_to_hls.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,14 +116,12 @@ def pyg_to_hls(config):
116116
aggr_count = 0
117117
forward_dict_new = OrderedDict()
118118
for key, val in forward_dict.items():
119-
if val=="NodeBlock":
119+
if val == "NodeBlock":
120120
aggr_count += 1
121121
aggr_key = f"aggr{aggr_count}"
122122
aggr_val = "Aggregate"
123123
forward_dict_new[aggr_key] = aggr_val
124124
forward_dict_new[key] = val
125-
print(f"forward_dict: {forward_dict}")
126-
print(f"forward_dict_new: {forward_dict_new}")
127125

128126
# complete the layer list
129127
for i, (key, val) in enumerate(forward_dict_new.items()):

0 commit comments

Comments
 (0)