TorchScript is now fully deprecated in PyTorch 2.9. We should investigate converting MONAI's uses of `torch.jit` to `torch.export`. Not all features and behaviours carry over, since `torch.export` is not the same kind of JIT architecture, and `torch.compile` is not meant for exporting models to a saved format.
Some MONAI models will export with `torch.export.export` without any issues. However, it does not capture control flow the way TorchScript does and behaves more like tracing. Networks without control flow in their `forward` definitions can already be exported, e.g.:
```python
import torch
from monai.networks.nets import UNet

# 2D UNet with 2 input channels and 1 output channel
net = UNet(spatial_dims=2, in_channels=2, out_channels=1, channels=[4, 8, 16], strides=[2, 2])

t1 = torch.rand(3, 2, 16, 16)
t2 = torch.rand(5, 2, 32, 32)  # different batch and spatial sizes
print(net(t1).shape, net(t2).shape)  # expected shapes

# Batch and spatial dimensions are dynamic; the channel dimension is fixed
D = torch.export.Dim.DYNAMIC
S = torch.export.Dim.STATIC
enet = torch.export.export(net, args=(t1,), dynamic_shapes=((D, S, D, D),))

# Save the exported program, then load it back as a callable module
torch.export.save(enet, "out.pt2")
enet1 = torch.export.load("out.pt2")
net1 = enet1.module()
print(net1(t1).shape, net1(t2).shape)  # same as expected shapes
```

Work is needed to develop:
- Helper routines for exporting and then re-importing networks in regular use.
- Helper routines and tests to ensure compatibility.
- Adaptation of existing code to overcome any compatibility issues.
- Tutorials on this new usage and how to design compatible networks.
- Removal of legacy TorchScript usage wherever present, especially when its removal from PyTorch is imminent.