Skip to content

Commit 9d3e40c

Browse files
narendasanNaren Dasan
authored and committed
feat: Adding live progress monitoring to the engine building phase
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 2589fdb commit 9d3e40c

File tree

5 files changed

+194
-21
lines changed

5 files changed

+194
-21
lines changed

.github/workflows/build-test-linux.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ jobs:
7777
pre-script: ${{ matrix.pre-script }}
7878
script: |
7979
export USE_HOST_DEPS=1
80+
export CI_BUILD=1
8081
export LD_LIBRARY_PATH=/usr/lib64:$LD_LIBRARY_PATH
8182
pushd .
8283
cd tests/modules
@@ -112,6 +113,7 @@ jobs:
112113
pre-script: ${{ matrix.pre-script }}
113114
script: |
114115
export USE_HOST_DEPS=1
116+
export CI_BUILD=1
115117
pushd .
116118
cd tests/py/dynamo
117119
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 8 conversion/
@@ -140,6 +142,7 @@ jobs:
140142
pre-script: ${{ matrix.pre-script }}
141143
script: |
142144
export USE_HOST_DEPS=1
145+
export CI_BUILD=1
143146
pushd .
144147
cd tests/py/dynamo
145148
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/
@@ -168,6 +171,7 @@ jobs:
168171
pre-script: ${{ matrix.pre-script }}
169172
script: |
170173
export USE_HOST_DEPS=1
174+
export CI_BUILD=1
171175
pushd .
172176
cd tests/py/dynamo
173177
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py
@@ -196,6 +200,7 @@ jobs:
196200
pre-script: ${{ matrix.pre-script }}
197201
script: |
198202
export USE_HOST_DEPS=1
203+
export CI_BUILD=1
199204
pushd .
200205
cd tests/py/dynamo
201206
python -m pytest -ra -n 10 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_be_test_results.xml backend/
@@ -226,6 +231,7 @@ jobs:
226231
pre-script: ${{ matrix.pre-script }}
227232
script: |
228233
export USE_HOST_DEPS=1
234+
export CI_BUILD=1
229235
pushd .
230236
cd tests/py/dynamo
231237
python -m pytest -ra -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_runtime_test_results.xml --ignore runtime/test_002_cudagraphs_py.py --ignore runtime/test_002_cudagraphs_cpp.py runtime/
@@ -256,6 +262,7 @@ jobs:
256262
pre-script: ${{ matrix.pre-script }}
257263
script: |
258264
export USE_HOST_DEPS=1
265+
export CI_BUILD=1
259266
pushd .
260267
cd tests/py/dynamo
261268
nvidia-smi
@@ -286,6 +293,7 @@ jobs:
286293
pre-script: ${{ matrix.pre-script }}
287294
script: |
288295
export USE_HOST_DEPS=1
296+
export CI_BUILD=1
289297
pushd .
290298
cd tests/py/core
291299
python -m pytest -ra -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_core_test_results.xml .

.github/workflows/build-test-windows.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ jobs:
8383
pre-script: packaging/driver_upgrade.bat
8484
script: |
8585
export USE_HOST_DEPS=1
86+
export CI_BUILD=1
8687
pushd .
8788
cd tests/modules
8889
python hub.py
@@ -114,6 +115,7 @@ jobs:
114115
pre-script: packaging/driver_upgrade.bat
115116
script: |
116117
export USE_HOST_DEPS=1
118+
export CI_BUILD=1
117119
pushd .
118120
cd tests/py/dynamo
119121
python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 10 conversion/
@@ -139,6 +141,7 @@ jobs:
139141
pre-script: packaging/driver_upgrade.bat
140142
script: |
141143
export USE_HOST_DEPS=1
144+
export CI_BUILD=1
142145
pushd .
143146
cd tests/py/dynamo
144147
python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/dyn_models_export.xml --ir dynamo models/
@@ -164,6 +167,7 @@ jobs:
164167
pre-script: packaging/driver_upgrade.bat
165168
script: |
166169
export USE_HOST_DEPS=1
170+
export CI_BUILD=1
167171
pushd .
168172
cd tests/py/dynamo
169173
python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/export_serde_test_results.xml --ir dynamo models/test_export_serde.py
@@ -189,6 +193,7 @@ jobs:
189193
pre-script: packaging/driver_upgrade.bat
190194
script: |
191195
export USE_HOST_DEPS=1
196+
export CI_BUILD=1
192197
pushd .
193198
cd tests/py/dynamo
194199
python -m pytest -ra -n 10 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_be_test_results.xml backend/
@@ -216,6 +221,7 @@ jobs:
216221
pre-script: packaging/driver_upgrade.bat
217222
script: |
218223
export USE_HOST_DEPS=1
224+
export CI_BUILD=1
219225
pushd .
220226
cd tests/py/dynamo
221227
python -m pytest -ra -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_runtime_test_results.xml --ignore runtime/test_002_cudagraphs_py.py --ignore runtime/test_002_cudagraphs_cpp.py runtime/
@@ -246,6 +252,7 @@ jobs:
246252
pre-script: ${{ matrix.pre-script }}
247253
script: |
248254
export USE_HOST_DEPS=1
255+
export CI_BUILD=1
249256
pushd .
250257
cd tests/py/dynamo
251258
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_dynamo_core_runtime_cudagraphs_cpp_test_results.xml runtime/test_002_cudagraphs_cpp.py
@@ -272,6 +279,7 @@ jobs:
272279
pre-script: packaging/driver_upgrade.bat
273280
script: |
274281
export USE_HOST_DEPS=1
282+
export CI_BUILD=1
275283
pushd .
276284
cd tests/py/core
277285
python -m pytest -ra -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/tests_py_core_test_results.xml .
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
import os
2+
import sys
3+
from typing import Any, Dict, Optional
4+
5+
import tensorrt as trt
6+
7+
8+
class _ASCIIMonitor(trt.IProgressMonitor):  # type: ignore
    """Plain-text progress monitor for TensorRT engine builds.

    Each active build phase is printed as one line containing an ASCII
    progress bar, the phase title and a ``steps/num_steps`` counter. The
    lines are redrawn in place using terminal escape sequences.

    Rendering is turned off when the ``CI_BUILD`` environment variable is
    set to ``"1"``; phase bookkeeping still happens in that case.
    """

    def __init__(self, engine_name: str = "") -> None:
        """Initialise the monitor.

        Args:
            engine_name: Accepted for interface compatibility; this
                implementation does not use it.
        """
        trt.IProgressMonitor.__init__(self)
        # phase name -> {"title", "steps", "num_steps", "nbIndents"}
        self._active_phases: Dict[str, Dict[str, Any]] = {}
        # Value returned from step_complete; set to False to request that the
        # build be cancelled.
        self._step_result = True
        # Only an exact value of "1" disables drawing.
        self._render = os.environ.get("CI_BUILD") != "1"

    def phase_start(
        self, phase_name: str, parent_phase: Optional[str], num_steps: int
    ) -> None:
        """Register a newly started phase and redraw.

        Args:
            phase_name: Name of the phase that is starting.
            parent_phase: Name of the enclosing phase, or ``None`` for a
                top-level phase. Must already be active when given.
            num_steps: Total number of steps the phase reports.
        """
        try:
            if parent_phase is None:
                indent_level = 0
            else:
                # Child phases are drawn one level deeper than their parent.
                indent_level = 1 + self._active_phases[parent_phase]["nbIndents"]
            self._active_phases[phase_name] = {
                "title": phase_name,
                "steps": 0,
                "num_steps": num_steps,
                "nbIndents": indent_level,
            }
            self._redraw()
        except KeyboardInterrupt:
            # This callback has no return value, so it cannot cancel the build
            # itself. Record the request on the instance so that the next
            # step_complete call returns False.
            self._step_result = False

    def phase_finish(self, phase_name: str) -> None:
        """Remove a finished phase and redraw.

        Args:
            phase_name: Name of the phase that has finished.
        """
        try:
            del self._active_phases[phase_name]
            self._redraw(blank_lines=1)  # Clear the removed phase.
        except KeyboardInterrupt:
            # Same as in phase_start: defer the cancellation to step_complete.
            self._step_result = False

    def step_complete(self, phase_name: str, step: int) -> bool:
        """Record progress for a phase and redraw.

        Args:
            phase_name: Name of the phase reporting progress.
            step: Step counter reported for the phase.

        Returns:
            ``True`` to let the build continue, ``False`` to cancel it
            (after a ``KeyboardInterrupt`` in any callback).
        """
        try:
            self._active_phases[phase_name]["steps"] = step
            self._redraw()
            return self._step_result
        except KeyboardInterrupt:
            return False

    def _redraw(self, *, blank_lines: int = 0) -> None:
        """Repaint every active phase line in place.

        Args:
            blank_lines: Number of extra lines to blank out below the active
                phases (used to erase the line of a phase that just ended).
        """
        if not self._render:
            return

        def clear_line() -> None:
            print("\x1B[2K", end="")

        def move_to_start_of_line() -> None:
            print("\x1B[0G", end="")

        def move_cursor_up(lines: int) -> None:
            print("\x1B[{}A".format(lines), end="")

        def progress_bar(steps: int, num_steps: int) -> str:
            INNER_WIDTH = 10
            if num_steps > 0:
                filled = int(INNER_WIDTH * steps / float(num_steps))
            else:
                # Defensive: avoid ZeroDivisionError on a zero-step phase.
                filled = 0
            # Keep the bar at a fixed width even for out-of-range step values.
            filled = max(0, min(INNER_WIDTH, filled))
            return "[{}{}]".format("=" * filled, "-" * (INNER_WIDTH - filled))

        # Set max_cols to a default of 200 if not run in interactive mode.
        max_cols = os.get_terminal_size().columns if sys.stdout.isatty() else 200

        move_to_start_of_line()
        for phase in self._active_phases.values():
            phase_prefix = "{indent}{bar} {title}".format(
                indent=" " * phase["nbIndents"],
                bar=progress_bar(phase["steps"], phase["num_steps"]),
                title=phase["title"],
            )
            phase_suffix = "{steps}/{num_steps}".format(**phase)
            allowable_prefix_chars = max_cols - len(phase_suffix) - 2
            if allowable_prefix_chars < len(phase_prefix):
                # Clamp so a very narrow terminal cannot turn the slice end
                # negative (which would index from the end of the string).
                keep = max(allowable_prefix_chars - 3, 0)
                phase_prefix = phase_prefix[:keep] + "..."
            clear_line()
            print(phase_prefix, phase_suffix)
        for _ in range(blank_lines):
            clear_line()
            print()
        # Return the cursor to the first drawn line so the next redraw
        # overwrites this output.
        move_cursor_up(len(self._active_phases) + blank_lines)
        sys.stdout.flush()
92+
93+
94+
try:
    from rich.progress import BarColumn, Progress, TaskID, TextColumn, TimeElapsedColumn

    class _RichMonitor(trt.IProgressMonitor):  # type: ignore
        """TensorRT build progress monitor rendered with ``rich`` progress bars.

        Each build phase becomes one task in a shared ``Progress`` display.
        The live display is not started when the ``CI_BUILD`` environment
        variable is set to ``"1"``.
        """

        def __init__(self, engine_name: str = "") -> None:
            """Initialise the monitor.

            Args:
                engine_name: Accepted for interface compatibility; this
                    implementation does not use it.
            """
            trt.IProgressMonitor.__init__(self)
            # phase name -> rich task id
            self._active_phases: Dict[str, TaskID] = {}
            # Value returned from step_complete; set to False to request that
            # the build be cancelled.
            self._step_result = True

            self._progress_monitors = Progress(
                TextColumn(" "),
                TimeElapsedColumn(),
                TextColumn("{task.description}: "),
                BarColumn(),
                TextColumn(" {task.percentage:.0f}% ({task.completed}/{task.total})"),
            )

            # Only an exact value of "1" disables drawing.
            self._render = os.environ.get("CI_BUILD") != "1"

            if self._render:
                self._progress_monitors.start()

        def phase_start(
            self, phase_name: str, parent_phase: Optional[str], num_steps: int
        ) -> None:
            """Add a task for a newly started phase.

            Args:
                phase_name: Name of the phase that is starting.
                parent_phase: Name of the enclosing phase, if any. Not used
                    by this implementation.
                num_steps: Total number of steps the phase reports.
            """
            try:
                self._active_phases[phase_name] = self._progress_monitors.add_task(
                    phase_name, total=num_steps
                )
                self._progress_monitors.refresh()
            except KeyboardInterrupt:
                # The phase_start callback cannot directly cancel the build, so
                # request the cancellation from within step_complete. The flag
                # must be stored on the instance for step_complete to see it.
                self._step_result = False

        def phase_finish(self, phase_name: str) -> None:
            """Hide, stop and remove the task of a finished phase.

            Args:
                phase_name: Name of the phase that has finished.
            """
            try:
                task_id = self._active_phases[phase_name]
                self._progress_monitors.update(task_id, visible=False)
                self._progress_monitors.stop_task(task_id)
                self._progress_monitors.remove_task(task_id)
                self._progress_monitors.refresh()
            except KeyboardInterrupt:
                # As in phase_start: defer the cancellation to step_complete.
                self._step_result = False

        def step_complete(self, phase_name: str, step: int) -> bool:
            """Update the completed-step count of a phase.

            Args:
                phase_name: Name of the phase reporting progress.
                step: Step counter reported for the phase.

            Returns:
                ``True`` to let the build continue, ``False`` to cancel it
                (after a ``KeyboardInterrupt`` in any callback).
            """
            try:
                self._progress_monitors.update(
                    self._active_phases[phase_name], completed=step
                )
                self._progress_monitors.refresh()
                return self._step_result
            except KeyboardInterrupt:
                # There is no need to propagate this exception to TensorRT. We can simply cancel the build.
                return False

        def __del__(self) -> None:
            """Stop the progress display when the monitor is destroyed."""
            # __del__ also runs when __init__ raised before the attribute was
            # assigned, so look it up defensively.
            progress = getattr(self, "_progress_monitors", None)
            if progress:
                progress.stop()

    TRTBulderMonitor: trt.IProgressMonitor = _RichMonitor
except ImportError:
    # rich is optional; fall back to the ANSI-escape based monitor.
    TRTBulderMonitor: trt.IProgressMonitor = _ASCIIMonitor  # type: ignore[no-redef]

# "TRTBulderMonitor" (sic) is the name other modules import, so it is kept.
# Correctly spelled alias for new code.
TRTBuilderMonitor = TRTBulderMonitor

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple
77

88
import numpy as np
9-
import tensorrt as trt
109
import torch
1110
import torch.fx
1211
from torch.fx.node import _get_qualified_name
@@ -21,6 +20,7 @@
2120
DYNAMO_CONVERTERS as CONVERTERS,
2221
)
2322
from torch_tensorrt.dynamo.conversion._ConverterRegistry import CallingConvention
23+
from torch_tensorrt.dynamo.conversion._TRTBuilderMonitor import TRTBulderMonitor
2424
from torch_tensorrt.dynamo.conversion.converter_utils import (
2525
get_node_io,
2626
get_node_name,
@@ -30,6 +30,7 @@
3030
from torch_tensorrt.fx.observer import Observer
3131
from torch_tensorrt.logging import TRT_LOGGER
3232

33+
import tensorrt as trt
3334
from packaging import version
3435

3536
_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -146,7 +147,7 @@ def clean_repr(x: Any, depth: int = 0) -> Any:
146147
else:
147148
return "(...)"
148149
else:
149-
return x
150+
return f"{x} <{type(x).__name__}>"
150151

151152
str_args = [clean_repr(a) for a in args]
152153
return repr(tuple(str_args))
@@ -176,6 +177,10 @@ def _populate_trt_builder_config(
176177
) -> trt.IBuilderConfig:
177178

178179
builder_config = self.builder.create_builder_config()
180+
181+
if self.compilation_settings.debug:
182+
builder_config.progress_monitor = TRTBulderMonitor()
183+
179184
if self.compilation_settings.workspace_size != 0:
180185
builder_config.set_memory_pool_limit(
181186
trt.MemoryPoolType.WORKSPACE, self.compilation_settings.workspace_size
@@ -516,18 +521,18 @@ def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
516521
kwargs["_itensor_to_tensor_meta"] = self._itensor_to_tensor_meta
517522
n.kwargs = kwargs
518523

519-
# run the node
520-
_LOGGER.debug(
521-
f"Running node {self._cur_node_name}, a {self._cur_node.op} node "
522-
f"with target {self._cur_node.target} in the TensorRT Interpreter"
523-
)
524+
if _LOGGER.isEnabledFor(logging.DEBUG):
525+
_LOGGER.debug(
526+
f"Converting node {self._cur_node_name} (kind: {n.target}, args: {TRTInterpreter._args_str(n.args)})"
527+
)
528+
524529
trt_node: torch.fx.Node = super().run_node(n)
525530

526531
if n.op == "get_attr":
527532
self.const_mapping[str(n)] = (tuple(trt_node.shape), str(trt_node.dtype))
528533

529-
_LOGGER.debug(
530-
f"Ran node {self._cur_node_name} with properties: {get_node_io(n, self.const_mapping)}"
534+
_LOGGER.info(
535+
f"Converted node {self._cur_node_name} [{n.target}] ({get_node_io(n, self.const_mapping)})"
531536
)
532537

533538
# remove "_itensor_to_tensor_meta"
@@ -611,9 +616,7 @@ def call_module(
611616
converter, calling_convention = converter_packet
612617

613618
assert self._cur_node_name is not None
614-
_LOGGER.debug(
615-
f"Converting node {self._cur_node_name} (kind: {target}, args: {TRTInterpreter._args_str(args)})"
616-
)
619+
617620
if calling_convention is CallingConvention.LEGACY:
618621
return converter(self.ctx.net, submod, args, kwargs, self._cur_node_name)
619622
else:
@@ -629,10 +632,6 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
629632

630633
converter, calling_convention = converter_packet
631634

632-
assert self._cur_node_name is not None
633-
_LOGGER.debug(
634-
f"Converting node {self._cur_node_name} (kind: {target}, args: {TRTInterpreter._args_str(args)})"
635-
)
636635
if calling_convention is CallingConvention.LEGACY:
637636
return converter(self.ctx.net, target, args, kwargs, self._cur_node_name)
638637
else:
@@ -663,10 +662,6 @@ def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
663662
)
664663
converter, calling_convention = converter_packet
665664

666-
assert self._cur_node_name is not None
667-
_LOGGER.debug(
668-
f"Converting node {self._cur_node_name} (kind: {target}, args: {TRTInterpreter._args_str(args)})"
669-
)
670665
if calling_convention is CallingConvention.LEGACY:
671666
return converter(self.ctx.net, target, args, kwargs, self._cur_node_name)
672667
else:

0 commit comments

Comments
 (0)