@@ -3,7 +3,18 @@
import os
import warnings
from datetime import datetime
-from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    NamedTuple,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Union,
+)

import numpy as np
import tensorrt as trt
@@ -26,9 +37,10 @@
    get_node_name,
    get_trt_tensor,
)
-from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
+from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, to_torch_device
from torch_tensorrt.fx.observer import Observer
from torch_tensorrt.logging import TRT_LOGGER
+from tqdm import tqdm

from packaging import version
@@ -334,18 +346,22 @@ def _save_weight_mapping(self) -> None:
        def find_weight(
            weight_name: str, np_map: dict[str, Any], sd: dict[str, Any]
        ) -> str:
-            network_weight = np_map[weight_name]
+            network_weight = torch.from_numpy(np_map[weight_name]).cuda()
            for sd_w_name, sd_weight in sd.items():
                if check_weight_equal(sd_weight, network_weight):
+                    del sd[sd_w_name]
                    return sd_w_name
            return ""

        def check_weight_equal(
-            sd_weight: torch.tensor, network_weight: np.ndarray
+            sd_weight: torch.tensor, network_weight: Union[np.ndarray, torch.Tensor]
        ) -> Any:
-            sd_weight = sd_weight.reshape(-1).cpu().numpy()
-            return sd_weight.size == network_weight.size and np.allclose(
-                sd_weight, network_weight, 1e-1, 1e-1
+            sd_weight = sd_weight.reshape(-1)
+            if not isinstance(network_weight, torch.Tensor):
+                network_weight = torch.from_numpy(network_weight).cuda()
+            return (
+                sd_weight.shape == network_weight.shape
+                and torch.all(torch.abs(sd_weight - network_weight) < 0.1).cpu()
            )

        MODULE_MAP = {
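
For reference, the reworked comparison above amounts to a shape check plus an element-wise absolute tolerance of 0.1, done with torch instead of np.allclose so the work stays on the GPU. Below is a minimal standalone sketch of that logic (not part of the diff; weights_match is an illustrative name, and it uses the state-dict tensor's device rather than a hard-coded .cuda() call so it also runs on a CPU-only box):

    import torch

    def weights_match(sd_weight: torch.Tensor, network_weight) -> bool:
        # Flatten the state-dict tensor, mirror the engine weight onto the same
        # device, then compare shapes and element-wise values (|diff| < 0.1).
        sd_weight = sd_weight.reshape(-1)
        if not isinstance(network_weight, torch.Tensor):
            network_weight = torch.from_numpy(network_weight)
        network_weight = network_weight.to(sd_weight.device).reshape(-1)
        return bool(
            sd_weight.shape == network_weight.shape
            and torch.all(torch.abs(sd_weight - network_weight) < 0.1)
        )

    w = torch.randn(16, 3, 3, 3)
    assert weights_match(w, w.numpy())              # identical values match
    assert not weights_match(w, (w + 1.0).numpy())  # values 1.0 apart do not
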
@@ -393,8 +409,14 @@ def check_weight_equal(
            )
        }
        """
+        _LOGGER.info("building weight name mapping...")
        # Stage 1: Name mapping
        sd = self.module.state_dict()
+        gm_is_on_cuda = list(sd.values())[0].device.type == "cuda"
+        # If the model was originally on CPU, move it to the GPU
+        if not gm_is_on_cuda:
+            self.module.to(to_torch_device(self.compilation_settings.device))
+            sd = self.module.state_dict()
        weight_name_map: dict[str, Any] = {}
        np_map = {}
        net = self.ctx.net
@@ -439,7 +461,7 @@ def check_weight_equal(
                np_map[engine_weight_name] = weight

        # Stage 2: Value mapping
-        for engine_weight_name, sd_weight_name in weight_name_map.items():
+        for engine_weight_name, sd_weight_name in tqdm(weight_name_map.items()):
            if "SCALE" in engine_weight_name:
                # There is no direct connection in batch_norm layer. So skip it
                pass
@@ -456,6 +478,10 @@ def check_weight_equal(
            ]

        self.weight_name_map = weight_name_map
+        # If the model was originally on CPU, move it back to CPU to save GPU memory
+        if not gm_is_on_cuda:
+            self.module.to("cpu")
+            torch.cuda.empty_cache()

    def run(
        self,
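
Taken together, the device handling added in this change follows one pattern: if the graph module's weights start out on the CPU, they are moved to the GPU only for the duration of the weight-name mapping, then moved back and the CUDA cache is released. A minimal sketch of that pattern in isolation (not part of the diff; run_on_gpu and the fn callback are illustrative names):

    import torch

    def run_on_gpu(module: torch.nn.Module, target_device: torch.device, fn):
        # Remember whether the module started on the CPU, borrow the GPU for the
        # duration of fn, then restore the original placement and free the cache.
        was_on_cpu = list(module.state_dict().values())[0].device.type != "cuda"
        if was_on_cpu:
            module.to(target_device)
        try:
            return fn(module)
        finally:
            if was_on_cpu:
                module.to("cpu")
                torch.cuda.empty_cache()

    # Usage (hypothetical): run_on_gpu(gm, torch.device("cuda:0"), build_weight_map)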