
🐛 [Bug] Refit bug: state_dict could be empty #3126

Closed
@zewenli98


Bug Description

The bug was observed while testing engine caching with torch.compile(...). The error message is:

Traceback (most recent call last):
  File "/home/zewenl/anaconda3/envs/trt-10.1-py310/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1446, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/home/zewenl/anaconda3/envs/trt-10.1-py310/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/zewenl/anaconda3/envs/trt-10.1-py310/lib/python3.10/site-packages/torch/__init__.py", line 2284, in __call__
    return self.compiler_fn(model_, inputs_, **self.kwargs)
  File "/home/zewenl/Documents/pytorch/TensorRT/py/torch_tensorrt/dynamo/backend/backends.py", line 44, in torch_tensorrt_backend
    return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
  File "/home/zewenl/Documents/pytorch/TensorRT/py/torch_tensorrt/dynamo/backend/backends.py", line 52, in aot_torch_tensorrt_aten_backend
    return _pretraced_backend(gm, sample_inputs, settings, engine_cache)
  File "/home/zewenl/Documents/pytorch/TensorRT/py/torch_tensorrt/dynamo/backend/backends.py", line 110, in _pretraced_backend
    trt_compiled = compile_module(
  File "/home/zewenl/Documents/pytorch/TensorRT/py/torch_tensorrt/dynamo/_compiler.py", line 485, in compile_module
    trt_module = convert_module(
  File "/home/zewenl/Documents/pytorch/TensorRT/py/torch_tensorrt/dynamo/conversion/_conversion.py", line 134, in convert_module
    interpreter_result = interpret_module_to_result(
  File "/home/zewenl/Documents/pytorch/TensorRT/py/torch_tensorrt/dynamo/conversion/_conversion.py", line 113, in interpret_module_to_result
    interpreter_result = interpreter.run()
  File "/home/zewenl/Documents/pytorch/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 589, in run
    self._save_weight_mapping()
  File "/home/zewenl/Documents/pytorch/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 442, in _save_weight_mapping
    gm_is_on_cuda = list(sd.values())[0].device.type == "cuda"
IndexError: list index out of range

where I printed out sd, which is:

OrderedDict()

and self.module is:

GraphModule()



def forward(self, arg1_1, arg0_1, arg4_1, arg5_1, arg2_1, arg3_1, arg6_1, arg9_1, arg10_1, arg7_1, arg8_1, arg11_1, arg14_1, arg15_1, arg12_1, arg13_1, arg16_1, arg19_1, arg20_1, arg17_1, arg18_1, arg21_1, arg24_1, arg25_1, arg22_1, arg23_1, arg26_1, arg29_1, arg30_1, arg27_1, arg28_1, arg31_1, arg34_1, arg35_1, arg32_1, arg33_1, arg36_1, arg39_1, arg40_1, arg37_1, arg38_1, arg41_1, arg44_1, arg45_1, arg42_1, arg43_1, arg46_1, arg49_1, arg50_1, arg47_1, arg48_1, arg51_1, arg54_1, arg55_1, arg52_1, arg53_1, arg56_1, arg59_1, arg60_1, arg57_1, arg58_1, arg61_1, arg64_1, arg65_1, arg62_1, arg63_1, arg66_1, arg69_1, arg70_1, arg67_1, arg68_1, arg71_1, arg74_1, arg75_1, arg72_1, arg73_1, arg76_1, arg79_1, arg80_1, arg77_1, arg78_1, arg81_1, arg84_1, arg85_1, arg82_1, arg83_1, arg86_1, arg89_1, arg90_1, arg87_1, arg88_1, arg91_1, arg94_1, arg95_1, arg92_1, arg93_1, arg96_1, arg99_1, arg100_1, arg97_1, arg98_1, arg101_1, arg102_1):
    convolution = torch.ops.aten.convolution.default(arg1_1, arg0_1, None, [2, 2], [3, 3], [1, 1], False, [0, 0], 1);  arg1_1 = arg0_1 = None
    _native_batch_norm_legit_no_training = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution, arg4_1, arg5_1, arg2_1, arg3_1, 0.1, 1e-05);  convolution = arg4_1 = arg5_1 = arg2_1 = arg3_1 = None
    getitem = _native_batch_norm_legit_no_training[0];  _native_batch_norm_legit_no_training = None
    relu = torch.ops.aten.relu.default(getitem);  getitem = None
    max_pool2d_default = torch.ops.aten.max_pool2d.default(relu, [3, 3], [2, 2], [1, 1]);  relu = None
    convolution_1 = torch.ops.aten.convolution.default(max_pool2d_default, arg6_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  arg6_1 = None
    _native_batch_norm_legit_no_training_1 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_1, arg9_1, arg10_1, arg7_1, arg8_1, 0.1, 1e-05);  convolution_1 = arg9_1 = arg10_1 = arg7_1 = arg8_1 = None
    getitem_5 = _native_batch_norm_legit_no_training_1[0];  _native_batch_norm_legit_no_training_1 = None
    relu_1 = torch.ops.aten.relu.default(getitem_5);  getitem_5 = None
    convolution_2 = torch.ops.aten.convolution.default(relu_1, arg11_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  relu_1 = arg11_1 = None
    _native_batch_norm_legit_no_training_2 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_2, arg14_1, arg15_1, arg12_1, arg13_1, 0.1, 1e-05);  convolution_2 = arg14_1 = arg15_1 = arg12_1 = arg13_1 = None
    getitem_8 = _native_batch_norm_legit_no_training_2[0];  _native_batch_norm_legit_no_training_2 = None
    add = torch.ops.aten.add.Tensor(getitem_8, max_pool2d_default);  getitem_8 = max_pool2d_default = None
    relu_2 = torch.ops.aten.relu.default(add);  add = None
...
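
All of the module's weights have been lifted into graph inputs (arg0_1 through arg102_1), so self.module.state_dict() is empty and indexing its first value fails. A minimal sketch of the failure, with an illustrative guard (an assumption on my part, not the actual fix):

```python
from collections import OrderedDict

sd = OrderedDict()  # what self.module.state_dict() returns for this GraphModule
# list(sd.values())[0]  # -> IndexError: list index out of range

# Illustrative guard (assumption, not the real fix): skip the device probe
# when the state_dict is empty, e.g. because torch.compile lifted every
# parameter into a graph input.
if sd:
    gm_is_on_cuda = next(iter(sd.values())).device.type == "cuda"
else:
    gm_is_on_cuda = False  # fall back to CPU handling by default
```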

The previous version of refitting worked with torch.compile. I believe this bug was introduced in PR #3097.

To Reproduce

On the engine_cache branch, run the test test_torch_compile_with_default_disk_engine_cache: https://github.com/pytorch/TensorRT/blob/ccddbc66e33da2a2693847073895c6263e8486a5/tests/py/dynamo/models/test_engine_cache.py#L190C9-L190C58
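
For context, a minimal sketch of the kind of invocation that test performs. The option names follow the engine_cache branch at the time and are assumptions here, not a verbatim copy of the test:

```python
import torch
import torch_tensorrt  # registers the "torch_tensorrt" backend
import torchvision.models as models

model = models.resnet18(pretrained=True).eval().to("cuda")
inputs = [torch.rand((100, 3, 224, 224)).to("cuda")]

# Engine caching requires refittable engines; the cache-related option
# names below are from the engine_cache branch and may differ elsewhere.
compiled = torch.compile(
    model,
    backend="torch_tensorrt",
    options={
        "use_python_runtime": True,
        "enabled_precisions": {torch.float},
        "make_refitable": True,
        "cache_built_engines": True,
        "reuse_cached_engines": True,
    },
)
compiled(*inputs)  # raises IndexError in _save_weight_mapping
```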
