
Commit 4615705

Changed the stream handling
1 parent b72a538 commit 4615705

File tree

4 files changed: +58 −15 lines

  core/runtime/execute_engine.cpp
  py/torch_tensorrt/dynamo/_compiler.py
  py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
  py/torch_tensorrt/dynamo/runtime/_stream_handler.py

core/runtime/execute_engine.cpp

Lines changed: 2 additions & 6 deletions
@@ -283,9 +283,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
   auto current_device_id = -1;
   if (inputs.size() > 0) {
     current_device_id = inputs[0].device().index(); // Done this way to avoid a call to cudart
-    if (current_device_id != compiled_engine->current_device_id) {
-      compiled_engine->stream = c10::cuda::getCurrentCUDAStream(current_device_id);
-    }
+    compiled_engine->stream = c10::cuda::getCurrentCUDAStream(current_device_id);
   }

   { // Engine Execution (execute on engine stream)

@@ -370,9 +368,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
   auto current_device_id = -1;
   if (inputs.size() > 0) {
     current_device_id = inputs[0].device().index(); // Done this way to avoid a call to cudart
-    if (current_device_id != compiled_engine->current_device_id) {
-      compiled_engine->stream = c10::cuda::getCurrentCUDAStream(current_device_id);
-    }
+    compiled_engine->stream = c10::cuda::getCurrentCUDAStream(current_device_id);
   }

   { // Engine Execution (execute on engine stream)
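Both hunks make the same change: the engine stream is refreshed from the active PyTorch CUDA stream on every call, instead of only when the input device id differs from the cached one. A minimal Python sketch of the case the old guard missed, where trt_mod and x are hypothetical stand-ins for a compiled module and its input:

    import torch

    s1 = torch.cuda.Stream()
    s2 = torch.cuda.Stream()

    with torch.cuda.stream(s1):
        out1 = trt_mod(x)  # engine picks up s1 for this device

    with torch.cuda.stream(s2):
        out2 = trt_mod(x)  # same device id: the old guard skipped the stream
                           # update, leaving the engine on stale s1; the new
                           # code re-reads the current stream and uses s2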

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 4 additions & 0 deletions
@@ -43,6 +43,7 @@
 from torch_tensorrt.dynamo.partitioning._resource_partitioner import (
     resource_partition,
 )
+from torch_tensorrt.dynamo.runtime._stream_handler import handle_cuda_stream
 from torch_tensorrt.dynamo.utils import (
     deallocate_module,
     get_cpu_memory_usage,

@@ -950,6 +951,7 @@ def preserve_module_specs(
         if attr.startswith("_frozen_param"):
             delattr(gm, attr)
     trt_module = None
+
     for name, _ in partitioned_module.named_children():
         submodule = getattr(partitioned_module, name)
         # filter on the GraphModule

@@ -1090,6 +1092,8 @@ def preserve_module_specs(
     settings.use_fast_partitioner = True

     dryrun_stats_display(dryrun_tracker, settings.dryrun)
+    if not settings.dryrun:
+        handle_cuda_stream(partitioned_module)

     return partitioned_module
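The hook runs only for real compilations (the dryrun path skips it) and rewrites the partitioned graph in place. A rough sketch of the intended transformation on a toy FX module; the toy module is an assumption, while the op names come from _stream_handler.py below:

    import torch
    import torch.fx

    class Toy(torch.nn.Module):
        def forward(self, x):
            return x + 1

    gm = torch.fx.symbolic_trace(Toy())
    # Before handle_cuda_stream(gm): placeholder(x) -> add -> output
    # After handle_cuda_stream(gm), per the pass below:
    #   tensorrt.enter_compute_stream() -> placeholder(x) -> add
    #   -> tensorrt.exit_compute_stream() -> output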

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 6 additions & 9 deletions
@@ -172,7 +172,6 @@ def __init__(
         self._input_buffers: List[torch.Tensor] = []
         self._output_buffers: List[torch.Tensor] = []
         self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
-        self._engine_stream: torch.cuda.Stream = torch.cuda.current_stream()
         self.output_tensors: Optional[List[torch.Tensor]] = None
         self.sync_stream = True

@@ -283,7 +282,6 @@ def setup_engine(self) -> None:
         ), f"TensorRT engine was not built to target current platform (target: {self.target_platform}, current: {Platform.current_platform()})"
         # Stream handling: if the caller stream is the pytorch default stream, create a new engine stream
         # otherwise, use the caller stream and disable stream synchronization
-        self._engine_stream = torch.cuda.current_stream()

         self.initialized = True
         runtime = trt.Runtime(TRT_LOGGER)

@@ -564,10 +562,10 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
             self.cudagraph.enable_debug_mode()

         with torch.cuda.graph(
-            self.cudagraph, stream=self._engine_stream
+            self.cudagraph, stream=torch.cuda.current_stream()
         ):
             self.context.execute_async_v3(
-                self._engine_stream.cuda_stream
+                torch.cuda.current_stream().cuda_stream
             )

         if self.profiling_enabled:

@@ -590,7 +588,7 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
         with warnings.catch_warnings():
             try:
                 self.context.execute_async_v3(
-                    self._engine_stream.cuda_stream
+                    torch.cuda.current_stream().cuda_stream
                 )
             except Warning as e:
                 breakpoint()

@@ -650,10 +648,9 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 else nullcontext()
             ):

-                with torch.cuda.stream(self._engine_stream):
-                    self.context.execute_async_v3(
-                        self._engine_stream.cuda_stream
-                    )  # The OutputAllocator is called by execute_async_v3()
+                self.context.execute_async_v3(
+                    torch.cuda.current_stream().cuda_stream
+                )  # The OutputAllocator is called by execute_async_v3()

             with (
                 torch.autograd.profiler.record_function(
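With the module-owned _engine_stream removed, execute_async_v3 always enqueues on whatever torch.cuda.current_stream() returns at call time, so the ambient stream (set by the caller, or by the inserted enter_compute_stream op) decides where the engine runs. A hedged sketch of the resulting caller-side contract, with trt_mod and x again hypothetical:

    import torch

    s = torch.cuda.Stream()
    with torch.cuda.stream(s):
        out = trt_mod(x)  # execute_async_v3 enqueues on s via current_stream()
    s.synchronize()       # the caller now owns synchronization of its stream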
py/torch_tensorrt/dynamo/runtime/_stream_handler.py

Lines changed: 46 additions & 0 deletions (new file)
@@ -0,0 +1,46 @@
+import torch
+import torch.fx
+
+
+def handle_cuda_stream(
+    partitioned_module: torch.fx.GraphModule,
+) -> torch.fx.GraphModule:
+    for node in partitioned_module.graph.nodes:
+        if node.op == "placeholder":
+            with partitioned_module.graph.inserting_before(node):
+                partitioned_module.graph.call_function(
+                    torch.ops.tensorrt.enter_compute_stream
+                )
+        elif node.op == "output":
+            with partitioned_module.graph.inserting_before(node):
+                partitioned_module.graph.call_function(
+                    torch.ops.tensorrt.exit_compute_stream
+                )
+
+    partitioned_module.graph.lint()
+    partitioned_module.recompile()
+    return partitioned_module
+
+
+@torch.library.custom_op("tensorrt::enter_compute_stream", mutates_args=())
+def enter_compute_stream() -> None:
+    stream = torch.cuda.Stream()
+    stream.wait_stream(torch.cuda.default_stream())
+    torch.cuda.set_stream(stream)
+
+
+@torch.library.custom_op("tensorrt::exit_compute_stream", mutates_args=())
+def exit_compute_stream() -> None:
+    stream = torch.cuda.current_stream()
+    torch.cuda.default_stream().wait_stream(stream)
+    torch.cuda.set_stream(torch.cuda.default_stream())
+
+
+@torch.library.register_fake("tensorrt::enter_compute_stream")
+def fake_enter_compute_stream() -> None:
+    pass
+
+
+@torch.library.register_fake("tensorrt::exit_compute_stream")
+def fake_exit_compute_stream() -> None:
+    pass
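The two ops are meant to be inserted as a pair by handle_cuda_stream, but the contract can be exercised by hand. A usage sketch, assuming a CUDA-enabled build; importing the module registers the ops as a side effect:

    import torch
    from torch_tensorrt.dynamo.runtime import _stream_handler  # registers the ops

    x = torch.randn(8, device="cuda")
    torch.ops.tensorrt.enter_compute_stream()  # switch to a fresh side stream
    y = x * 2                                  # kernels launch on the side stream
    torch.ops.tensorrt.exit_compute_stream()   # default stream waits, then is restored
    torch.cuda.synchronize()

Note that enter_compute_stream allocates a new torch.cuda.Stream on each call and torch.cuda.set_stream changes the thread-local current stream, so the exit op both orders the default stream after the side stream and restores the default stream as current.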
