Skip to content

Commit c3b61e3

Browse files
authored
Merge pull request #902 from rogday:split_slice_shenanigans
tests for Normalize subgraph, Slice, Mul and Expand
1 parent 4830352 commit c3b61e3

20 files changed

+128
-0
lines changed
224 Bytes
Binary file not shown.
224 Bytes
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
168 Bytes
Binary file not shown.
168 Bytes
Binary file not shown.
416 Bytes
Binary file not shown.
224 Bytes
Binary file not shown.
Binary file not shown.
Binary file not shown.
168 Bytes
Binary file not shown.
168 Bytes
Binary file not shown.

testdata/dnn/onnx/generate_onnx_models.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import numpy as np
99
import os.path
1010
import onnx
11+
import onnxsim
1112
import google.protobuf.text_format
1213
import io
1314

@@ -72,6 +73,14 @@ def save_onnx_data_and_model(input, output, name, operation, *args, **kwargs):
7273
model = onnx.helper.make_model(graph, producer_name=name)
7374
onnx.save(model, models_files)
7475

76+
def simplify(name, rename=False, **kwargs):
77+
model, check = onnxsim.simplify(name, **kwargs)
78+
assert check, "couldn't valide"
79+
name = name[:-5]
80+
if rename:
81+
name += '_optimized'
82+
onnx.save(model, name + '.onnx')
83+
7584
torch.manual_seed(0)
7685
np.random.seed(0)
7786

@@ -412,6 +421,17 @@ def forward(self, x):
412421
save_data_and_model("slice", input, model)
413422
save_data_and_model("slice_opset_11", input, model, version=11)
414423

424+
class SliceStarts(nn.Module):
425+
def __init__(self, *args, **kwargs):
426+
super(SliceStarts, self).__init__()
427+
428+
def forward(self, x):
429+
return x[-1:]
430+
431+
model = SliceStarts()
432+
input_ = Variable(torch.randn(1, 10, dtype=torch.float32))
433+
save_data_and_model("slice_neg_starts", input_, model)
434+
415435
input_2 = Variable(torch.randn(6, 6))
416436
custom_slice_list = [
417437
slice(1, 3, 1),
@@ -573,6 +593,18 @@ def forward(self, x):
573593
input_ = Variable(torch.tensor(list(range(20)), dtype=torch.float32))
574594
save_data_and_model("split_sizes", input_, model)
575595

596+
class SplitAxis(nn.Module):
597+
def __init__(self, *args, **kwargs):
598+
super(SplitAxis, self).__init__()
599+
600+
def forward(self, x):
601+
tup = torch.split(x, 2, -1)
602+
return torch.cat(tup, 1)
603+
604+
model = SplitAxis()
605+
input_ = Variable(torch.randn(1, 10, dtype=torch.float32))
606+
save_data_and_model("split_neg_axis", input_, model)
607+
576608
class SplitMax(nn.Module):
577609

578610
def __init__(self):
@@ -840,6 +872,32 @@ def forward(self, x):
840872
output = np.mean(x, axis=2, keepdims=True)
841873
save_onnx_data_and_model(x, output, 'reduce_mean_axis2', 'ReduceMean', axes=(2), keepdims=True)
842874

875+
class Expand(nn.Module):
876+
def __init__(self):
877+
super(Expand, self).__init__()
878+
879+
def forward(self, x):
880+
return x.expand(1, 3, -1, -1, -1)
881+
882+
input = Variable(torch.randn(1, 3, 2, 4))
883+
model = Expand()
884+
model.eval()
885+
save_data_and_model("expand", input, model, export_params=True, version=12)
886+
simplify('models/expand.onnx', False)
887+
888+
class ExpandIdentity(nn.Module):
889+
def __init__(self):
890+
super(ExpandIdentity, self).__init__()
891+
892+
def forward(self, x):
893+
return x.expand(1, 3, -1, -1)
894+
895+
input = Variable(torch.randn(1, 3, 2, 4))
896+
model = ExpandIdentity()
897+
model.eval()
898+
save_data_and_model("expand_identity", input, model, export_params=True, version=12)
899+
simplify('models/expand_identity.onnx', False)
900+
843901
class Expand(nn.Module):
844902
def __init__(self, shape):
845903
super(Expand, self).__init__()
@@ -908,6 +966,23 @@ def forward(self, x):
908966
x = Variable(torch.randn(1, 2, 3, 4))
909967
save_data_and_model("reduceL2_subgraph_2", x, model)
910968

969+
class reduceL2_subgraph2_2(nn.Module):
970+
def __init__(self):
971+
super(reduceL2_subgraph2_2, self).__init__()
972+
self.size = torch.Size([1, 3, 2, 4])
973+
974+
def forward(self, x):
975+
norm = torch.norm(x, p=2, dim=1, keepdim=True)
976+
clip = torch.clamp(norm, min=0)
977+
expand = clip.expand([1, 3, 2, 4])
978+
return x / expand
979+
980+
input = Variable(torch.randn(1, 3, 2, 4))
981+
model = reduceL2_subgraph2_2()
982+
model.eval()
983+
save_data_and_model("reduceL2_subgraph2_2", input, model, export_params=True, version=12)
984+
simplify('models/reduceL2_subgraph2_2.onnx', False)
985+
911986
from torchvision.ops.misc import *
912987
n = 3
913988
model = FrozenBatchNorm2d(n)
@@ -1148,6 +1223,18 @@ def forward(self, x0, x1, x2):
11481223
input_2 = Variable(torch.ones(2, 1, 4, 1, dtype=torch.float32))
11491224
save_data_and_model_multy_inputs("scale_broadcast", model, input_0, input_1, input_2)
11501225

1226+
class ScaleBroadcastMid(nn.Module):
1227+
def __init__(self, *args, **kwargs):
1228+
super(ScaleBroadcastMid, self).__init__()
1229+
1230+
def forward(self, x0, x1):
1231+
return torch.mul(x0, x1)
1232+
1233+
model = ScaleBroadcastMid()
1234+
input_0 = Variable(torch.ones(2, 1, 4, dtype=torch.float32))
1235+
input_1 = Variable(torch.ones(2, 5, 4, dtype=torch.float32))
1236+
save_data_and_model_multy_inputs("scale_broadcast_mid", model, input_0, input_1)
1237+
11511238
x = Variable(torch.randn(1, 3, 25))
11521239
conv1d = nn.Conv1d(3, 2, kernel_size=3, padding=2, stride=2, dilation=2, bias=False)
11531240
save_data_and_model("conv1d", x, conv1d)

testdata/dnn/onnx/models/expand.onnx

185 Bytes
Binary file not shown.
173 Bytes
Binary file not shown.
393 Bytes
Binary file not shown.
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
pytorch1.9:t
2+

3+
0
4+
12Mul_0"Multorch-jit-exportZ
5+
0
6+

7+

8+

9+
Z
10+
1
11+

12+

13+

14+
b
15+
2
16+

17+

18+

19+
B
165 Bytes
Binary file not shown.
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
pytorch1.9:�
2+
S
3+
tensor12345Split_0"Split*
4+
axis����������*
5+
split@@@@@�
6+
1
7+
1
8+
2
9+
3
10+
4
11+
56Concat_1"Concat*
12+
axis�torch-jit-exportZ
13+
tensor
14+

15+

16+

17+
b
18+
6
19+

20+

21+

22+
B

0 commit comments

Comments
 (0)