diff --git a/docsrc/user_guide/saving_models.rst b/docsrc/user_guide/saving_models.rst index 42cff7b954..dc4b5da222 100644 --- a/docsrc/user_guide/saving_models.rst +++ b/docsrc/user_guide/saving_models.rst @@ -34,7 +34,7 @@ Here's an example usage model = MyModel().eval().cuda() inputs = [torch.randn((1, 3, 224, 224)).cuda()] # trt_ep is a torch.fx.GraphModule object - trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs) + trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs) torch_tensorrt.save(trt_gm, "trt.ep", inputs=inputs) # Later, you can load it and run inference @@ -52,7 +52,7 @@ b) Torchscript model = MyModel().eval().cuda() inputs = [torch.randn((1, 3, 224, 224)).cuda()] # trt_gm is a torch.fx.GraphModule object - trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs) + trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs) torch_tensorrt.save(trt_gm, "trt.ts", output_format="torchscript", inputs=inputs) # Later, you can load it and run inference @@ -73,7 +73,7 @@ For `ir=ts`, this behavior stays the same in 2.X versions as well. model = MyModel().eval().cuda() inputs = [torch.randn((1, 3, 224, 224)).cuda()] - trt_ts = torch_tensorrt.compile(model, ir="ts", inputs) # Output is a ScriptModule object + trt_ts = torch_tensorrt.compile(model, ir="ts", inputs=inputs) # Output is a ScriptModule object torch.jit.save(trt_ts, "trt_model.ts") # Later, you can load it and run inference