diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index c97c3a6229..b3e243187d 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -613,10 +613,10 @@ def convert_exported_program_to_serialized_trt_engine( DeprecationWarning, stacklevel=2, ) - if not arg_inputs and not inputs: + if arg_inputs is None and inputs is None: raise AssertionError("'arg_inputs' and 'inputs' should not both be None.") - elif arg_inputs and inputs: + elif arg_inputs is not None and inputs is not None: raise AssertionError( "'arg_inputs' and 'inputs' should not be used at the same time." ) diff --git a/tests/py/dynamo/models/test_export_kwargs_serde.py b/tests/py/dynamo/models/test_export_kwargs_serde.py index 08b23d55e0..f4587375ea 100644 --- a/tests/py/dynamo/models/test_export_kwargs_serde.py +++ b/tests/py/dynamo/models/test_export_kwargs_serde.py @@ -511,3 +511,50 @@ def forward(self, x, b=5, c=None, d=None): engine = convert_exported_program_to_serialized_trt_engine( exp_program, **compile_spec ) + + +def test_custom_model_compile_engine_with_pure_kwarg_inputs(): + class net(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(3, 12, 3, padding=1) + self.bn = nn.BatchNorm2d(12) + self.conv2 = nn.Conv2d(12, 12, 3, padding=1) + self.fc1 = nn.Linear(12 * 56 * 56, 10) + + def forward(self, x, b=5, c=None, d=None): + x = self.conv1(x) + x = F.relu(x) + x = self.bn(x) + x = F.max_pool2d(x, (2, 2)) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, (2, 2)) + x = torch.flatten(x, 1) + x = x + b + if c is not None: + x = x * c + if d is not None: + x = x - d["value"] + return self.fc1(x) + + model = net().eval().to("cuda") + kwargs = { + "x": torch.rand((1, 3, 224, 224)).to("cuda"), + "b": torch.tensor(6).to("cuda"), + "d": {"value": torch.tensor(8).to("cuda")}, + } + + compile_spec = { + "arg_inputs": (), + "kwarg_inputs": kwargs, + "device": torchtrt.Device("cuda:0"), + "enabled_precisions": {torch.float}, + "pass_through_build_failures": True, + "optimization_level": 1, + "min_block_size": 1, + "ir": "dynamo", + } + + exp_program = torch.export.export(model, args=(), kwargs=kwargs) + _ = convert_exported_program_to_serialized_trt_engine(exp_program, **compile_spec)