From c98e36d246903652f481e8fc51ed71a269c2221e Mon Sep 17 00:00:00 2001 From: Ohad Ravid Date: Sun, 22 Dec 2024 14:44:36 +0200 Subject: [PATCH] Fix usage example --- docsrc/user_guide/saving_models.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docsrc/user_guide/saving_models.rst b/docsrc/user_guide/saving_models.rst index 42cff7b954..dc4b5da222 100644 --- a/docsrc/user_guide/saving_models.rst +++ b/docsrc/user_guide/saving_models.rst @@ -34,7 +34,7 @@ Here's an example usage model = MyModel().eval().cuda() inputs = [torch.randn((1, 3, 224, 224)).cuda()] # trt_ep is a torch.fx.GraphModule object - trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs) + trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs) torch_tensorrt.save(trt_gm, "trt.ep", inputs=inputs) # Later, you can load it and run inference @@ -52,7 +52,7 @@ b) Torchscript model = MyModel().eval().cuda() inputs = [torch.randn((1, 3, 224, 224)).cuda()] # trt_gm is a torch.fx.GraphModule object - trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs) + trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs) torch_tensorrt.save(trt_gm, "trt.ts", output_format="torchscript", inputs=inputs) # Later, you can load it and run inference @@ -73,7 +73,7 @@ For `ir=ts`, this behavior stays the same in 2.X versions as well. model = MyModel().eval().cuda() inputs = [torch.randn((1, 3, 224, 224)).cuda()] - trt_ts = torch_tensorrt.compile(model, ir="ts", inputs) # Output is a ScriptModule object + trt_ts = torch_tensorrt.compile(model, ir="ts", inputs=inputs) # Output is a ScriptModule object torch.jit.save(trt_ts, "trt_model.ts") # Later, you can load it and run inference