Commit 4e69f5d

feat: engine caching

1 parent feb4d84 commit 4e69f5d

File tree

7 files changed: +365 -12 lines changed
Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
import time

import numpy as np
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models

np.random.seed(0)
torch.manual_seed(0)
size = (100, 3, 224, 224)
inputs = [torch.rand(size).to("cuda")]

model = models.resnet18(pretrained=True).eval().to("cuda")
exp_program = torch.export.export(model, tuple(inputs))
enabled_precisions = {torch.float}
debug = False
workspace_size = 20 << 30
min_block_size = 0
use_python_runtime = False
torch_executed_ops = set()  # empty set: run every supported op in TRT


def dynamo_path():
    ############### warmup ###############
    inputs = [torch.rand(size).to("cuda")]
    t1 = time.time()
    trt_gm = torch_trt.dynamo.compile(
        exp_program,
        tuple(inputs),
        use_python_runtime=use_python_runtime,
        enabled_precisions=enabled_precisions,
        debug=debug,
        min_block_size=min_block_size,
        torch_executed_ops=torch_executed_ops,
        make_refitable=True,
        ignore_engine_cache=True,
    )  # Output is a torch.fx.GraphModule
    t2 = time.time()

    ############### compile for the first time ###############
    inputs = [torch.rand(size).to("cuda")]
    t3 = time.time()
    trt_gm1 = torch_trt.dynamo.compile(
        exp_program,
        tuple(inputs),
        use_python_runtime=use_python_runtime,
        enabled_precisions=enabled_precisions,
        debug=debug,
        min_block_size=min_block_size,
        torch_executed_ops=torch_executed_ops,
        make_refitable=True,
        ignore_engine_cache=False,
    )  # Output is a torch.fx.GraphModule
    t4 = time.time()
    # Check the output
    outputs = trt_gm1(*inputs)
    print("----------> 1st output:", outputs)

    ############### compile for the second time ###############
    inputs = [torch.rand(size).to("cuda")]
    t5 = time.time()
    trt_gm2 = torch_trt.dynamo.compile(
        exp_program,
        tuple(inputs),
        use_python_runtime=use_python_runtime,
        enabled_precisions=enabled_precisions,
        debug=debug,
        min_block_size=min_block_size,
        torch_executed_ops=torch_executed_ops,
        make_refitable=True,
        ignore_engine_cache=False,
    )  # Output is a torch.fx.GraphModule
    t6 = time.time()
    # Check the output
    outputs = trt_gm2(*inputs)
    print("----------> 2nd output:", outputs)

    print("----------> warmup compilation time:", t2 - t1, "seconds")
    print("----------> 1st compilation time:", t4 - t3, "seconds")
    print("----------> 2nd compilation time:", t6 - t5, "seconds")


def compile_path():
    inputs = [torch.rand(size).to("cuda")]
    model = models.resnet18(pretrained=True).eval().to("cuda")
    t1 = time.time()
    model = torch.compile(
        model,
        backend="tensorrt",
        options={
            "use_python_runtime": use_python_runtime,
            "enabled_precisions": enabled_precisions,
            "debug": debug,
            "min_block_size": min_block_size,
            "torch_executed_ops": torch_executed_ops,
            "make_refitable": True,
            "ignore_engine_cache": True,
        },
    )
    t2 = time.time()
    print("---------->", model(*inputs))

    t3 = time.time()
    model1 = torch.compile(
        model,
        backend="tensorrt",
        options={
            "use_python_runtime": use_python_runtime,
            "enabled_precisions": enabled_precisions,
            "debug": debug,
            "min_block_size": min_block_size,
            "torch_executed_ops": torch_executed_ops,
            "make_refitable": True,
            "ignore_engine_cache": False,
        },
    )
    t4 = time.time()
    print("----------> 1st output:", model1(*inputs))

    t5 = time.time()
    model2 = torch.compile(
        model,
        backend="tensorrt",
        options={
            "use_python_runtime": use_python_runtime,
            "enabled_precisions": enabled_precisions,
            "debug": debug,
            "min_block_size": min_block_size,
            "torch_executed_ops": torch_executed_ops,
            "make_refitable": True,
            "ignore_engine_cache": False,
        },
    )
    t6 = time.time()
    print("----------> 2nd output:", model2(*inputs))

    print("----------> warmup compilation time:", t2 - t1, "seconds")
    print("----------> 1st compilation time:", t4 - t3, "seconds")
    print("----------> 2nd compilation time:", t6 - t5, "seconds")


if __name__ == "__main__":
    dynamo_path()
    compile_path()
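Note that the "1st compilation" above is only a genuine cache miss if no engine from an earlier run is still on disk. A minimal sketch for resetting the cache between experiments, assuming the default ENGINE_CACHE_DIR defined in _defaults.py below:

import os
import shutil
import tempfile

# Default cache location, mirroring ENGINE_CACHE_DIR in _defaults.py
engine_cache_dir = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache")

# Remove any engines cached by previous runs so the first
# ignore_engine_cache=False compile starts from an empty cache
shutil.rmtree(engine_cache_dir, ignore_errors=True)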

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 9 additions & 0 deletions
@@ -79,6 +79,9 @@ def compile(
     dryrun: bool = _defaults.DRYRUN,
     hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
     timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
+    ignore_engine_cache: bool = _defaults.IGNORE_ENGINE_CACHE,
+    engine_cache_dir: str = _defaults.ENGINE_CACHE_DIR,
+    engine_cache_size: int = _defaults.ENGINE_CACHE_SIZE,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -139,6 +142,9 @@ def compile(
         dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs
         hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
         timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
+        ignore_engine_cache (bool): Whether to ignore the cached TRT engines and recompile the module
+        engine_cache_dir (str): Directory to store the cached TRT engines
+        engine_cache_size (int): Maximum hard-disk space to use for the engine cache
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -234,6 +240,9 @@ def compile(
         "dryrun": dryrun,
         "hardware_compatible": hardware_compatible,
         "timing_cache_path": timing_cache_path,
+        "ignore_engine_cache": ignore_engine_cache,
+        "engine_cache_dir": engine_cache_dir,
+        "engine_cache_size": engine_cache_size,
     }

     settings = CompilationSettings(**compilation_options)
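Putting the new knobs together, a minimal sketch of a compile call that overrides the cache settings (the directory and size values are illustrative, not the defaults; exp_program and inputs as in the example script above):

import torch_tensorrt as torch_trt

trt_gm = torch_trt.dynamo.compile(
    exp_program,
    tuple(inputs),
    make_refitable=True,
    ignore_engine_cache=False,                # reuse a cached engine on a graph-hash match
    engine_cache_dir="/tmp/my_engine_cache",  # illustrative override of ENGINE_CACHE_DIR
    engine_cache_size=4 << 30,                # 4 GiB budget instead of the 1 GiB default
)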

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 3 additions & 0 deletions
@@ -32,6 +32,9 @@
 HARDWARE_COMPATIBLE = False
 SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8}
 TIMING_CACHE_PATH = os.path.join(tempfile.gettempdir(), "timing_cache.bin")
+IGNORE_ENGINE_CACHE = False
+ENGINE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache")
+ENGINE_CACHE_SIZE = 1 << 30


 def default_device() -> Device:
Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
import ast
import copy
import logging
import os
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, cast

import torch
from torch._inductor.codecache import FxGraphCachePickler
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
from torch_tensorrt.dynamo._defaults import ENGINE_CACHE_DIR, ENGINE_CACHE_SIZE

_LOGGER: logging.Logger = logging.getLogger(__name__)


class BaseEngineCache(ABC):

    @staticmethod
    def get_hash(gm: torch.fx.GraphModule) -> str:
        """Get the hash value of the GraphModule

        Args:
            gm (torch.fx.GraphModule): GraphModule to hash

        Returns:
            str: hash value of the GraphModule
        """
        # Zero out the parameters on a copy so that the hash reflects the
        # graph structure rather than the weight values
        with maybe_disable_fake_tensor_mode():
            new_gm = copy.deepcopy(gm)
            for name, param in new_gm.named_parameters():
                param.data.zero_()

            hash_val = cast(str, FxGraphCachePickler.get_hash(new_gm))

        return hash_val

    @abstractmethod
    def save(
        self,
        hash: str,
        serialized_engine: bytes,
        input_names: List[str],
        output_names: List[str],
    ) -> None:
        """Save the serialized engine to hard disk

        Args:
            hash (str): hash value of the GraphModule
            serialized_engine (bytes): serialized TRT engine
            input_names (List[str]): input names of TRT engine
            output_names (List[str]): output names of TRT engine

        Returns:
            None
        """
        pass

    @abstractmethod
    def load(self, hash: str) -> Tuple[Optional[bytes], List[str], List[str]]:
        """Load the serialized engine from hard disk

        Args:
            hash (str): hash value of the GraphModule

        Returns:
            Tuple[Optional[bytes], List[str], List[str]]: serialized TRT engine, input names of TRT Engine, output names of TRT Engine
        """
        pass

    @abstractmethod
    def clear_cache(self, size: int) -> None:
        """Clear the cache to make sure at least `size` bytes are available

        Args:
            size (int): the needed size
        """
        pass


class EngineCache(BaseEngineCache):

    def __init__(
        self,
        engine_cache_size: int = ENGINE_CACHE_SIZE,
        engine_cache_dir: str = ENGINE_CACHE_DIR,
    ) -> None:
        self.total_engine_cache_size = engine_cache_size
        self.available_engine_cache_size = engine_cache_size
        self.engine_cache_dir = engine_cache_dir

    def has_available_cache_size(self, serialized_engine: bytes) -> bool:
        """Check if the cache has available space for saving the serialized engine

        Args:
            serialized_engine (bytes): serialized TRT engine

        Returns:
            bool: whether the cache has available space for the serialized engine
        """
        return len(serialized_engine) <= self.available_engine_cache_size

    def clear_cache(self, size: int) -> None:
        # TODO: evict cached engines (e.g. least recently used) until at
        # least `size` bytes are available
        pass

    def save(
        self,
        hash: str,
        serialized_engine: bytes,
        input_names: List[str],
        output_names: List[str],
    ) -> None:
        serialized_engine_size = len(serialized_engine)
        # An engine larger than the entire cache budget can never be stored
        if serialized_engine_size > self.total_engine_cache_size:
            _LOGGER.warning(
                f"The serialized engine cannot be saved because the size of the engine {serialized_engine_size} is larger than the total cache size {self.total_engine_cache_size}."
            )
            return

        if not self.has_available_cache_size(serialized_engine):
            self.clear_cache(serialized_engine_size)

        path = os.path.join(
            self.engine_cache_dir, f"{hash}/engine_{input_names}_{output_names}.trt"
        )
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "wb") as f:
            f.write(serialized_engine)
        _LOGGER.info(f"A TRT engine was cached to {path}")

    def load(self, hash: str) -> Tuple[Optional[bytes], List[str], List[str]]:
        directory = os.path.join(self.engine_cache_dir, hash)
        if os.path.exists(directory):
            engine_list = os.listdir(directory)
            assert (
                len(engine_list) == 1
            ), f"There is more than one engine {engine_list} under {directory}."
            path = os.path.join(directory, engine_list[0])
            # Recover the name lists from a file name of the form
            # "engine_['x']_['output0'].trt"; assumes the names themselves
            # contain no underscores
            input_names_str, output_names_str = (
                engine_list[0].split(".")[0].split("_")[1:]
            )
            input_names = ast.literal_eval(input_names_str)
            output_names = ast.literal_eval(output_names_str)
            with open(path, "rb") as f:
                serialized_engine = f.read()
            return serialized_engine, input_names, output_names
        else:
            return None, [], []
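A quick round-trip sketch of EngineCache (the hash string and engine bytes are stand-ins for a real graph hash and serialized TRT engine):

cache = EngineCache()  # defaults: 1 GiB budget under the tempdir cache directory

hash_val = "deadbeef"  # in real use: BaseEngineCache.get_hash(gm)
cache.save(hash_val, b"fake-engine-bytes", ["x"], ["output0"])

engine, input_names, output_names = cache.load(hash_val)
assert engine == b"fake-engine-bytes"
print(input_names, output_names)  # ['x'] ['output0']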

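Since BaseEngineCache is an ABC, storage backends other than the local disk are possible. A hypothetical in-memory variant (not part of this commit), e.g. for tests that should not touch the disk:

from typing import Dict, List, Optional, Tuple

class RAMEngineCache(BaseEngineCache):
    """Hypothetical in-memory backend keyed by the graph hash."""

    def __init__(self) -> None:
        self.cache: Dict[str, Tuple[bytes, List[str], List[str]]] = {}

    def save(
        self,
        hash: str,
        serialized_engine: bytes,
        input_names: List[str],
        output_names: List[str],
    ) -> None:
        self.cache[hash] = (serialized_engine, input_names, output_names)

    def load(self, hash: str) -> Tuple[Optional[bytes], List[str], List[str]]:
        return self.cache.get(hash, (None, [], []))

    def clear_cache(self, size: int) -> None:
        self.cache.clear()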
py/torch_tensorrt/dynamo/_settings.py

Lines changed: 9 additions & 0 deletions
@@ -14,8 +14,11 @@
     DRYRUN,
     ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
     ENABLED_PRECISIONS,
+    ENGINE_CACHE_DIR,
+    ENGINE_CACHE_SIZE,
     ENGINE_CAPABILITY,
     HARDWARE_COMPATIBLE,
+    IGNORE_ENGINE_CACHE,
     MAKE_REFITABLE,
     MAX_AUX_STREAMS,
     MIN_BLOCK_SIZE,
@@ -73,6 +76,9 @@ class CompilationSettings:
             output to a file if a string path is specified
         hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
         timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
+        ignore_engine_cache (bool): Whether to ignore the cached TRT engines and recompile the module
+        engine_cache_dir (str): Directory to store the cached TRT engines
+        engine_cache_size (int): Maximum hard-disk space to use for the engine cache
     """

     enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -104,3 +110,6 @@ class CompilationSettings:
     dryrun: Union[bool, str] = DRYRUN
     hardware_compatible: bool = HARDWARE_COMPATIBLE
     timing_cache_path: str = TIMING_CACHE_PATH
+    ignore_engine_cache: bool = IGNORE_ENGINE_CACHE
+    engine_cache_dir: str = ENGINE_CACHE_DIR
+    engine_cache_size: int = ENGINE_CACHE_SIZE
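CompilationSettings is a dataclass with a default for every field, so the new cache fields can presumably be set in isolation; an untested sketch:

from torch_tensorrt.dynamo._settings import CompilationSettings

settings = CompilationSettings(
    ignore_engine_cache=False,
    engine_cache_dir="/tmp/my_engine_cache",  # illustrative path
    engine_cache_size=2 << 30,                # 2 GiB budget
)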

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 10 additions & 12 deletions
@@ -96,21 +96,19 @@ def _pretraced_backend(
             ),
         )

-            logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
+        logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))

-            gm = post_lowering(gm, sample_inputs)
+        gm = post_lowering(gm, sample_inputs)

-            logger.debug("Lowered Input graph:\n " + str(gm.graph))
+        logger.debug("Lowered Input graph:\n " + str(gm.graph))

-            torchtrt_inputs = prepare_inputs(
-                torch_inputs, disable_memory_format_check=True
-            )
-            trt_compiled = compile_module(
-                gm,
-                torchtrt_inputs,
-                settings=settings,
-            )
-            return trt_compiled
+        torchtrt_inputs = prepare_inputs(torch_inputs, disable_memory_format_check=True)
+        trt_compiled = compile_module(
+            gm,
+            torchtrt_inputs,
+            settings=settings,
+        )
+        return trt_compiled
     except (AssertionError, RuntimeError):
         if not settings.pass_through_build_failures:
             logger.warning(
