
🐛 [Bug] bug with torchvision.transforms.GaussianBlur #1526


Bug Description

Your test does not work correctly. I have two models that use torchvision.transforms.GaussianBlur, and torch_executed_modules is not able to skip the operation in the second model.

To Reproduce

I have prepared a toy sample:

import torch
import torchvision
import tensorrt
import torch_tensorrt

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.gaus = torchvision.transforms.GaussianBlur([33, 33], [4., 4.])
        self.conv = torch.nn.Conv2d(3, 64, (3,3))
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.relu(self.conv(x))
        x = self.gaus(x)
        return x

model1 = ToyModel().eval().to('cuda')
model2 = ToyModel().eval().to('cuda')

traced1 = torch.jit.trace(model1, torch.randn(1,3,224,224).to('cuda'))
traced2 = torch.jit.trace(model2, torch.randn(1,3,224,224).to('cuda'))

print(traced1.graph) ## Here we can see torchvision.transforms.transforms.GaussianBlur
print(traced2.graph) ## But here we can see torchvision.transforms.transforms.___torch_mangle_4.GaussianBlur

trt_model1 = torch_tensorrt.compile(
    traced1,
    "default",
    [torch_tensorrt.Input((1, 3, 224, 224), dtype=torch.float32)],
    torch.float32,
    truncate_long_and_double = True,
    torch_executed_modules = ['torchvision.transforms.transforms.GaussianBlur']
)

print("**** Done first! *****")

trt_model2 = torch_tensorrt.compile(
    traced2,
    "default",
    [torch_tensorrt.Input((1, 3, 224, 224), dtype=torch.float32)],
    torch.float32,
    truncate_long_and_double = True,
    torch_executed_modules = ['torchvision.transforms.transforms.GaussianBlur']
)
print("***** Done second! *****")

Expected behavior

According to your API test, the second model should also be correctly converted to TensorRT, skipping the unsupported GaussianBlur operation.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): v1.1.0
  • PyTorch Version (e.g. 1.0): 1.11.0+cu113
  • OS (e.g., Linux): Ubuntu 22
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Python version: 3.10

Additional context

This can also be reproduced with resnet18, similar to your API test:

model1 = torchvision.models.resnet.resnet18(pretrained=True).eval().to('cuda')
model2 = torchvision.models.resnet.resnet18(pretrained=True).eval().to('cuda')

scripted_model1 = torch.jit.trace(model1, torch.randn(1, 3, 224, 224).cuda())
scripted_model2 = torch.jit.trace(model2, torch.randn(1, 3, 224, 224).cuda())

print(scripted_model1.graph) # Here we can see  torchvision.models.resnet.ResNet
print(scripted_model2.graph) # And here torchvision.models.resnet.___torch_mangle_194.ResNet

The graphs have different type names, so a module passed as torch_executed_modules = ['torchvision.models.resnet.BasicBlock'] may not be properly skipped.
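To make the mismatch visible programmatically, the mangled qualified type names can be pulled out of the printed graph. This is only a sketch based on regexing the graph text; the helper name mangled_type_names and the exact pattern (including the __torch__. prefix the graph printer uses) are my own assumptions, not part of any API.

import re

def mangled_type_names(traced_module, class_name):
    # Scan the printed graph for fully-qualified (possibly mangled) type
    # names that end with the given class name, e.g.
    # "torchvision.models.resnet.___torch_mangle_194.ResNet".
    pattern = r"__torch__\.((?:\w+\.)*" + re.escape(class_name) + r")"
    return sorted(set(re.findall(pattern, str(traced_module.graph))))

print(mangled_type_names(scripted_model1, "ResNet"))  # ['torchvision.models.resnet.ResNet']
print(mangled_type_names(scripted_model2, "ResNet"))  # ['torchvision.models.resnet.___torch_mangle_194.ResNet']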

The question is

How could I manage to skip the GaussianBlur in the second model? It seems that the two compiled models cannot exist at the same time.
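For what it is worth, here is a workaround I would try (only a sketch, not verified): pass the mangled qualified name exactly as it appears in traced2's graph, reusing the mangled_type_names helper sketched above. Whether torch_executed_modules actually matches on the mangled form is an assumption on my side.

# Hypothetical workaround: collect the mangled name(s) from traced2's graph
# instead of hard-coding "torchvision.transforms.transforms.GaussianBlur".
fallback_modules = mangled_type_names(traced2, "GaussianBlur")
# e.g. ['torchvision.transforms.transforms.___torch_mangle_4.GaussianBlur']

trt_model2 = torch_tensorrt.compile(
    traced2,
    "default",
    [torch_tensorrt.Input((1, 3, 224, 224), dtype=torch.float32)],
    torch.float32,
    truncate_long_and_double = True,
    torch_executed_modules = fallback_modules
)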
