Skip to content

Commit 2d943d1

Browse files
committed
Add testdata for Gather with multiple outputs
1 parent 6927bec commit 2d943d1

File tree

4 files changed

+17
-0
lines changed

4 files changed

+17
-0
lines changed
Binary file not shown.
Binary file not shown.

testdata/dnn/onnx/generate_onnx_models.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -952,3 +952,20 @@ def forward(self, x):
952952
x = Variable(torch.randn(1, 3, 2, 2))
953953
model = ReduceMax()
954954
save_data_and_model("reduce_max", x, model)
955+
956+
class GatherMultiOutput(nn.Module):
957+
def __init__(self, in_dim = 2):
958+
super(GatherMultiOutput, self).__init__()
959+
self.in_dim = in_dim
960+
self.lin_inp = nn.Linear(in_dim, 2, bias=False)
961+
def forward(self, x):
962+
x_projected = self.lin_inp(x).long()
963+
x_gather = x_projected[:,0,:]
964+
x_float1 = x_gather.float()
965+
x_float2 = x_gather.float()
966+
x_float3 = x_gather.float()
967+
return x_float1+x_float2+x_float3
968+
969+
x = Variable(torch.zeros([1, 2, 2]))
970+
model = GatherMultiOutput()
971+
save_data_and_model("gather_multi_output", x, model)
432 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)