Commit df121b9

Fixed batchnorm converter change
1 parent df95f00 commit df121b9

File tree

5 files changed (+25, -42 lines)

examples/dynamo/mutable_torchtrt_module_example.py
2 additions & 2 deletions

@@ -34,7 +34,7 @@
     "make_refitable": True,
 }

-model = models.resnet18(pretrained=False).eval().to("cuda")
+model = models.resnet18(pretrained=True).eval().to("cuda")
 mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings)
 # You can use the mutable module just like the original pytorch module. The compilation happens while you first call the mutable module.
 mutable_module(*inputs)
@@ -45,7 +45,7 @@

 # %%
 # Making changes to mutable module can trigger refit or re-compilation. For example, loading a different state_dict and setting new weight values will trigger refit, and adding a module to the model will trigger re-compilation.
-model2 = models.resnet18(pretrained=True).eval().to("cuda")
+model2 = models.resnet18(pretrained=False).eval().to("cuda")
 mutable_module.load_state_dict(model2.state_dict())
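The example comment above explains that loading a different state_dict triggers a refit, while structural changes force a re-compilation. A minimal sketch of that flow, assuming the imports used in the example file; the settings, inputs, and input shape shown here are stand-ins for the values defined earlier in that example:

import torch
import torch_tensorrt as torch_trt
import torchvision.models as models

# Sketch only: settings and inputs mirror the values defined earlier in the example.
settings = {"make_refitable": True}
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]

model = models.resnet18(pretrained=True).eval().to("cuda")
mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings)
mutable_module(*inputs)  # first call triggers compilation

# Same architecture, new weights: load_state_dict triggers a refit, not a rebuild.
model2 = models.resnet18(pretrained=False).eval().to("cuda")
mutable_module.load_state_dict(model2.state_dict())
mutable_module(*inputs)  # runs with the refitted weights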
examples/dynamo/refit_engine_example.py
2 additions & 2 deletions

@@ -39,7 +39,7 @@
 # Compile the module for the first time and save it.
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

-model = models.resnet18(pretrained=False).eval().to("cuda")
+model = models.resnet18(pretrained=True).eval().to("cuda")
 exp_program = torch.export.export(model, tuple(inputs))
 enabled_precisions = {torch.float}
 debug = False
@@ -68,7 +68,7 @@
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

 # Create and compile the updated model
-model2 = models.resnet18(pretrained=True).eval().to("cuda")
+model2 = models.resnet18(pretrained=False).eval().to("cuda")
 exp_program2 = torch.export.export(model2, tuple(inputs))
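For context, this example compiles a refittable engine from the pretrained network and then refits it with the freshly initialized one. A condensed sketch under the same assumptions as the example file; the compile options and refit_module_weights argument names follow that example as written and may differ between releases, so treat them as assumptions:

import torch
import torch_tensorrt as torch_trt
import torchvision.models as models
from torch_tensorrt.dynamo import refit_module_weights

inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]

# Compile once from the pretrained weights (as this hunk now does).
model = models.resnet18(pretrained=True).eval().to("cuda")
exp_program = torch.export.export(model, tuple(inputs))
trt_gm = torch_trt.dynamo.compile(
    exp_program, inputs=inputs, enabled_precisions={torch.float}, make_refitable=True
)

# Export the randomly initialized network and refit the existing engine.
model2 = models.resnet18(pretrained=False).eval().to("cuda")
exp_program2 = torch.export.export(model2, tuple(inputs))
new_trt_gm = refit_module_weights(compiled_module=trt_gm, new_weight_module=exp_program2)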
py/torch_tensorrt/dynamo/_refit.py
2 additions & 19 deletions

@@ -115,25 +115,8 @@ def construct_refit_mapping_from_weight_name_map(
     for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
         trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
         torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
-        if engine_weight_name.split(" ")[-1] in ["SCALE", "SHIFT"]:
-            # Batch Norm Layer
-            params = {
-                "weight": 1.0,
-                "bias": 0.0,
-                "running_mean": 0.0,
-                "running_var": 1.0,
-            }
-            for w in sd_weight_name:
-                if w in state_dict:
-                    params[w.split(".")[-1]] = state_dict[w]
-            scale = params["weight"] / torch.sqrt(params["running_var"] + 1e-5)
-            shift = params["bias"] - params["running_mean"] * scale
-            # Set scale to scale or shift to shift
-            engine_weight_map[engine_weight_name] = eval(
-                engine_weight_name.split(" ")[-1].lower()
-            )

-        elif sd_weight_name not in state_dict:
+        if sd_weight_name not in state_dict:
             # If weights is not in sd, we can leave it unchanged
             continue
         else:
@@ -180,7 +163,7 @@ def _refit_single_trt_engine_with_gm(

     # Debug Use
     # correct = construct_refit_mapping(new_gm, input_list, settings)
-    # {k: np.allclose(correct[k][0], mapping[k][0].cpu().numpy(), 1e-2, 1e-2) for k in mapping if k in correct}
+    # comparison = {k: (np.allclose(correct[k][0], mapping[k][0].cpu().numpy(), 1e-2, 1e-2), correct[k][0], mapping[k][0]) for k in mapping if k in correct}

     for layer_name in weight_list:
         if layer_name not in mapping:
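The deleted branch folded batch-norm statistics into the SCALE/SHIFT pair that a TensorRT scale layer consumes, presumably because the batch norm converter now registers weights that can be refit directly. A standalone sketch that documents only the formula the removed code computed (eps of 1e-5 as in the original):

import torch

def fold_batchnorm(weight, bias, running_mean, running_var, eps=1e-5):
    # A TensorRT scale layer applies y = x * scale + shift; batch norm in
    # inference mode reduces to exactly that affine transform.
    scale = weight / torch.sqrt(running_var + eps)
    shift = bias - running_mean * scale
    return scale, shift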
py/torch_tensorrt/dynamo/conversion/_conversion.py
1 addition & 1 deletion

@@ -137,7 +137,6 @@ def convert_module(
         refit_test_engine = runtime.deserialize_cuda_engine(
             interpreter_result.serialized_engine
         )
-        weight_name_map = interpreter_result.weight_name_map
         try:
             _refit_single_trt_engine_with_gm(
                 new_gm=module,
@@ -146,6 +145,7 @@ def convert_module(
                 settings=settings,
                 weight_name_map=interpreter_result.weight_name_map,
             )
+            weight_name_map = interpreter_result.weight_name_map
         except AssertionError:
             logger.warning("Fast refit test failed. Removing the weight map caching.")
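Moving the assignment below the call means weight_name_map is cached only after the fast-refit check passes, so a failed check no longer leaves a stale map behind. A minimal sketch of that ordering, with a hypothetical helper standing in for the real refit test:

import logging

logger = logging.getLogger(__name__)

def cache_weight_map_if_refittable(candidate_map, fast_refit_test):
    # Hypothetical helper: fast_refit_test raises AssertionError on failure,
    # mirroring _refit_single_trt_engine_with_gm in the hunk above.
    weight_name_map = None
    try:
        fast_refit_test(weight_name_map=candidate_map)
        weight_name_map = candidate_map  # cache only after the test passes
    except AssertionError:
        logger.warning("Fast refit test failed. Removing the weight map caching.")
    return weight_name_map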
tests/py/dynamo/models/test_model_refit.py
18 additions & 18 deletions

@@ -35,8 +35,8 @@
 @pytest.mark.unit
 def test_mapping():

-    model = models.resnet18(pretrained=False).eval().to("cuda")
-    model2 = models.resnet18(pretrained=True).eval().to("cuda")
+    model = models.resnet18(pretrained=True).eval().to("cuda")
+    model2 = models.resnet18(pretrained=False).eval().to("cuda")
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     trt_input = [
         torchtrt.Input(i.shape, dtype=torch.float, format=torch.contiguous_format)
@@ -91,8 +91,8 @@ def test_mapping():
 @pytest.mark.unit
 def test_refit_one_engine_with_weightmap():

-    model = models.resnet152(pretrained=False).eval().to("cuda")
-    model2 = models.resnet152(pretrained=True).eval().to("cuda")
+    model = models.resnet18(pretrained=True).eval().to("cuda")
+    model2 = models.resnet18(pretrained=False).eval().to("cuda")
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     enabled_precisions = {torch.float}
     debug = False
@@ -140,8 +140,8 @@ def test_refit_one_engine_with_weightmap():
 @pytest.mark.unit
 def test_refit_one_engine_no_map_with_weightmap():

-    model = models.resnet18(pretrained=False).eval().to("cuda")
-    model2 = models.resnet18(pretrained=True).eval().to("cuda")
+    model = models.resnet18(pretrained=True).eval().to("cuda")
+    model2 = models.resnet18(pretrained=False).eval().to("cuda")
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     enabled_precisions = {torch.float}
     debug = False
@@ -191,8 +191,8 @@ def test_refit_one_engine_no_map_with_weightmap():
 @pytest.mark.unit
 def test_refit_one_engine_with_wrong_weightmap():

-    model = models.resnet18(pretrained=False).eval().to("cuda")
-    model2 = models.resnet18(pretrained=True).eval().to("cuda")
+    model = models.resnet18(pretrained=True).eval().to("cuda")
+    model2 = models.resnet18(pretrained=False).eval().to("cuda")
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     enabled_precisions = {torch.float}
     debug = False
@@ -301,8 +301,8 @@ def test_refit_one_engine_bert_with_weightmap():
 @pytest.mark.unit
 def test_refit_one_engine_inline_runtime__with_weightmap():
     trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
-    model = models.resnet18(pretrained=False).eval().to("cuda")
-    model2 = models.resnet18(pretrained=True).eval().to("cuda")
+    model = models.resnet18(pretrained=True).eval().to("cuda")
+    model2 = models.resnet18(pretrained=False).eval().to("cuda")
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     enabled_precisions = {torch.float}
     debug = False
@@ -347,8 +347,8 @@ def test_refit_one_engine_inline_runtime__with_weightmap():
 @pytest.mark.unit
 def test_refit_one_engine_python_runtime_with_weightmap():

-    model = models.resnet18(pretrained=False).eval().to("cuda")
-    model2 = models.resnet18(pretrained=True).eval().to("cuda")
+    model = models.resnet18(pretrained=True).eval().to("cuda")
+    model2 = models.resnet18(pretrained=False).eval().to("cuda")
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     enabled_precisions = {torch.float}
     debug = False
@@ -467,8 +467,8 @@ def forward(self, x):
 @pytest.mark.unit
 def test_refit_one_engine_without_weightmap():

-    model = models.resnet18(pretrained=False).eval().to("cuda")
-    model2 = models.resnet18(pretrained=True).eval().to("cuda")
+    model = models.resnet18(pretrained=True).eval().to("cuda")
+    model2 = models.resnet18(pretrained=False).eval().to("cuda")
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     enabled_precisions = {torch.float}
     debug = False
@@ -571,8 +571,8 @@ def test_refit_one_engine_bert_without_weightmap():
 @pytest.mark.unit
 def test_refit_one_engine_inline_runtime_without_weightmap():
     trt_ep_path = os.path.join(tempfile.gettempdir(), "compiled.ep")
-    model = models.resnet18(pretrained=False).eval().to("cuda")
-    model2 = models.resnet18(pretrained=True).eval().to("cuda")
+    model = models.resnet18(pretrained=True).eval().to("cuda")
+    model2 = models.resnet18(pretrained=False).eval().to("cuda")
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     enabled_precisions = {torch.float}
     debug = False
@@ -617,8 +617,8 @@ def test_refit_one_engine_inline_runtime_without_weightmap():
 @pytest.mark.unit
 def test_refit_one_engine_python_runtime_without_weightmap():

-    model = models.resnet18(pretrained=False).eval().to("cuda")
-    model2 = models.resnet18(pretrained=True).eval().to("cuda")
+    model = models.resnet18(pretrained=True).eval().to("cuda")
+    model2 = models.resnet18(pretrained=False).eval().to("cuda")
     inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
     enabled_precisions = {torch.float}
     debug = False
