
Commit 490ac46

Merge pull request #847 from JanFSchulte/batchNormFix
Remove obsolete parameter mapping between pytorch and keras
2 parents: c6174c3 + 39798af

File tree

3 files changed (+44, -21 lines):

hls4ml/converters/pytorch/core.py
hls4ml/converters/pytorch_to_hls.py
test/pytest/test_batchnorm_pytorch.py


hls4ml/converters/pytorch/core.py

Lines changed: 1 addition & 2 deletions

@@ -109,9 +109,8 @@ def parse_batchnorm_layer(operation, layer_name, input_names, input_shapes, node
     layer['beta_data'] = 0
 
     layer['mean_data'], layer['variance_data'] = get_weights_data(
-        data_reader, layer['name'], ['running_mean', 'running_variance']
+        data_reader, layer['name'], ['running_mean', 'running_var']
     )
-
     in_size = 1
     for dim in input_shapes[0][1:]:
         in_size *= dim
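The one-line fix above is the crux of the commit: PyTorch registers the batch-norm running statistics as buffers named running_mean and running_var; there is no buffer called running_variance, so the old key could never match the variance tensor. A minimal check of the buffer names, in plain PyTorch rather than hls4ml code:

    from torch import nn

    bn = nn.BatchNorm1d(16)
    # The running statistics live in buffers with PyTorch's native names
    print([name for name, _ in bn.named_buffers()])
    # ['running_mean', 'running_var', 'num_batches_tracked']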

hls4ml/converters/pytorch_to_hls.py

Lines changed: 0 additions & 19 deletions

@@ -17,30 +17,11 @@ def __init__(self, config):
     def get_weights_data(self, layer_name, var_name):
         data = None
 
-        # Parameter mapping from pytorch to keras
-        torch_paramap = {
-            # Conv
-            'kernel': 'weight',
-            # Batchnorm
-            'gamma': 'weight',
-            # Activiation
-            'alpha': 'weight',
-            'beta': 'bias',
-            'moving_mean': 'running_mean',
-            'moving_variance': 'running_var',
-        }
-
         # Workaround for naming schme in nn.Sequential,
         # have to remove the prefix we previously had to add to make sure the tensors are found
         if 'layer_' in layer_name:
             layer_name = layer_name.split('layer_')[-1]
 
-        if var_name not in list(torch_paramap.keys()) + ['weight', 'bias']:
-            raise Exception('Pytorch parameter not yet supported!')
-
-        elif var_name in list(torch_paramap.keys()):
-            var_name = torch_paramap[var_name]
-
         # if a layer is reused in the model, torch.FX will append a "_n" for the n-th use
         # have to snap that off to find the tensors
         if layer_name.split('_')[-1].isdigit() and len(layer_name.split('_')) > 1:
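With the pytorch-to-keras translation table removed, get_weights_data resolves tensors directly under their native PyTorch names. A rough sketch of the lookup this enables, assuming a plain state_dict access rather than the actual hls4ml data reader:

    from torch import nn

    model = nn.Sequential(nn.BatchNorm1d(16))
    state_dict = model.state_dict()

    # Keys are '<layer_name>.<var_name>' using PyTorch's own parameter
    # and buffer names, so no mapping to Keras names is needed anymore.
    gamma = state_dict['0.weight']
    beta = state_dict['0.bias']
    mean = state_dict['0.running_mean']
    var = state_dict['0.running_var']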

test/pytest/test_batchnorm_pytorch.py

Lines changed: 43 additions & 0 deletions

@@ -0,0 +1,43 @@
+from pathlib import Path
+
+import numpy as np
+import pytest
+import torch
+from torch import nn
+
+import hls4ml
+
+test_root_path = Path(__file__).parent
+
+in_shape = 16
+atol = 5e-3
+
+
+@pytest.fixture(scope='module')
+def data():
+    np.random.seed(0)
+    X = np.random.rand(100, in_shape)
+    return X
+
+
+@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
+@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
+def test_batchnorm(data, backend, io_type):
+    model = nn.Sequential(
+        nn.BatchNorm1d(in_shape),
+    ).to()
+    model.eval()
+
+    default_precision = 'ac_fixed<32, 1, true>' if backend == 'Quartus' else 'ac_fixed<32, 1>'
+
+    config = hls4ml.utils.config_from_pytorch_model(model, default_precision=default_precision, granularity='name')
+    output_dir = str(test_root_path / f'hls4mlprj_batchnorm_{backend}_{io_type}')
+    hls_model = hls4ml.converters.convert_from_pytorch_model(
+        model, (None, in_shape), backend=backend, hls_config=config, io_type=io_type, output_dir=output_dir
+    )
+    hls_model.compile()
+
+    # Predict
+    pytorch_prediction = model(torch.Tensor(data)).detach().numpy()
+    hls_prediction = hls_model.predict(data)
+    np.testing.assert_allclose(pytorch_prediction, hls_prediction, rtol=0, atol=atol, verbose=True)
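To exercise the new test locally, one parametrization at a time (assuming a development install of hls4ml; compile() and predict() typically need only a C++ compiler, not the vendor synthesis tools), something like:

    pytest test/pytest/test_batchnorm_pytorch.py -k "Vivado and io_parallel" -v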
