|
40 | 40 | from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, to_torch_device
|
41 | 41 | from torch_tensorrt.fx.observer import Observer
|
42 | 42 | from torch_tensorrt.logging import TRT_LOGGER
|
43 |
| -from tqdm import tqdm |
44 | 43 |
|
45 | 44 | from packaging import version
|
46 | 45 |
|
@@ -336,13 +335,21 @@ def _construct_trt_network_def(self) -> None:
|
336 | 335 |
|
337 | 336 | @staticmethod
|
338 | 337 | def find_weight(
|
339 |
| - weight_name: str, np_map: dict[str, Any], sd: dict[str, Any] |
| 338 | + weight_name: str, np_map: dict[str, Any], state_dict: dict[str, Any] |
340 | 339 | ) -> str:
|
| 340 | + """ |
| 341 | + We need to build a map from engine weight name to state_dict weight name. |
| 342 | + The purpose of this function is to find the corresponding weight name in the module's state_dict. |
| 343 | +
|
| 344 | + weight_name: the target weight name we want to search for |
| 345 | + np_map: the map from weight name to np values in INetworkDefinition |
| 346 | + state_dict: state of the graph module |
| 347 | + """ |
341 | 348 | network_weight = np_map[weight_name]
|
342 | 349 | network_weight = torch.from_numpy(np_map[weight_name]).cuda()
|
343 |
| - for sd_w_name, sd_weight in sd.items(): |
| 350 | + for sd_w_name, sd_weight in state_dict.items(): |
344 | 351 | if TRTInterpreter.check_weight_equal(sd_weight, network_weight):
|
345 |
| - del sd[sd_w_name] |
| 352 | + del state_dict[sd_w_name] |
346 | 353 | return sd_w_name
|
347 | 354 | return ""
|
348 | 355 |
|
@@ -470,7 +477,7 @@ def _save_weight_mapping(self) -> None:
|
470 | 477 | np_map[engine_weight_name] = weight
|
471 | 478 |
|
472 | 479 | # Stage 2: Value mapping
|
473 |
| - for engine_weight_name, sd_weight_name in tqdm(weight_name_map.items()): |
| 480 | + for engine_weight_name, sd_weight_name in weight_name_map.items(): |
474 | 481 | if "SCALE" in engine_weight_name:
|
475 | 482 | # There is no direct connection in batch_norm layer. So skip it
|
476 | 483 | pass
|
|
0 commit comments