Skip to content

Commit 5c36aad

Browse files
author
abdelabd
committed
special handling for Pyg dim_names wasn't necessary
1 parent ca49221 commit 5c36aad

File tree

2 files changed

+2
-8
lines changed

2 files changed

+2
-8
lines changed

hls4ml/converters/pyg_to_hls.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ def pyg_to_hls(config):
5555
'class_name': 'InputLayer',
5656
'input_shape': input_shapes['NodeAttr'],
5757
'inputs': 'input',
58-
'dim_names': ['N_NODE', 'NODE_DIM'],
5958
'precision': fp_type
6059
}
6160
layer_list.append(NodeAttr_layer)
@@ -64,7 +63,6 @@ def pyg_to_hls(config):
6463
'class_name': 'InputLayer',
6564
'input_shape': input_shapes['EdgeAttr'],
6665
'inputs': 'input',
67-
'dim_names': ['N_EDGE', 'EDGE_DIM'],
6866
'precision': fp_type
6967
}
7068
layer_list.append(EdgeAttr_layer)
@@ -73,7 +71,6 @@ def pyg_to_hls(config):
7371
'class_name': 'InputLayer',
7472
'input_shape': input_shapes['EdgeIndex'],
7573
'inputs': 'input',
76-
'dim_names': ['N_EDGE', 'TWO'],
7774
'precision': int_type
7875
}
7976
layer_list.append(EdgeIndex_layer)

hls4ml/model/hls_layers.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -564,10 +564,7 @@ def initialize(self):
564564
shape = self.attributes['input_shape']
565565
if shape[0] is None:
566566
shape = shape[1:]
567-
try:
568-
dims = self.attributes['dim_names']
569-
except KeyError:
570-
dims = ['N_INPUT_{}_{}'.format(i, self.index) for i in range(1, len(shape) + 1)]
567+
dims = ['N_INPUT_{}_{}'.format(i, self.index) for i in range(1, len(shape) + 1)]
571568
if self.index == 1:
572569
default_type_name = 'input_t'
573570
else:
@@ -2281,7 +2278,7 @@ def initialize(self):
22812278

22822279
aggr_name = f"layer{self.index}_out"
22832280
aggr_shape = [self.n_node, self.out_dim]
2284-
aggr_dims = ['N_NODE', f'LAYER{self.index}_OUT_DIM']
2281+
aggr_dims = ['N_NODE', f'LAYER{self.index}_OUT_DIM'] #todo: see if we can do without 'N_NODE'
22852282
self.add_output_variable(shape=aggr_shape, dim_names=aggr_dims, out_name=aggr_name, var_name=aggr_name,
22862283
precision=self.attributes.get('precision', None))
22872284

0 commit comments

Comments
 (0)