🐛 [Bug] #2881 regression #2991

Closed as not planned

@HolyWu

Bug Description

Since #2881, if inference is performed in its own CUDA stream, the output randomly becomes all zeros.

cc: @gs-olive

To Reproduce

import torch
import torch.nn as nn
import torch_tensorrt


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(x)


with torch.inference_mode():
    dtype = torch.half
    device = torch.device("cuda", 0)
    model = MyModule().eval().to(device)
    inputs = [torch_tensorrt.Input(shape=(1, 3, 5), dtype=dtype)]

    optimized_model = torch_tensorrt.compile(
        model,
        ir="dynamo",
        inputs=inputs,
        enabled_precisions={dtype},
        min_block_size=1,
        device=device,
    )

    for _ in range(10):
        new_input = torch.randn((1, 3, 5), dtype=dtype, device=device)

        eager_output = model(new_input)

        # Run TRT inference on a side stream that waits on the current stream,
        # then have the current stream wait on the side stream before reading.
        stream = torch.cuda.Stream(device=device)
        stream.wait_stream(torch.cuda.current_stream(device=device))
        with torch.cuda.stream(stream):
            trt_output_with_stream = optimized_model(new_input)
        torch.cuda.current_stream(device=device).wait_stream(stream)

        # Same inference on the current (default) stream for comparison.
        trt_output_without_stream = optimized_model(new_input)

        print("")
        print(f"{torch.allclose(eager_output, trt_output_with_stream)=}")
        print(f"{torch.allclose(eager_output, trt_output_without_stream)=}")
        print(f"{trt_output_with_stream=}")
        print(f"{trt_output_without_stream=}")
Sample output:

torch.allclose(eager_output, trt_output_with_stream)=True
torch.allclose(eager_output, trt_output_without_stream)=True
trt_output_with_stream=tensor([[[1.4609, 1.7363, 2.2363, 0.5176, 1.1904],
         [0.0000, 0.0000, 0.9243, 0.4497, 0.0000],
         [0.3186, 0.0000, 0.0000, 0.2537, 0.0707]]], device='cuda:0',
       dtype=torch.float16)
trt_output_without_stream=tensor([[[1.4609, 1.7363, 2.2363, 0.5176, 1.1904],
         [0.0000, 0.0000, 0.9243, 0.4497, 0.0000],
         [0.3186, 0.0000, 0.0000, 0.2537, 0.0707]]], device='cuda:0',
       dtype=torch.float16)

torch.allclose(eager_output, trt_output_with_stream)=False
torch.allclose(eager_output, trt_output_without_stream)=True
trt_output_with_stream=tensor([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]], device='cuda:0', dtype=torch.float16)
trt_output_without_stream=tensor([[[0.6279, 0.6396, 0.9829, 0.2437, 0.0855],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.9268, 0.2942, 0.1624, 0.2323, 0.5444]]], device='cuda:0',
       dtype=torch.float16)

torch.allclose(eager_output, trt_output_with_stream)=False
torch.allclose(eager_output, trt_output_without_stream)=True
trt_output_with_stream=tensor([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]], device='cuda:0', dtype=torch.float16)
trt_output_without_stream=tensor([[[0.4104, 0.9424, 0.9624, 0.0000, 0.6167],
         [0.2534, 0.2571, 0.0000, 0.0000, 0.0000],
         [0.4695, 0.0136, 0.9429, 0.2498, 0.6011]]], device='cuda:0',
       dtype=torch.float16)

torch.allclose(eager_output, trt_output_with_stream)=False
torch.allclose(eager_output, trt_output_without_stream)=True
trt_output_with_stream=tensor([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]], device='cuda:0', dtype=torch.float16)
trt_output_without_stream=tensor([[[0.3899, 0.0000, 0.0000, 0.0000, 0.0000],
         [1.1504, 1.1680, 1.5703, 0.7173, 0.0000],
         [0.4333, 0.1777, 2.0723, 1.0098, 0.0000]]], device='cuda:0',
       dtype=torch.float16)

torch.allclose(eager_output, trt_output_with_stream)=True
torch.allclose(eager_output, trt_output_without_stream)=True
trt_output_with_stream=tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 1.1924],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [1.0889, 0.0027, 0.0000, 1.1680, 0.9487]]], device='cuda:0',
       dtype=torch.float16)
trt_output_without_stream=tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 1.1924],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [1.0889, 0.0027, 0.0000, 1.1680, 0.9487]]], device='cuda:0',
       dtype=torch.float16)

torch.allclose(eager_output, trt_output_with_stream)=False
torch.allclose(eager_output, trt_output_without_stream)=True
trt_output_with_stream=tensor([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]], device='cuda:0', dtype=torch.float16)
trt_output_without_stream=tensor([[[0.0000, 0.0765, 0.4006, 0.0000, 2.0410],
         [2.2793, 0.0000, 0.0000, 0.0000, 0.0000],
         [1.6064, 0.0000, 0.6040, 0.4290, 0.0000]]], device='cuda:0',
       dtype=torch.float16)

torch.allclose(eager_output, trt_output_with_stream)=False
torch.allclose(eager_output, trt_output_without_stream)=True
trt_output_with_stream=tensor([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]], device='cuda:0', dtype=torch.float16)
trt_output_without_stream=tensor([[[0.8057, 0.9116, 0.3110, 0.0000, 0.0000],
         [1.6016, 0.0000, 0.3757, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0048, 0.9331, 1.1475]]], device='cuda:0',
       dtype=torch.float16)

torch.allclose(eager_output, trt_output_with_stream)=False
torch.allclose(eager_output, trt_output_without_stream)=True
trt_output_with_stream=tensor([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]], device='cuda:0', dtype=torch.float16)
trt_output_without_stream=tensor([[[0.0000, 0.7251, 0.0000, 0.0000, 0.9985],
         [0.0000, 1.3789, 0.6831, 0.0000, 0.7051],
         [0.0000, 0.0000, 0.0000, 0.6748, 0.0000]]], device='cuda:0',
       dtype=torch.float16)

torch.allclose(eager_output, trt_output_with_stream)=True
torch.allclose(eager_output, trt_output_without_stream)=True
trt_output_with_stream=tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 1.0420],
         [0.3943, 0.0000, 0.0000, 0.5200, 1.4277]]], device='cuda:0',
       dtype=torch.float16)
trt_output_without_stream=tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 1.0420],
         [0.3943, 0.0000, 0.0000, 0.5200, 1.4277]]], device='cuda:0',
       dtype=torch.float16)

torch.allclose(eager_output, trt_output_with_stream)=False
torch.allclose(eager_output, trt_output_without_stream)=True
trt_output_with_stream=tensor([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]], device='cuda:0', dtype=torch.float16)
trt_output_without_stream=tensor([[[1.1484, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.5659, 0.0000, 1.3525, 0.5669],
         [1.7070, 0.0000, 2.0957, 0.0000, 1.1602]]], device='cuda:0',
       dtype=torch.float16)

Environment

  • Torch-TensorRT Version: 2.5.0.dev20240709+cu124
  • PyTorch Version: 2.5.0.dev20240709+cu124
  • CPU Architecture: x86_64
  • OS: Ubuntu 24.04 LTS
  • How you installed PyTorch: pip
  • Python version: 3.12
  • CUDA version: 12.4
  • GPU models and configuration: RTX 3050

Additional context

Interestingly, the issue is only reproducible with dtype=torch.half, not with dtype=torch.float.
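
A minimal sketch for double-checking that dependence, reusing MyModule from the repro above (the helper name count_side_stream_mismatches is just for illustration); with the behavior described here, torch.half should report some mismatching runs while torch.float should report none:

import torch
import torch_tensorrt


def count_side_stream_mismatches(dtype: torch.dtype, iters: int = 20) -> int:
    # Compile the same tiny ReLU module for the given dtype, then count how
    # many side-stream runs diverge from eager execution.
    device = torch.device("cuda", 0)
    model = MyModule().eval().to(device)
    trt_model = torch_tensorrt.compile(
        model,
        ir="dynamo",
        inputs=[torch_tensorrt.Input(shape=(1, 3, 5), dtype=dtype)],
        enabled_precisions={dtype},
        min_block_size=1,
        device=device,
    )
    mismatches = 0
    with torch.inference_mode():
        for _ in range(iters):
            x = torch.randn((1, 3, 5), dtype=dtype, device=device)
            stream = torch.cuda.Stream(device=device)
            stream.wait_stream(torch.cuda.current_stream(device=device))
            with torch.cuda.stream(stream):
                y = trt_model(x)
            torch.cuda.current_stream(device=device).wait_stream(stream)
            if not torch.allclose(model(x), y):
                mismatches += 1
    return mismatches


print(f"half: {count_side_stream_mismatches(torch.half)} of 20 runs mismatched")
print(f"float: {count_side_stream_mismatches(torch.float)} of 20 runs mismatched")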
