
Commit 2f1df87

Made revision according to comments
1 parent d9653cd commit 2f1df87

File tree

1 file changed: +12 -5 lines changed


py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 12 additions & 5 deletions
@@ -40,7 +40,6 @@
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, to_torch_device
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.logging import TRT_LOGGER
-from tqdm import tqdm

 from packaging import version

@@ -336,13 +335,21 @@ def _construct_trt_network_def(self) -> None:

     @staticmethod
     def find_weight(
-        weight_name: str, np_map: dict[str, Any], sd: dict[str, Any]
+        weight_name: str, np_map: dict[str, Any], state_dict: dict[str, Any]
     ) -> str:
+        """
+        We need to build a map from engine weight name to state_dict weight name.
+        The purpose of this function is to find the corresponding weight name in the module state_dict.
+
+        weight_name: the target weight name we want to search for
+        np_map: the map from weight name to np values in INetworkDefinition
+        state_dict: state of the graph module
+        """
         network_weight = np_map[weight_name]
         network_weight = torch.from_numpy(np_map[weight_name]).cuda()
-        for sd_w_name, sd_weight in sd.items():
+        for sd_w_name, sd_weight in state_dict.items():
             if TRTInterpreter.check_weight_equal(sd_weight, network_weight):
-                del sd[sd_w_name]
+                del state_dict[sd_w_name]
                 return sd_w_name
         return ""
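
For readers following along, the value-matching idea behind find_weight can be shown with a small standalone sketch. This is not the library implementation: the real method moves the tensor to the GPU and delegates the comparison to TRTInterpreter.check_weight_equal, while the helper name find_weight_sketch and the torch.allclose check below are illustrative assumptions.

# Standalone sketch of the value-matching idea (illustrative only; the real
# method uses TRTInterpreter.check_weight_equal and moves tensors to CUDA).
from typing import Any, Dict

import numpy as np
import torch


def find_weight_sketch(
    weight_name: str, np_map: Dict[str, Any], state_dict: Dict[str, torch.Tensor]
) -> str:
    # Pull the engine-side weight out of the INetworkDefinition map.
    network_weight = torch.from_numpy(np_map[weight_name])
    for sd_w_name, sd_weight in state_dict.items():
        # Hypothetical equality check: same shape and (nearly) equal values.
        if sd_weight.shape == network_weight.shape and torch.allclose(
            sd_weight.to(network_weight.dtype), network_weight
        ):
            # Drop the matched entry so it cannot be matched twice; returning
            # immediately keeps the dict mutation safe during iteration.
            del state_dict[sd_w_name]
            return sd_w_name
    return ""


# Toy usage: the KERNEL suffix mimics TensorRT-style engine weight names.
np_map = {"conv1 KERNEL": np.ones((2, 2), dtype=np.float32)}
sd = {"conv1.weight": torch.ones(2, 2)}
print(find_weight_sketch("conv1 KERNEL", np_map, sd))  # -> "conv1.weight"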

@@ -470,7 +477,7 @@ def _save_weight_mapping(self) -> None:
             np_map[engine_weight_name] = weight

         # Stage 2: Value mapping
-        for engine_weight_name, sd_weight_name in tqdm(weight_name_map.items()):
+        for engine_weight_name, sd_weight_name in weight_name_map.items():
             if "SCALE" in engine_weight_name:
                 # There is no direct connection in batch_norm layer. So skip it
                 pass
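
As a rough, hypothetical illustration of the Stage 2 pass touched by this hunk (the actual _save_weight_mapping does considerably more than this), the loop walks the engine-to-state_dict name map and skips TensorRT scale weights, which come from batch_norm layers and have no direct state_dict counterpart. The map contents and the "SCALE" naming below are assumptions based only on the diff above.

# Hypothetical sketch of the Stage 2 value-mapping loop (illustrative only).
from typing import Dict

weight_name_map: Dict[str, str] = {
    "conv1 KERNEL": "conv1.weight",
    "bn1 SCALE": "",  # batch_norm scale weight: no direct state_dict match
}

for engine_weight_name, sd_weight_name in weight_name_map.items():
    if "SCALE" in engine_weight_name:
        # There is no direct connection in batch_norm layers, so skip it.
        continue
    print(f"map {engine_weight_name} -> {sd_weight_name}")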
