Closed as not planned
Description
Bug Description
Since #2881, if the inference is performed in its own stream, the output randomly becomes all zeros.
cc: @gs-olive
To Reproduce
import torch
import torch.nn as nn
import torch_tensorrt
class MyModule(nn.Module):
    """Minimal module for the reproduction: applies a single ReLU to its input."""

    def __init__(self) -> None:
        super().__init__()
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(x)``."""
        return self.relu(x)
# Reproduction: compare eager output against the Torch-TensorRT compiled module,
# once when the compiled module is called under a dedicated CUDA stream and once
# when it is called on the current stream.
with torch.inference_mode():
    dtype = torch.half
    device = torch.device("cuda", 0)

    model = MyModule().eval().to(device)

    inputs = [torch_tensorrt.Input(shape=(1, 3, 5), dtype=dtype)]
    optimized_model = torch_tensorrt.compile(
        model,
        ir="dynamo",
        inputs=inputs,
        enabled_precisions={dtype},
        min_block_size=1,
        device=device,
    )

    for _ in range(10):
        new_input = torch.randn((1, 3, 5), dtype=dtype, device=device)
        eager_output = model(new_input)

        # Create a fresh side stream and make it wait on the current stream
        # before the compiled module is invoked under it.
        stream = torch.cuda.Stream(device=device)
        stream.wait_stream(torch.cuda.current_stream(device=device))
        with torch.cuda.stream(stream):
            trt_output_with_stream = optimized_model(new_input)
        # Make the current stream wait on the side stream before the results
        # are compared and printed below.
        torch.cuda.current_stream(device=device).wait_stream(stream)

        # Same compiled module, same input, called on the current stream.
        trt_output_without_stream = optimized_model(new_input)

        print("")
        print(f"{torch.allclose(eager_output, trt_output_with_stream)=}")
        print(f"{torch.allclose(eager_output, trt_output_without_stream)=}")
        print(f"{trt_output_with_stream=}")
        print(f"{trt_output_without_stream=}")
torch.allclose(eager_output, trt_output_with_stream)=True
torch.allclose(eager_output, trt_output_without_stream)=True
trt_output_with_stream=tensor([[[1.4609, 1.7363, 2.2363, 0.5176, 1.1904],
[0.0000, 0.0000, 0.9243, 0.4497, 0.0000],
[0.3186, 0.0000, 0.0000, 0.2537, 0.0707]]], device='cuda:0',
dtype=torch.float16)
trt_output_without_stream=tensor([[[1.4609, 1.7363, 2.2363, 0.5176, 1.1904],
[0.0000, 0.0000, 0.9243, 0.4497, 0.0000],
[0.3186, 0.0000, 0.0000, 0.2537, 0.0707]]], device='cuda:0',
dtype=torch.float16)
torch.allclose(eager_output, trt_output_with_stream)=False
torch.allclose(eager_output, trt_output_without_stream)=True
trt_output_with_stream=tensor([[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]], device='cuda:0', dtype=torch.float16)
trt_output_without_stream=tensor([[[0.6279, 0.6396, 0.9829, 0.2437, 0.0855],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.9268, 0.2942, 0.1624, 0.2323, 0.5444]]], device='cuda:0',
dtype=torch.float16)
torch.allclose(eager_output, trt_output_with_stream)=False
torch.allclose(eager_output, trt_output_without_stream)=True
trt_output_with_stream=tensor([[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]], device='cuda:0', dtype=torch.float16)
trt_output_without_stream=tensor([[[0.4104, 0.9424, 0.9624, 0.0000, 0.6167],
[0.2534, 0.2571, 0.0000, 0.0000, 0.0000],
[0.4695, 0.0136, 0.9429, 0.2498, 0.6011]]], device='cuda:0',
dtype=torch.float16)
torch.allclose(eager_output, trt_output_with_stream)=False
torch.allclose(eager_output, trt_output_without_stream)=True
trt_output_with_stream=tensor([[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]], device='cuda:0', dtype=torch.float16)
trt_output_without_stream=tensor([[[0.3899, 0.0000, 0.0000, 0.0000, 0.0000],
[1.1504, 1.1680, 1.5703, 0.7173, 0.0000],
[0.4333, 0.1777, 2.0723, 1.0098, 0.0000]]], device='cuda:0',
dtype=torch.float16)
torch.allclose(eager_output, trt_output_with_stream)=True
torch.allclose(eager_output, trt_output_without_stream)=True
trt_output_with_stream=tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 1.1924],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[1.0889, 0.0027, 0.0000, 1.1680, 0.9487]]], device='cuda:0',
dtype=torch.float16)
trt_output_without_stream=tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 1.1924],
[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[1.0889, 0.0027, 0.0000, 1.1680, 0.9487]]], device='cuda:0',
dtype=torch.float16)
torch.allclose(eager_output, trt_output_with_stream)=False
torch.allclose(eager_output, trt_output_without_stream)=True
trt_output_with_stream=tensor([[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]], device='cuda:0', dtype=torch.float16)
trt_output_without_stream=tensor([[[0.0000, 0.0765, 0.4006, 0.0000, 2.0410],
[2.2793, 0.0000, 0.0000, 0.0000, 0.0000],
[1.6064, 0.0000, 0.6040, 0.4290, 0.0000]]], device='cuda:0',
dtype=torch.float16)
torch.allclose(eager_output, trt_output_with_stream)=False
torch.allclose(eager_output, trt_output_without_stream)=True
trt_output_with_stream=tensor([[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]], device='cuda:0', dtype=torch.float16)
trt_output_without_stream=tensor([[[0.8057, 0.9116, 0.3110, 0.0000, 0.0000],
[1.6016, 0.0000, 0.3757, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0048, 0.9331, 1.1475]]], device='cuda:0',
dtype=torch.float16)
torch.allclose(eager_output, trt_output_with_stream)=False
torch.allclose(eager_output, trt_output_without_stream)=True
trt_output_with_stream=tensor([[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]], device='cuda:0', dtype=torch.float16)
trt_output_without_stream=tensor([[[0.0000, 0.7251, 0.0000, 0.0000, 0.9985],
[0.0000, 1.3789, 0.6831, 0.0000, 0.7051],
[0.0000, 0.0000, 0.0000, 0.6748, 0.0000]]], device='cuda:0',
dtype=torch.float16)
torch.allclose(eager_output, trt_output_with_stream)=True
torch.allclose(eager_output, trt_output_without_stream)=True
trt_output_with_stream=tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 1.0420],
[0.3943, 0.0000, 0.0000, 0.5200, 1.4277]]], device='cuda:0',
dtype=torch.float16)
trt_output_without_stream=tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.0000, 1.0420],
[0.3943, 0.0000, 0.0000, 0.5200, 1.4277]]], device='cuda:0',
dtype=torch.float16)
torch.allclose(eager_output, trt_output_with_stream)=False
torch.allclose(eager_output, trt_output_without_stream)=True
trt_output_with_stream=tensor([[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]]], device='cuda:0', dtype=torch.float16)
trt_output_without_stream=tensor([[[1.1484, 0.0000, 0.0000, 0.0000, 0.0000],
[0.0000, 0.5659, 0.0000, 1.3525, 0.5669],
[1.7070, 0.0000, 2.0957, 0.0000, 1.1602]]], device='cuda:0',
dtype=torch.float16)
Environment
- Torch-TensorRT Version (e.g. 1.0.0): 2.5.0.dev20240709+cu124
- PyTorch Version (e.g. 1.0): 2.5.0.dev20240709+cu124
- CPU Architecture: x86_64
- OS (e.g., Linux): Ubuntu 24.04 LTS
- How you installed PyTorch (`conda`, `pip`, `libtorch`, source): pip
- Build command you used (if compiling from source):
- Are you using local sources or building from archives:
- Python version: 3.12
- CUDA version: 12.4
- GPU models and configuration: RTX 3050
- Any other relevant information:
Additional context
Interestingly, it's only reproducible when using `dtype=torch.half`, but not with `dtype=torch.float`.