
Bug Fix for Operand Shape Mismatch in BatchNorm Fusion (PyTorch) #1045


Merged: 11 commits merged into fastmachinelearning:main on Aug 18, 2024

Conversation

@sei-rquartiano (Contributor) commented on Jul 31, 2024

Main PR Information:

Description of Change

This is a bug fix to prevent an error during batchnorm fusion of two layers in a PyTorch model (in the test case provided, a Conv1d followed by a BatchNorm1d). The fix adds a condition that checks the 'data_format' attribute of the parent node in order to properly index its self.get_output_variable().shape array. There is supposed to be one bias term per output channel of the previous layer, so if the data format is 'channels_first' the shape tuple is accessed at index 0, and if it is 'channels_last' it is accessed at index -1.
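
To make the indexing convention concrete, here is a small illustration using the shapes from the test below (illustration only, not code from the fix):

# Conv1d output for the test model: 2 output channels, width 1024
shape_channels_first = (2, 1024)   # (n_filt, width) -> channel count at index 0
shape_channels_last = (1024, 2)    # (width, n_filt) -> channel count at index -1
n_bias_terms = shape_channels_first[0]   # == shape_channels_last[-1] == 2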

Type of change

  • Bug fix (non-breaking change that fixes an issue)

Tests

Test Configuration: test_batchnorm_pytorch.py

Checklist

  • I have read the guidelines for contributing.
  • I have commented my code, particularly in hard-to-understand areas.
  • I have made corresponding changes to the documentation.
  • My changes generate no new warnings.
  • I have installed and run pre-commit on the files I edited or added.
  • I have added tests that prove my fix is effective or that my feature works.

Additional PR Information:

Bug:

When synthesizing a simple CNN model with one convolution->batchnorm->relu block written in PyTorch (see test file above), I get the following error:

File "/home/hls4ml-user/work/sei-rquartiano-hls4ml/test/bn_fusion_test.py", line 69, in <module>
    hls_model = convert_from_pytorch_model(
  File "/home/hls4ml-user/miniconda3/envs/hls4ml/lib/python3.10/site-packages/hls4ml/converters/__init__.py", line 309, in convert_from_pytorch_model
    return pytorch_to_hls(config)
  File "/home/hls4ml-user/miniconda3/envs/hls4ml/lib/python3.10/site-packages/hls4ml/converters/pytorch_to_hls.py", line 355, in pytorch_to_hls
    hls_model = ModelGraph(config, layer_list, inputs=input_layers)
  File "/home/hls4ml-user/miniconda3/envs/hls4ml/lib/python3.10/site-packages/hls4ml/model/graph.py", line 380, in __init__
    self.apply_flow(flow)
  File "/home/hls4ml-user/miniconda3/envs/hls4ml/lib/python3.10/site-packages/hls4ml/model/graph.py", line 442, in apply_flow
    self._apply_sub_flow(flow, applied_flows)
  File "/home/hls4ml-user/miniconda3/envs/hls4ml/lib/python3.10/site-packages/hls4ml/model/graph.py", line 451, in _apply_sub_flow
    self._apply_sub_flow(sub_flow, applied_flows)
  File "/home/hls4ml-user/miniconda3/envs/hls4ml/lib/python3.10/site-packages/hls4ml/model/graph.py", line 454, in _apply_sub_flow
    applied_passes = optimize_model(self, flow.optimizers)
  File "/home/hls4ml-user/miniconda3/envs/hls4ml/lib/python3.10/site-packages/hls4ml/model/optimizer/optimizer.py", line 318, in optimize_model
    res = opt.transform(model, node)
  File "/home/hls4ml-user/miniconda3/envs/hls4ml/lib/python3.10/site-packages/hls4ml/model/optimizer/passes/bn_fuse.py", line 30, in transform
    fused_bias = bn_scale.data * parent_bias.data + bn_bias.data
ValueError: operands could not be broadcast together with shapes (2,) (1024,)

Cause:

There should be one bias term per output channel of the convolution layer, which is why bn_scale.data and bn_bias.data are vectors of length 2. However, parent_bias.data has length 1024, which causes an error when it is multiplied by bn_scale.data. This happens because the dimensions of the parent node are assumed to be in channels_last format when indexing self.get_output_variable().shape[-1] (-1 referring to the last dimension of the shape) in model/layers.py.
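
The mismatch can be reproduced in isolation with plain numpy (an illustration only, not hls4ml code):

import numpy as np

bn_scale = np.ones(2)         # one scale per output channel (n_filt = 2)
bn_bias = np.zeros(2)         # one bias per output channel
parent_bias = np.zeros(1024)  # wrongly sized from the last (width) dimension

try:
    fused_bias = bn_scale * parent_bias + bn_bias
except ValueError as err:
    print(err)  # operands could not be broadcast together with shapes (2,) (1024,)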

While I do set channels_last_conversion="full" during config creation, it doesn't seem to actually do anything. This appears to be confirmed by the hardcoded layer['data_format'] = 'channels_first' attribute in converters/convolution.py:14. I was able to find a channels-last converter in model/optimizer/passes/convert_to_channels_last.py, but as far as I can tell it isn't called anywhere in the PyTorch conversion process, even with channels_last_conversion="full" set when making the config. Also, according to the comment on converters/convolution.py:14, this isn't changeable for PyTorch anyway. Maybe this is an in-progress hls4ml feature?

Fix:

I added a condition checking the 'data_format' attribute of the parent node. If it is 'channels_last', the shape of the parent node is indexed at -1 just like before. But if it is 'channels_first' (which, according to the converter hardcoding, it always will be for PyTorch), it is indexed at 0. Now the add_bias() method of the Layer class looks like this:

def add_bias(self, quantizer=None):
    data = self.get_attr('bias_data', None)
    precision = None
    type_name = None
    if data is None:
        if self.attributes['data_format'] == "channels_first":
            data = np.zeros(self.get_output_variable().shape[0])
        elif self.attributes['data_format'] == "channels_last":
            data = np.zeros(self.get_output_variable().shape[-1])

        precision = IntegerPrecisionType(width=1, signed=False)
        type_name = 'bias{index}_t'
        quantizer = None  # Don't quantize non-existent bias

    self.add_weights_variable(
        name='bias', var_name='b{index}', type_name=type_name, precision=precision, data=data, quantizer=quantizer
    )

Whenever the channels_last converter is finished and actually invoked during PyTorch conversion, I assume some conditional logic will be needed to check whether it was applied for a given config, hence adding the check now. After this fix, the test file runs without error (see below):

$ python test/bn_fusion_test.py
2024-07-31 19:11:07.762539: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-07-31 19:11:08.368743: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
class BatchNormModel(torch.nn.Module):
    def forward(self, x):
        # No stacktrace found for following nodes
        conv1 = self.conv1(x);  x = None
        bn1 = self.bn1(conv1);  conv1 = None
        relu1 = self.relu1(bn1);  bn1 = None
        return relu1
        
main dbg: placeholder
main dbg: call_module
main dbg: call_module
main dbg: call_module
main dbg: output
X Shape:  (2, 2, 1024)
y Shape:  (2, 2, 1024)
{'Model': {'Precision': 'ap_fixed<64,24>', 'ReuseFactor': 12, 'ChannelsLastConversion': 'full', 'TransposeOutputs': False, 'Strategy': 'Resource'}}
Interpreting Model ...
Topology:
Layer name: conv1, layer type: Conv1D, input shape: [[None, 2, 1024]]
Layer name: bn1, layer type: BatchNormalization, input shape: [[None, 2, 1024]]
Layer name: relu1, layer type: Activation, input shape: [[None, 2, 1024]]
Creating HLS model
WARNING: Changing pipeline style to "dataflow".

@JanFSchulte (Contributor)

Hi! Thanks for bringing this to our attention. The channels_last conversion is actually run for pytorch (and only for pytorch, in fact), and it is active in the case of your model. However, you have stumbled upon a case we overlooked that can't be addressed by this conversion, since it is applied after the model is created and transposes the tensors in the model. That isn't possible in this case: we just have a 1D tensor of the wrong size, which can't be fixed that way. I do think your solution is good.

However, I would ask for 3 things:

Thanks!

Commit: "… and test case has been moved from standalone file to existing pytests"
@sei-rquartiano (Contributor, Author) commented on Aug 1, 2024

Hi JanFSchulte, thanks for getting back to me so quickly and clarifying the source of my issue. If you could point me towards where specifically the channels_last conversion takes place, I'd greatly appreciate it (something tells me I'll need to know that in the future haha). That being said, I'm glad the fix is still suitable.

As for the PR itself, I made the changes you requested. The original PR has been updated to match the desired format (with my original submission tacked on at the bottom), the test case has been moved from a standalone script to the file you linked, and I ran pre-commit. I also added an exception raise for unsupported data_format values; a rough sketch of that guard is below. Please let me know if there are any other changes you'd like me to make!
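
The guard behaves roughly like this standalone sketch (illustrative only: the helper name zero_bias_for is not the committed code, whose logic lives inside add_bias()):

import numpy as np

def zero_bias_for(output_shape, data_format):
    # Pick the channel axis according to data_format, mirroring the add_bias() logic
    if data_format == 'channels_first':
        return np.zeros(output_shape[0])
    elif data_format == 'channels_last':
        return np.zeros(output_shape[-1])
    raise Exception(f'Unsupported data_format: {data_format}')

# e.g. zero_bias_for((2, 1024), 'channels_first').shape == (2,)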

A note about the pytests-- I noticed that in your original pytest for batchnorm you take the model through compilation and compare inference results (which is further than I had pushed the model in my original test file), so I tried to replicate that in my pytest. The model compiles successfully, but the test aborts on hls_prediction = hls_model.predict(fusion_data).reshape(pytorch_prediction.shape). When I try to compare the results of pytorch and hls inference outside pytest, I get the following AssertionError indicating that the tolerances are exceeded:

Traceback (most recent call last):
  File "/home/hls4ml-user/work/sei-rquartiano-hls4ml/test/bn_fusion_test.py", line 84, in <module>
    np.testing.assert_allclose(pytorch_prediction, hls_prediction, rtol=0, atol=atol, verbose=True)
  File "/home/hls4ml-user/miniconda3/envs/hls4ml/lib/python3.10/site-packages/numpy/testing/_private/utils.py", line 1592, in assert_allclose
    assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
  File "/home/hls4ml-user/miniconda3/envs/hls4ml/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/hls4ml-user/miniconda3/envs/hls4ml/lib/python3.10/site-packages/numpy/testing/_private/utils.py", line 862, in assert_array_compare
    raise AssertionError(msg)
AssertionError: 
Not equal to tolerance rtol=0, atol=0.005

Mismatched elements: 2033 / 4096 (49.6%)
Max absolute difference: 1.01677549
Max relative difference: 1055.02454578
 x: array([[[0.744443, 0.094531, 0.      , ..., 0.140032, 0.      ,
         0.514814],
        [0.532051, 0.444561, 0.066727, ..., 0.370584, 0.120806,...
 y: array([[[0.     , 0.     , 0.     , ..., 0.     , 0.     , 0.     ],
        [0.     , 0.     , 0.     , ..., 0.     , 0.     , 0.     ]],

To reiterate-- the model gets through compilation without error, and as far as I can tell this inaccuracy seems unrelated to the proposed bug fix. I just wanted to include the erroring steps in the pytest and bring them to your attention for transparency's sake. Happy to make a separate issue for this inaccuracy once the fix is merged, or to keep working on it beforehand (whichever you prefer).
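
For reference, the comparison step in question looks roughly like this (a sketch using the names from the traceback above; model, hls_model, and fusion_data come from the surrounding test setup, and this is not the exact pytest):

import numpy as np
import torch

hls_model.compile()
pytorch_prediction = model(torch.Tensor(fusion_data)).detach().numpy()
hls_prediction = hls_model.predict(fusion_data).reshape(pytorch_prediction.shape)
np.testing.assert_allclose(pytorch_prediction, hls_prediction, rtol=0, atol=5e-3)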

@JanFSchulte (Contributor)

Thanks for addressing my comments!

For the channels_last conversion, it is implemented as one of the optimizers that hls4ml runs after the initial parsing of every model; you can see it here in the list of optimizers that are applied: https://github.com/fastmachinelearning/hls4ml/blob/main/hls4ml/model/optimizer/__init__.py#L37. So you will not see an explicit call to it anywhere, but the code in https://github.com/fastmachinelearning/hls4ml/blob/main/hls4ml/model/optimizer/passes/convert_to_channels_last.py is applied to every node in the model graph.

As for the other issues, it looks to me like there is something going wrong that I would like to understand before merging this, as we can't merge in code with a failing pytest. I will have a look.

@JanFSchulte (Contributor)

Three things:

  • The crash of the pytests is due to the large reuse factor in your configuration; it fails this assert in the code: (CONFIG_T::reuse_factor <= CONFIG_T::filt_width * CONFIG_T::n_chan) && "This function is correct only for RF <= FILT_WIDTH * N_CHAN".
  • We don't have a transpose layer in hls4ml that works with streamed inputs. In that case, the inputs and outputs have to be transformed to channels_last outside of hls4ml, and the channels_last_conversion="internal" option needs to be used (see the sketch after this list).
  • After fixing these two things, the pytests still fail because we seem to be having issues reproducing the exact behavior of batchnorm layers in pytorch with the HLS code we currently have. That will require some deeper investigation that will take some time.
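
A minimal sketch of how the first two points translate into the configuration (assuming the config layout printed in the log above; the exact signature of config_from_pytorch_model can differ between hls4ml versions):

from hls4ml.utils import config_from_pytorch_model

# `model` is the PyTorch module under test
config = config_from_pytorch_model(model, channels_last_conversion='internal')
config['Model']['ReuseFactor'] = 1  # small enough to satisfy RF <= filt_width * n_chan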

@sei-rquartiano (Contributor, Author) commented on Aug 5, 2024

Hi Jan, thanks again for the additional information about channels-last conversion and reuse factor. I have reduced the reuse factor like you recommended and pytests are no longer crashing. As for the deeper investigation into batchnorm layer behavior, please let me know if there's anything I should look for or try in order to help. Thanks again for looking into this!

@sei-rquartiano (Contributor, Author)

Hi Jan, I've continued working on my model since we last spoke. I'm encountering similar channels-last-related issues when trying to synthesize other PyTorch layers (AvgPool2d and shortcut connections). In all of these cases it seems like the source code assumes the model is in channels_last format, but for PyTorch channels_first is the default and remains unchanged regardless of whether I set channels_last_conversion="full" in config_from_pytorch_model().

Is there a larger issue with channels-last conversion for pytorch models? I remember you mentioned that ChannelsLastConverter() is called during the synthesis process, but maybe its invocation isn't being recorded properly in the 'data_format' attribute? Or perhaps these layer-specific conversion functions assume it has already been called when in actuality it runs afterwards? Overall, is there a common root cause that may be behind these errors?

If you'd like me to make a separate git issue for these layer-specific errors please let me know, and if there is a larger issue you think may be worth looking into please let me know where. Thanks!

@JanFSchulte (Contributor)

Hi! Thanks for reporting further problems. I was away last week and haven't looked at this in more detail yet; I will do so this week. In general, hls4ml does indeed assume that all models are in channels_last format, and in principle we convert all layers before the HLS code is created and compiled. At first glance you are finding the edge cases that we hadn't tested for so far, so I need to do some bug fixing, but the expectation is that we can do this for all types of supported layers.

@JanFSchulte (Contributor)

I have opened another PR to fix this issue; it seemed easier to just merge my changes: #1050

The BatchNorm works fine; the issue was that the pytorch model has to be switched to eval() mode when evaluating it. BatchNorm is one of the layers that changes behavior between training and evaluation mode, which was causing the differences in the pytests. Now everything checks out fine.
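
In code terms, the relevant change is just the following (illustrative sketch; BatchNormModel and fusion_data are the names used in the test discussed above):

import torch

model = BatchNormModel()  # the conv -> batchnorm -> relu module from the trace above
model.eval()              # freeze BatchNorm running statistics for inference
with torch.no_grad():
    pytorch_prediction = model(torch.Tensor(fusion_data)).numpy()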

I will close this here. For your other issues, can you open an issue and provide an example to reproduce? Thanks!

@JanFSchulte reopened this on Aug 13, 2024
@JanFSchulte (Contributor)

On second thought, this way we are losing track of your contribution to the fix. So I am reopening this so you can incorporate the further changes I made in https://github.com/fastmachinelearning/hls4ml/pull/1050/files, and we can merge your branch instead :)

@sei-rquartiano (Contributor, Author)

Hi Jan, thanks so much for working with me on this! I just merged your PR into mine. Please reach out if there's anything else I need to do before this can be merged.

@JanFSchulte added the "please test" label (Trigger testing by creating local PR branch) on Aug 13, 2024
Inline review comment on the updated test (diff context):

@@ -13,15 +13,24 @@
atol = 5e-3


@pytest.fixture(scope='module')

@JanFSchulte (Contributor):

Looks like I accidentally committed this change. Can you undo it, @sei-rquartiano?

@pytest.mark.parametrize('io_type', ['io_parallel', 'io_stream'])
@pytest.mark.parametrize('backend', ['Vivado', 'Vitis', 'Quartus', 'Catapult'])

@JanFSchulte (Contributor):

Looks like I accidentally committed this change. Can you undo it, @sei-rquartiano?

@sei-rquartiano (Contributor, Author)

Sure, no worries! To clarify, I'm adding scope='module' back to the pytest fixture and 'Catapult' back to the list of backend parameters now. Please let me know if this is incorrect or any other changes need to be made. Thanks!

@JanFSchulte (Contributor)

Those are the two changes I made accidentally, yes. Thanks a lot! Once that is done, I think this is ready to merge. And I'm looking forward to your issue on the AvgPool2D so we can make sure the framework works for you.

@JanFSchulte removed the "please test" label on Aug 14, 2024
@JanFSchulte added the "please test" label (Trigger testing by creating local PR branch) on Aug 14, 2024
@sei-rquartiano (Contributor, Author) commented on Aug 14, 2024

Done! Thanks again for all your help. I'll look more into what's going on with AvgPool2D and shortcut connections and will make an issue when I have more information (or hopefully a solution haha).

@JanFSchulte (Contributor) left a review comment:

Bug fix verified to work; the additional tests are effective at ensuring this type of issue will be caught next time.

@JanFSchulte (Contributor)

The pytests for this seem to have failed randomly in two cases, which looks to me like bad luck. I am considering just going ahead with merging this. @vloncar @jmitrevs Thoughts?

@JanFSchulte added and then removed the "please test" label (Trigger testing by creating local PR branch) on Aug 15, 2024
@vloncar (Contributor) left a review comment:

Looks good! Thanks @sei-rquartiano and @JanFSchulte

@vloncar merged commit 63acf34 into fastmachinelearning:main on Aug 18, 2024
8 of 10 checks passed
@sei-rquartiano (Contributor, Author)

Thank you @JanFSchulte and @vloncar for your help! I'm happy the bug fix has been useful and was able to be merged! @JanFSchulte, thanks for finding the cause of those pytest failures; I'm glad a solution to that was found as well.

I'm excited to keep working on this and to make an issue/PR for AvgPool2d soon. I will be out of town next week but will continue working on this right when I return, so you can expect a submission in early-to-mid September. Cheers

Labels: please test (Trigger testing by creating local PR branch)
Participants: 3