Commit 30f3b15

Fixed prolonged weight name mapping construction time
1 parent cf77e4c commit 30f3b15

File tree

1 file changed, +34 -8 lines changed


py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 34 additions & 8 deletions
@@ -3,7 +3,18 @@
 import os
 import warnings
 from datetime import datetime
-from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    NamedTuple,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Union,
+)
 
 import numpy as np
 import tensorrt as trt
@@ -26,9 +37,10 @@
     get_node_name,
     get_trt_tensor,
 )
-from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
+from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, to_torch_device
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.logging import TRT_LOGGER
+from tqdm import tqdm
 
 from packaging import version
 
@@ -334,18 +346,22 @@ def _save_weight_mapping(self) -> None:
         def find_weight(
             weight_name: str, np_map: dict[str, Any], sd: dict[str, Any]
         ) -> str:
-            network_weight = np_map[weight_name]
+            network_weight = torch.from_numpy(np_map[weight_name]).cuda()
             for sd_w_name, sd_weight in sd.items():
                 if check_weight_equal(sd_weight, network_weight):
+                    del sd[sd_w_name]
                     return sd_w_name
             return ""
 
         def check_weight_equal(
-            sd_weight: torch.tensor, network_weight: np.ndarray
+            sd_weight: torch.tensor, network_weight: Union[np.ndarray, torch.Tensor]
         ) -> Any:
-            sd_weight = sd_weight.reshape(-1).cpu().numpy()
-            return sd_weight.size == network_weight.size and np.allclose(
-                sd_weight, network_weight, 1e-1, 1e-1
+            sd_weight = sd_weight.reshape(-1)
+            if not isinstance(network_weight, torch.Tensor):
+                network_weight = torch.from_numpy(network_weight).cuda()
+            return (
+                sd_weight.shape == network_weight.shape
+                and torch.all(torch.abs(sd_weight - network_weight) < 0.1).cpu()
             )
 
         MODULE_MAP = {
@@ -393,8 +409,14 @@ def check_weight_equal(
             )
         }
         """
+        _LOGGER.info("building weight name mapping...")
         # Stage 1: Name mapping
         sd = self.module.state_dict()
+        gm_is_on_cuda = list(sd.values())[0].device.type == "cuda"
+        # If the model original position is on CPU, move it GPU
+        if not gm_is_on_cuda:
+            self.module.to(to_torch_device(self.compilation_settings.device))
+            sd = self.module.state_dict()
         weight_name_map: dict[str, Any] = {}
         np_map = {}
         net = self.ctx.net
@@ -439,7 +461,7 @@ def check_weight_equal(
             np_map[engine_weight_name] = weight
 
         # Stage 2: Value mapping
-        for engine_weight_name, sd_weight_name in weight_name_map.items():
+        for engine_weight_name, sd_weight_name in tqdm(weight_name_map.items()):
             if "SCALE" in engine_weight_name:
                 # There is no direct connection in batch_norm layer. So skip it
                 pass
@@ -456,6 +478,10 @@ def check_weight_equal(
             ]
 
         self.weight_name_map = weight_name_map
+        # If the model original position is on CPU, set it back to CPU and save GPU memory
+        if not gm_is_on_cuda:
+            self.module.to("cpu")
+            torch.cuda.empty_cache()
 
     def run(
         self,
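
The core of the speed-up is visible in find_weight / check_weight_equal: instead of round-tripping every state-dict tensor through .cpu().numpy() and calling np.allclose, each engine weight is moved to the GPU once and compared with torch ops, and a matched state-dict entry is deleted so later lookups scan a smaller dictionary. Below is a minimal standalone sketch of that pattern, not the patched file itself; the names np_map and sd mirror the diff, while the toy data under the __main__ guard is hypothetical and only there to make the sketch runnable.

# Standalone sketch of the GPU-side comparison pattern from the diff (not the
# patched file itself). Requires a CUDA-capable build of PyTorch.
from typing import Any, Dict, Union

import numpy as np
import torch


def check_weight_equal(
    sd_weight: torch.Tensor, network_weight: Union[np.ndarray, torch.Tensor]
) -> Any:
    # Flatten and compare elementwise on the GPU; 0.1 is the same loose
    # absolute tolerance the diff uses.
    sd_weight = sd_weight.reshape(-1)
    if not isinstance(network_weight, torch.Tensor):
        network_weight = torch.from_numpy(network_weight).cuda()
    return (
        sd_weight.shape == network_weight.shape
        and torch.all(torch.abs(sd_weight - network_weight) < 0.1).cpu()
    )


def find_weight(weight_name: str, np_map: Dict[str, Any], sd: Dict[str, Any]) -> str:
    # Move the engine weight to the GPU once per lookup instead of converting
    # every candidate state-dict tensor to a NumPy array.
    network_weight = torch.from_numpy(np_map[weight_name]).cuda()
    for sd_w_name, sd_weight in list(sd.items()):
        if check_weight_equal(sd_weight, network_weight):
            # Each engine weight matches at most one state-dict entry, so
            # dropping the match shrinks the search space for later lookups.
            del sd[sd_w_name]
            return sd_w_name
    return ""


if __name__ == "__main__" and torch.cuda.is_available():
    # Hypothetical toy data: one linear weight registered under both names.
    sd = {"layer.weight": torch.randn(4, 4).cuda()}
    np_map = {"engine_w0": sd["layer.weight"].reshape(-1).cpu().numpy()}
    print(find_weight("engine_w0", np_map, sd))  # -> layer.weight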
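
The second part of the change temporarily moves a CPU-resident module to the compilation device while the mapping is built, then restores it and releases cached GPU memory afterwards. A minimal sketch of that round-trip follows, assuming a plain torch.nn.Module and an available CUDA device; the function name and the placeholder work in the middle are illustrative, not the interpreter's actual logic.

# Sketch of the temporary-device pattern: move to the GPU only if the module
# started on the CPU, do the GPU-heavy work, then restore placement and free
# the cached allocations.
import torch


def build_mapping_on_cuda(module: torch.nn.Module) -> None:
    sd = module.state_dict()
    # Infer the module's current device from any of its tensors.
    module_was_on_cuda = next(iter(sd.values())).device.type == "cuda"
    if not module_was_on_cuda:
        module.to("cuda")
        sd = module.state_dict()  # refresh: entries now reference GPU storage

    # ... GPU-side weight comparisons over `sd` would happen here ...

    if not module_was_on_cuda:
        # Put the module back where the caller left it and return cached memory.
        module.to("cpu")
        torch.cuda.empty_cache()


if torch.cuda.is_available():
    build_mapping_on_cuda(torch.nn.Linear(8, 8))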
