Implement aten::div.Tensor_mode | feat(torchlib) #988

Conversation
Codecov Report
@@ Coverage Diff @@
## main #988 +/- ##
==========================================
+ Coverage 77.12% 77.18% +0.05%
==========================================
Files 112 112
Lines 13947 13970 +23
Branches 1438 1441 +3
==========================================
+ Hits 10757 10783 +26
+ Misses 2829 2826 -3
Partials 361 361
Test Results: 18 files ±0, 18 suites ±0, 1h 6m 0s ⏱️ −6m 15s. Results for commit ea0158c; comparison against base commit c6e216e. This pull request removes 352 tests and adds 389. Note that renamed tests count towards both.
Split out a variant for float16 and use a cast to INT64 for trunc: the maximum finite float16 value is 65504, which is well within INT64's range, so the cast cannot overflow for finite quotients.
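As a rough illustration of the technique (a minimal NumPy sketch, not the actual ONNX graph in this PR; the function name `div_trunc_float16` is hypothetical):

```python
import numpy as np

def div_trunc_float16(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # Divide, then truncate the quotient toward zero by casting to INT64
    # and back. Every finite float16 value (max 65504) is exactly
    # representable in INT64, so the integer cast itself cannot overflow
    # for finite quotients.
    quotient = a / b
    return quotient.astype(np.int64).astype(np.float16)
```

Note that this round trip cannot represent -0.0 and is undefined for infinities, which is where the float16 variant needs extra care.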
Example mismatch report:

Summary: the output of ONNX Runtime does not match that of PyTorch when executing the test. To recreate this report, use:

    CREATE_REPRODUCTION_REPORT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k test_output_match_opinfo__div_mode_trunc_rounding_cpu_float16

Inputs:
inputs = (tensor([[[-6.6953, 6.7070, -7.7422, -5.9688, -3.5234],
[-0.0703, 8.8828, 5.0195, -1.2656, -8.9922],
[-1.1250, 2.6641, -6.7773, 5.7227, -5.0273],
[ 0.3164, 1.3975, 8.9453, -0.0879, -8.9297],
[ 2.0391, 0.8613, 5.7812, -1.8369, -7.1797],
[ 1.1426, -5.4570, -7.6719, -8.5312, 1.0459],
[ 5.9062, 8.0781, -3.4883, -3.8496, 5.4844],
[ 8.5000, 3.6836, 3.6992, -8.0938, 7.4805],
[-6.7773, 1.9863, 0.9756, -4.5273, 4.1484],
[-8.1406, -7.6641, 7.5586, -0.9756, 7.8672]],
[[-1.2920, -6.0391, 6.1953, -1.7842, 1.4238],
[ 6.3203, -0.5977, 5.9766, 8.7422, -7.5938],
[ 5.0898, -1.9688, -5.5117, 2.4258, -1.8369],
[ 8.0391, 4.4219, 7.8672, 3.4375, 1.1777],
[ 1.7051, -5.4844, -3.3828, 0.2812, -2.9609],
[-4.9648, -0.6152, 7.7500, -4.8789, 3.2871],
[ 0.1846, -2.4180, 5.8789, -8.6719, -6.9883],
[ 3.1992, -4.0625, -5.1602, 6.8125, 2.4785],
[-8.7734, -7.0234, 2.4258, -1.9951, 5.3438],
[-6.5117, -1.2832, -0.5713, 4.2188, -5.0273]],
[[-0.0439, 7.2773, -1.8369, -3.2168, 4.5273],
[-4.2031, 1.6787, 0.4219, 4.4922, -8.8594],
[ 1.5029, 1.9248, -7.6211, -4.8164, 0.3955],
[ 3.2695, -4.4570, -8.1406, 7.8320, 6.5391],
[ 0.3164, -6.5469, 3.0586, -4.6406, -6.3555],
[-7.4180, -3.6738, 0.8613, -2.8555, -0.2812],
[ 1.0107, -7.3281, -6.7852, -0.3867, -0.8525],
[ 8.6328, 1.4062, 2.2422, 2.8301, -7.8828],
[-8.8516, 6.9609, 8.7969, 6.4531, -4.0352],
[-4.0000, 2.8301, 7.6562, -2.9805, 6.9531]],
[[-3.5078, 2.5234, 8.2812, -1.5029, 5.8203],
[-0.5537, 5.6094, -7.6484, -4.7031, 3.3828],
[ 8.6562, -2.1094, 0.9053, 8.1562, 3.4453],
[ 7.1172, -6.8477, 1.5381, 0.3340, 4.1836],
[ 5.9844, -2.3379, -4.2812, 7.6094, 7.1797],
[-0.9141, -5.4219, -4.3945, -1.7314, 4.4375],
[-1.0723, 7.2422, -8.6953, -2.9883, -8.8359],
[-6.1875, 6.8125, -4.5078, 8.8359, -5.4922],
[-3.2617, -0.4658, 7.7500, -4.8672, 1.2480],
[ 0.7998, 3.1016, -1.6787, 7.0234, -2.6191]],
[[ 8.4219, -1.5029, -8.5625, 5.8359, 6.1250],
[-3.9551, 1.6611, 1.0898, -7.6914, -6.2500],
[-4.0000, 2.9961, -6.3281, 6.0742, -3.0156],
[-1.6436, 6.8750, 0.4658, -0.1934, -5.1055],
[-3.6211, -0.2812, -6.6875, 0.4570, 5.0195],
[-0.1846, 5.4414, -8.9688, 0.8965, 1.7139],
[-2.2070, -1.3008, -0.2900, -1.2393, 3.0664],
[-8.5938, -1.3887, -1.7754, -0.7822, 7.1992],
[-3.2422, -6.6445, -5.7578, -3.4180, 0.9229],
[ 5.8789, 5.7812, -3.2344, -7.9531, -1.7051]]], dtype=torch.float16), tensor([[[ 6.9062, 5.3008, 0.5010, 8.3516, 3.0312],
[-2.4961, 4.1562, -7.9805, 7.2852, -2.0566],
[ 1.1074, 3.4453, 7.9727, -0.1055, 5.6250],
[ 0.5010, -3.4375, -4.5000, 6.5938, -1.9600],
[ 0.8350, 5.5977, 1.0283, -4.0859, 5.2734],
[-4.7188, 2.0742, 7.7695, -1.1250, 2.3828],
[ 1.6523, 8.7734, -2.8203, 3.9199, -3.9023],
[-2.0664, -1.4414, -6.8359, 8.4531, 0.6592],
[-7.6562, 5.7578, 3.8145, -0.2461, 0.7471],
[-2.8828, 2.4961, 8.5547, 1.9248, 0.5977]],
[[ 0.4043, -0.1406, -8.1797, 6.5820, -0.5889],
[-0.1143, -5.0273, 1.7842, -0.9932, -7.0938],
[-4.8438, -5.2109, -6.0117, -4.1641, 5.3711],
[ 2.5312, 2.2227, 6.4414, -5.8711, -8.8047],
[ 6.6875, -4.4219, 6.6523, 7.3477, 0.8701],
[-1.9600, 1.5732, 1.9512, 7.9531, -2.3125],
[ 5.7812, -5.1250, 5.3359, -2.2500, -0.0791],
[ 2.8477, 8.8281, -4.4727, 0.1318, -2.0469],
[-6.0547, 6.6641, -6.7422, -4.6406, 8.6719],
[-1.1338, -7.6016, 4.1641, -1.5205, 5.8281]],
[[ 7.9727, -2.9453, 5.9219, -5.6680, 4.7812],
[ 8.1562, -0.5977, -2.9355, 3.7441, 5.3789],
[ 2.3477, 7.3281, 5.0625, 3.6562, 1.4678],
[ 7.3906, 0.3691, 4.8789, -6.5039, -5.5195],
[-8.9062, 5.3281, 2.6445, 7.8828, 2.2148],
[-4.6250, -2.3828, -8.7031, -2.6191, -1.5469],
[-8.2188, -2.1875, -3.6914, 8.0938, -0.4834],
[ 5.2812, -0.0264, -5.3359, -8.0000, -3.5859],
[ 7.9297, 1.7139, -1.7490, 4.3945, 5.4922],
[ 4.3242, 0.6504, -5.3789, 3.2969, 3.6836]],
[[-2.6016, 2.4258, -0.0352, -2.3203, -3.0664],
[ 3.4531, 7.6289, -4.9922, -0.9229, -7.3203],
[-2.5137, 6.0469, 0.8613, 8.8516, 4.8086],
[-0.4131, -0.1406, 8.7734, -8.1641, 5.8281],
[-3.1816, 1.6611, 6.4609, -8.6875, 6.8477],
[ 3.2344, 6.1016, 4.6055, -6.5547, -7.1016],
[-1.3887, -1.3359, 7.0039, 2.3906, 7.7344],
[-8.4609, -3.7695, 7.7266, -2.6016, 2.9961],
[-5.5195, -2.8652, 0.1582, -7.4531, 5.3711],
[-5.5273, -5.0977, 2.0391, 8.0312, -8.3438]],
[[-6.6797, -5.0078, 6.9883, -2.6992, -2.5488],
[ 3.3672, 4.9492, -1.5293, 4.4375, 3.2168],
[-7.7422, -1.2305, 0.5977, 0.8613, -5.5273],
[ 0.2109, -0.6768, -1.1777, 4.1133, 4.2461],
[ 4.9570, 1.3184, 1.2568, -4.0078, 0.3340],
[-8.0469, 6.0820, 0.3604, 1.6260, -0.7910],
[-5.5117, 1.9512, -3.4375, 1.1602, 1.4238],
[ 1.8809, -2.5664, -6.9453, 7.3984, -3.2266],
[-2.4082, -3.8398, 0.0703, 0.6416, -0.6240],
[-4.6836, -0.9844, 2.2148, 2.0117, 4.5547]]], dtype=torch.float16))
kwargs = {'rounding_mode': 'trunc'}

Expected output:

expected = tensor([[[ -0., 1., -15., -0., -1.],
[ 0., 2., -0., -0., 4.],
[ -1., 0., -0., -54., -0.],
[ 0., -0., -1., -0., 4.],
[ 2., 0., 5., 0., -1.],
[ -0., -2., -0., 7., 0.],
[ 3., 0., 1., -0., -1.],
[ -4., -2., -0., -0., 11.],
[ 0., 0., 0., 18., 5.],
[ 2., -3., 0., -0., 13.]],
[[ -3., 42., -0., -0., -2.],
[ -55., 0., 3., -8., 1.],
[ -1., 0., 0., -0., -0.],
[ 3., 1., 1., -0., -0.],
[ 0., 1., -0., 0., -3.],
[ 2., -0., 3., -0., -1.],
[ 0., 0., 1., 3., 88.],
[ 1., -0., 1., 51., -1.],
[ 1., -1., -0., 0., 0.],
[ 5., 0., -0., -2., -0.]],
[[ -0., -2., -0., 0., 0.],
[ -0., -2., -0., 1., -1.],
[ 0., 0., -1., -1., 0.],
[ 0., -12., -1., -1., -1.],
[ -0., -1., 1., -0., -2.],
[ 1., 1., -0., 1., 0.],
[ -0., 3., 1., -0., 1.],
[ 1., -53., -0., -0., 2.],
[ -1., 4., -5., 1., -0.],
[ -0., 4., -1., -0., 1.]],
[[ 1., 1., -235., 0., -1.],
[ -0., 0., 1., 5., -0.],
[ -3., -0., 1., 0., 0.],
[ -17., 48., 0., -0., 0.],
[ -1., -1., -0., -0., 1.],
[ -0., -0., -0., 0., -0.],
[ 0., -5., -1., -1., -1.],
[ 0., -1., -0., -3., -1.],
[ 0., 0., 49., 0., 0.],
[ -0., -0., -0., 0., 0.]],
[[ -1., 0., -1., -2., -2.],
[ -1., 0., -0., -1., -1.],
[ 0., -2., -10., 7., 0.],
[ -7., -10., -0., -0., -1.],
[ -0., -0., -5., -0., 15.],
[ 0., 0., -24., 0., -2.],
[ 0., -0., 0., -1., 2.],
[ -4., 0., 0., -0., -2.],
[ 1., 1., -81., -5., -1.],
[ -1., -5., -1., -3., -0.]]], dtype=torch.float16)

Actual output:

actual = tensor([[[ 0., 1., -15., 0., -1.],
[ 0., 2., 0., 0., 4.],
[ -1., 0., 0., -54., 0.],
[ 0., 0., -1., 0., 4.],
[ 2., 0., 5., 0., -1.],
[ 0., -2., 0., 7., 0.],
[ 3., 0., 1., 0., -1.],
[ -4., -2., 0., 0., 11.],
[ 0., 0., 0., 18., 5.],
[ 2., -3., 0., 0., 13.]],
[[ -3., 42., 0., 0., -2.],
[ -55., 0., 3., -8., 1.],
[ -1., 0., 0., 0., 0.],
[ 3., 1., 1., 0., 0.],
[ 0., 1., 0., 0., -3.],
[ 2., 0., 3., 0., -1.],
[ 0., 0., 1., 3., 88.],
[ 1., 0., 1., 51., -1.],
[ 1., -1., 0., 0., 0.],
[ 5., 0., 0., -2., 0.]],
[[ 0., -2., 0., 0., 0.],
[ 0., -2., 0., 1., -1.],
[ 0., 0., -1., -1., 0.],
[ 0., -12., -1., -1., -1.],
[ 0., -1., 1., 0., -2.],
[ 1., 1., 0., 1., 0.],
[ 0., 3., 1., 0., 1.],
[ 1., -53., 0., 0., 2.],
[ -1., 4., -5., 1., 0.],
[ 0., 4., -1., 0., 1.]],
[[ 1., 1., -235., 0., -1.],
[ 0., 0., 1., 5., 0.],
[ -3., 0., 1., 0., 0.],
[ -17., 48., 0., 0., 0.],
[ -1., -1., 0., 0., 1.],
[ 0., 0., 0., 0., 0.],
[ 0., -5., -1., -1., -1.],
[ 0., -1., 0., -3., -1.],
[ 0., 0., 48., 0., 0.],
[ 0., 0., 0., 0., 0.]],
[[ -1., 0., -1., -2., -2.],
[ -1., 0., 0., -1., -1.],
[ 0., -2., -10., 7., 0.],
[ -7., -10., 0., 0., -1.],
[ 0., 0., -5., 0., 15.],
[ 0., 0., -24., 0., -2.],
[ 0., 0., 0., -1., 2.],
[ -4., 0., 0., 0., -2.],
[ 1., 1., -81., -5., -1.],
[ -1., -5., -1., -3., 0.]]], dtype=torch.float16)

Full error stack

CREATE_REPRODUCTION_REPORT=1 is nice!
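The mismatches in the report above are all signed zeros: PyTorch's trunc keeps the sign of a zero quotient, while a round trip through INT64 cannot, since integers have no -0. A quick NumPy check of that difference:

```python
import numpy as np

q = np.float16(-0.25)         # negative quotient with magnitude < 1
direct = np.trunc(q)          # floating-point truncation keeps the sign: -0.0
via_int = np.float16(int(q))  # through an integer the sign of zero is lost: 0.0
```

The two results compare equal numerically; only the sign bit differs, which is why the reports show -0. versus 0.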
Infinities are flipped:

Summary: the output of ONNX Runtime does not match that of PyTorch when executing the test. To recreate this report, use:

    CREATE_REPRODUCTION_REPORT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k test_output_match_opinfo__div_mode_trunc_rounding_cpu_float16

Inputs:

inputs = (tensor([[[ 1.5205, 5.7305, 1.4678, -7.5156, 3.4453]],
[[-0.6943, -2.0215, -3.4375, 1.6963, 4.5352]],
[[-8.7500, 1.4414, 3.6992, -6.0312, 5.6875]],
[[ 4.3516, 3.1992, 7.3125, 2.9961, -4.6133]],
[[ 5.7656, -4.5156, 2.8652, -7.1016, 2.6367]],
[[ 1.1514, 7.0586, 0.6064, 2.2070, 3.6387]],
[[ 6.9062, -5.0977, -1.3096, 1.5908, -7.5391]],
[[-8.2109, -2.7344, 7.5312, 4.6562, -8.6641]],
[[ 1.0986, 3.9551, -7.4883, -3.2695, -2.3828]],
[[ 1.9248, 5.5977, 0.5010, 5.7930, -4.8242]]], dtype=torch.float16), tensor([[-5.0000e+00, -1.7227e+00, -1.1426e+00, 3.1641e+00, 8.2109e+00],
[ 6.1035e-05, -5.6094e+00, 8.5469e+00, -5.9492e+00, -7.3398e+00],
[-4.5703e+00, 1.6699e+00, 4.4570e+00, 3.2344e+00, 3.5859e+00],
[-1.7666e+00, -8.4375e-01, -7.9980e-01, 1.0195e+00, -7.8320e+00],
[-8.5000e+00, 4.8789e+00, -2.9883e+00, -3.1016e+00, -7.4883e+00],
[ 7.2773e+00, -1.8457e+00, 1.3008e+00, -8.3047e+00, -1.2217e+00],
[ 1.8633e+00, -7.2266e+00, -2.2930e+00, 2.0039e+00, 5.9336e+00],
[ 1.9863e+00, 5.7812e+00, 4.6680e+00, 3.8672e+00, 3.5234e+00],
[-8.5938e+00, 2.0391e+00, 2.6719e+00, 4.6133e+00, 1.2480e+00],
[ 3.6387e+00, 4.5000e+00, 3.2695e+00, 4.9648e+00, 7.7773e+00]],
dtype=torch.float16))
kwargs = {'rounding_mode': 'trunc'}

Expected output:

expected = tensor([[[-0.0000e+00, -3.0000e+00, -1.0000e+00, -2.0000e+00, 0.0000e+00],
[ 2.4912e+04, -1.0000e+00, 0.0000e+00, 1.0000e+00, -0.0000e+00],
[-0.0000e+00, 3.0000e+00, 0.0000e+00, -2.0000e+00, 0.0000e+00],
[-0.0000e+00, -6.0000e+00, -1.0000e+00, -7.0000e+00, -0.0000e+00],
[-0.0000e+00, 1.0000e+00, -0.0000e+00, 2.0000e+00, -0.0000e+00],
[ 0.0000e+00, -3.0000e+00, 1.0000e+00, 0.0000e+00, -2.0000e+00],
[ 0.0000e+00, -0.0000e+00, -0.0000e+00, -3.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, -1.0000e+00, 0.0000e+00],
[-0.0000e+00, 2.0000e+00, 0.0000e+00, -1.0000e+00, 2.0000e+00],
[ 0.0000e+00, 1.0000e+00, 0.0000e+00, -1.0000e+00, 0.0000e+00]],
[[ 0.0000e+00, 1.0000e+00, 3.0000e+00, 0.0000e+00, 0.0000e+00],
[-1.1376e+04, 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],
[ 0.0000e+00, -1.0000e+00, -0.0000e+00, 0.0000e+00, 1.0000e+00],
[ 0.0000e+00, 2.0000e+00, 4.0000e+00, 1.0000e+00, -0.0000e+00],
[ 0.0000e+00, -0.0000e+00, 1.0000e+00, -0.0000e+00, -0.0000e+00],
[-0.0000e+00, 1.0000e+00, -2.0000e+00, -0.0000e+00, -3.0000e+00],
[-0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00],
[-0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, 1.0000e+00],
[ 0.0000e+00, -0.0000e+00, -1.0000e+00, 0.0000e+00, 3.0000e+00],
[-0.0000e+00, -0.0000e+00, -1.0000e+00, 0.0000e+00, 0.0000e+00]],
[[ 1.0000e+00, -0.0000e+00, -3.0000e+00, -1.0000e+00, 0.0000e+00],
[ -inf, -0.0000e+00, 0.0000e+00, 1.0000e+00, -0.0000e+00],
[ 1.0000e+00, 0.0000e+00, 0.0000e+00, -1.0000e+00, 1.0000e+00],
[ 4.0000e+00, -1.0000e+00, -4.0000e+00, -5.0000e+00, -0.0000e+00],
[ 1.0000e+00, 0.0000e+00, -1.0000e+00, 1.0000e+00, -0.0000e+00],
[-1.0000e+00, -0.0000e+00, 2.0000e+00, 0.0000e+00, -4.0000e+00],
[-4.0000e+00, -0.0000e+00, -1.0000e+00, -3.0000e+00, 0.0000e+00],
[-4.0000e+00, 0.0000e+00, 0.0000e+00, -1.0000e+00, 1.0000e+00],
[ 1.0000e+00, 0.0000e+00, 1.0000e+00, -1.0000e+00, 4.0000e+00],
[-2.0000e+00, 0.0000e+00, 1.0000e+00, -1.0000e+00, 0.0000e+00]],
[[-0.0000e+00, -1.0000e+00, -6.0000e+00, 0.0000e+00, -0.0000e+00],
[ inf, -0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00],
[-0.0000e+00, 1.0000e+00, 1.0000e+00, 0.0000e+00, -1.0000e+00],
[-2.0000e+00, -3.0000e+00, -9.0000e+00, 2.0000e+00, 0.0000e+00],
[-0.0000e+00, 0.0000e+00, -2.0000e+00, -0.0000e+00, 0.0000e+00],
[ 0.0000e+00, -1.0000e+00, 5.0000e+00, -0.0000e+00, 3.0000e+00],
[ 2.0000e+00, -0.0000e+00, -3.0000e+00, 1.0000e+00, -0.0000e+00],
[ 2.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, -1.0000e+00],
[-0.0000e+00, 1.0000e+00, 2.0000e+00, 0.0000e+00, -3.0000e+00],
[ 1.0000e+00, 0.0000e+00, 2.0000e+00, 0.0000e+00, -0.0000e+00]],
[[-1.0000e+00, 2.0000e+00, -2.0000e+00, -2.0000e+00, 0.0000e+00],
[ inf, 0.0000e+00, 0.0000e+00, 1.0000e+00, -0.0000e+00],
[-1.0000e+00, -2.0000e+00, 0.0000e+00, -2.0000e+00, 0.0000e+00],
[-3.0000e+00, 5.0000e+00, -3.0000e+00, -6.0000e+00, -0.0000e+00],
[-0.0000e+00, -0.0000e+00, -0.0000e+00, 2.0000e+00, -0.0000e+00],
[ 0.0000e+00, 2.0000e+00, 2.0000e+00, 0.0000e+00, -2.0000e+00],
[ 3.0000e+00, 0.0000e+00, -1.0000e+00, -3.0000e+00, 0.0000e+00],
[ 2.0000e+00, -0.0000e+00, 0.0000e+00, -1.0000e+00, 0.0000e+00],
[-0.0000e+00, -2.0000e+00, 1.0000e+00, -1.0000e+00, 2.0000e+00],
[ 1.0000e+00, -1.0000e+00, 0.0000e+00, -1.0000e+00, 0.0000e+00]],
[[-0.0000e+00, -4.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 1.8864e+04, -1.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00],
[-0.0000e+00, 4.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00],
[-0.0000e+00, -8.0000e+00, -0.0000e+00, 2.0000e+00, -0.0000e+00],
[-0.0000e+00, 1.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],
[ 0.0000e+00, -3.0000e+00, 0.0000e+00, -0.0000e+00, -2.0000e+00],
[ 0.0000e+00, -0.0000e+00, -0.0000e+00, 1.0000e+00, 0.0000e+00],
[ 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00],
[-0.0000e+00, 3.0000e+00, 0.0000e+00, 0.0000e+00, 2.0000e+00],
[ 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]],
[[-1.0000e+00, 2.0000e+00, 1.0000e+00, 0.0000e+00, -0.0000e+00],
[ inf, 0.0000e+00, -0.0000e+00, -0.0000e+00, 1.0000e+00],
[-1.0000e+00, -3.0000e+00, -0.0000e+00, 0.0000e+00, -2.0000e+00],
[-3.0000e+00, 6.0000e+00, 1.0000e+00, 1.0000e+00, 0.0000e+00],
[-0.0000e+00, -1.0000e+00, 0.0000e+00, -0.0000e+00, 1.0000e+00],
[ 0.0000e+00, 2.0000e+00, -1.0000e+00, -0.0000e+00, 6.0000e+00],
[ 3.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, -1.0000e+00],
[ 3.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, -2.0000e+00],
[-0.0000e+00, -2.0000e+00, -0.0000e+00, 0.0000e+00, -6.0000e+00],
[ 1.0000e+00, -1.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00]],
[[ 1.0000e+00, 1.0000e+00, -6.0000e+00, 1.0000e+00, -1.0000e+00],
[ -inf, 0.0000e+00, 0.0000e+00, -0.0000e+00, 1.0000e+00],
[ 1.0000e+00, -1.0000e+00, 1.0000e+00, 1.0000e+00, -2.0000e+00],
[ 4.0000e+00, 3.0000e+00, -9.0000e+00, 4.0000e+00, 1.0000e+00],
[ 0.0000e+00, -0.0000e+00, -2.0000e+00, -1.0000e+00, 1.0000e+00],
[-1.0000e+00, 1.0000e+00, 5.0000e+00, -0.0000e+00, 7.0000e+00],
[-4.0000e+00, 0.0000e+00, -3.0000e+00, 2.0000e+00, -1.0000e+00],
[-4.0000e+00, -0.0000e+00, 1.0000e+00, 1.0000e+00, -2.0000e+00],
[ 0.0000e+00, -1.0000e+00, 2.0000e+00, 1.0000e+00, -6.0000e+00],
[-2.0000e+00, -0.0000e+00, 2.0000e+00, 0.0000e+00, -1.0000e+00]],
[[-0.0000e+00, -2.0000e+00, 6.0000e+00, -1.0000e+00, -0.0000e+00],
[ 1.8000e+04, -0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00],
[-0.0000e+00, 2.0000e+00, -1.0000e+00, -1.0000e+00, -0.0000e+00],
[-0.0000e+00, -4.0000e+00, 9.0000e+00, -3.0000e+00, 0.0000e+00],
[-0.0000e+00, 0.0000e+00, 2.0000e+00, 1.0000e+00, 0.0000e+00],
[ 0.0000e+00, -2.0000e+00, -5.0000e+00, 0.0000e+00, 1.0000e+00],
[ 0.0000e+00, -0.0000e+00, 3.0000e+00, -1.0000e+00, -0.0000e+00],
[ 0.0000e+00, 0.0000e+00, -1.0000e+00, -0.0000e+00, -0.0000e+00],
[-0.0000e+00, 1.0000e+00, -2.0000e+00, -0.0000e+00, -1.0000e+00],
[ 0.0000e+00, 0.0000e+00, -2.0000e+00, -0.0000e+00, -0.0000e+00]],
[[-0.0000e+00, -3.0000e+00, -0.0000e+00, 1.0000e+00, -0.0000e+00],
[ 3.1536e+04, -0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00],
[-0.0000e+00, 3.0000e+00, 0.0000e+00, 1.0000e+00, -1.0000e+00],
[-1.0000e+00, -6.0000e+00, -0.0000e+00, 5.0000e+00, 0.0000e+00],
[-0.0000e+00, 1.0000e+00, -0.0000e+00, -1.0000e+00, 0.0000e+00],
[ 0.0000e+00, -3.0000e+00, 0.0000e+00, -0.0000e+00, 3.0000e+00],
[ 1.0000e+00, -0.0000e+00, -0.0000e+00, 2.0000e+00, -0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00, -1.0000e+00],
[-0.0000e+00, 2.0000e+00, 0.0000e+00, 1.0000e+00, -3.0000e+00],
[ 0.0000e+00, 1.0000e+00, 0.0000e+00, 1.0000e+00, -0.0000e+00]]],
dtype=torch.float16)

Actual output:

actual = tensor([[[ 0.0000e+00, -3.0000e+00, -1.0000e+00, -2.0000e+00, 0.0000e+00],
[ 2.4912e+04, -1.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00],
[ 0.0000e+00, 3.0000e+00, 0.0000e+00, -2.0000e+00, 0.0000e+00],
[ 0.0000e+00, -6.0000e+00, -1.0000e+00, -7.0000e+00, 0.0000e+00],
[ 0.0000e+00, 1.0000e+00, 0.0000e+00, 2.0000e+00, 0.0000e+00],
[ 0.0000e+00, -3.0000e+00, 1.0000e+00, 0.0000e+00, -2.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, -3.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, -1.0000e+00, 0.0000e+00],
[ 0.0000e+00, 2.0000e+00, 0.0000e+00, -1.0000e+00, 2.0000e+00],
[ 0.0000e+00, 1.0000e+00, 0.0000e+00, -1.0000e+00, 0.0000e+00]],
[[ 0.0000e+00, 1.0000e+00, 3.0000e+00, 0.0000e+00, 0.0000e+00],
[-1.1376e+04, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, -1.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00],
[ 0.0000e+00, 2.0000e+00, 4.0000e+00, 1.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 1.0000e+00, -2.0000e+00, 0.0000e+00, -3.0000e+00],
[ 0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00],
[ 0.0000e+00, 0.0000e+00, -1.0000e+00, 0.0000e+00, 3.0000e+00],
[ 0.0000e+00, 0.0000e+00, -1.0000e+00, 0.0000e+00, 0.0000e+00]],
[[ 1.0000e+00, 0.0000e+00, -3.0000e+00, -1.0000e+00, 0.0000e+00],
[ inf, 0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00],
[ 1.0000e+00, 0.0000e+00, 0.0000e+00, -1.0000e+00, 1.0000e+00],
[ 4.0000e+00, -1.0000e+00, -4.0000e+00, -5.0000e+00, 0.0000e+00],
[ 1.0000e+00, 0.0000e+00, -1.0000e+00, 1.0000e+00, 0.0000e+00],
[-1.0000e+00, 0.0000e+00, 2.0000e+00, 0.0000e+00, -4.0000e+00],
[-4.0000e+00, 0.0000e+00, -1.0000e+00, -3.0000e+00, 0.0000e+00],
[-4.0000e+00, 0.0000e+00, 0.0000e+00, -1.0000e+00, 1.0000e+00],
[ 1.0000e+00, 0.0000e+00, 1.0000e+00, -1.0000e+00, 4.0000e+00],
[-2.0000e+00, 0.0000e+00, 1.0000e+00, -1.0000e+00, 0.0000e+00]],
[[ 0.0000e+00, -1.0000e+00, -6.0000e+00, 0.0000e+00, 0.0000e+00],
[ -inf, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 1.0000e+00, 1.0000e+00, 0.0000e+00, -1.0000e+00],
[-2.0000e+00, -3.0000e+00, -9.0000e+00, 2.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, -2.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, -1.0000e+00, 5.0000e+00, 0.0000e+00, 3.0000e+00],
[ 2.0000e+00, 0.0000e+00, -3.0000e+00, 1.0000e+00, 0.0000e+00],
[ 2.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, -1.0000e+00],
[ 0.0000e+00, 1.0000e+00, 2.0000e+00, 0.0000e+00, -3.0000e+00],
[ 1.0000e+00, 0.0000e+00, 2.0000e+00, 0.0000e+00, 0.0000e+00]],
[[-1.0000e+00, 2.0000e+00, -2.0000e+00, -2.0000e+00, 0.0000e+00],
[ -inf, 0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00],
[-1.0000e+00, -2.0000e+00, 0.0000e+00, -2.0000e+00, 0.0000e+00],
[-3.0000e+00, 5.0000e+00, -3.0000e+00, -6.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 2.0000e+00, 0.0000e+00],
[ 0.0000e+00, 2.0000e+00, 2.0000e+00, 0.0000e+00, -2.0000e+00],
[ 3.0000e+00, 0.0000e+00, -1.0000e+00, -3.0000e+00, 0.0000e+00],
[ 2.0000e+00, 0.0000e+00, 0.0000e+00, -1.0000e+00, 0.0000e+00],
[ 0.0000e+00, -2.0000e+00, 1.0000e+00, -1.0000e+00, 2.0000e+00],
[ 1.0000e+00, -1.0000e+00, 0.0000e+00, -1.0000e+00, 0.0000e+00]],
[[ 0.0000e+00, -4.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 1.8864e+04, -1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 4.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00],
[ 0.0000e+00, -8.0000e+00, 0.0000e+00, 2.0000e+00, 0.0000e+00],
[ 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, -3.0000e+00, 0.0000e+00, 0.0000e+00, -2.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00],
[ 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00],
[ 0.0000e+00, 3.0000e+00, 0.0000e+00, 0.0000e+00, 2.0000e+00],
[ 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]],
[[-1.0000e+00, 2.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00],
[ -inf, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00],
[-1.0000e+00, -3.0000e+00, 0.0000e+00, 0.0000e+00, -2.0000e+00],
[-3.0000e+00, 6.0000e+00, 1.0000e+00, 1.0000e+00, 0.0000e+00],
[ 0.0000e+00, -1.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00],
[ 0.0000e+00, 2.0000e+00, -1.0000e+00, 0.0000e+00, 6.0000e+00],
[ 3.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, -1.0000e+00],
[ 3.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, -2.0000e+00],
[ 0.0000e+00, -2.0000e+00, 0.0000e+00, 0.0000e+00, -6.0000e+00],
[ 1.0000e+00, -1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]],
[[ 1.0000e+00, 1.0000e+00, -6.0000e+00, 1.0000e+00, -1.0000e+00],
[ inf, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00],
[ 1.0000e+00, -1.0000e+00, 1.0000e+00, 1.0000e+00, -2.0000e+00],
[ 4.0000e+00, 3.0000e+00, -9.0000e+00, 4.0000e+00, 1.0000e+00],
[ 0.0000e+00, 0.0000e+00, -2.0000e+00, -1.0000e+00, 1.0000e+00],
[-1.0000e+00, 1.0000e+00, 5.0000e+00, 0.0000e+00, 7.0000e+00],
[-4.0000e+00, 0.0000e+00, -3.0000e+00, 2.0000e+00, -1.0000e+00],
[-4.0000e+00, 0.0000e+00, 1.0000e+00, 1.0000e+00, -2.0000e+00],
[ 0.0000e+00, -1.0000e+00, 2.0000e+00, 1.0000e+00, -6.0000e+00],
[-2.0000e+00, 0.0000e+00, 2.0000e+00, 0.0000e+00, -1.0000e+00]],
[[ 0.0000e+00, -2.0000e+00, 6.0000e+00, -1.0000e+00, 0.0000e+00],
[ 1.8000e+04, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 2.0000e+00, -1.0000e+00, -1.0000e+00, 0.0000e+00],
[ 0.0000e+00, -4.0000e+00, 9.0000e+00, -3.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 2.0000e+00, 1.0000e+00, 0.0000e+00],
[ 0.0000e+00, -2.0000e+00, -5.0000e+00, 0.0000e+00, 1.0000e+00],
[ 0.0000e+00, 0.0000e+00, 3.0000e+00, -1.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, -1.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 1.0000e+00, -2.0000e+00, 0.0000e+00, -1.0000e+00],
[ 0.0000e+00, 0.0000e+00, -2.0000e+00, 0.0000e+00, 0.0000e+00]],
[[ 0.0000e+00, -3.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00],
[ 3.1536e+04, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[ 0.0000e+00, 3.0000e+00, 0.0000e+00, 1.0000e+00, -1.0000e+00],
[-1.0000e+00, -6.0000e+00, 0.0000e+00, 5.0000e+00, 0.0000e+00],
[ 0.0000e+00, 1.0000e+00, 0.0000e+00, -1.0000e+00, 0.0000e+00],
[ 0.0000e+00, -3.0000e+00, 0.0000e+00, 0.0000e+00, 3.0000e+00],
[ 1.0000e+00, 0.0000e+00, 0.0000e+00, 2.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00, -1.0000e+00],
[ 0.0000e+00, 2.0000e+00, 0.0000e+00, 1.0000e+00, -3.0000e+00],
[ 0.0000e+00, 1.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00]]],
dtype=torch.float16)

Difference:

--- actual
+++ expected
@@ -1,110 +1,110 @@
-tensor([[[ 0.0000e+00, -3.0000e+00, -1.0000e+00, -2.0000e+00, 0.0000e+00],
- [ 2.4912e+04, -1.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00],
- [ 0.0000e+00, 3.0000e+00, 0.0000e+00, -2.0000e+00, 0.0000e+00],
- [ 0.0000e+00, -6.0000e+00, -1.0000e+00, -7.0000e+00, 0.0000e+00],
- [ 0.0000e+00, 1.0000e+00, 0.0000e+00, 2.0000e+00, 0.0000e+00],
+tensor([[[-0.0000e+00, -3.0000e+00, -1.0000e+00, -2.0000e+00, 0.0000e+00],
+ [ 2.4912e+04, -1.0000e+00, 0.0000e+00, 1.0000e+00, -0.0000e+00],
+ [-0.0000e+00, 3.0000e+00, 0.0000e+00, -2.0000e+00, 0.0000e+00],
+ [-0.0000e+00, -6.0000e+00, -1.0000e+00, -7.0000e+00, -0.0000e+00],
+ [-0.0000e+00, 1.0000e+00, -0.0000e+00, 2.0000e+00, -0.0000e+00],
[ 0.0000e+00, -3.0000e+00, 1.0000e+00, 0.0000e+00, -2.0000e+00],
- [ 0.0000e+00, 0.0000e+00, 0.0000e+00, -3.0000e+00, 0.0000e+00],
+ [ 0.0000e+00, -0.0000e+00, -0.0000e+00, -3.0000e+00, 0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, -1.0000e+00, 0.0000e+00],
- [ 0.0000e+00, 2.0000e+00, 0.0000e+00, -1.0000e+00, 2.0000e+00],
+ [-0.0000e+00, 2.0000e+00, 0.0000e+00, -1.0000e+00, 2.0000e+00],
[ 0.0000e+00, 1.0000e+00, 0.0000e+00, -1.0000e+00, 0.0000e+00]],
[[ 0.0000e+00, 1.0000e+00, 3.0000e+00, 0.0000e+00, 0.0000e+00],
- [-1.1376e+04, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
- [ 0.0000e+00, -1.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00],
- [ 0.0000e+00, 2.0000e+00, 4.0000e+00, 1.0000e+00, 0.0000e+00],
- [ 0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00],
- [ 0.0000e+00, 1.0000e+00, -2.0000e+00, 0.0000e+00, -3.0000e+00],
- [ 0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00],
- [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00],
- [ 0.0000e+00, 0.0000e+00, -1.0000e+00, 0.0000e+00, 3.0000e+00],
- [ 0.0000e+00, 0.0000e+00, -1.0000e+00, 0.0000e+00, 0.0000e+00]],
+ [-1.1376e+04, 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],
+ [ 0.0000e+00, -1.0000e+00, -0.0000e+00, 0.0000e+00, 1.0000e+00],
+ [ 0.0000e+00, 2.0000e+00, 4.0000e+00, 1.0000e+00, -0.0000e+00],
+ [ 0.0000e+00, -0.0000e+00, 1.0000e+00, -0.0000e+00, -0.0000e+00],
+ [-0.0000e+00, 1.0000e+00, -2.0000e+00, -0.0000e+00, -3.0000e+00],
+ [-0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00],
+ [-0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, 1.0000e+00],
+ [ 0.0000e+00, -0.0000e+00, -1.0000e+00, 0.0000e+00, 3.0000e+00],
+ [-0.0000e+00, -0.0000e+00, -1.0000e+00, 0.0000e+00, 0.0000e+00]],
- [[ 1.0000e+00, 0.0000e+00, -3.0000e+00, -1.0000e+00, 0.0000e+00],
- [ inf, 0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00],
+ [[ 1.0000e+00, -0.0000e+00, -3.0000e+00, -1.0000e+00, 0.0000e+00],
+ [ -inf, -0.0000e+00, 0.0000e+00, 1.0000e+00, -0.0000e+00],
[ 1.0000e+00, 0.0000e+00, 0.0000e+00, -1.0000e+00, 1.0000e+00],
- [ 4.0000e+00, -1.0000e+00, -4.0000e+00, -5.0000e+00, 0.0000e+00],
- [ 1.0000e+00, 0.0000e+00, -1.0000e+00, 1.0000e+00, 0.0000e+00],
- [-1.0000e+00, 0.0000e+00, 2.0000e+00, 0.0000e+00, -4.0000e+00],
- [-4.0000e+00, 0.0000e+00, -1.0000e+00, -3.0000e+00, 0.0000e+00],
+ [ 4.0000e+00, -1.0000e+00, -4.0000e+00, -5.0000e+00, -0.0000e+00],
+ [ 1.0000e+00, 0.0000e+00, -1.0000e+00, 1.0000e+00, -0.0000e+00],
+ [-1.0000e+00, -0.0000e+00, 2.0000e+00, 0.0000e+00, -4.0000e+00],
+ [-4.0000e+00, -0.0000e+00, -1.0000e+00, -3.0000e+00, 0.0000e+00],
[-4.0000e+00, 0.0000e+00, 0.0000e+00, -1.0000e+00, 1.0000e+00],
[ 1.0000e+00, 0.0000e+00, 1.0000e+00, -1.0000e+00, 4.0000e+00],
[-2.0000e+00, 0.0000e+00, 1.0000e+00, -1.0000e+00, 0.0000e+00]],
- [[ 0.0000e+00, -1.0000e+00, -6.0000e+00, 0.0000e+00, 0.0000e+00],
- [ -inf, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
- [ 0.0000e+00, 1.0000e+00, 1.0000e+00, 0.0000e+00, -1.0000e+00],
+ [[-0.0000e+00, -1.0000e+00, -6.0000e+00, 0.0000e+00, -0.0000e+00],
+ [ inf, -0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00],
+ [-0.0000e+00, 1.0000e+00, 1.0000e+00, 0.0000e+00, -1.0000e+00],
[-2.0000e+00, -3.0000e+00, -9.0000e+00, 2.0000e+00, 0.0000e+00],
- [ 0.0000e+00, 0.0000e+00, -2.0000e+00, 0.0000e+00, 0.0000e+00],
- [ 0.0000e+00, -1.0000e+00, 5.0000e+00, 0.0000e+00, 3.0000e+00],
- [ 2.0000e+00, 0.0000e+00, -3.0000e+00, 1.0000e+00, 0.0000e+00],
+ [-0.0000e+00, 0.0000e+00, -2.0000e+00, -0.0000e+00, 0.0000e+00],
+ [ 0.0000e+00, -1.0000e+00, 5.0000e+00, -0.0000e+00, 3.0000e+00],
+ [ 2.0000e+00, -0.0000e+00, -3.0000e+00, 1.0000e+00, -0.0000e+00],
[ 2.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, -1.0000e+00],
- [ 0.0000e+00, 1.0000e+00, 2.0000e+00, 0.0000e+00, -3.0000e+00],
- [ 1.0000e+00, 0.0000e+00, 2.0000e+00, 0.0000e+00, 0.0000e+00]],
+ [-0.0000e+00, 1.0000e+00, 2.0000e+00, 0.0000e+00, -3.0000e+00],
+ [ 1.0000e+00, 0.0000e+00, 2.0000e+00, 0.0000e+00, -0.0000e+00]],
[[-1.0000e+00, 2.0000e+00, -2.0000e+00, -2.0000e+00, 0.0000e+00],
- [ -inf, 0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00],
+ [ inf, 0.0000e+00, 0.0000e+00, 1.0000e+00, -0.0000e+00],
[-1.0000e+00, -2.0000e+00, 0.0000e+00, -2.0000e+00, 0.0000e+00],
- [-3.0000e+00, 5.0000e+00, -3.0000e+00, -6.0000e+00, 0.0000e+00],
- [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 2.0000e+00, 0.0000e+00],
+ [-3.0000e+00, 5.0000e+00, -3.0000e+00, -6.0000e+00, -0.0000e+00],
+ [-0.0000e+00, -0.0000e+00, -0.0000e+00, 2.0000e+00, -0.0000e+00],
[ 0.0000e+00, 2.0000e+00, 2.0000e+00, 0.0000e+00, -2.0000e+00],
[ 3.0000e+00, 0.0000e+00, -1.0000e+00, -3.0000e+00, 0.0000e+00],
- [ 2.0000e+00, 0.0000e+00, 0.0000e+00, -1.0000e+00, 0.0000e+00],
- [ 0.0000e+00, -2.0000e+00, 1.0000e+00, -1.0000e+00, 2.0000e+00],
+ [ 2.0000e+00, -0.0000e+00, 0.0000e+00, -1.0000e+00, 0.0000e+00],
+ [-0.0000e+00, -2.0000e+00, 1.0000e+00, -1.0000e+00, 2.0000e+00],
[ 1.0000e+00, -1.0000e+00, 0.0000e+00, -1.0000e+00, 0.0000e+00]],
- [[ 0.0000e+00, -4.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
- [ 1.8864e+04, -1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
- [ 0.0000e+00, 4.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00],
- [ 0.0000e+00, -8.0000e+00, 0.0000e+00, 2.0000e+00, 0.0000e+00],
- [ 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
- [ 0.0000e+00, -3.0000e+00, 0.0000e+00, 0.0000e+00, -2.0000e+00],
- [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00],
+ [[-0.0000e+00, -4.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00],
+ [ 1.8864e+04, -1.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00],
+ [-0.0000e+00, 4.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00],
+ [-0.0000e+00, -8.0000e+00, -0.0000e+00, 2.0000e+00, -0.0000e+00],
+ [-0.0000e+00, 1.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],
+ [ 0.0000e+00, -3.0000e+00, 0.0000e+00, -0.0000e+00, -2.0000e+00],
+ [ 0.0000e+00, -0.0000e+00, -0.0000e+00, 1.0000e+00, 0.0000e+00],
[ 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00],
- [ 0.0000e+00, 3.0000e+00, 0.0000e+00, 0.0000e+00, 2.0000e+00],
+ [-0.0000e+00, 3.0000e+00, 0.0000e+00, 0.0000e+00, 2.0000e+00],
[ 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]],
- [[-1.0000e+00, 2.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00],
- [ -inf, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00],
- [-1.0000e+00, -3.0000e+00, 0.0000e+00, 0.0000e+00, -2.0000e+00],
+ [[-1.0000e+00, 2.0000e+00, 1.0000e+00, 0.0000e+00, -0.0000e+00],
+ [ inf, 0.0000e+00, -0.0000e+00, -0.0000e+00, 1.0000e+00],
+ [-1.0000e+00, -3.0000e+00, -0.0000e+00, 0.0000e+00, -2.0000e+00],
[-3.0000e+00, 6.0000e+00, 1.0000e+00, 1.0000e+00, 0.0000e+00],
- [ 0.0000e+00, -1.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00],
- [ 0.0000e+00, 2.0000e+00, -1.0000e+00, 0.0000e+00, 6.0000e+00],
+ [-0.0000e+00, -1.0000e+00, 0.0000e+00, -0.0000e+00, 1.0000e+00],
+ [ 0.0000e+00, 2.0000e+00, -1.0000e+00, -0.0000e+00, 6.0000e+00],
[ 3.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, -1.0000e+00],
- [ 3.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, -2.0000e+00],
- [ 0.0000e+00, -2.0000e+00, 0.0000e+00, 0.0000e+00, -6.0000e+00],
- [ 1.0000e+00, -1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]],
+ [ 3.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, -2.0000e+00],
+ [-0.0000e+00, -2.0000e+00, -0.0000e+00, 0.0000e+00, -6.0000e+00],
+ [ 1.0000e+00, -1.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00]],
[[ 1.0000e+00, 1.0000e+00, -6.0000e+00, 1.0000e+00, -1.0000e+00],
- [ inf, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00],
+ [ -inf, 0.0000e+00, 0.0000e+00, -0.0000e+00, 1.0000e+00],
[ 1.0000e+00, -1.0000e+00, 1.0000e+00, 1.0000e+00, -2.0000e+00],
[ 4.0000e+00, 3.0000e+00, -9.0000e+00, 4.0000e+00, 1.0000e+00],
- [ 0.0000e+00, 0.0000e+00, -2.0000e+00, -1.0000e+00, 1.0000e+00],
- [-1.0000e+00, 1.0000e+00, 5.0000e+00, 0.0000e+00, 7.0000e+00],
+ [ 0.0000e+00, -0.0000e+00, -2.0000e+00, -1.0000e+00, 1.0000e+00],
+ [-1.0000e+00, 1.0000e+00, 5.0000e+00, -0.0000e+00, 7.0000e+00],
[-4.0000e+00, 0.0000e+00, -3.0000e+00, 2.0000e+00, -1.0000e+00],
- [-4.0000e+00, 0.0000e+00, 1.0000e+00, 1.0000e+00, -2.0000e+00],
+ [-4.0000e+00, -0.0000e+00, 1.0000e+00, 1.0000e+00, -2.0000e+00],
[ 0.0000e+00, -1.0000e+00, 2.0000e+00, 1.0000e+00, -6.0000e+00],
- [-2.0000e+00, 0.0000e+00, 2.0000e+00, 0.0000e+00, -1.0000e+00]],
+ [-2.0000e+00, -0.0000e+00, 2.0000e+00, 0.0000e+00, -1.0000e+00]],
- [[ 0.0000e+00, -2.0000e+00, 6.0000e+00, -1.0000e+00, 0.0000e+00],
- [ 1.8000e+04, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
- [ 0.0000e+00, 2.0000e+00, -1.0000e+00, -1.0000e+00, 0.0000e+00],
- [ 0.0000e+00, -4.0000e+00, 9.0000e+00, -3.0000e+00, 0.0000e+00],
- [ 0.0000e+00, 0.0000e+00, 2.0000e+00, 1.0000e+00, 0.0000e+00],
+ [[-0.0000e+00, -2.0000e+00, 6.0000e+00, -1.0000e+00, -0.0000e+00],
+ [ 1.8000e+04, -0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00],
+ [-0.0000e+00, 2.0000e+00, -1.0000e+00, -1.0000e+00, -0.0000e+00],
+ [-0.0000e+00, -4.0000e+00, 9.0000e+00, -3.0000e+00, 0.0000e+00],
+ [-0.0000e+00, 0.0000e+00, 2.0000e+00, 1.0000e+00, 0.0000e+00],
[ 0.0000e+00, -2.0000e+00, -5.0000e+00, 0.0000e+00, 1.0000e+00],
- [ 0.0000e+00, 0.0000e+00, 3.0000e+00, -1.0000e+00, 0.0000e+00],
- [ 0.0000e+00, 0.0000e+00, -1.0000e+00, 0.0000e+00, 0.0000e+00],
- [ 0.0000e+00, 1.0000e+00, -2.0000e+00, 0.0000e+00, -1.0000e+00],
- [ 0.0000e+00, 0.0000e+00, -2.0000e+00, 0.0000e+00, 0.0000e+00]],
+ [ 0.0000e+00, -0.0000e+00, 3.0000e+00, -1.0000e+00, -0.0000e+00],
+ [ 0.0000e+00, 0.0000e+00, -1.0000e+00, -0.0000e+00, -0.0000e+00],
+ [-0.0000e+00, 1.0000e+00, -2.0000e+00, -0.0000e+00, -1.0000e+00],
+ [ 0.0000e+00, 0.0000e+00, -2.0000e+00, -0.0000e+00, -0.0000e+00]],
- [[ 0.0000e+00, -3.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00],
- [ 3.1536e+04, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
- [ 0.0000e+00, 3.0000e+00, 0.0000e+00, 1.0000e+00, -1.0000e+00],
- [-1.0000e+00, -6.0000e+00, 0.0000e+00, 5.0000e+00, 0.0000e+00],
- [ 0.0000e+00, 1.0000e+00, 0.0000e+00, -1.0000e+00, 0.0000e+00],
- [ 0.0000e+00, -3.0000e+00, 0.0000e+00, 0.0000e+00, 3.0000e+00],
- [ 1.0000e+00, 0.0000e+00, 0.0000e+00, 2.0000e+00, 0.0000e+00],
+ [[-0.0000e+00, -3.0000e+00, -0.0000e+00, 1.0000e+00, -0.0000e+00],
+ [ 3.1536e+04, -0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00],
+ [-0.0000e+00, 3.0000e+00, 0.0000e+00, 1.0000e+00, -1.0000e+00],
+ [-1.0000e+00, -6.0000e+00, -0.0000e+00, 5.0000e+00, 0.0000e+00],
+ [-0.0000e+00, 1.0000e+00, -0.0000e+00, -1.0000e+00, 0.0000e+00],
+ [ 0.0000e+00, -3.0000e+00, 0.0000e+00, -0.0000e+00, 3.0000e+00],
+ [ 1.0000e+00, -0.0000e+00, -0.0000e+00, 2.0000e+00, -0.0000e+00],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00, -1.0000e+00],
- [ 0.0000e+00, 2.0000e+00, 0.0000e+00, 1.0000e+00, -3.0000e+00],
- [ 0.0000e+00, 1.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00]]],
+ [-0.0000e+00, 2.0000e+00, 0.0000e+00, 1.0000e+00, -3.0000e+00],
+ [ 0.0000e+00, 1.0000e+00, 0.0000e+00, 1.0000e+00, -0.0000e+00]]],
dtype=torch.float16) Full error stack
|
|
aten_trunc(fp16(-143360.3670025395)) is -inf. Puzzling
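For what it's worth, the float16 range seems to explain this: the largest finite float16 is 65504, so a quotient with magnitude 143360 overflows to infinity as soon as it is rounded to float16, before Trunc ever runs. A quick numpy check (a stand-in for the runtime behavior, not the actual ORT kernel):

```python
import numpy as np

# The largest finite float16 is 65504, so any magnitude above that
# overflows to infinity when the value is rounded to float16.
q = np.float16(-143360.3670025395)
print(q)            # -inf
print(np.trunc(q))  # trunc of -inf is still -inf
```

This is also why casting to INT64 (or float32) for the trunc step avoids the problem: both types can represent the quotient exactly.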
Summary

The output of ONNX Runtime does not match that of PyTorch when executing test

To recreate this report, use

CREATE_REPRODUCTION_REPORT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k test_output_match_opinfo__div_mode_floor_rounding_cpu_float16

Inputs

Shapes:

inputs = (tensor([[[-6.6953, 6.7070, -7.7422, -5.9688, -3.5234],
[-0.0703, 8.8828, 5.0195, -1.2656, -8.9922],
[-1.1250, 2.6641, -6.7773, 5.7227, -5.0273],
[ 0.3164, 1.3975, 8.9453, -0.0879, -8.9297],
[ 2.0391, 0.8613, 5.7812, -1.8369, -7.1797],
[ 1.1426, -5.4570, -7.6719, -8.5312, 1.0459],
[ 5.9062, 8.0781, -3.4883, -3.8496, 5.4844],
[ 8.5000, 3.6836, 3.6992, -8.0938, 7.4805],
[-6.7773, 1.9863, 0.9756, -4.5273, 4.1484],
[-8.1406, -7.6641, 7.5586, -0.9756, 7.8672]],
[[-1.2920, -6.0391, 6.1953, -1.7842, 1.4238],
[ 6.3203, -0.5977, 5.9766, 8.7422, -7.5938],
[ 5.0898, -1.9688, -5.5117, 2.4258, -1.8369],
[ 8.0391, 4.4219, 7.8672, 3.4375, 1.1777],
[ 1.7051, -5.4844, -3.3828, 0.2812, -2.9609],
[-4.9648, -0.6152, 7.7500, -4.8789, 3.2871],
[ 0.1846, -2.4180, 5.8789, -8.6719, -6.9883],
[ 3.1992, -4.0625, -5.1602, 6.8125, 2.4785],
[-8.7734, -7.0234, 2.4258, -1.9951, 5.3438],
[-6.5117, -1.2832, -0.5713, 4.2188, -5.0273]],
[[-0.0439, 7.2773, -1.8369, -3.2168, 4.5273],
[-4.2031, 1.6787, 0.4219, 4.4922, -8.8594],
[ 1.5029, 1.9248, -7.6211, -4.8164, 0.3955],
[ 3.2695, -4.4570, -8.1406, 7.8320, 6.5391],
[ 0.3164, -6.5469, 3.0586, -4.6406, -6.3555],
[-7.4180, -3.6738, 0.8613, -2.8555, -0.2812],
[ 1.0107, -7.3281, -6.7852, -0.3867, -0.8525],
[ 8.6328, 1.4062, 2.2422, 2.8301, -7.8828],
[-8.8516, 6.9609, 8.7969, 6.4531, -4.0352],
[-4.0000, 2.8301, 7.6562, -2.9805, 6.9531]],
[[-3.5078, 2.5234, 8.2812, -1.5029, 5.8203],
[-0.5537, 5.6094, -7.6484, -4.7031, 3.3828],
[ 8.6562, -2.1094, 0.9053, 8.1562, 3.4453],
[ 7.1172, -6.8477, 1.5381, 0.3340, 4.1836],
[ 5.9844, -2.3379, -4.2812, 7.6094, 7.1797],
[-0.9141, -5.4219, -4.3945, -1.7314, 4.4375],
[-1.0723, 7.2422, -8.6953, -2.9883, -8.8359],
[-6.1875, 6.8125, -4.5078, 8.8359, -5.4922],
[-3.2617, -0.4658, 7.7500, -4.8672, 1.2480],
[ 0.7998, 3.1016, -1.6787, 7.0234, -2.6191]],
[[ 8.4219, -1.5029, -8.5625, 5.8359, 6.1250],
[-3.9551, 1.6611, 1.0898, -7.6914, -6.2500],
[-4.0000, 2.9961, -6.3281, 6.0742, -3.0156],
[-1.6436, 6.8750, 0.4658, -0.1934, -5.1055],
[-3.6211, -0.2812, -6.6875, 0.4570, 5.0195],
[-0.1846, 5.4414, -8.9688, 0.8965, 1.7139],
[-2.2070, -1.3008, -0.2900, -1.2393, 3.0664],
[-8.5938, -1.3887, -1.7754, -0.7822, 7.1992],
[-3.2422, -6.6445, -5.7578, -3.4180, 0.9229],
[ 5.8789, 5.7812, -3.2344, -7.9531, -1.7051]]], dtype=torch.float16), tensor([[[ 6.9062, 5.3008, 0.5010, 8.3516, 3.0312],
[-2.4961, 4.1562, -7.9805, 7.2852, -2.0566],
[ 1.1074, 3.4453, 7.9727, -0.1055, 5.6250],
[ 0.5010, -3.4375, -4.5000, 6.5938, -1.9600],
[ 0.8350, 5.5977, 1.0283, -4.0859, 5.2734],
[-4.7188, 2.0742, 7.7695, -1.1250, 2.3828],
[ 1.6523, 8.7734, -2.8203, 3.9199, -3.9023],
[-2.0664, -1.4414, -6.8359, 8.4531, 0.6592],
[-7.6562, 5.7578, 3.8145, -0.2461, 0.7471],
[-2.8828, 2.4961, 8.5547, 1.9248, 0.5977]],
[[ 0.4043, -0.1406, -8.1797, 6.5820, -0.5889],
[-0.1143, -5.0273, 1.7842, -0.9932, -7.0938],
[-4.8438, -5.2109, -6.0117, -4.1641, 5.3711],
[ 2.5312, 2.2227, 6.4414, -5.8711, -8.8047],
[ 6.6875, -4.4219, 6.6523, 7.3477, 0.8701],
[-1.9600, 1.5732, 1.9512, 7.9531, -2.3125],
[ 5.7812, -5.1250, 5.3359, -2.2500, -0.0791],
[ 2.8477, 8.8281, -4.4727, 0.1318, -2.0469],
[-6.0547, 6.6641, -6.7422, -4.6406, 8.6719],
[-1.1338, -7.6016, 4.1641, -1.5205, 5.8281]],
[[ 7.9727, -2.9453, 5.9219, -5.6680, 4.7812],
[ 8.1562, -0.5977, -2.9355, 3.7441, 5.3789],
[ 2.3477, 7.3281, 5.0625, 3.6562, 1.4678],
[ 7.3906, 0.3691, 4.8789, -6.5039, -5.5195],
[-8.9062, 5.3281, 2.6445, 7.8828, 2.2148],
[-4.6250, -2.3828, -8.7031, -2.6191, -1.5469],
[-8.2188, -2.1875, -3.6914, 8.0938, -0.4834],
[ 5.2812, -0.0264, -5.3359, -8.0000, -3.5859],
[ 7.9297, 1.7139, -1.7490, 4.3945, 5.4922],
[ 4.3242, 0.6504, -5.3789, 3.2969, 3.6836]],
[[-2.6016, 2.4258, -0.0352, -2.3203, -3.0664],
[ 3.4531, 7.6289, -4.9922, -0.9229, -7.3203],
[-2.5137, 6.0469, 0.8613, 8.8516, 4.8086],
[-0.4131, -0.1406, 8.7734, -8.1641, 5.8281],
[-3.1816, 1.6611, 6.4609, -8.6875, 6.8477],
[ 3.2344, 6.1016, 4.6055, -6.5547, -7.1016],
[-1.3887, -1.3359, 7.0039, 2.3906, 7.7344],
[-8.4609, -3.7695, 7.7266, -2.6016, 2.9961],
[-5.5195, -2.8652, 0.1582, -7.4531, 5.3711],
[-5.5273, -5.0977, 2.0391, 8.0312, -8.3438]],
[[-6.6797, -5.0078, 6.9883, -2.6992, -2.5488],
[ 3.3672, 4.9492, -1.5293, 4.4375, 3.2168],
[-7.7422, -1.2305, 0.5977, 0.8613, -5.5273],
[ 0.2109, -0.6768, -1.1777, 4.1133, 4.2461],
[ 4.9570, 1.3184, 1.2568, -4.0078, 0.3340],
[-8.0469, 6.0820, 0.3604, 1.6260, -0.7910],
[-5.5117, 1.9512, -3.4375, 1.1602, 1.4238],
[ 1.8809, -2.5664, -6.9453, 7.3984, -3.2266],
[-2.4082, -3.8398, 0.0703, 0.6416, -0.6240],
[-4.6836, -0.9844, 2.2148, 2.0117, 4.5547]]], dtype=torch.float16))
kwargs = {'rounding_mode': 'floor'} Expected outputexpected = tensor([[[ -1., 1., -16., -1., -2.],
[ 0., 2., -1., -1., 4.],
[ -2., 0., -1., -55., -1.],
[ 0., -1., -2., -1., 4.],
[ 2., 0., 5., 0., -2.],
[ -1., -3., -1., 7., 0.],
[ 3., 0., 1., -1., -2.],
[ -5., -3., -1., -1., 11.],
[ 0., 0., 0., 18., 5.],
[ 2., -4., 0., -1., 13.]],
[[ -4., 42., -1., -1., -3.],
[ -56., 0., 3., -9., 1.],
[ -2., 0., 0., -1., -1.],
[ 3., 1., 1., -1., -1.],
[ 0., 1., -1., 0., -4.],
[ 2., -1., 3., -1., -2.],
[ 0., 0., 1., 3., 88.],
[ 1., -1., 1., 51., -2.],
[ 1., -2., -1., 0., 0.],
[ 5., 0., -1., -3., -1.]],
[[ -1., -3., -1., 0., 0.],
[ -1., -3., -1., 1., -2.],
[ 0., 0., -2., -2., 0.],
[ 0., -13., -2., -2., -2.],
[ -1., -2., 1., -1., -3.],
[ 1., 1., -1., 1., 0.],
[ -1., 3., 1., -1., 1.],
[ 1., -54., -1., -1., 2.],
[ -2., 4., -6., 1., -1.],
[ -1., 4., -2., -1., 1.]],
[[ 1., 1., -236., 0., -2.],
[ -1., 0., 1., 5., -1.],
[ -4., -1., 1., 0., 0.],
[ -18., 48., 0., -1., 0.],
[ -2., -2., -1., -1., 1.],
[ -1., -1., -1., 0., -1.],
[ 0., -6., -2., -2., -2.],
[ 0., -2., -1., -4., -2.],
[ 0., 0., 48., 0., 0.],
[ -1., -1., -1., 0., 0.]],
[[ -2., 0., -2., -3., -3.],
[ -2., 0., -1., -2., -2.],
[ 0., -3., -11., 7., 0.],
[ -8., -11., -1., -1., -2.],
[ -1., -1., -6., -1., 15.],
[ 0., 0., -25., 0., -3.],
[ 0., -1., 0., -2., 2.],
[ -5., 0., 0., -1., -3.],
[ 1., 1., -82., -6., -2.],
[ -2., -6., -2., -4., -1.]]], dtype=torch.float16) Shape: Actual outputactual = tensor([[[ -1., 1., -16., -1., -2.],
[ 0., 2., -1., -1., 4.],
[ -2., 0., -1., -55., -1.],
[ 0., -1., -2., -1., 4.],
[ 2., 0., 5., 0., -2.],
[ -1., -3., -1., 7., 0.],
[ 3., 0., 1., -1., -2.],
[ -5., -3., -1., -1., 11.],
[ 0., 0., 0., 18., 5.],
[ 2., -4., 0., -1., 13.]],
[[ -4., 42., -1., -1., -3.],
[ -56., 0., 3., -9., 1.],
[ -2., 0., 0., -1., -1.],
[ 3., 1., 1., -1., -1.],
[ 0., 1., -1., 0., -4.],
[ 2., -1., 3., -1., -2.],
[ 0., 0., 1., 3., 88.],
[ 1., -1., 1., 51., -2.],
[ 1., -2., -1., 0., 0.],
[ 5., 0., -1., -3., -1.]],
[[ -1., -3., -1., 0., 0.],
[ -1., -3., -1., 1., -2.],
[ 0., 0., -2., -2., 0.],
[ 0., -13., -2., -2., -2.],
[ -1., -2., 1., -1., -3.],
[ 1., 1., -1., 1., 0.],
[ -1., 3., 1., -1., 1.],
[ 1., -54., -1., -1., 2.],
[ -2., 4., -6., 1., -1.],
[ -1., 4., -2., -1., 1.]],
[[ 1., 1., -236., 0., -2.],
[ -1., 0., 1., 5., -1.],
[ -4., -1., 1., 0., 0.],
[ -18., 48., 0., -1., 0.],
[ -2., -2., -1., -1., 1.],
[ -1., -1., -1., 0., -1.],
[ 0., -6., -2., -2., -2.],
[ 0., -2., -1., -4., -2.],
[ 0., 0., 49., 0., 0.],
[ -1., -1., -1., 0., 0.]],
[[ -2., 0., -2., -3., -3.],
[ -2., 0., -1., -2., -2.],
[ 0., -3., -11., 7., 0.],
[ -8., -11., -1., -1., -2.],
[ -1., -1., -6., -1., 15.],
[ 0., 0., -25., 0., -3.],
[ 0., -1., 0., -2., 2.],
[ -5., 0., 0., -1., -3.],
[ 1., 1., -82., -6., -2.],
[ -2., -6., -2., -4., -1.]]], dtype=torch.float16) Shape: Difference--- actual
+++ expected
@@ -39,7 +39,7 @@
[ -1., -1., -1., 0., -1.],
[ 0., -6., -2., -2., -2.],
[ 0., -2., -1., -4., -2.],
- [ 0., 0., 49., 0., 0.],
+ [ 0., 0., 48., 0., 0.],
[ -1., -1., -1., 0., 0.]],
[[ -2., 0., -2., -3., -3.], Full error stack
|
numpy decides
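The 48-vs-49 disagreement above is consistent with where the division gets rounded. Using the values from the report (7.7500 and 0.1582): the true quotient is about 48.99, which float16 rounds up to exactly 49.0 before the floor, while float32 keeps enough precision to floor to 48. A sketch with numpy:

```python
import numpy as np

a, b = np.float16(7.75), np.float16(0.1582)  # values taken from the report above
# float16 division rounds the true quotient (~48.99) up to exactly 49.0,
# so floor gives 49; float32 preserves the fraction and floors to 48.
print(np.floor(a / b))                          # 49.0
print(np.floor(np.float32(a) / np.float32(b)))  # 48.0
```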
After casting to float32:

Summary

The output of ONNX Runtime does not match that of PyTorch when executing test

To recreate this report, use

CREATE_REPRODUCTION_REPORT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k test_output_match_opinfo__div_mode_trunc_rounding_cpu_float16

Inputs

Shapes:

inputs = (tensor([[[-6.6953, 6.7070, -7.7422, -5.9688, -3.5234],
[-0.0703, 8.8828, 5.0195, -1.2656, -8.9922],
[-1.1250, 2.6641, -6.7773, 5.7227, -5.0273],
[ 0.3164, 1.3975, 8.9453, -0.0879, -8.9297],
[ 2.0391, 0.8613, 5.7812, -1.8369, -7.1797],
[ 1.1426, -5.4570, -7.6719, -8.5312, 1.0459],
[ 5.9062, 8.0781, -3.4883, -3.8496, 5.4844],
[ 8.5000, 3.6836, 3.6992, -8.0938, 7.4805],
[-6.7773, 1.9863, 0.9756, -4.5273, 4.1484],
[-8.1406, -7.6641, 7.5586, -0.9756, 7.8672]],
[[-1.2920, -6.0391, 6.1953, -1.7842, 1.4238],
[ 6.3203, -0.5977, 5.9766, 8.7422, -7.5938],
[ 5.0898, -1.9688, -5.5117, 2.4258, -1.8369],
[ 8.0391, 4.4219, 7.8672, 3.4375, 1.1777],
[ 1.7051, -5.4844, -3.3828, 0.2812, -2.9609],
[-4.9648, -0.6152, 7.7500, -4.8789, 3.2871],
[ 0.1846, -2.4180, 5.8789, -8.6719, -6.9883],
[ 3.1992, -4.0625, -5.1602, 6.8125, 2.4785],
[-8.7734, -7.0234, 2.4258, -1.9951, 5.3438],
[-6.5117, -1.2832, -0.5713, 4.2188, -5.0273]],
[[-0.0439, 7.2773, -1.8369, -3.2168, 4.5273],
[-4.2031, 1.6787, 0.4219, 4.4922, -8.8594],
[ 1.5029, 1.9248, -7.6211, -4.8164, 0.3955],
[ 3.2695, -4.4570, -8.1406, 7.8320, 6.5391],
[ 0.3164, -6.5469, 3.0586, -4.6406, -6.3555],
[-7.4180, -3.6738, 0.8613, -2.8555, -0.2812],
[ 1.0107, -7.3281, -6.7852, -0.3867, -0.8525],
[ 8.6328, 1.4062, 2.2422, 2.8301, -7.8828],
[-8.8516, 6.9609, 8.7969, 6.4531, -4.0352],
[-4.0000, 2.8301, 7.6562, -2.9805, 6.9531]],
[[-3.5078, 2.5234, 8.2812, -1.5029, 5.8203],
[-0.5537, 5.6094, -7.6484, -4.7031, 3.3828],
[ 8.6562, -2.1094, 0.9053, 8.1562, 3.4453],
[ 7.1172, -6.8477, 1.5381, 0.3340, 4.1836],
[ 5.9844, -2.3379, -4.2812, 7.6094, 7.1797],
[-0.9141, -5.4219, -4.3945, -1.7314, 4.4375],
[-1.0723, 7.2422, -8.6953, -2.9883, -8.8359],
[-6.1875, 6.8125, -4.5078, 8.8359, -5.4922],
[-3.2617, -0.4658, 7.7500, -4.8672, 1.2480],
[ 0.7998, 3.1016, -1.6787, 7.0234, -2.6191]],
[[ 8.4219, -1.5029, -8.5625, 5.8359, 6.1250],
[-3.9551, 1.6611, 1.0898, -7.6914, -6.2500],
[-4.0000, 2.9961, -6.3281, 6.0742, -3.0156],
[-1.6436, 6.8750, 0.4658, -0.1934, -5.1055],
[-3.6211, -0.2812, -6.6875, 0.4570, 5.0195],
[-0.1846, 5.4414, -8.9688, 0.8965, 1.7139],
[-2.2070, -1.3008, -0.2900, -1.2393, 3.0664],
[-8.5938, -1.3887, -1.7754, -0.7822, 7.1992],
[-3.2422, -6.6445, -5.7578, -3.4180, 0.9229],
[ 5.8789, 5.7812, -3.2344, -7.9531, -1.7051]]], dtype=torch.float16), tensor([[[ 6.9062, 5.3008, 0.5010, 8.3516, 3.0312],
[-2.4961, 4.1562, -7.9805, 7.2852, -2.0566],
[ 1.1074, 3.4453, 7.9727, -0.1055, 5.6250],
[ 0.5010, -3.4375, -4.5000, 6.5938, -1.9600],
[ 0.8350, 5.5977, 1.0283, -4.0859, 5.2734],
[-4.7188, 2.0742, 7.7695, -1.1250, 2.3828],
[ 1.6523, 8.7734, -2.8203, 3.9199, -3.9023],
[-2.0664, -1.4414, -6.8359, 8.4531, 0.6592],
[-7.6562, 5.7578, 3.8145, -0.2461, 0.7471],
[-2.8828, 2.4961, 8.5547, 1.9248, 0.5977]],
[[ 0.4043, -0.1406, -8.1797, 6.5820, -0.5889],
[-0.1143, -5.0273, 1.7842, -0.9932, -7.0938],
[-4.8438, -5.2109, -6.0117, -4.1641, 5.3711],
[ 2.5312, 2.2227, 6.4414, -5.8711, -8.8047],
[ 6.6875, -4.4219, 6.6523, 7.3477, 0.8701],
[-1.9600, 1.5732, 1.9512, 7.9531, -2.3125],
[ 5.7812, -5.1250, 5.3359, -2.2500, -0.0791],
[ 2.8477, 8.8281, -4.4727, 0.1318, -2.0469],
[-6.0547, 6.6641, -6.7422, -4.6406, 8.6719],
[-1.1338, -7.6016, 4.1641, -1.5205, 5.8281]],
[[ 7.9727, -2.9453, 5.9219, -5.6680, 4.7812],
[ 8.1562, -0.5977, -2.9355, 3.7441, 5.3789],
[ 2.3477, 7.3281, 5.0625, 3.6562, 1.4678],
[ 7.3906, 0.3691, 4.8789, -6.5039, -5.5195],
[-8.9062, 5.3281, 2.6445, 7.8828, 2.2148],
[-4.6250, -2.3828, -8.7031, -2.6191, -1.5469],
[-8.2188, -2.1875, -3.6914, 8.0938, -0.4834],
[ 5.2812, -0.0264, -5.3359, -8.0000, -3.5859],
[ 7.9297, 1.7139, -1.7490, 4.3945, 5.4922],
[ 4.3242, 0.6504, -5.3789, 3.2969, 3.6836]],
[[-2.6016, 2.4258, -0.0352, -2.3203, -3.0664],
[ 3.4531, 7.6289, -4.9922, -0.9229, -7.3203],
[-2.5137, 6.0469, 0.8613, 8.8516, 4.8086],
[-0.4131, -0.1406, 8.7734, -8.1641, 5.8281],
[-3.1816, 1.6611, 6.4609, -8.6875, 6.8477],
[ 3.2344, 6.1016, 4.6055, -6.5547, -7.1016],
[-1.3887, -1.3359, 7.0039, 2.3906, 7.7344],
[-8.4609, -3.7695, 7.7266, -2.6016, 2.9961],
[-5.5195, -2.8652, 0.1582, -7.4531, 5.3711],
[-5.5273, -5.0977, 2.0391, 8.0312, -8.3438]],
[[-6.6797, -5.0078, 6.9883, -2.6992, -2.5488],
[ 3.3672, 4.9492, -1.5293, 4.4375, 3.2168],
[-7.7422, -1.2305, 0.5977, 0.8613, -5.5273],
[ 0.2109, -0.6768, -1.1777, 4.1133, 4.2461],
[ 4.9570, 1.3184, 1.2568, -4.0078, 0.3340],
[-8.0469, 6.0820, 0.3604, 1.6260, -0.7910],
[-5.5117, 1.9512, -3.4375, 1.1602, 1.4238],
[ 1.8809, -2.5664, -6.9453, 7.3984, -3.2266],
[-2.4082, -3.8398, 0.0703, 0.6416, -0.6240],
[-4.6836, -0.9844, 2.2148, 2.0117, 4.5547]]], dtype=torch.float16))
kwargs = {'rounding_mode': 'trunc'} Expected outputexpected = tensor([[[ -0., 1., -15., -0., -1.],
[ 0., 2., -0., -0., 4.],
[ -1., 0., -0., -54., -0.],
[ 0., -0., -1., -0., 4.],
[ 2., 0., 5., 0., -1.],
[ -0., -2., -0., 7., 0.],
[ 3., 0., 1., -0., -1.],
[ -4., -2., -0., -0., 11.],
[ 0., 0., 0., 18., 5.],
[ 2., -3., 0., -0., 13.]],
[[ -3., 42., -0., -0., -2.],
[ -55., 0., 3., -8., 1.],
[ -1., 0., 0., -0., -0.],
[ 3., 1., 1., -0., -0.],
[ 0., 1., -0., 0., -3.],
[ 2., -0., 3., -0., -1.],
[ 0., 0., 1., 3., 88.],
[ 1., -0., 1., 51., -1.],
[ 1., -1., -0., 0., 0.],
[ 5., 0., -0., -2., -0.]],
[[ -0., -2., -0., 0., 0.],
[ -0., -2., -0., 1., -1.],
[ 0., 0., -1., -1., 0.],
[ 0., -12., -1., -1., -1.],
[ -0., -1., 1., -0., -2.],
[ 1., 1., -0., 1., 0.],
[ -0., 3., 1., -0., 1.],
[ 1., -53., -0., -0., 2.],
[ -1., 4., -5., 1., -0.],
[ -0., 4., -1., -0., 1.]],
[[ 1., 1., -235., 0., -1.],
[ -0., 0., 1., 5., -0.],
[ -3., -0., 1., 0., 0.],
[ -17., 48., 0., -0., 0.],
[ -1., -1., -0., -0., 1.],
[ -0., -0., -0., 0., -0.],
[ 0., -5., -1., -1., -1.],
[ 0., -1., -0., -3., -1.],
[ 0., 0., 49., 0., 0.],
[ -0., -0., -0., 0., 0.]],
[[ -1., 0., -1., -2., -2.],
[ -1., 0., -0., -1., -1.],
[ 0., -2., -10., 7., 0.],
[ -7., -10., -0., -0., -1.],
[ -0., -0., -5., -0., 15.],
[ 0., 0., -24., 0., -2.],
[ 0., -0., 0., -1., 2.],
[ -4., 0., 0., -0., -2.],
[ 1., 1., -81., -5., -1.],
[ -1., -5., -1., -3., -0.]]], dtype=torch.float16) Shape: Actual outputactual = tensor([[[ 0., 1., -15., 0., -1.],
[ 0., 2., 0., 0., 4.],
[ -1., 0., 0., -54., 0.],
[ 0., 0., -1., 0., 4.],
[ 2., 0., 5., 0., -1.],
[ 0., -2., 0., 7., 0.],
[ 3., 0., 1., 0., -1.],
[ -4., -2., 0., 0., 11.],
[ 0., 0., 0., 18., 5.],
[ 2., -3., 0., 0., 13.]],
[[ -3., 42., 0., 0., -2.],
[ -55., 0., 3., -8., 1.],
[ -1., 0., 0., 0., 0.],
[ 3., 1., 1., 0., 0.],
[ 0., 1., 0., 0., -3.],
[ 2., 0., 3., 0., -1.],
[ 0., 0., 1., 3., 88.],
[ 1., 0., 1., 51., -1.],
[ 1., -1., 0., 0., 0.],
[ 5., 0., 0., -2., 0.]],
[[ 0., -2., 0., 0., 0.],
[ 0., -2., 0., 1., -1.],
[ 0., 0., -1., -1., 0.],
[ 0., -12., -1., -1., -1.],
[ 0., -1., 1., 0., -2.],
[ 1., 1., 0., 1., 0.],
[ 0., 3., 1., 0., 1.],
[ 1., -53., 0., 0., 2.],
[ -1., 4., -5., 1., 0.],
[ 0., 4., -1., 0., 1.]],
[[ 1., 1., -235., 0., -1.],
[ 0., 0., 1., 5., 0.],
[ -3., 0., 1., 0., 0.],
[ -17., 48., 0., 0., 0.],
[ -1., -1., 0., 0., 1.],
[ 0., 0., 0., 0., 0.],
[ 0., -5., -1., -1., -1.],
[ 0., -1., 0., -3., -1.],
[ 0., 0., 48., 0., 0.],
[ 0., 0., 0., 0., 0.]],
[[ -1., 0., -1., -2., -2.],
[ -1., 0., 0., -1., -1.],
[ 0., -2., -10., 7., 0.],
[ -7., -10., 0., 0., -1.],
[ 0., 0., -5., 0., 15.],
[ 0., 0., -24., 0., -2.],
[ 0., 0., 0., -1., 2.],
[ -4., 0., 0., 0., -2.],
[ 1., 1., -81., -5., -1.],
[ -1., -5., -1., -3., 0.]]], dtype=torch.float16) Shape: Difference--- actual
+++ expected
@@ -1,54 +1,54 @@
-tensor([[[ 0., 1., -15., 0., -1.],
- [ 0., 2., 0., 0., 4.],
- [ -1., 0., 0., -54., 0.],
- [ 0., 0., -1., 0., 4.],
+tensor([[[ -0., 1., -15., -0., -1.],
+ [ 0., 2., -0., -0., 4.],
+ [ -1., 0., -0., -54., -0.],
+ [ 0., -0., -1., -0., 4.],
[ 2., 0., 5., 0., -1.],
- [ 0., -2., 0., 7., 0.],
- [ 3., 0., 1., 0., -1.],
- [ -4., -2., 0., 0., 11.],
+ [ -0., -2., -0., 7., 0.],
+ [ 3., 0., 1., -0., -1.],
+ [ -4., -2., -0., -0., 11.],
[ 0., 0., 0., 18., 5.],
- [ 2., -3., 0., 0., 13.]],
+ [ 2., -3., 0., -0., 13.]],
- [[ -3., 42., 0., 0., -2.],
+ [[ -3., 42., -0., -0., -2.],
[ -55., 0., 3., -8., 1.],
- [ -1., 0., 0., 0., 0.],
- [ 3., 1., 1., 0., 0.],
- [ 0., 1., 0., 0., -3.],
- [ 2., 0., 3., 0., -1.],
+ [ -1., 0., 0., -0., -0.],
+ [ 3., 1., 1., -0., -0.],
+ [ 0., 1., -0., 0., -3.],
+ [ 2., -0., 3., -0., -1.],
[ 0., 0., 1., 3., 88.],
- [ 1., 0., 1., 51., -1.],
- [ 1., -1., 0., 0., 0.],
- [ 5., 0., 0., -2., 0.]],
+ [ 1., -0., 1., 51., -1.],
+ [ 1., -1., -0., 0., 0.],
+ [ 5., 0., -0., -2., -0.]],
- [[ 0., -2., 0., 0., 0.],
- [ 0., -2., 0., 1., -1.],
+ [[ -0., -2., -0., 0., 0.],
+ [ -0., -2., -0., 1., -1.],
[ 0., 0., -1., -1., 0.],
[ 0., -12., -1., -1., -1.],
- [ 0., -1., 1., 0., -2.],
- [ 1., 1., 0., 1., 0.],
- [ 0., 3., 1., 0., 1.],
- [ 1., -53., 0., 0., 2.],
- [ -1., 4., -5., 1., 0.],
- [ 0., 4., -1., 0., 1.]],
+ [ -0., -1., 1., -0., -2.],
+ [ 1., 1., -0., 1., 0.],
+ [ -0., 3., 1., -0., 1.],
+ [ 1., -53., -0., -0., 2.],
+ [ -1., 4., -5., 1., -0.],
+ [ -0., 4., -1., -0., 1.]],
[[ 1., 1., -235., 0., -1.],
- [ 0., 0., 1., 5., 0.],
- [ -3., 0., 1., 0., 0.],
- [ -17., 48., 0., 0., 0.],
- [ -1., -1., 0., 0., 1.],
- [ 0., 0., 0., 0., 0.],
+ [ -0., 0., 1., 5., -0.],
+ [ -3., -0., 1., 0., 0.],
+ [ -17., 48., 0., -0., 0.],
+ [ -1., -1., -0., -0., 1.],
+ [ -0., -0., -0., 0., -0.],
[ 0., -5., -1., -1., -1.],
- [ 0., -1., 0., -3., -1.],
- [ 0., 0., 48., 0., 0.],
- [ 0., 0., 0., 0., 0.]],
+ [ 0., -1., -0., -3., -1.],
+ [ 0., 0., 49., 0., 0.],
+ [ -0., -0., -0., 0., 0.]],
[[ -1., 0., -1., -2., -2.],
- [ -1., 0., 0., -1., -1.],
+ [ -1., 0., -0., -1., -1.],
[ 0., -2., -10., 7., 0.],
- [ -7., -10., 0., 0., -1.],
- [ 0., 0., -5., 0., 15.],
+ [ -7., -10., -0., -0., -1.],
+ [ -0., -0., -5., -0., 15.],
[ 0., 0., -24., 0., -2.],
- [ 0., 0., 0., -1., 2.],
- [ -4., 0., 0., 0., -2.],
+ [ 0., -0., 0., -1., 2.],
+ [ -4., 0., 0., -0., -2.],
[ 1., 1., -81., -5., -1.],
- [ -1., -5., -1., -3., 0.]]], dtype=torch.float16)
+ [ -1., -5., -1., -3., -0.]]], dtype=torch.float16) Full error stack
|
I am now adding xfails
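The remaining mismatches above are all signed zeros: a trunc implemented via a round-trip through INT64 drops the sign of zero, while PyTorch's trunc keeps it. A minimal numpy illustration of the difference (assuming the cast-based trunc described earlier):

```python
import numpy as np

x = np.float32(-0.3)
print(np.trunc(x))              # -0.0: floating-point trunc preserves the sign bit
print(np.float32(np.int64(x)))  # 0.0: the INT64 round-trip loses it
```

Since -0.0 == 0.0 compares equal, the mismatch only shows up in exact bitwise comparisons, which is why xfail (rather than a fix) is a reasonable resolution here.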
if rounding_mode == "trunc":
    # Rounds the results of the division towards zero.
    # Equivalent to C-style integer division
    result = aten_trunc(op.Div(self, other))
Will move to a common function when #834 is done.
I missed this. Could you share more about why we can use nested OnnxFunction now?
This is a trace-only function, so calling functions is fine. However, we still do not like calling other aten functions. When we can have nested OnnxFunction calls, I will extract the trunc logic to a common function and call it from aten_trunc and this.

Right now I am doing this so aten_trunc doesn't become trace_only.
"aten::div.Tensor",
"aten::div.Scalar",
# When rounding_mode is None, performs a true division
# https://pytorch.org/docs/stable/generated/torch.div.html
Is the dispatcher expected to filter out any attribute that is None?
I would consider this to be a better match I think? Any suggestions?
I do think the dispatcher should strip None keyword args
No, I think that makes sense. It's just that we are altering attributes, so param_schema matching diverges from the inputs/attributes actually sent into the OnnxFunction. There are many indirections around dispatching and OnnxFunction param_schema, and that's not good for debugging.

The dispatcher alters inputs/attributes under hidden assumptions, but never returns the altered inputs/attributes. So from the OnnxFunction's perspective, it runs directly as the dispatched function, with attributes it doesn't need (which won't error).
Do we want implicit logic such that, when we feed args/kwargs into param_schema matching, we ignore attributes that are None? As if they never existed, so they won't match a function signature that has that attribute? I don't think I have this in the dispatcher yet.

I take it that we do, as I see no function handling the case where a certain attribute is None. My concern is at what level we should expose this information for debugging purposes.

cc @BowenBao
I think so. And unless that attribute has a default value, of course.
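A hypothetical helper of the kind discussed here — the name and placement are assumed, not the actual dispatcher API — would simply drop None-valued attributes before param_schema matching, so they behave as if never supplied:

```python
def strip_none_attributes(kwargs: dict) -> dict:
    """Drop attributes explicitly passed as None so that, for schema
    matching, they behave as if they were never supplied."""
    return {name: value for name, value in kwargs.items() if value is not None}

# A None rounding_mode would then match the plain aten::div signature:
print(strip_none_attributes({"rounding_mode": None}))  # {}
```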
I could add it into pytorch/pytorch#106478.
def aten_divide(self: TensorType, other: TensorType) -> TensorType:
    """divide.Tensor(Tensor self, Tensor other) -> Tensor"""

@torch_op(("aten::div.Tensor_mode", "aten::div.Scalar_mode"), trace_only=True)
def aten_div_mode(self: TFloat, other: TFloat, rounding_mode: str) -> TFloat:
Do you recall what kinds of attributes can have defaults?
str, int, float, bool attributes can have defaults I think. But I suppose any attributes should be able to have defaults with the attribute proto. Is this what you are asking?
I wonder if there is a situation where two ONNX variants differ only in one default attribute. In that case, the dispatcher won't be able to dispatch it.

aten_op_attr(X, Y, attr="Good"):
    ...
aten_op(X, Y):
    ...
True. I would just hope/make sure that we don’t create variants like these.
I wonder if there is a way to test it. I think the matching logic you created can come in handy here: we can use it to check that no two variants registered in torchlib are compatible with each other.
In the dispatcher, if we do see this case, we can only pick one, I suppose?
Supposedly, if I pick any of them, there shouldn't be an issue, because they should be equivalent when no attribute is specified.
Windows test: please don't fail here. Fail in #986 instead.
Just tested and found that we actually don't need to change anything in pytorch/pytorch#106478. An attribute of None can't be a perfect match anyway, since it's a NoneType, while aten::div's param_schemas ignores the None attribute, so it is dispatched.
aten::div.Tensor_mode is implemented with two ONNX functions. When rounding_mode is None, we use aten_div. When it is not None, we use aten_div_mode. This way we don't need to handle rounding_mode == None in aten_div.

For float16 inputs we need to cast to float32 to preserve precision. Otherwise -inf sometimes becomes inf in the output of aten_div.

Fixes #980