Skip to content

Commit 62e7b82

Browse files
committed
feat: Utilize non-default stream for runtimes
- Add support for non-default streams
1 parent 0155e2f commit 62e7b82

File tree

6 files changed

+55
-26
lines changed

6 files changed

+55
-26
lines changed

core/runtime/TRTEngine.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <cuda_runtime.h>
44
#include "NvInfer.h"
5+
#include "c10/cuda/CUDAStream.h"
56
#include "torch/csrc/jit/frontend/function_schema_parser.h"
67
#include "torch/cuda.h"
78

@@ -66,6 +67,10 @@ TRTEngine::TRTEngine(
6667
multi_gpu_device_check();
6768
set_rt_device(device_info);
6869

70+
// Set active stream to high-priority, non-default stream
71+
active_stream = c10::cuda::getStreamFromPool(true, device_info.id);
72+
c10::cuda::setCurrentCUDAStream(active_stream);
73+
6974
rt = make_trt(nvinfer1::createInferRuntime(util::logging::get_logger()));
7075

7176
name = slugify(mod_name);

core/runtime/TRTEngine.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "ATen/core/function_schema.h"
1010
#include "ATen/cuda/CUDAGraph.h"
1111
#include "NvInfer.h"
12+
#include "c10/cuda/CUDAStream.h"
1213
#include "torch/custom_class.h"
1314

1415
#include "core/runtime/TRTEngineProfiler.h"
@@ -65,6 +66,7 @@ struct TRTEngine : torch::CustomClassHolder {
6566

6667
// CUDAGraph-Related Functionality
6768
at::cuda::CUDAGraph cudagraph = {};
69+
at::cuda::CUDAStream active_stream = c10::cuda::getDefaultCUDAStream();
6870
std::vector<at::Tensor> input_buffers = {};
6971
std::vector<at::Tensor> output_buffers = {};
7072
std::string shape_key;

core/runtime/execute_engine.cpp

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#include "c10/cuda/CUDAGuard.h"
21
#include "c10/cuda/CUDAStream.h"
32

43
#include "torch/csrc/jit/runtime/custom_operator.h"
@@ -64,9 +63,20 @@ bool _cudagraphs_validate_shapes(std::vector<at::Tensor> inputs, c10::intrusive_
6463
// invalidate the existing cudagraphs object
6564

6665
// Populate the shape key for the inputs
66+
// x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5)
6767
std::stringstream new_shape_key_ss;
6868
for (auto input : inputs) {
69-
new_shape_key_ss << input.sizes();
69+
new_shape_key_ss << "(";
70+
auto sizes = input.sizes();
71+
auto rank = input.sizes().size();
72+
for (auto i = 0; i < rank; i++) {
73+
new_shape_key_ss << sizes[i];
74+
// For all but the final dimension in the shape key, add comma separator
75+
if (i < rank - 1) {
76+
new_shape_key_ss << ",";
77+
}
78+
}
79+
new_shape_key_ss << ")";
7080
}
7181

7282
auto new_shape_key = new_shape_key_ss.str();
@@ -128,6 +138,10 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
128138
select_rt_device(compiled_engine->device_info, curr_device, compiled_engine->hardware_compatible);
129139
set_rt_device(device);
130140

141+
// Update active stream based on new device
142+
compiled_engine->active_stream = c10::cuda::getStreamFromPool(true, device.id);
143+
c10::cuda::setCurrentCUDAStream(compiled_engine->active_stream);
144+
131145
// Target device is new device
132146
target_device += std::to_string(device.id);
133147

@@ -157,6 +171,8 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
157171
}
158172
}
159173

174+
// this is a buffer to store shape tensor input addresses throughout the runtime scope
175+
std::list<std::vector<int32_t>> inputShapeTensorValues;
160176
{
161177
std::unique_ptr<torch::autograd::profiler::RecordProfile> input_profiler_guard;
162178
if (compiled_engine->profile_execution) {
@@ -252,23 +268,18 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
252268

253269
if (!CUDAGRAPHS_MODE) {
254270
// If not in cudagraphs mode, proceed with enqueueV3 as normal
255-
c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(inputs[0].device().index());
256-
compiled_engine->exec_ctx->enqueueV3(stream);
271+
compiled_engine->exec_ctx->enqueueV3(compiled_engine->active_stream);
257272
} else if (need_cudagraphs_record) {
258273
// If cudagraphs needs to record a graph, capture the enqueueV3 call in a graph
259274

260275
// Cudagraphs cannot record on the default stream, so use an alternate
261276
c10::cuda::CUDAStream stream = c10::cuda::getStreamFromPool(true, inputs[0].device().index());
262-
c10::cuda::CUDAStreamGuard guard(stream);
263-
compiled_engine->exec_ctx->enqueueV3(stream);
277+
compiled_engine->exec_ctx->enqueueV3(compiled_engine->active_stream);
264278

265279
compiled_engine->cudagraph.capture_begin();
266-
compiled_engine->exec_ctx->enqueueV3(stream);
280+
compiled_engine->exec_ctx->enqueueV3(compiled_engine->active_stream);
267281
compiled_engine->cudagraph.capture_end();
268282

269-
// Reset the stream to its original setting
270-
guard.reset_stream(guard.original_stream());
271-
272283
} else {
273284
// If the cudagraph has already been recorded, copy the input buffers and replay it
274285
for (auto i = 0; i < inputs.size(); i++) {

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,9 @@ def __init__(
4646
self.input_buffers: List[torch.Tensor] = []
4747
self.output_buffers: List[torch.Tensor] = []
4848
self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
49-
# {shape: cudagraph}
50-
# limitation on CG
49+
self.active_stream: Optional[torch.cuda.Stream] = None
50+
51+
# TODO: Make the below a Dictionary {shape: cudagraph}
5152
self.shape_key: Optional[str] = None
5253

5354
# See https://github.com/pytorch/pytorch/blob/acfe237a71af609e837a34bb38048aa8acb8eb4d/torch/cuda/graphs.py#L92-L98
@@ -97,6 +98,10 @@ def _initialize(self) -> None:
9798
self.cudagraph = torch.cuda.CUDAGraph()
9899
self.graph_capturer = torch.cuda.graphs.graph(self.cudagraph)
99100

101+
# Set the active stream using the current device, with a high priority flag
102+
self.active_stream = torch.cuda.Stream(torch.cuda.current_device(), priority=-1)
103+
torch.cuda.set_stream(self.active_stream)
104+
100105
def _check_initialized(self) -> None:
101106
if not self.initialized:
102107
raise RuntimeError("PythonTorchTensorRTModule is not initialized.")
@@ -185,9 +190,15 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
185190
self.target_device_id,
186191
self.target_device_properties,
187192
)
193+
194+
# Update current device
188195
device = torch.device(device_id)
189196
torch.cuda.set_device(device_id)
190197

198+
# Update current stream
199+
self.active_stream = torch.cuda.Stream(device, priority=-1)
200+
torch.cuda.set_stream(self.active_stream)
201+
191202
contiguous_inputs = [
192203
tensor.to(device) for tensor in contiguous_inputs
193204
]
@@ -306,21 +317,19 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
306317
):
307318

308319
if not cudagraphs_enabled:
309-
self.context.execute_async_v3(
310-
torch.cuda.current_stream().cuda_stream
311-
)
320+
self.context.execute_async_v3(self.active_stream)
312321

313322
elif need_cudagraphs_record:
314323
self.input_buffers = list(contiguous_inputs)
315324
self.output_buffers = list(outputs)
316325

317-
current_stream = self.graph_capturer.capture_stream
326+
graph_capturer_stream = self.graph_capturer.capture_stream
318327

319-
self.context.execute_async_v3(current_stream.cuda_stream)
320-
current_stream.synchronize()
328+
self.context.execute_async_v3(graph_capturer_stream.cuda_stream)
329+
graph_capturer_stream.synchronize()
321330

322331
with self.graph_capturer:
323-
self.context.execute_async_v3(current_stream.cuda_stream)
332+
self.context.execute_async_v3(graph_capturer_stream.cuda_stream)
324333

325334
else:
326335
for idx, input_tensor in enumerate(inputs):
@@ -377,8 +386,8 @@ def cudagraphs_validate_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
377386
"""
378387
# Representation of input shapes to a given model
379388
# Shapes are concatenated as so:
380-
# x: (3, 4), y: (4, 5) --> Key: (3, 4)(4, 5)
381-
new_shape_key = "".join(str(tuple(t.shape)) for t in inputs)
389+
# x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5)
390+
new_shape_key = "".join(str(tuple(t.shape)).replace(" ", "") for t in inputs)
382391

383392
# If the new shape key differs from the existing one,
384393
# invalidate the old shape key and remove the CUDAGraph

py/torch_tensorrt/runtime/cudagraphs.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44

55
import torch
66

7-
if find_spec("torch_tensorrt._C") is not None:
7+
import torch_tensorrt
8+
9+
if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime:
810
_PY_RT_CUDAGRAPHS = torch.ops.tensorrt.get_cudagraphs_mode()
911
else:
1012
_PY_RT_CUDAGRAPHS = False

tests/py/dynamo/runtime/test_cudagraphs.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,6 @@
88
from ..testing_utilities import DECIMALS_OF_AGREEMENT
99

1010

11-
@unittest.skipIf(
12-
not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime,
13-
"Torch-TensorRT runtime is not available",
14-
)
1511
class TestCudagraphs(TestCase):
1612
def test_cudagraphs_on(self):
1713
torch_tensorrt.runtime.set_cudagraphs_mode(True)
@@ -66,6 +62,10 @@ def forward(self, x):
6662
)
6763
torch._dynamo.reset()
6864

65+
@unittest.skipIf(
66+
not torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime,
67+
"Torch-TensorRT runtime is not available",
68+
)
6969
def test_cudagraphs_enabled_inference_cpp(self):
7070
class SampleModel(torch.nn.Module):
7171
def forward(self, x):

0 commit comments

Comments
 (0)