diff --git a/hls4ml/converters/pytorch/core.py b/hls4ml/converters/pytorch/core.py
index e45e3c9fae..5a2caba1ea 100644
--- a/hls4ml/converters/pytorch/core.py
+++ b/hls4ml/converters/pytorch/core.py
@@ -109,9 +109,8 @@ def parse_batchnorm_layer(operation, layer_name, input_names, input_shapes, node
         layer['beta_data'] = 0
 
     layer['mean_data'], layer['variance_data'] = get_weights_data(
-        data_reader, layer['name'], ['running_mean', 'running_variance']
+        data_reader, layer['name'], ['running_mean', 'running_var']
     )
-
     in_size = 1
     for dim in input_shapes[0][1:]:
         in_size *= dim
diff --git a/hls4ml/converters/pytorch_to_hls.py b/hls4ml/converters/pytorch_to_hls.py
index ddddbc04c7..961fb735af 100644
--- a/hls4ml/converters/pytorch_to_hls.py
+++ b/hls4ml/converters/pytorch_to_hls.py
@@ -17,30 +17,11 @@ def __init__(self, config):
     def get_weights_data(self, layer_name, var_name):
         data = None
 
-        # Parameter mapping from pytorch to keras
-        torch_paramap = {
-            # Conv
-            'kernel': 'weight',
-            # Batchnorm
-            'gamma': 'weight',
-            # Activiation
-            'alpha': 'weight',
-            'beta': 'bias',
-            'moving_mean': 'running_mean',
-            'moving_variance': 'running_var',
-        }
-
         # Workaround for naming schme in nn.Sequential,
         # have to remove the prefix we previously had to add to make sure the tensors are found
         if 'layer_' in layer_name:
             layer_name = layer_name.split('layer_')[-1]
 
-        if var_name not in list(torch_paramap.keys()) + ['weight', 'bias']:
-            raise Exception('Pytorch parameter not yet supported!')
-
-        elif var_name in list(torch_paramap.keys()):
-            var_name = torch_paramap[var_name]
-
         # if a layer is reused in the model, torch.FX will append a "_n" for the n-th use
         # have to snap that off to find the tensors
         if layer_name.split('_')[-1].isdigit() and len(layer_name.split('_')) > 1:
diff --git a/test/pytest/test_batchnorm_pytorch.py b/test/pytest/test_batchnorm_pytorch.py
new file mode 100644
index 0000000000..a7a0c80247
--- /dev/null
+++ b/test/pytest/test_batchnorm_pytorch.py
@@ -0,0 +1,43 @@
+from pathlib import Path
+
+import numpy as np
+import pytest
+import torch
+from torch import nn
+
+import hls4ml
+
+test_root_path = Path(__file__).parent
+
+in_shape = 16
+atol = 5e-3
+
+
+@pytest.fixture(scope='module')
+def data():
+    np.random.seed(0)
+    X = np.random.rand(100, in_shape)
+    return X
+
+
+@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
+@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus'])
+def test_batchnorm(data, backend, io_type):
+    model = nn.Sequential(
+        nn.BatchNorm1d(in_shape),
+    ).to()
+    model.eval()
+
+    default_precision = 'ac_fixed<32, 1, true>' if backend == 'Quartus' else 'ac_fixed<32, 1>'
+
+    config = hls4ml.utils.config_from_pytorch_model(model, default_precision=default_precision, granularity='name')
+    output_dir = str(test_root_path / f'hls4mlprj_batchnorm_{backend}_{io_type}')
+    hls_model = hls4ml.converters.convert_from_pytorch_model(
+        model, (None, in_shape), backend=backend, hls_config=config, io_type=io_type, output_dir=output_dir
+    )
+    hls_model.compile()
+
+    # Predict
+    pytorch_prediction = model(torch.Tensor(data)).detach().numpy()
+    hls_prediction = hls_model.predict(data)
+    np.testing.assert_allclose(pytorch_prediction, hls_prediction, rtol=0, atol=atol, verbose=True)
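
Note on the `running_variance` → `running_var` rename: PyTorch's BatchNorm modules store their running statistics as buffers named `running_mean` and `running_var`, so the reader can now look the tensors up under their native names and the old Keras-style `torch_paramap` translation table becomes unnecessary. A quick sanity check, not part of this patch, illustrating the names PyTorch actually exposes:

```python
import torch
from torch import nn

# List the tensor names stored by a BatchNorm1d layer's state_dict.
bn = nn.BatchNorm1d(16)
print(list(bn.state_dict().keys()))
# Expected: ['weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked']
# There is no 'running_variance' key, which is why core.py now requests
# 'running_var' directly instead of going through a name-translation step.
```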