Commit 5cac79c

Merge pull request #848 from JanFSchulte/batchNormFix
Bug fix for named nn.Sequential in pytorch parser
2 parents 3d227e5 + 2b0ae80 commit 5cac79c

File tree: 7 files changed, +348 -45 lines changed

hls4ml/converters/pytorch/convolution.py

Lines changed: 15 additions & 5 deletions

@@ -1,4 +1,4 @@
-from hls4ml.converters.pytorch_to_hls import get_weights_data, pytorch_handler
+from hls4ml.converters.pytorch_to_hls import pytorch_handler
 from hls4ml.converters.utils import compute_padding_1d_pytorch, compute_padding_2d_pytorch, parse_data_format


@@ -9,11 +9,16 @@ def parse_conv1d_layer(operation, layer_name, input_names, input_shapes, node, c
     layer = {}

     layer['name'] = layer_name
+    layer['inputs'] = input_names
     layer['class_name'] = 'Conv1D'
     layer['data_format'] = 'channels_first'  # Pytorch default (can't change)

-    layer['weight_data'] = get_weights_data(data_reader, layer['name'], 'weight')
-    layer['bias_data'] = get_weights_data(data_reader, layer['name'], 'bias')
+    layer['weight_data'] = class_object.weight.data.numpy()
+    if class_object.bias is not None:
+        layer['bias_data'] = class_object.bias.data.numpy()
+    else:
+        layer['bias_data'] = None
+
     # Input info
     (layer['in_width'], layer['n_chan']) = parse_data_format(
         input_shapes[0], 'channels_first'
@@ -54,11 +59,16 @@ def parse_conv2d_layer(operation, layer_name, input_names, input_shapes, node, c
     layer = {}

     layer['name'] = layer_name
+    layer['inputs'] = input_names
     layer['class_name'] = 'Conv2D'
     layer['data_format'] = 'channels_first'  # Pytorch default (can't change)

-    layer['weight_data'] = get_weights_data(data_reader, layer['name'], 'weight')
-    layer['bias_data'] = get_weights_data(data_reader, layer['name'], 'bias')
+    layer['weight_data'] = class_object.weight.data.numpy()
+    if class_object.bias is not None:
+        layer['bias_data'] = class_object.bias.data.numpy()
+    else:
+        layer['bias_data'] = None
+
     # Input info
     (layer['in_height'], layer['in_width'], layer['n_chan']) = parse_data_format(
         input_shapes[0], 'channels_first'
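
The pattern in these handlers is the same throughout the PR: instead of looking tensors up by name through `get_weights_data` and the reader's state_dict, the parser now reads them directly off the module object that torch.fx hands it, which is insensitive to how the module is nested or named. A minimal sketch of the new access pattern, using a hypothetical stand-alone conv layer (not code from this commit):

```python
import torch.nn as nn

# Hypothetical module, standing in for the class_object a handler receives
conv = nn.Conv1d(in_channels=3, out_channels=8, kernel_size=5, bias=False)

# Tensors come straight off the module; no name-based state_dict lookup
weight_data = conv.weight.data.numpy()  # shape (8, 3, 5): (out_chan, in_chan, kernel)
bias_data = conv.bias.data.numpy() if conv.bias is not None else None

print(weight_data.shape, bias_data)  # (8, 3, 5) None, since bias=False
```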

hls4ml/converters/pytorch/core.py

Lines changed: 16 additions & 8 deletions

@@ -1,4 +1,4 @@
-from hls4ml.converters.pytorch_to_hls import get_weights_data, pytorch_handler
+from hls4ml.converters.pytorch_to_hls import pytorch_handler


 @pytorch_handler('Linear')
@@ -9,8 +9,14 @@ def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, c

     layer['class_name'] = 'Dense'
     layer['name'] = layer_name
+    layer['inputs'] = input_names
+
+    layer['weight_data'] = class_object.weight.data.numpy()
+    if class_object.bias is not None:
+        layer['bias_data'] = class_object.bias.data.numpy()
+    else:
+        layer['bias_data'] = None

-    layer['weight_data'], layer['bias_data'] = get_weights_data(data_reader, layer['name'], ['weight', 'bias'])
     if class_object is not None:
         layer['n_in'] = class_object.in_features
         layer['n_out'] = class_object.out_features
@@ -39,6 +45,7 @@ def parse_activation_layer(operation, layer_name, input_names, input_shapes, nod
     layer['class_name'] = operation
     layer['activation'] = layer['class_name']
     layer['name'] = layer_name
+    layer['inputs'] = input_names

     # if layer['class_name'] != 'Activation':
     #     layer['activation'] = layer['class_name']
@@ -50,7 +57,7 @@ def parse_activation_layer(operation, layer_name, input_names, input_shapes, nod
     if layer['class_name'] == 'ELU':
         layer['activ_param'] = class_object.alpha
     if layer['class_name'] == 'PReLU':
-        layer['alpha_data'] = get_weights_data(data_reader, layer['name'], 'weight')
+        layer['alpha_data'] = class_object.weight.data.numpy()
     if layer['class_name'] == 'Threshold':
         layer['activ_param'] = class_object.threshold
         layer['class_name'] = 'ThresholdedReLU'
@@ -92,25 +99,26 @@ def parse_batchnorm_layer(operation, layer_name, input_names, input_shapes, node
     layer['class_name'] = 'BatchNormalization'
     layer['data_format'] = 'channels_first'
     layer['name'] = layer_name
+    layer['inputs'] = input_names

     # batchnorm para
     if node.op == 'call_module':
         layer['epsilon'] = class_object.eps
         layer['use_gamma'] = layer['use_beta'] = class_object.affine

         if layer['use_gamma']:
-            layer['gamma_data'] = get_weights_data(data_reader, layer['name'], 'weight')
+            layer['gamma_data'] = class_object.weight.data.numpy()
         else:
             layer['gamma_data'] = 1

         if layer['use_beta']:
-            layer['beta_data'] = get_weights_data(data_reader, layer['name'], 'bias')
+            layer['beta_data'] = class_object.bias.data.numpy()
         else:
             layer['beta_data'] = 0

-        layer['mean_data'], layer['variance_data'] = get_weights_data(
-            data_reader, layer['name'], ['running_mean', 'running_var']
-        )
+        layer['mean_data'] = class_object.running_mean.data.numpy()
+        layer['variance_data'] = class_object.running_var.data.numpy()
+
     in_size = 1
     for dim in input_shapes[0][1:]:
         in_size *= dim
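
The batchnorm handler applies the same direct-access pattern to all four parameter tensors at once. A short sketch of what those attribute reads return, assuming a hypothetical BatchNorm1d module in inference mode (not code from this commit):

```python
import torch.nn as nn

# Hypothetical module, standing in for the class_object of a batchnorm node
bn = nn.BatchNorm1d(num_features=16)
bn.eval()  # inference mode: the running statistics are what get deployed

gamma = bn.weight.data.numpy()       # scale; only meaningful if bn.affine
beta = bn.bias.data.numpy()          # shift; only meaningful if bn.affine
mean = bn.running_mean.data.numpy()  # running mean accumulated during training
var = bn.running_var.data.numpy()    # running variance
eps = bn.eps                         # epsilon added to var, 1e-05 by default
```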

hls4ml/converters/pytorch/pooling.py

Lines changed: 1 addition & 0 deletions

@@ -20,6 +20,7 @@ def parse_pooling_layer(operation, layer_name, input_names, input_shapes, node,
         layer['class_name'] = 'AveragePooling2D'

     layer['name'] = layer_name
+    layer['inputs'] = input_names
     layer['data_format'] = 'channels_first'  # Pytorch default (can't change)
     if node.op == 'call_module' and 'Avg' in operation:
         if class_object.count_include_pad:

hls4ml/converters/pytorch/reshape.py

Lines changed: 55 additions & 0 deletions

@@ -12,6 +12,7 @@ def parse_reshape_layer(operation, layer_name, input_names, input_shapes, node,
     layer = {}
     layer['class_name'] = 'Reshape'
     layer['name'] = layer_name
+    layer['inputs'] = input_names

     layer['target_shape'] = [int(i) for i in node.args[1:]]
     # View can have -1 as one of the dimensions,
@@ -29,6 +30,60 @@ def parse_reshape_layer(operation, layer_name, input_names, input_shapes, node,
     return layer, output_shape


+@pytorch_handler('squeeze')
+def parse_squeeze_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
+    assert operation == 'squeeze'
+
+    layer = {}
+    layer['class_name'] = 'Reshape'
+    layer['name'] = layer_name
+
+    if len(node.args) > 1 or len(node.kwargs) > 0:  # 'dim' argument is specified
+        output_shape = [i for i in input_shapes[0]]
+        squeeze_dim = node.kwargs.get('dim', None)
+        if squeeze_dim is None:
+            squeeze_dim = node.args[1]
+        if isinstance(squeeze_dim, tuple):
+            for dim in squeeze_dim:
+                del output_shape[dim]
+        else:
+            del output_shape[squeeze_dim]
+    else:
+        output_shape = [i for i in input_shapes[0] if i != 1]
+
+    layer['target_shape'] = output_shape.copy()
+    if layer['target_shape'][0] is None:
+        del layer['target_shape'][0]
+
+    return layer, output_shape
+
+
+@pytorch_handler('unsqueeze')
+def parse_unsqueeze_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
+    assert operation == 'unsqueeze'
+
+    layer = {}
+    layer['class_name'] = 'Reshape'
+    layer['name'] = layer_name
+    layer['inputs'] = input_names
+
+    # Unlike in 'squeeze', in 'unsqueeze' the dim argument must exist
+    output_shape = [i for i in input_shapes[0]]
+    if len(node.args) > 1:  # Specified as unsqueeze(x, n)
+        squeeze_dim = node.args[1]
+    else:  # Specified as unsqueeze(x, dim=n)
+        squeeze_dim = node.kwargs['dim']
+    # insert() will add an element before the index, unsqueeze expects the location
+    index = output_shape.index(output_shape[squeeze_dim])  # + 1
+    output_shape.insert(index, 1)
+
+    layer['target_shape'] = output_shape.copy()
+    if layer['target_shape'][0] is None:
+        del layer['target_shape'][0]
+
+    return layer, output_shape
+
+
 @pytorch_handler('Flatten')
 def parse_flatten_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
     assert operation == 'Flatten'
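
Both new handlers lower squeeze/unsqueeze to hls4ml Reshape layers by editing the tracked shape list, whose leading batch dimension is None. A standalone sketch of that shape arithmetic, on a made-up shape (not code from this commit):

```python
# Made-up tracked shape, as stored in the parser's output_shapes:
# the leading batch dimension is None
input_shape = [None, 1, 10]

# squeeze with no 'dim' argument: drop every dimension of size 1
squeezed = [i for i in input_shape if i != 1]  # [None, 10]

# unsqueeze(x, dim=1): insert a new size-1 dimension at position 1
unsqueezed = list(squeezed)
unsqueezed.insert(1, 1)  # [None, 1, 10]

# the target_shape handed to the Reshape layer drops the batch dimension
target_shape = unsqueezed.copy()
if target_shape[0] is None:
    del target_shape[0]  # [1, 10]
```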

hls4ml/converters/pytorch_to_hls.py

Lines changed: 23 additions & 27 deletions

@@ -16,22 +16,12 @@ def __init__(self, config):
     def get_weights_data(self, layer_name, var_name):
         data = None

-        # Workaround for naming schme in nn.Sequential,
-        # have to remove the prefix we previously had to add to make sure the tensors are found
-        if 'layer_' in layer_name:
-            layer_name = layer_name.split('layer_')[-1]
+        tensorName = layer_name + '.' + var_name

-        # if a layer is reused in the model, torch.FX will append a "_n" for the n-th use
-        # have to snap that off to find the tensors
-        if layer_name.split('_')[-1].isdigit() and len(layer_name.split('_')) > 1:
-            layer_name = '_'.join(layer_name.split('_')[:-1])
+        if tensorName in self.state_dict:
+            data = self.state_dict[tensorName].numpy()

-        if layer_name + '.' + var_name in self.state_dict:
-            data = self.state_dict[layer_name + '.' + var_name].numpy()
-            return data
-
-        else:
-            return None
+        return data


 class PyTorchFileReader(PyTorchModelReader):  # Inherit get_weights_data method
@@ -98,7 +88,7 @@ def decorator(function):
     'elu': 'ELU',
     'prelu': 'PReLU',
     'sigmoid': 'Sigmoid',
-    'layer_threshold': 'Threshold',
+    '_threshold': 'Threshold',
     'softmax': 'Softmax',
     'max_pool1d': 'MaxPool1d',
     'max_pool2d': 'MaxPool2d',
@@ -134,6 +124,7 @@ def pytorch_to_hls(config):
         input_shapes = [list(reader.input_shape)]
     else:
         input_shapes = list(reader.input_shape)
+        input_shapes = [list(shape) for shape in input_shapes]

     model = reader.torch_model

@@ -149,6 +140,9 @@ def pytorch_to_hls(config):
     # All supported layers
     supported_layers = get_supported_pytorch_layers() + skip_layers

+    # Map inputs of skipped and split (activation) layers
+    inputs_map = {}
+
     input_layers = []

     # Output shape tracking
@@ -162,15 +156,16 @@ def pytorch_to_hls(config):
     n_inputs = 0

     for node in traced_model.graph.nodes:
-        # If part of a nn.Sequntial, the node name will start with an "_" which messes up the parsing
-        if node.name[0] == '_':
-            node.name = 'layer' + node.name
-
         if node.op == 'call_module':
             # modules that are part of a torch.nn.Sequential with name 'name' have target names 'name.x',
             # where x is an integer numbering the elements of the Sequential
             if '.' in node.target:
-                class_object = children[node.target.split('.')[0]][int(node.target.split('.')[1])]
+                fqn_path = node.target.split('.')
+                sub_children = dict(children[fqn_path[0]].named_children())
+                for name in fqn_path[1:-1]:
+                    sub_children = dict(sub_children[name].named_children())
+                sub_children[fqn_path[-1]]
+                class_object = sub_children[fqn_path[-1]]
             else:
                 class_object = children[node.target]

@@ -189,6 +184,10 @@ def pytorch_to_hls(config):
             if pytorch_class == 'Sequential':  # Ignore the mother module's class name
                 continue

+            # Assuming only one input
+            parent_input = [str(i) for i in node.args][0]
+            inputs_map[layer_name] = inputs_map.get(parent_input, parent_input)
+
             output_shapes[layer_name] = input_shapes[0]

             continue
@@ -198,7 +197,7 @@ def pytorch_to_hls(config):
             layer_counter += 1

             # parse info from class object
-            input_names = [str(i) for i in node.args]
+            input_names = [inputs_map.get(str(i), str(i)) for i in node.args]
             input_shapes = [output_shapes[str(i)] for i in node.args]

             # for Conv layers
@@ -236,7 +235,7 @@ def pytorch_to_hls(config):
             input_layer['input_shape'] = list(input_shapes[n_inputs][1:])
             layer_list.insert(n_inputs, input_layer)

-            output_shapes[input_layer['name']] = input_shapes[n_inputs]
+            output_shapes[input_layer['name']] = list(input_shapes[n_inputs])
             input_layers.append(input_layer['name'])
             n_inputs += 1

@@ -265,7 +264,7 @@ def pytorch_to_hls(config):

             layer_counter += 1

-            input_names = [str(i) for i in node.all_input_nodes]
+            input_names = [inputs_map.get(str(i), str(i)) for i in node.all_input_nodes]
             input_shapes = [list(output_shapes[str(i)]) for i in input_names]

             # Process the layer
@@ -318,10 +317,7 @@ def pytorch_to_hls(config):

             layer_counter += 1

-            if 'View' in operation:
-                input_names = [str(node.args[0])]
-            else:
-                input_names = [str(i) for i in node.args]
+            input_names = [inputs_map.get(str(i), str(i)) for i in node.all_input_nodes]

             # Process the layer
             input_shapes = [list(output_shapes[str(i)]) for i in input_names]
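
The heart of the fix: when a model uses a named nn.Sequential, torch.fx emits call_module nodes with dotted targets such as 'seq.1' (or deeper, 'block.seq.1'), and the old single split-and-index lookup only handled one level of nesting. A minimal sketch of the new named_children() walk, on a hypothetical model (the model and target name below are illustrative, not from the commit):

```python
import torch.nn as nn

# Hypothetical model with a *named* nn.Sequential, the case this PR fixes
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))

    def forward(self, x):
        return self.seq(x)

children = dict(Net().named_children())

# torch.fx would give the BatchNorm's call_module node the target 'seq.1';
# resolve it one level of the fully qualified name at a time
fqn_path = 'seq.1'.split('.')
sub_children = dict(children[fqn_path[0]].named_children())
for name in fqn_path[1:-1]:
    sub_children = dict(sub_children[name].named_children())
class_object = sub_children[fqn_path[-1]]

print(type(class_object).__name__)  # BatchNorm1d
```

The inputs_map serves a companion purpose: when a node is skipped (the Sequential wrapper itself, or a layer in skip_layers such as Dropout), downstream nodes still name the skipped node as their input, so the map redirects them to the skipped node's own input, and the .get() chaining resolves runs of consecutive skips. A toy illustration of that chaining with made-up node names:

```python
# Toy illustration: 'drop' and 'drop_1' are skipped nodes feeding each other;
# anything consuming either of them is rewired back to the kept 'dense' node
inputs_map = {}
inputs_map['drop'] = inputs_map.get('dense', 'dense')  # 'dense' is kept
inputs_map['drop_1'] = inputs_map.get('drop', 'drop')  # resolves to 'dense'

input_names = [inputs_map.get(n, n) for n in ['drop_1']]
print(input_names)  # ['dense']
```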
