InstanceNorm always has training attribute set to True #1262


Closed
yuanyao-nv opened this issue Feb 2, 2024 · 15 comments

@yuanyao-nv
Contributor

I'm exporting a small ONNX model consisting of just an InstanceNorm2d op. The exported model contains a BatchNorm node with the training attribute set to True.

[screenshot of the exported ONNX graph showing the BatchNormalization node]

Here's my script:

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.instancenorm = torch.nn.InstanceNorm2d(100)

    def forward(self, tensor_x: torch.Tensor):
        output = self.instancenorm(tensor_x)
        return output

def Dataloader():
    yield torch.randn(20, 100, 35, 45).cuda()

model = Model()
data = next(Dataloader())

export_output = torch.onnx.dynamo_export(
    model.eval().to('cuda'),
    data,
)
export_output.save('instancenorm_dynamo.onnx')
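
For reference, the attribute can also be checked programmatically rather than in a viewer. A minimal sketch using the onnx package (the save path is the one above; depending on the exporter version the node may live inside a model-local function rather than the top-level graph):

import onnx
from onnx import helper

model_proto = onnx.load('instancenorm_dynamo.onnx')
# Walk the top-level graph as well as any model-local functions.
all_nodes = list(model_proto.graph.node) + [n for f in model_proto.functions for n in f.node]
for node in all_nodes:
    if node.op_type == 'BatchNormalization':
        # In recent opsets the flag is the training_mode attribute.
        print({a.name: helper.get_attribute_value(a) for a in node.attribute})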

If I understand correctly, the origin of this issue is in the torch repo, since onnxscript just takes whatever value is passed for the training input of aten_native_batch_norm()?

Thanks.

@yuanyao-nv changed the title from "InstanceNorm always have training attribute set to True" to "InstanceNorm always has training attribute set to True" on Feb 2, 2024
@justinchuby
Collaborator

I think you are right that this is PyTorch's issue in setting the parameter. Could you print the FX graph, and/or try exporting with torch.export.export()?

@yuanyao-nv
Contributor Author

If I print with

exported_program = torch.export.export(model.eval().to('cuda'), (data,))
print(exported_program)

I get

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, l_tensor_x_: "f32[20, 100, 35, 45]"):
            # File: /ws/dynamo_exporter_workflow/export_instancenorm.py:24, code: output = self.instancenorm(tensor_x)
            view: "f32[1, 2000, 35, 45]" = torch.ops.aten.view.default(l_tensor_x_, [1, 2000, 35, 45]);  l_tensor_x_ = None
            _native_batch_norm_legit = torch.ops.aten._native_batch_norm_legit.no_stats(view, None, None, True, 0.1, 1e-05);  view = None
            getitem: "f32[1, 2000, 35, 45]" = _native_batch_norm_legit[0];  _native_batch_norm_legit = None
            view_1: "f32[20, 100, 35, 45]" = torch.ops.aten.view.default(getitem, [20, 100, 35, 45]);  getitem = None
            return (view_1,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='l_tensor_x_'), target=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='view_1'), target=None)])
Range constraints: {}
Equality constraints: []

If I print with

gm, _ = torch._dynamo.export(model, data, aten_graph=True)
gm = torch.fx.experimental.proxy_tensor.make_fx(torch.func.functionalize(gm))(data)
gm.print_readable()

I get

class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[20, 100, 35, 45]"):
        # No stacktrace found for following nodes
        view: "f32[1, 2000, 35, 45]" = torch.ops.aten.view.default(arg0_1, [1, 2000, 35, 45]);  arg0_1 = None
        empty: "u8[0]" = torch.ops.aten.empty.memory_format([0], dtype = torch.uint8, layout = torch.strided, device = device(type='cuda', index=0))
        _native_batch_norm_legit = torch.ops.aten._native_batch_norm_legit.no_stats(view, None, None, True, 0.1, 1e-05);  view = None
        getitem: "f32[1, 2000, 35, 45]" = _native_batch_norm_legit[0]
        getitem_1: "f32[2000]" = _native_batch_norm_legit[1]
        getitem_2: "f32[2000]" = _native_batch_norm_legit[2];  _native_batch_norm_legit = None
        view_1: "f32[20, 100, 35, 45]" = torch.ops.aten.view.default(getitem, [20, 100, 35, 45]);  getitem = None
        return view_1

So it does look like the training flag is set to True in the fx graph.

@justinchuby
Collaborator

@BowenBao I remember you looked at this issue before. Do you have anything to add?

@BowenBao
Contributor

BowenBao commented Feb 5, 2024

There are several problems here. The main goal is to match the behavior of the exported ONNX model with PyTorch.

training in _native_batch_norm_legit is an overloaded term; it controls how the op handles the running mean and variance. The exported ONNX tries to match that behavior.
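
As a concrete illustration (a hedged sketch calling the same aten overload that appears in the traced graph above; private overloads like this can change between torch versions), the no_stats variant with training=True always computes the statistics from the current input, which is what instance norm needs:

import torch

x = torch.randn(1, 2000, 35, 45)
# Same call as in the traced graph: input, weight=None, bias=None, training=True, momentum, eps.
out, save_mean, save_rstd = torch.ops.aten._native_batch_norm_legit.no_stats(
    x, None, None, True, 0.1, 1e-5
)
print(out.shape, save_mean.shape, save_rstd.shape)  # statistics are per channel of the flattened view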

@yuanyao-nv
Contributor Author

yuanyao-nv commented Feb 5, 2024

@BowenBao Shouldn't calling model.eval() turn off training-related parameters?
If training is overloaded, what is the correct way to set the training attribute to false in the exported model?

@BowenBao
Contributor

BowenBao commented Feb 6, 2024

In this case that is not enough. You probably need to change track_running_stats, which affects training at the aten level.

https://github.com/pytorch/pytorch/blob/fd0bf96c2b9aea46f0597ba6fef9b896f5b874bb/torch/nn/modules/instancenorm.py#L38

However, that might have unwanted effects on the operator's behavior. I feel this is a legacy issue in PyTorch.
This could potentially be handled by an optimization/rewrite pass in onnx-rewriter.
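
As a quick experiment (a sketch only, not a recommended fix; the ops printed may differ across torch versions), one can compare what torch.export traces with and without track_running_stats:

import torch
import torch.nn as nn

x = torch.randn(20, 100, 35, 45)
for track in (False, True):
    m = nn.InstanceNorm2d(100, track_running_stats=track).eval()
    ep = torch.export.export(m, (x,))
    ops = [n.target for n in ep.graph.nodes if n.op == "call_function"]
    print(track, ops)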

@yuanyao-nv
Contributor Author

I'm seeing this problem for BatchNorm as well. I guess it affects all normalization ops that are implemented this way in PyTorch.

@BowenBao
Contributor

@yuanyao-nv could you provide a sample repro for BatchNorm?

Back to InstanceNorm, I'm curious whether there were unwanted decompositions happening during dynamo tracing. Normally we should be receiving aten::instance_norm instead of aten::batch_norm.

@BowenBao
Contributor

Turns out aten::instance_norm is another one of those normally unskippable decompositions. It is slightly complicated since mutations are involved when updating the running mean & var.
We can skip the decomp in the exporter for the non-mutating version and add a proper onnxscript torchlib export for it.
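
For reference, the ONNX-level mapping itself is simple. A minimal onnxscript sketch of the idea (illustrative only, not the actual torchlib registration; the function name here is made up):

from onnxscript import FLOAT, script
from onnxscript import opset18 as op

@script()
def instance_norm_sketch(x: FLOAT[...], scale: FLOAT[...], bias: FLOAT[...]) -> FLOAT[...]:
    # Lower directly to ONNX InstanceNormalization instead of
    # BatchNormalization with training_mode=1.
    return op.InstanceNormalization(x, scale, bias, epsilon=1e-5)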

@yuanyao-nv
Contributor Author

yuanyao-nv commented Feb 15, 2024

@BowenBao A BatchNorm example can be found in this model:

import torch
from monai.networks.nets import VarAutoEncoder

def VarAutoEncoderDataloader():
    yield torch.randn(1, 1, 32, 32).to("cuda")

model = lambda: VarAutoEncoder(dimensions=2, in_shape=(1, 32, 32), out_channels=1, latent_size=2, channels=(4, 8), strides=(2, 2))
data = next(VarAutoEncoderDataloader())

model = model().eval().to('cuda')

export_output = torch.onnx.dynamo_export(
    model,
    data,
)
export_output.save('Clara_VarAutoEncoder_dynamo.onnx')

Env:
NGC container: docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.09-py3
Update the package versions to the following:
torch 2.3.0.dev20240212+cu121
onnxscript 0.1.0.dev20240203
install MONAI: pip3 install git+https://github.com/Project-MONAI/[email protected]#egg=MONAI

I can also upload the model if you need.

@BowenBao
Contributor

@yuanyao-nv thanks for the repro; this also originates from InstanceNorm. So hopefully the issue is limited to instance norm.

VarAutoEncoder(
  (encode): Sequential(
    (encode_0): Convolution(
      (conv): Conv2d(1, 4, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (adn): ADN(
        (N): InstanceNorm2d(4, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (A): PReLU(num_parameters=1)
      )
    )
    (encode_1): Convolution(
      (conv): Conv2d(4, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (adn): ADN(
        (N): InstanceNorm2d(8, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (A): PReLU(num_parameters=1)
      )
    )
  )
  (intermediate): Identity()
  (decode): Sequential(
    (decode_0): Sequential(
      (conv): Convolution(
        (conv): ConvTranspose2d(8, 4, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
        (adn): ADN(
          (N): InstanceNorm2d(4, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
          (A): PReLU(num_parameters=1)
        )
      )
    )
    (decode_1): Sequential(
      (conv): Convolution(
        (conv): ConvTranspose2d(4, 1, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
      )
    )
  )
  (mu): Linear(in_features=512, out_features=2, bias=True)
  (logvar): Linear(in_features=512, out_features=2, bias=True)
  (decodeL): Linear(in_features=2, out_features=512, bias=True)
)

@yuanyao-nv
Contributor Author

@BowenBao That's good to know. Can you please also share the command you used for this printout?

@BowenBao
Contributor

...
model = model().eval().to('cuda')
print(model)

BowenBao added a commit that referenced this issue Feb 29, 2024
- Fixes #1280, #1262. Avoid exporting as onnx::BatchNormalization with training=True.
- Fixes mismatch in unittest.

Pull Request resolved: #1284
BowenBao added a commit that referenced this issue Feb 29, 2024
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* __->__ #1284

- Fixes #1280, #1262. Avoid exporting as onnx::BatchNormalization with training=True.
- Fixes mismatch in unittest.
@gramalingam
Collaborator

How about these two lines? Shouldn't they be fixed? Can we just use the return value of BatchNorm in the previous line for the mean and var?

@BowenBao
Contributor

BowenBao commented Mar 5, 2024

This issue is now fixed by #1284 and pytorch/pytorch#120866

> How about these two lines? Shouldn't they be fixed? Can we just use the return value of BatchNorm in the previous line for the mean and var?

We can keep tracking the batchnorm outputs issue in #1256, since now they are irrelevant to instance norm after the fix.
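
As a sanity check after picking up the fixes (a sketch; the exact ops emitted depend on the exporter version):

import onnx

m = onnx.load('instancenorm_dynamo.onnx')
op_types = {n.op_type for n in m.graph.node} | {n.op_type for f in m.functions for n in f.node}
print(op_types)  # before the fix this included BatchNormalization with training_mode=1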
