@@ -8,6 +8,7 @@
 import numpy as np
 import os.path
 import onnx
+import onnxsim
 import google.protobuf.text_format
 import io
 
@@ -72,6 +73,14 @@ def save_onnx_data_and_model(input, output, name, operation, *args, **kwargs):
     model = onnx.helper.make_model(graph, producer_name=name)
     onnx.save(model, models_files)
 
+def simplify(name, rename=False, **kwargs):
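+    # Runs onnx-simplifier on the saved model file and writes the result back;
+    # with rename=True the simplified model is saved as '<name>_optimized.onnx'
+    # instead of overwriting the original.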
+    model, check = onnxsim.simplify(name, **kwargs)
+    assert check, "couldn't validate"
+    name = name[:-5]
+    if rename:
+        name += '_optimized'
+    onnx.save(model, name + '.onnx')
+
 torch.manual_seed(0)
 np.random.seed(0)
 
@@ -412,6 +421,17 @@ def forward(self, x):
 save_data_and_model("slice", input, model)
 save_data_and_model("slice_opset_11", input, model, version=11)
 
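+# Slice with a negative start index (x[-1:]), exercising negative `starts`
+# values in the exported ONNX Slice node.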
+class SliceStarts(nn.Module):
+    def __init__(self, *args, **kwargs):
+        super(SliceStarts, self).__init__()
+
+    def forward(self, x):
+        return x[-1:]
+
+model = SliceStarts()
+input_ = Variable(torch.randn(1, 10, dtype=torch.float32))
+save_data_and_model("slice_neg_starts", input_, model)
+
 input_2 = Variable(torch.randn(6, 6))
 custom_slice_list = [
     slice(1, 3, 1),
@@ -573,6 +593,18 @@ def forward(self, x):
 input_ = Variable(torch.tensor(list(range(20)), dtype=torch.float32))
 save_data_and_model("split_sizes", input_, model)
 
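+# Split along the negative axis -1, then concatenate the chunks back on axis 1.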
+class SplitAxis(nn.Module):
+    def __init__(self, *args, **kwargs):
+        super(SplitAxis, self).__init__()
+
+    def forward(self, x):
+        tup = torch.split(x, 2, -1)
+        return torch.cat(tup, 1)
+
+model = SplitAxis()
+input_ = Variable(torch.randn(1, 10, dtype=torch.float32))
+save_data_and_model("split_neg_axis", input_, model)
+
 class SplitMax(nn.Module):
 
     def __init__(self):
@@ -840,6 +872,32 @@ def forward(self, x):
 output = np.mean(x, axis=2, keepdims=True)
 save_onnx_data_and_model(x, output, 'reduce_mean_axis2', 'ReduceMean', axes=(2), keepdims=True)
 
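+# Expand that broadcasts a (1, 3, 2, 4) input to an extra leading dimension;
+# the exported model is then post-processed with onnx-simplifier (see simplify()).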
+class Expand(nn.Module):
+    def __init__(self):
+        super(Expand, self).__init__()
+
+    def forward(self, x):
+        return x.expand(1, 3, -1, -1, -1)
+
+input = Variable(torch.randn(1, 3, 2, 4))
+model = Expand()
+model.eval()
+save_data_and_model("expand", input, model, export_params=True, version=12)
+simplify('models/expand.onnx', False)
+
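+# Expand whose target shape matches the input shape, i.e. an identity Expand.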
+class ExpandIdentity(nn.Module):
+    def __init__(self):
+        super(ExpandIdentity, self).__init__()
+
+    def forward(self, x):
+        return x.expand(1, 3, -1, -1)
+
+input = Variable(torch.randn(1, 3, 2, 4))
+model = ExpandIdentity()
+model.eval()
+save_data_and_model("expand_identity", input, model, export_params=True, version=12)
+simplify('models/expand_identity.onnx', False)
+
 class Expand(nn.Module):
     def __init__(self, shape):
         super(Expand, self).__init__()
@@ -908,6 +966,23 @@ def forward(self, x):
 x = Variable(torch.randn(1, 2, 3, 4))
 save_data_and_model("reduceL2_subgraph_2", x, model)
 
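+# L2 normalization expressed as norm -> clamp -> expand -> div, producing a
+# ReduceL2 subgraph in the exported model.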
+class reduceL2_subgraph2_2(nn.Module):
+    def __init__(self):
+        super(reduceL2_subgraph2_2, self).__init__()
+        self.size = torch.Size([1, 3, 2, 4])
+
+    def forward(self, x):
+        norm = torch.norm(x, p=2, dim=1, keepdim=True)
+        clip = torch.clamp(norm, min=0)
+        expand = clip.expand([1, 3, 2, 4])
+        return x / expand
+
+input = Variable(torch.randn(1, 3, 2, 4))
+model = reduceL2_subgraph2_2()
+model.eval()
+save_data_and_model("reduceL2_subgraph2_2", input, model, export_params=True, version=12)
+simplify('models/reduceL2_subgraph2_2.onnx', False)
+
 from torchvision.ops.misc import *
 n = 3
 model = FrozenBatchNorm2d(n)
@@ -1148,6 +1223,18 @@ def forward(self, x0, x1, x2):
 input_2 = Variable(torch.ones(2, 1, 4, 1, dtype=torch.float32))
 save_data_and_model_multy_inputs("scale_broadcast", model, input_0, input_1, input_2)
 
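+# Mul where the first operand broadcasts over the middle axis:
+# (2, 1, 4) * (2, 5, 4).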
+class ScaleBroadcastMid(nn.Module):
+    def __init__(self, *args, **kwargs):
+        super(ScaleBroadcastMid, self).__init__()
+
+    def forward(self, x0, x1):
+        return torch.mul(x0, x1)
+
+model = ScaleBroadcastMid()
+input_0 = Variable(torch.ones(2, 1, 4, dtype=torch.float32))
+input_1 = Variable(torch.ones(2, 5, 4, dtype=torch.float32))
+save_data_and_model_multy_inputs("scale_broadcast_mid", model, input_0, input_1)
+
 x = Variable(torch.randn(1, 3, 25))
 conv1d = nn.Conv1d(3, 2, kernel_size=3, padding=2, stride=2, dilation=2, bias=False)
 save_data_and_model("conv1d", x, conv1d)