Implement aten::div.Tensor_mode | feat(torchlib) #988


Merged
merged 16 commits into main
Aug 8, 2023

Conversation

justinchuby
Collaborator

@justinchuby justinchuby commented Aug 8, 2023

aten::div.Tensor_mode is implemented with two ONNX functions. When rounding_mode is None, we use aten_div; when it is not None, we use aten_div_mode. This way we don't need to handle the rounding_mode == None case inside aten_div.

For float16 inputs we need to cast to float32 to preserve precision. Otherwise -inf sometimes becomes inf in the output.
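The float16 issue can be reproduced in plain numpy with values that appear in the reports below: dividing -8.75 by the smallest normal float16 (≈6.1e-05) overflows the float16 range, while the same division carried out in float32 stays finite. This is an illustrative sketch of the problem, not code from the PR:

```python
import numpy as np

a = np.float16(-8.75)
b = np.float16(6.1035e-05)  # smallest normal float16, 2**-14

# Dividing directly in float16 overflows (max finite float16 is 65504),
# so the quotient collapses to -inf.
q16 = a / b

# Upcasting to float32 first keeps the quotient finite (-143360.0).
q32 = a.astype(np.float32) / b.astype(np.float32)
```
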

Fixes #980
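The split described above can be sketched in numpy (semantics only; the actual implementations are ONNX functions in core.py, and the names here just mirror the description):

```python
import numpy as np

def div(a, b):
    # Used when rounding_mode is None: plain true division.
    return a / b

def div_mode(a, b, rounding_mode):
    # Used when rounding_mode is "trunc" or "floor".
    q = a / b
    if rounding_mode == "trunc":
        return np.trunc(q)  # round toward zero
    if rounding_mode == "floor":
        return np.floor(q)  # round toward negative infinity
    raise ValueError(f"unexpected rounding_mode: {rounding_mode!r}")
```

Note that "trunc" and "floor" only differ for negative quotients, e.g. -7 / 2 truncates to -3 but floors to -4.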

@justinchuby justinchuby changed the title Support rounding_mode in div | feat(torchlib) Implement aten::div.Tensor_mode | feat(torchlib) Aug 8, 2023
@justinchuby justinchuby added the module: torchlib Related to the torch/aten function lib in development label Aug 8, 2023
@codecov

codecov bot commented Aug 8, 2023

Codecov Report

Merging #988 (ea0158c) into main (c6e216e) will increase coverage by 0.05%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##             main     #988      +/-   ##
==========================================
+ Coverage   77.12%   77.18%   +0.05%     
==========================================
  Files         112      112              
  Lines       13947    13970      +23     
  Branches     1438     1441       +3     
==========================================
+ Hits        10757    10783      +26     
+ Misses       2829     2826       -3     
  Partials      361      361              
Files Changed Coverage Δ
onnxscript/function_libs/torch_lib/ops/core.py 77.60% <100.00%> (+0.14%) ⬆️
...ests/function_libs/torch_lib/error_reproduction.py 100.00% <100.00%> (ø)
...nxscript/tests/function_libs/torch_lib/ops_test.py 94.73% <100.00%> (+0.12%) ⬆️
...ipt/tests/function_libs/torch_lib/ops_test_data.py 95.95% <100.00%> (+0.01%) ⬆️

@github-actions

github-actions bot commented Aug 8, 2023

Test Results

         18 files  ±    0         18 suites  ±0   1h 6m 0s ⏱️ - 6m 15s
  10 213 tests +  37    7 458 ✔️ +  11      2 754 💤 +  26  0 ±0  1 🔥 ±0 
152 663 runs  +370  33 356 ✔️ +110  119 306 💤 +260  0 ±0  1 🔥 ±0 


Results for commit ea0158c. ± Comparison against base commit c6e216e.

This pull request removes 352 tests and adds 389 tests. Note that renamed tests count towards both.
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_051_aten_dot
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_052_aten_empty
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_053_aten_eq
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_054_aten_equal
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_055_aten_exp
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_056_aten_exp2
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_057_aten_expand
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_058_aten_expand_as
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_059_aten_special_erf
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_060_aten_special_erfc
…
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_051_aten_div_mode
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_052_aten_dot
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_053_aten_empty
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_054_aten_eq
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_055_aten_equal
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_056_aten_exp
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_057_aten_exp2
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_058_aten_expand
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_059_aten_expand_as
onnxscript.tests.function_libs.torch_lib.ops_test.TestFunctionValidity ‑ test_function_has_op_schema_060_aten_special_erf
…

♻️ This comment has been updated with latest results.

@justinchuby justinchuby mentioned this pull request Aug 8, 2023
@justinchuby
Collaborator Author

Split out a variant for float16 and cast to INT64 for trunc, because the maximum finite float16 value is 65504, which is well within INT64's range.
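The trick can be sketched in numpy (a sketch of the idea only, not the PR's ONNX code; inf/nan inputs would still need separate handling, since casting them to an integer type is undefined):

```python
import numpy as np

def trunc_fp16_via_int64(x):
    # Round toward zero by round-tripping through INT64: every finite
    # float16 magnitude (at most 65504) fits in an int64, and the
    # int64 -> float16 cast back is exact for values in that range.
    return x.astype(np.int64).astype(np.float16)
```
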

@justinchuby
Collaborator Author

justinchuby commented Aug 8, 2023

Example mismatch report:

Summary

The output of ONNX Runtime does not match that of PyTorch when executing test
ops_test.TestOutputConsistencyEagerCPU.test_output_match_opinfo__div_mode_trunc_rounding_cpu_float16, sample 0 in ONNX Script TorchLib.

To recreate this report, use

CREATE_REPRODUCTION_REPORT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k test_output_match_opinfo__div_mode_trunc_rounding_cpu_float16

Inputs

Details

inputs = (tensor([[[-6.6953,  6.7070, -7.7422, -5.9688, -3.5234],
         [-0.0703,  8.8828,  5.0195, -1.2656, -8.9922],
         [-1.1250,  2.6641, -6.7773,  5.7227, -5.0273],
         [ 0.3164,  1.3975,  8.9453, -0.0879, -8.9297],
         [ 2.0391,  0.8613,  5.7812, -1.8369, -7.1797],
         [ 1.1426, -5.4570, -7.6719, -8.5312,  1.0459],
         [ 5.9062,  8.0781, -3.4883, -3.8496,  5.4844],
         [ 8.5000,  3.6836,  3.6992, -8.0938,  7.4805],
         [-6.7773,  1.9863,  0.9756, -4.5273,  4.1484],
         [-8.1406, -7.6641,  7.5586, -0.9756,  7.8672]],

        [[-1.2920, -6.0391,  6.1953, -1.7842,  1.4238],
         [ 6.3203, -0.5977,  5.9766,  8.7422, -7.5938],
         [ 5.0898, -1.9688, -5.5117,  2.4258, -1.8369],
         [ 8.0391,  4.4219,  7.8672,  3.4375,  1.1777],
         [ 1.7051, -5.4844, -3.3828,  0.2812, -2.9609],
         [-4.9648, -0.6152,  7.7500, -4.8789,  3.2871],
         [ 0.1846, -2.4180,  5.8789, -8.6719, -6.9883],
         [ 3.1992, -4.0625, -5.1602,  6.8125,  2.4785],
         [-8.7734, -7.0234,  2.4258, -1.9951,  5.3438],
         [-6.5117, -1.2832, -0.5713,  4.2188, -5.0273]],

        [[-0.0439,  7.2773, -1.8369, -3.2168,  4.5273],
         [-4.2031,  1.6787,  0.4219,  4.4922, -8.8594],
         [ 1.5029,  1.9248, -7.6211, -4.8164,  0.3955],
         [ 3.2695, -4.4570, -8.1406,  7.8320,  6.5391],
         [ 0.3164, -6.5469,  3.0586, -4.6406, -6.3555],
         [-7.4180, -3.6738,  0.8613, -2.8555, -0.2812],
         [ 1.0107, -7.3281, -6.7852, -0.3867, -0.8525],
         [ 8.6328,  1.4062,  2.2422,  2.8301, -7.8828],
         [-8.8516,  6.9609,  8.7969,  6.4531, -4.0352],
         [-4.0000,  2.8301,  7.6562, -2.9805,  6.9531]],

        [[-3.5078,  2.5234,  8.2812, -1.5029,  5.8203],
         [-0.5537,  5.6094, -7.6484, -4.7031,  3.3828],
         [ 8.6562, -2.1094,  0.9053,  8.1562,  3.4453],
         [ 7.1172, -6.8477,  1.5381,  0.3340,  4.1836],
         [ 5.9844, -2.3379, -4.2812,  7.6094,  7.1797],
         [-0.9141, -5.4219, -4.3945, -1.7314,  4.4375],
         [-1.0723,  7.2422, -8.6953, -2.9883, -8.8359],
         [-6.1875,  6.8125, -4.5078,  8.8359, -5.4922],
         [-3.2617, -0.4658,  7.7500, -4.8672,  1.2480],
         [ 0.7998,  3.1016, -1.6787,  7.0234, -2.6191]],

        [[ 8.4219, -1.5029, -8.5625,  5.8359,  6.1250],
         [-3.9551,  1.6611,  1.0898, -7.6914, -6.2500],
         [-4.0000,  2.9961, -6.3281,  6.0742, -3.0156],
         [-1.6436,  6.8750,  0.4658, -0.1934, -5.1055],
         [-3.6211, -0.2812, -6.6875,  0.4570,  5.0195],
         [-0.1846,  5.4414, -8.9688,  0.8965,  1.7139],
         [-2.2070, -1.3008, -0.2900, -1.2393,  3.0664],
         [-8.5938, -1.3887, -1.7754, -0.7822,  7.1992],
         [-3.2422, -6.6445, -5.7578, -3.4180,  0.9229],
         [ 5.8789,  5.7812, -3.2344, -7.9531, -1.7051]]], dtype=torch.float16), tensor([[[ 6.9062,  5.3008,  0.5010,  8.3516,  3.0312],
         [-2.4961,  4.1562, -7.9805,  7.2852, -2.0566],
         [ 1.1074,  3.4453,  7.9727, -0.1055,  5.6250],
         [ 0.5010, -3.4375, -4.5000,  6.5938, -1.9600],
         [ 0.8350,  5.5977,  1.0283, -4.0859,  5.2734],
         [-4.7188,  2.0742,  7.7695, -1.1250,  2.3828],
         [ 1.6523,  8.7734, -2.8203,  3.9199, -3.9023],
         [-2.0664, -1.4414, -6.8359,  8.4531,  0.6592],
         [-7.6562,  5.7578,  3.8145, -0.2461,  0.7471],
         [-2.8828,  2.4961,  8.5547,  1.9248,  0.5977]],

        [[ 0.4043, -0.1406, -8.1797,  6.5820, -0.5889],
         [-0.1143, -5.0273,  1.7842, -0.9932, -7.0938],
         [-4.8438, -5.2109, -6.0117, -4.1641,  5.3711],
         [ 2.5312,  2.2227,  6.4414, -5.8711, -8.8047],
         [ 6.6875, -4.4219,  6.6523,  7.3477,  0.8701],
         [-1.9600,  1.5732,  1.9512,  7.9531, -2.3125],
         [ 5.7812, -5.1250,  5.3359, -2.2500, -0.0791],
         [ 2.8477,  8.8281, -4.4727,  0.1318, -2.0469],
         [-6.0547,  6.6641, -6.7422, -4.6406,  8.6719],
         [-1.1338, -7.6016,  4.1641, -1.5205,  5.8281]],

        [[ 7.9727, -2.9453,  5.9219, -5.6680,  4.7812],
         [ 8.1562, -0.5977, -2.9355,  3.7441,  5.3789],
         [ 2.3477,  7.3281,  5.0625,  3.6562,  1.4678],
         [ 7.3906,  0.3691,  4.8789, -6.5039, -5.5195],
         [-8.9062,  5.3281,  2.6445,  7.8828,  2.2148],
         [-4.6250, -2.3828, -8.7031, -2.6191, -1.5469],
         [-8.2188, -2.1875, -3.6914,  8.0938, -0.4834],
         [ 5.2812, -0.0264, -5.3359, -8.0000, -3.5859],
         [ 7.9297,  1.7139, -1.7490,  4.3945,  5.4922],
         [ 4.3242,  0.6504, -5.3789,  3.2969,  3.6836]],

        [[-2.6016,  2.4258, -0.0352, -2.3203, -3.0664],
         [ 3.4531,  7.6289, -4.9922, -0.9229, -7.3203],
         [-2.5137,  6.0469,  0.8613,  8.8516,  4.8086],
         [-0.4131, -0.1406,  8.7734, -8.1641,  5.8281],
         [-3.1816,  1.6611,  6.4609, -8.6875,  6.8477],
         [ 3.2344,  6.1016,  4.6055, -6.5547, -7.1016],
         [-1.3887, -1.3359,  7.0039,  2.3906,  7.7344],
         [-8.4609, -3.7695,  7.7266, -2.6016,  2.9961],
         [-5.5195, -2.8652,  0.1582, -7.4531,  5.3711],
         [-5.5273, -5.0977,  2.0391,  8.0312, -8.3438]],

        [[-6.6797, -5.0078,  6.9883, -2.6992, -2.5488],
         [ 3.3672,  4.9492, -1.5293,  4.4375,  3.2168],
         [-7.7422, -1.2305,  0.5977,  0.8613, -5.5273],
         [ 0.2109, -0.6768, -1.1777,  4.1133,  4.2461],
         [ 4.9570,  1.3184,  1.2568, -4.0078,  0.3340],
         [-8.0469,  6.0820,  0.3604,  1.6260, -0.7910],
         [-5.5117,  1.9512, -3.4375,  1.1602,  1.4238],
         [ 1.8809, -2.5664, -6.9453,  7.3984, -3.2266],
         [-2.4082, -3.8398,  0.0703,  0.6416, -0.6240],
         [-4.6836, -0.9844,  2.2148,  2.0117,  4.5547]]], dtype=torch.float16))
kwargs = {'rounding_mode': 'trunc'}

Expected output

expected = tensor([[[  -0.,    1.,  -15.,   -0.,   -1.],
         [   0.,    2.,   -0.,   -0.,    4.],
         [  -1.,    0.,   -0.,  -54.,   -0.],
         [   0.,   -0.,   -1.,   -0.,    4.],
         [   2.,    0.,    5.,    0.,   -1.],
         [  -0.,   -2.,   -0.,    7.,    0.],
         [   3.,    0.,    1.,   -0.,   -1.],
         [  -4.,   -2.,   -0.,   -0.,   11.],
         [   0.,    0.,    0.,   18.,    5.],
         [   2.,   -3.,    0.,   -0.,   13.]],

        [[  -3.,   42.,   -0.,   -0.,   -2.],
         [ -55.,    0.,    3.,   -8.,    1.],
         [  -1.,    0.,    0.,   -0.,   -0.],
         [   3.,    1.,    1.,   -0.,   -0.],
         [   0.,    1.,   -0.,    0.,   -3.],
         [   2.,   -0.,    3.,   -0.,   -1.],
         [   0.,    0.,    1.,    3.,   88.],
         [   1.,   -0.,    1.,   51.,   -1.],
         [   1.,   -1.,   -0.,    0.,    0.],
         [   5.,    0.,   -0.,   -2.,   -0.]],

        [[  -0.,   -2.,   -0.,    0.,    0.],
         [  -0.,   -2.,   -0.,    1.,   -1.],
         [   0.,    0.,   -1.,   -1.,    0.],
         [   0.,  -12.,   -1.,   -1.,   -1.],
         [  -0.,   -1.,    1.,   -0.,   -2.],
         [   1.,    1.,   -0.,    1.,    0.],
         [  -0.,    3.,    1.,   -0.,    1.],
         [   1.,  -53.,   -0.,   -0.,    2.],
         [  -1.,    4.,   -5.,    1.,   -0.],
         [  -0.,    4.,   -1.,   -0.,    1.]],

        [[   1.,    1., -235.,    0.,   -1.],
         [  -0.,    0.,    1.,    5.,   -0.],
         [  -3.,   -0.,    1.,    0.,    0.],
         [ -17.,   48.,    0.,   -0.,    0.],
         [  -1.,   -1.,   -0.,   -0.,    1.],
         [  -0.,   -0.,   -0.,    0.,   -0.],
         [   0.,   -5.,   -1.,   -1.,   -1.],
         [   0.,   -1.,   -0.,   -3.,   -1.],
         [   0.,    0.,   49.,    0.,    0.],
         [  -0.,   -0.,   -0.,    0.,    0.]],

        [[  -1.,    0.,   -1.,   -2.,   -2.],
         [  -1.,    0.,   -0.,   -1.,   -1.],
         [   0.,   -2.,  -10.,    7.,    0.],
         [  -7.,  -10.,   -0.,   -0.,   -1.],
         [  -0.,   -0.,   -5.,   -0.,   15.],
         [   0.,    0.,  -24.,    0.,   -2.],
         [   0.,   -0.,    0.,   -1.,    2.],
         [  -4.,    0.,    0.,   -0.,   -2.],
         [   1.,    1.,  -81.,   -5.,   -1.],
         [  -1.,   -5.,   -1.,   -3.,   -0.]]], dtype=torch.float16)

Actual output

actual = tensor([[[   0.,    1.,  -15.,    0.,   -1.],
         [   0.,    2.,    0.,    0.,    4.],
         [  -1.,    0.,    0.,  -54.,    0.],
         [   0.,    0.,   -1.,    0.,    4.],
         [   2.,    0.,    5.,    0.,   -1.],
         [   0.,   -2.,    0.,    7.,    0.],
         [   3.,    0.,    1.,    0.,   -1.],
         [  -4.,   -2.,    0.,    0.,   11.],
         [   0.,    0.,    0.,   18.,    5.],
         [   2.,   -3.,    0.,    0.,   13.]],

        [[  -3.,   42.,    0.,    0.,   -2.],
         [ -55.,    0.,    3.,   -8.,    1.],
         [  -1.,    0.,    0.,    0.,    0.],
         [   3.,    1.,    1.,    0.,    0.],
         [   0.,    1.,    0.,    0.,   -3.],
         [   2.,    0.,    3.,    0.,   -1.],
         [   0.,    0.,    1.,    3.,   88.],
         [   1.,    0.,    1.,   51.,   -1.],
         [   1.,   -1.,    0.,    0.,    0.],
         [   5.,    0.,    0.,   -2.,    0.]],

        [[   0.,   -2.,    0.,    0.,    0.],
         [   0.,   -2.,    0.,    1.,   -1.],
         [   0.,    0.,   -1.,   -1.,    0.],
         [   0.,  -12.,   -1.,   -1.,   -1.],
         [   0.,   -1.,    1.,    0.,   -2.],
         [   1.,    1.,    0.,    1.,    0.],
         [   0.,    3.,    1.,    0.,    1.],
         [   1.,  -53.,    0.,    0.,    2.],
         [  -1.,    4.,   -5.,    1.,    0.],
         [   0.,    4.,   -1.,    0.,    1.]],

        [[   1.,    1., -235.,    0.,   -1.],
         [   0.,    0.,    1.,    5.,    0.],
         [  -3.,    0.,    1.,    0.,    0.],
         [ -17.,   48.,    0.,    0.,    0.],
         [  -1.,   -1.,    0.,    0.,    1.],
         [   0.,    0.,    0.,    0.,    0.],
         [   0.,   -5.,   -1.,   -1.,   -1.],
         [   0.,   -1.,    0.,   -3.,   -1.],
         [   0.,    0.,   48.,    0.,    0.],
         [   0.,    0.,    0.,    0.,    0.]],

        [[  -1.,    0.,   -1.,   -2.,   -2.],
         [  -1.,    0.,    0.,   -1.,   -1.],
         [   0.,   -2.,  -10.,    7.,    0.],
         [  -7.,  -10.,    0.,    0.,   -1.],
         [   0.,    0.,   -5.,    0.,   15.],
         [   0.,    0.,  -24.,    0.,   -2.],
         [   0.,    0.,    0.,   -1.,    2.],
         [  -4.,    0.,    0.,    0.,   -2.],
         [   1.,    1.,  -81.,   -5.,   -1.],
         [  -1.,   -5.,   -1.,   -3.,    0.]]], dtype=torch.float16)

Full error stack

Tensor-likes are not close!

Mismatched elements: 1 / 250 (0.4%)
Greatest absolute difference: 1.0 at index (3, 8, 2) (up to 1e-05 allowed)
Greatest relative difference: 0.0204010009765625 at index (3, 8, 2) (up to 0.001 allowed)
  File "/home/justinchu/dev/onnx-script/onnxscript/tests/function_libs/torch_lib/ops_test.py", line 259, in run_test_output_match
    torch.testing.assert_close(
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/testing/_comparison.py", line 1511, in assert_close
    raise error_metas[0].to_error(msg)
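As a cross-check, the reported differences follow from the single mismatched element (expected 49, actual 48 at index (3, 8, 2)), with the relative difference evidently computed in float16 arithmetic:

```python
import numpy as np

expected = np.float16(49.0)
actual = np.float16(48.0)

abs_diff = np.abs(actual - expected)    # 1.0, the greatest absolute difference
rel_diff = abs_diff / np.abs(expected)  # float16(1/49), matching the report
```
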

@titaiwangms
Contributor

CREATE_REPRODUCTION_REPORT=1 is nice!

@justinchuby
Collaborator Author

justinchuby commented Aug 8, 2023

Infinities are flipped:

Summary

The output of ONNX Runtime does not match that of PyTorch when executing test
ops_test.TestOutputConsistencyEagerCPU.test_output_match_opinfo__div_mode_trunc_rounding_cpu_float16, sample 6 in ONNX Script TorchLib.

To recreate this report, use

CREATE_REPRODUCTION_REPORT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k test_output_match_opinfo__div_mode_trunc_rounding_cpu_float16

Inputs

Shapes: ['Tensor<torch.Size([10, 1, 5]), dtype=torch.float16>', 'Tensor<torch.Size([10, 5]), dtype=torch.float16>']

inputs = (tensor([[[ 1.5205,  5.7305,  1.4678, -7.5156,  3.4453]],

        [[-0.6943, -2.0215, -3.4375,  1.6963,  4.5352]],

        [[-8.7500,  1.4414,  3.6992, -6.0312,  5.6875]],

        [[ 4.3516,  3.1992,  7.3125,  2.9961, -4.6133]],

        [[ 5.7656, -4.5156,  2.8652, -7.1016,  2.6367]],

        [[ 1.1514,  7.0586,  0.6064,  2.2070,  3.6387]],

        [[ 6.9062, -5.0977, -1.3096,  1.5908, -7.5391]],

        [[-8.2109, -2.7344,  7.5312,  4.6562, -8.6641]],

        [[ 1.0986,  3.9551, -7.4883, -3.2695, -2.3828]],

        [[ 1.9248,  5.5977,  0.5010,  5.7930, -4.8242]]], dtype=torch.float16), tensor([[-5.0000e+00, -1.7227e+00, -1.1426e+00,  3.1641e+00,  8.2109e+00],
        [ 6.1035e-05, -5.6094e+00,  8.5469e+00, -5.9492e+00, -7.3398e+00],
        [-4.5703e+00,  1.6699e+00,  4.4570e+00,  3.2344e+00,  3.5859e+00],
        [-1.7666e+00, -8.4375e-01, -7.9980e-01,  1.0195e+00, -7.8320e+00],
        [-8.5000e+00,  4.8789e+00, -2.9883e+00, -3.1016e+00, -7.4883e+00],
        [ 7.2773e+00, -1.8457e+00,  1.3008e+00, -8.3047e+00, -1.2217e+00],
        [ 1.8633e+00, -7.2266e+00, -2.2930e+00,  2.0039e+00,  5.9336e+00],
        [ 1.9863e+00,  5.7812e+00,  4.6680e+00,  3.8672e+00,  3.5234e+00],
        [-8.5938e+00,  2.0391e+00,  2.6719e+00,  4.6133e+00,  1.2480e+00],
        [ 3.6387e+00,  4.5000e+00,  3.2695e+00,  4.9648e+00,  7.7773e+00]],
       dtype=torch.float16))
kwargs = {'rounding_mode': 'trunc'}

Expected output

expected = tensor([[[-0.0000e+00, -3.0000e+00, -1.0000e+00, -2.0000e+00,  0.0000e+00],
         [ 2.4912e+04, -1.0000e+00,  0.0000e+00,  1.0000e+00, -0.0000e+00],
         [-0.0000e+00,  3.0000e+00,  0.0000e+00, -2.0000e+00,  0.0000e+00],
         [-0.0000e+00, -6.0000e+00, -1.0000e+00, -7.0000e+00, -0.0000e+00],
         [-0.0000e+00,  1.0000e+00, -0.0000e+00,  2.0000e+00, -0.0000e+00],
         [ 0.0000e+00, -3.0000e+00,  1.0000e+00,  0.0000e+00, -2.0000e+00],
         [ 0.0000e+00, -0.0000e+00, -0.0000e+00, -3.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+00,  0.0000e+00],
         [-0.0000e+00,  2.0000e+00,  0.0000e+00, -1.0000e+00,  2.0000e+00],
         [ 0.0000e+00,  1.0000e+00,  0.0000e+00, -1.0000e+00,  0.0000e+00]],

        [[ 0.0000e+00,  1.0000e+00,  3.0000e+00,  0.0000e+00,  0.0000e+00],
         [-1.1376e+04,  0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],
         [ 0.0000e+00, -1.0000e+00, -0.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 0.0000e+00,  2.0000e+00,  4.0000e+00,  1.0000e+00, -0.0000e+00],
         [ 0.0000e+00, -0.0000e+00,  1.0000e+00, -0.0000e+00, -0.0000e+00],
         [-0.0000e+00,  1.0000e+00, -2.0000e+00, -0.0000e+00, -3.0000e+00],
         [-0.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00],
         [-0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 0.0000e+00, -0.0000e+00, -1.0000e+00,  0.0000e+00,  3.0000e+00],
         [-0.0000e+00, -0.0000e+00, -1.0000e+00,  0.0000e+00,  0.0000e+00]],

        [[ 1.0000e+00, -0.0000e+00, -3.0000e+00, -1.0000e+00,  0.0000e+00],
         [       -inf, -0.0000e+00,  0.0000e+00,  1.0000e+00, -0.0000e+00],
         [ 1.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+00,  1.0000e+00],
         [ 4.0000e+00, -1.0000e+00, -4.0000e+00, -5.0000e+00, -0.0000e+00],
         [ 1.0000e+00,  0.0000e+00, -1.0000e+00,  1.0000e+00, -0.0000e+00],
         [-1.0000e+00, -0.0000e+00,  2.0000e+00,  0.0000e+00, -4.0000e+00],
         [-4.0000e+00, -0.0000e+00, -1.0000e+00, -3.0000e+00,  0.0000e+00],
         [-4.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+00,  1.0000e+00],
         [ 1.0000e+00,  0.0000e+00,  1.0000e+00, -1.0000e+00,  4.0000e+00],
         [-2.0000e+00,  0.0000e+00,  1.0000e+00, -1.0000e+00,  0.0000e+00]],

        [[-0.0000e+00, -1.0000e+00, -6.0000e+00,  0.0000e+00, -0.0000e+00],
         [        inf, -0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00],
         [-0.0000e+00,  1.0000e+00,  1.0000e+00,  0.0000e+00, -1.0000e+00],
         [-2.0000e+00, -3.0000e+00, -9.0000e+00,  2.0000e+00,  0.0000e+00],
         [-0.0000e+00,  0.0000e+00, -2.0000e+00, -0.0000e+00,  0.0000e+00],
         [ 0.0000e+00, -1.0000e+00,  5.0000e+00, -0.0000e+00,  3.0000e+00],
         [ 2.0000e+00, -0.0000e+00, -3.0000e+00,  1.0000e+00, -0.0000e+00],
         [ 2.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00, -1.0000e+00],
         [-0.0000e+00,  1.0000e+00,  2.0000e+00,  0.0000e+00, -3.0000e+00],
         [ 1.0000e+00,  0.0000e+00,  2.0000e+00,  0.0000e+00, -0.0000e+00]],

        [[-1.0000e+00,  2.0000e+00, -2.0000e+00, -2.0000e+00,  0.0000e+00],
         [        inf,  0.0000e+00,  0.0000e+00,  1.0000e+00, -0.0000e+00],
         [-1.0000e+00, -2.0000e+00,  0.0000e+00, -2.0000e+00,  0.0000e+00],
         [-3.0000e+00,  5.0000e+00, -3.0000e+00, -6.0000e+00, -0.0000e+00],
         [-0.0000e+00, -0.0000e+00, -0.0000e+00,  2.0000e+00, -0.0000e+00],
         [ 0.0000e+00,  2.0000e+00,  2.0000e+00,  0.0000e+00, -2.0000e+00],
         [ 3.0000e+00,  0.0000e+00, -1.0000e+00, -3.0000e+00,  0.0000e+00],
         [ 2.0000e+00, -0.0000e+00,  0.0000e+00, -1.0000e+00,  0.0000e+00],
         [-0.0000e+00, -2.0000e+00,  1.0000e+00, -1.0000e+00,  2.0000e+00],
         [ 1.0000e+00, -1.0000e+00,  0.0000e+00, -1.0000e+00,  0.0000e+00]],

        [[-0.0000e+00, -4.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 1.8864e+04, -1.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00],
         [-0.0000e+00,  4.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00],
         [-0.0000e+00, -8.0000e+00, -0.0000e+00,  2.0000e+00, -0.0000e+00],
         [-0.0000e+00,  1.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],
         [ 0.0000e+00, -3.0000e+00,  0.0000e+00, -0.0000e+00, -2.0000e+00],
         [ 0.0000e+00, -0.0000e+00, -0.0000e+00,  1.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00],
         [-0.0000e+00,  3.0000e+00,  0.0000e+00,  0.0000e+00,  2.0000e+00],
         [ 0.0000e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]],

        [[-1.0000e+00,  2.0000e+00,  1.0000e+00,  0.0000e+00, -0.0000e+00],
         [        inf,  0.0000e+00, -0.0000e+00, -0.0000e+00,  1.0000e+00],
         [-1.0000e+00, -3.0000e+00, -0.0000e+00,  0.0000e+00, -2.0000e+00],
         [-3.0000e+00,  6.0000e+00,  1.0000e+00,  1.0000e+00,  0.0000e+00],
         [-0.0000e+00, -1.0000e+00,  0.0000e+00, -0.0000e+00,  1.0000e+00],
         [ 0.0000e+00,  2.0000e+00, -1.0000e+00, -0.0000e+00,  6.0000e+00],
         [ 3.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+00],
         [ 3.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00, -2.0000e+00],
         [-0.0000e+00, -2.0000e+00, -0.0000e+00,  0.0000e+00, -6.0000e+00],
         [ 1.0000e+00, -1.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00]],

        [[ 1.0000e+00,  1.0000e+00, -6.0000e+00,  1.0000e+00, -1.0000e+00],
         [       -inf,  0.0000e+00,  0.0000e+00, -0.0000e+00,  1.0000e+00],
         [ 1.0000e+00, -1.0000e+00,  1.0000e+00,  1.0000e+00, -2.0000e+00],
         [ 4.0000e+00,  3.0000e+00, -9.0000e+00,  4.0000e+00,  1.0000e+00],
         [ 0.0000e+00, -0.0000e+00, -2.0000e+00, -1.0000e+00,  1.0000e+00],
         [-1.0000e+00,  1.0000e+00,  5.0000e+00, -0.0000e+00,  7.0000e+00],
         [-4.0000e+00,  0.0000e+00, -3.0000e+00,  2.0000e+00, -1.0000e+00],
         [-4.0000e+00, -0.0000e+00,  1.0000e+00,  1.0000e+00, -2.0000e+00],
         [ 0.0000e+00, -1.0000e+00,  2.0000e+00,  1.0000e+00, -6.0000e+00],
         [-2.0000e+00, -0.0000e+00,  2.0000e+00,  0.0000e+00, -1.0000e+00]],

        [[-0.0000e+00, -2.0000e+00,  6.0000e+00, -1.0000e+00, -0.0000e+00],
         [ 1.8000e+04, -0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00],
         [-0.0000e+00,  2.0000e+00, -1.0000e+00, -1.0000e+00, -0.0000e+00],
         [-0.0000e+00, -4.0000e+00,  9.0000e+00, -3.0000e+00,  0.0000e+00],
         [-0.0000e+00,  0.0000e+00,  2.0000e+00,  1.0000e+00,  0.0000e+00],
         [ 0.0000e+00, -2.0000e+00, -5.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 0.0000e+00, -0.0000e+00,  3.0000e+00, -1.0000e+00, -0.0000e+00],
         [ 0.0000e+00,  0.0000e+00, -1.0000e+00, -0.0000e+00, -0.0000e+00],
         [-0.0000e+00,  1.0000e+00, -2.0000e+00, -0.0000e+00, -1.0000e+00],
         [ 0.0000e+00,  0.0000e+00, -2.0000e+00, -0.0000e+00, -0.0000e+00]],

        [[-0.0000e+00, -3.0000e+00, -0.0000e+00,  1.0000e+00, -0.0000e+00],
         [ 3.1536e+04, -0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00],
         [-0.0000e+00,  3.0000e+00,  0.0000e+00,  1.0000e+00, -1.0000e+00],
         [-1.0000e+00, -6.0000e+00, -0.0000e+00,  5.0000e+00,  0.0000e+00],
         [-0.0000e+00,  1.0000e+00, -0.0000e+00, -1.0000e+00,  0.0000e+00],
         [ 0.0000e+00, -3.0000e+00,  0.0000e+00, -0.0000e+00,  3.0000e+00],
         [ 1.0000e+00, -0.0000e+00, -0.0000e+00,  2.0000e+00, -0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00, -1.0000e+00],
         [-0.0000e+00,  2.0000e+00,  0.0000e+00,  1.0000e+00, -3.0000e+00],
         [ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00, -0.0000e+00]]],
       dtype=torch.float16)

Shape: torch.Size([10, 10, 5])

Actual output

actual = tensor([[[ 0.0000e+00, -3.0000e+00, -1.0000e+00, -2.0000e+00,  0.0000e+00],
         [ 2.4912e+04, -1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  3.0000e+00,  0.0000e+00, -2.0000e+00,  0.0000e+00],
         [ 0.0000e+00, -6.0000e+00, -1.0000e+00, -7.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  1.0000e+00,  0.0000e+00,  2.0000e+00,  0.0000e+00],
         [ 0.0000e+00, -3.0000e+00,  1.0000e+00,  0.0000e+00, -2.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  2.0000e+00,  0.0000e+00, -1.0000e+00,  2.0000e+00],
         [ 0.0000e+00,  1.0000e+00,  0.0000e+00, -1.0000e+00,  0.0000e+00]],

        [[ 0.0000e+00,  1.0000e+00,  3.0000e+00,  0.0000e+00,  0.0000e+00],
         [-1.1376e+04,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00, -1.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 0.0000e+00,  2.0000e+00,  4.0000e+00,  1.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  1.0000e+00, -2.0000e+00,  0.0000e+00, -3.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 0.0000e+00,  0.0000e+00, -1.0000e+00,  0.0000e+00,  3.0000e+00],
         [ 0.0000e+00,  0.0000e+00, -1.0000e+00,  0.0000e+00,  0.0000e+00]],

        [[ 1.0000e+00,  0.0000e+00, -3.0000e+00, -1.0000e+00,  0.0000e+00],
         [        inf,  0.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00],
         [ 1.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+00,  1.0000e+00],
         [ 4.0000e+00, -1.0000e+00, -4.0000e+00, -5.0000e+00,  0.0000e+00],
         [ 1.0000e+00,  0.0000e+00, -1.0000e+00,  1.0000e+00,  0.0000e+00],
         [-1.0000e+00,  0.0000e+00,  2.0000e+00,  0.0000e+00, -4.0000e+00],
         [-4.0000e+00,  0.0000e+00, -1.0000e+00, -3.0000e+00,  0.0000e+00],
         [-4.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+00,  1.0000e+00],
         [ 1.0000e+00,  0.0000e+00,  1.0000e+00, -1.0000e+00,  4.0000e+00],
         [-2.0000e+00,  0.0000e+00,  1.0000e+00, -1.0000e+00,  0.0000e+00]],

        [[ 0.0000e+00, -1.0000e+00, -6.0000e+00,  0.0000e+00,  0.0000e+00],
         [       -inf,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  1.0000e+00,  1.0000e+00,  0.0000e+00, -1.0000e+00],
         [-2.0000e+00, -3.0000e+00, -9.0000e+00,  2.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00, -2.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00, -1.0000e+00,  5.0000e+00,  0.0000e+00,  3.0000e+00],
         [ 2.0000e+00,  0.0000e+00, -3.0000e+00,  1.0000e+00,  0.0000e+00],
         [ 2.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00, -1.0000e+00],
         [ 0.0000e+00,  1.0000e+00,  2.0000e+00,  0.0000e+00, -3.0000e+00],
         [ 1.0000e+00,  0.0000e+00,  2.0000e+00,  0.0000e+00,  0.0000e+00]],

        [[-1.0000e+00,  2.0000e+00, -2.0000e+00, -2.0000e+00,  0.0000e+00],
         [       -inf,  0.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00],
         [-1.0000e+00, -2.0000e+00,  0.0000e+00, -2.0000e+00,  0.0000e+00],
         [-3.0000e+00,  5.0000e+00, -3.0000e+00, -6.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  2.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  2.0000e+00,  2.0000e+00,  0.0000e+00, -2.0000e+00],
         [ 3.0000e+00,  0.0000e+00, -1.0000e+00, -3.0000e+00,  0.0000e+00],
         [ 2.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+00,  0.0000e+00],
         [ 0.0000e+00, -2.0000e+00,  1.0000e+00, -1.0000e+00,  2.0000e+00],
         [ 1.0000e+00, -1.0000e+00,  0.0000e+00, -1.0000e+00,  0.0000e+00]],

        [[ 0.0000e+00, -4.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 1.8864e+04, -1.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  4.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 0.0000e+00, -8.0000e+00,  0.0000e+00,  2.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00, -3.0000e+00,  0.0000e+00,  0.0000e+00, -2.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 0.0000e+00,  3.0000e+00,  0.0000e+00,  0.0000e+00,  2.0000e+00],
         [ 0.0000e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]],

        [[-1.0000e+00,  2.0000e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00],
         [       -inf,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00],
         [-1.0000e+00, -3.0000e+00,  0.0000e+00,  0.0000e+00, -2.0000e+00],
         [-3.0000e+00,  6.0000e+00,  1.0000e+00,  1.0000e+00,  0.0000e+00],
         [ 0.0000e+00, -1.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 0.0000e+00,  2.0000e+00, -1.0000e+00,  0.0000e+00,  6.0000e+00],
         [ 3.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+00],
         [ 3.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -2.0000e+00],
         [ 0.0000e+00, -2.0000e+00,  0.0000e+00,  0.0000e+00, -6.0000e+00],
         [ 1.0000e+00, -1.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]],

        [[ 1.0000e+00,  1.0000e+00, -6.0000e+00,  1.0000e+00, -1.0000e+00],
         [        inf,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 1.0000e+00, -1.0000e+00,  1.0000e+00,  1.0000e+00, -2.0000e+00],
         [ 4.0000e+00,  3.0000e+00, -9.0000e+00,  4.0000e+00,  1.0000e+00],
         [ 0.0000e+00,  0.0000e+00, -2.0000e+00, -1.0000e+00,  1.0000e+00],
         [-1.0000e+00,  1.0000e+00,  5.0000e+00,  0.0000e+00,  7.0000e+00],
         [-4.0000e+00,  0.0000e+00, -3.0000e+00,  2.0000e+00, -1.0000e+00],
         [-4.0000e+00,  0.0000e+00,  1.0000e+00,  1.0000e+00, -2.0000e+00],
         [ 0.0000e+00, -1.0000e+00,  2.0000e+00,  1.0000e+00, -6.0000e+00],
         [-2.0000e+00,  0.0000e+00,  2.0000e+00,  0.0000e+00, -1.0000e+00]],

        [[ 0.0000e+00, -2.0000e+00,  6.0000e+00, -1.0000e+00,  0.0000e+00],
         [ 1.8000e+04,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  2.0000e+00, -1.0000e+00, -1.0000e+00,  0.0000e+00],
         [ 0.0000e+00, -4.0000e+00,  9.0000e+00, -3.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  2.0000e+00,  1.0000e+00,  0.0000e+00],
         [ 0.0000e+00, -2.0000e+00, -5.0000e+00,  0.0000e+00,  1.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  3.0000e+00, -1.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00, -1.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  1.0000e+00, -2.0000e+00,  0.0000e+00, -1.0000e+00],
         [ 0.0000e+00,  0.0000e+00, -2.0000e+00,  0.0000e+00,  0.0000e+00]],

        [[ 0.0000e+00, -3.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00],
         [ 3.1536e+04,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  3.0000e+00,  0.0000e+00,  1.0000e+00, -1.0000e+00],
         [-1.0000e+00, -6.0000e+00,  0.0000e+00,  5.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  1.0000e+00,  0.0000e+00, -1.0000e+00,  0.0000e+00],
         [ 0.0000e+00, -3.0000e+00,  0.0000e+00,  0.0000e+00,  3.0000e+00],
         [ 1.0000e+00,  0.0000e+00,  0.0000e+00,  2.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00, -1.0000e+00],
         [ 0.0000e+00,  2.0000e+00,  0.0000e+00,  1.0000e+00, -3.0000e+00],
         [ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00]]],
       dtype=torch.float16)

Shape: torch.Size([10, 10, 5])

Difference

--- actual
+++ expected
@@ -1,110 +1,110 @@
-tensor([[[ 0.0000e+00, -3.0000e+00, -1.0000e+00, -2.0000e+00,  0.0000e+00],
-         [ 2.4912e+04, -1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00],
-         [ 0.0000e+00,  3.0000e+00,  0.0000e+00, -2.0000e+00,  0.0000e+00],
-         [ 0.0000e+00, -6.0000e+00, -1.0000e+00, -7.0000e+00,  0.0000e+00],
-         [ 0.0000e+00,  1.0000e+00,  0.0000e+00,  2.0000e+00,  0.0000e+00],
+tensor([[[-0.0000e+00, -3.0000e+00, -1.0000e+00, -2.0000e+00,  0.0000e+00],
+         [ 2.4912e+04, -1.0000e+00,  0.0000e+00,  1.0000e+00, -0.0000e+00],
+         [-0.0000e+00,  3.0000e+00,  0.0000e+00, -2.0000e+00,  0.0000e+00],
+         [-0.0000e+00, -6.0000e+00, -1.0000e+00, -7.0000e+00, -0.0000e+00],
+         [-0.0000e+00,  1.0000e+00, -0.0000e+00,  2.0000e+00, -0.0000e+00],
          [ 0.0000e+00, -3.0000e+00,  1.0000e+00,  0.0000e+00, -2.0000e+00],
-         [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.0000e+00,  0.0000e+00],
+         [ 0.0000e+00, -0.0000e+00, -0.0000e+00, -3.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+00,  0.0000e+00],
-         [ 0.0000e+00,  2.0000e+00,  0.0000e+00, -1.0000e+00,  2.0000e+00],
+         [-0.0000e+00,  2.0000e+00,  0.0000e+00, -1.0000e+00,  2.0000e+00],
          [ 0.0000e+00,  1.0000e+00,  0.0000e+00, -1.0000e+00,  0.0000e+00]],
 
         [[ 0.0000e+00,  1.0000e+00,  3.0000e+00,  0.0000e+00,  0.0000e+00],
-         [-1.1376e+04,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
-         [ 0.0000e+00, -1.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00],
-         [ 0.0000e+00,  2.0000e+00,  4.0000e+00,  1.0000e+00,  0.0000e+00],
-         [ 0.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00],
-         [ 0.0000e+00,  1.0000e+00, -2.0000e+00,  0.0000e+00, -3.0000e+00],
-         [ 0.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00],
-         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00],
-         [ 0.0000e+00,  0.0000e+00, -1.0000e+00,  0.0000e+00,  3.0000e+00],
-         [ 0.0000e+00,  0.0000e+00, -1.0000e+00,  0.0000e+00,  0.0000e+00]],
+         [-1.1376e+04,  0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],
+         [ 0.0000e+00, -1.0000e+00, -0.0000e+00,  0.0000e+00,  1.0000e+00],
+         [ 0.0000e+00,  2.0000e+00,  4.0000e+00,  1.0000e+00, -0.0000e+00],
+         [ 0.0000e+00, -0.0000e+00,  1.0000e+00, -0.0000e+00, -0.0000e+00],
+         [-0.0000e+00,  1.0000e+00, -2.0000e+00, -0.0000e+00, -3.0000e+00],
+         [-0.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00],
+         [-0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00,  1.0000e+00],
+         [ 0.0000e+00, -0.0000e+00, -1.0000e+00,  0.0000e+00,  3.0000e+00],
+         [-0.0000e+00, -0.0000e+00, -1.0000e+00,  0.0000e+00,  0.0000e+00]],
 
-        [[ 1.0000e+00,  0.0000e+00, -3.0000e+00, -1.0000e+00,  0.0000e+00],
-         [        inf,  0.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00],
+        [[ 1.0000e+00, -0.0000e+00, -3.0000e+00, -1.0000e+00,  0.0000e+00],
+         [       -inf, -0.0000e+00,  0.0000e+00,  1.0000e+00, -0.0000e+00],
          [ 1.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+00,  1.0000e+00],
-         [ 4.0000e+00, -1.0000e+00, -4.0000e+00, -5.0000e+00,  0.0000e+00],
-         [ 1.0000e+00,  0.0000e+00, -1.0000e+00,  1.0000e+00,  0.0000e+00],
-         [-1.0000e+00,  0.0000e+00,  2.0000e+00,  0.0000e+00, -4.0000e+00],
-         [-4.0000e+00,  0.0000e+00, -1.0000e+00, -3.0000e+00,  0.0000e+00],
+         [ 4.0000e+00, -1.0000e+00, -4.0000e+00, -5.0000e+00, -0.0000e+00],
+         [ 1.0000e+00,  0.0000e+00, -1.0000e+00,  1.0000e+00, -0.0000e+00],
+         [-1.0000e+00, -0.0000e+00,  2.0000e+00,  0.0000e+00, -4.0000e+00],
+         [-4.0000e+00, -0.0000e+00, -1.0000e+00, -3.0000e+00,  0.0000e+00],
          [-4.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+00,  1.0000e+00],
          [ 1.0000e+00,  0.0000e+00,  1.0000e+00, -1.0000e+00,  4.0000e+00],
          [-2.0000e+00,  0.0000e+00,  1.0000e+00, -1.0000e+00,  0.0000e+00]],
 
-        [[ 0.0000e+00, -1.0000e+00, -6.0000e+00,  0.0000e+00,  0.0000e+00],
-         [       -inf,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
-         [ 0.0000e+00,  1.0000e+00,  1.0000e+00,  0.0000e+00, -1.0000e+00],
+        [[-0.0000e+00, -1.0000e+00, -6.0000e+00,  0.0000e+00, -0.0000e+00],
+         [        inf, -0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00],
+         [-0.0000e+00,  1.0000e+00,  1.0000e+00,  0.0000e+00, -1.0000e+00],
          [-2.0000e+00, -3.0000e+00, -9.0000e+00,  2.0000e+00,  0.0000e+00],
-         [ 0.0000e+00,  0.0000e+00, -2.0000e+00,  0.0000e+00,  0.0000e+00],
-         [ 0.0000e+00, -1.0000e+00,  5.0000e+00,  0.0000e+00,  3.0000e+00],
-         [ 2.0000e+00,  0.0000e+00, -3.0000e+00,  1.0000e+00,  0.0000e+00],
+         [-0.0000e+00,  0.0000e+00, -2.0000e+00, -0.0000e+00,  0.0000e+00],
+         [ 0.0000e+00, -1.0000e+00,  5.0000e+00, -0.0000e+00,  3.0000e+00],
+         [ 2.0000e+00, -0.0000e+00, -3.0000e+00,  1.0000e+00, -0.0000e+00],
          [ 2.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00, -1.0000e+00],
-         [ 0.0000e+00,  1.0000e+00,  2.0000e+00,  0.0000e+00, -3.0000e+00],
-         [ 1.0000e+00,  0.0000e+00,  2.0000e+00,  0.0000e+00,  0.0000e+00]],
+         [-0.0000e+00,  1.0000e+00,  2.0000e+00,  0.0000e+00, -3.0000e+00],
+         [ 1.0000e+00,  0.0000e+00,  2.0000e+00,  0.0000e+00, -0.0000e+00]],
 
         [[-1.0000e+00,  2.0000e+00, -2.0000e+00, -2.0000e+00,  0.0000e+00],
-         [       -inf,  0.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00],
+         [        inf,  0.0000e+00,  0.0000e+00,  1.0000e+00, -0.0000e+00],
          [-1.0000e+00, -2.0000e+00,  0.0000e+00, -2.0000e+00,  0.0000e+00],
-         [-3.0000e+00,  5.0000e+00, -3.0000e+00, -6.0000e+00,  0.0000e+00],
-         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  2.0000e+00,  0.0000e+00],
+         [-3.0000e+00,  5.0000e+00, -3.0000e+00, -6.0000e+00, -0.0000e+00],
+         [-0.0000e+00, -0.0000e+00, -0.0000e+00,  2.0000e+00, -0.0000e+00],
          [ 0.0000e+00,  2.0000e+00,  2.0000e+00,  0.0000e+00, -2.0000e+00],
          [ 3.0000e+00,  0.0000e+00, -1.0000e+00, -3.0000e+00,  0.0000e+00],
-         [ 2.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+00,  0.0000e+00],
-         [ 0.0000e+00, -2.0000e+00,  1.0000e+00, -1.0000e+00,  2.0000e+00],
+         [ 2.0000e+00, -0.0000e+00,  0.0000e+00, -1.0000e+00,  0.0000e+00],
+         [-0.0000e+00, -2.0000e+00,  1.0000e+00, -1.0000e+00,  2.0000e+00],
          [ 1.0000e+00, -1.0000e+00,  0.0000e+00, -1.0000e+00,  0.0000e+00]],
 
-        [[ 0.0000e+00, -4.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
-         [ 1.8864e+04, -1.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
-         [ 0.0000e+00,  4.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00],
-         [ 0.0000e+00, -8.0000e+00,  0.0000e+00,  2.0000e+00,  0.0000e+00],
-         [ 0.0000e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
-         [ 0.0000e+00, -3.0000e+00,  0.0000e+00,  0.0000e+00, -2.0000e+00],
-         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00],
+        [[-0.0000e+00, -4.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00],
+         [ 1.8864e+04, -1.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00],
+         [-0.0000e+00,  4.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00],
+         [-0.0000e+00, -8.0000e+00, -0.0000e+00,  2.0000e+00, -0.0000e+00],
+         [-0.0000e+00,  1.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],
+         [ 0.0000e+00, -3.0000e+00,  0.0000e+00, -0.0000e+00, -2.0000e+00],
+         [ 0.0000e+00, -0.0000e+00, -0.0000e+00,  1.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00],
-         [ 0.0000e+00,  3.0000e+00,  0.0000e+00,  0.0000e+00,  2.0000e+00],
+         [-0.0000e+00,  3.0000e+00,  0.0000e+00,  0.0000e+00,  2.0000e+00],
          [ 0.0000e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]],
 
-        [[-1.0000e+00,  2.0000e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00],
-         [       -inf,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00],
-         [-1.0000e+00, -3.0000e+00,  0.0000e+00,  0.0000e+00, -2.0000e+00],
+        [[-1.0000e+00,  2.0000e+00,  1.0000e+00,  0.0000e+00, -0.0000e+00],
+         [        inf,  0.0000e+00, -0.0000e+00, -0.0000e+00,  1.0000e+00],
+         [-1.0000e+00, -3.0000e+00, -0.0000e+00,  0.0000e+00, -2.0000e+00],
          [-3.0000e+00,  6.0000e+00,  1.0000e+00,  1.0000e+00,  0.0000e+00],
-         [ 0.0000e+00, -1.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00],
-         [ 0.0000e+00,  2.0000e+00, -1.0000e+00,  0.0000e+00,  6.0000e+00],
+         [-0.0000e+00, -1.0000e+00,  0.0000e+00, -0.0000e+00,  1.0000e+00],
+         [ 0.0000e+00,  2.0000e+00, -1.0000e+00, -0.0000e+00,  6.0000e+00],
          [ 3.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+00],
-         [ 3.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -2.0000e+00],
-         [ 0.0000e+00, -2.0000e+00,  0.0000e+00,  0.0000e+00, -6.0000e+00],
-         [ 1.0000e+00, -1.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]],
+         [ 3.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00, -2.0000e+00],
+         [-0.0000e+00, -2.0000e+00, -0.0000e+00,  0.0000e+00, -6.0000e+00],
+         [ 1.0000e+00, -1.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00]],
 
         [[ 1.0000e+00,  1.0000e+00, -6.0000e+00,  1.0000e+00, -1.0000e+00],
-         [        inf,  0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00],
+         [       -inf,  0.0000e+00,  0.0000e+00, -0.0000e+00,  1.0000e+00],
          [ 1.0000e+00, -1.0000e+00,  1.0000e+00,  1.0000e+00, -2.0000e+00],
          [ 4.0000e+00,  3.0000e+00, -9.0000e+00,  4.0000e+00,  1.0000e+00],
-         [ 0.0000e+00,  0.0000e+00, -2.0000e+00, -1.0000e+00,  1.0000e+00],
-         [-1.0000e+00,  1.0000e+00,  5.0000e+00,  0.0000e+00,  7.0000e+00],
+         [ 0.0000e+00, -0.0000e+00, -2.0000e+00, -1.0000e+00,  1.0000e+00],
+         [-1.0000e+00,  1.0000e+00,  5.0000e+00, -0.0000e+00,  7.0000e+00],
          [-4.0000e+00,  0.0000e+00, -3.0000e+00,  2.0000e+00, -1.0000e+00],
-         [-4.0000e+00,  0.0000e+00,  1.0000e+00,  1.0000e+00, -2.0000e+00],
+         [-4.0000e+00, -0.0000e+00,  1.0000e+00,  1.0000e+00, -2.0000e+00],
          [ 0.0000e+00, -1.0000e+00,  2.0000e+00,  1.0000e+00, -6.0000e+00],
-         [-2.0000e+00,  0.0000e+00,  2.0000e+00,  0.0000e+00, -1.0000e+00]],
+         [-2.0000e+00, -0.0000e+00,  2.0000e+00,  0.0000e+00, -1.0000e+00]],
 
-        [[ 0.0000e+00, -2.0000e+00,  6.0000e+00, -1.0000e+00,  0.0000e+00],
-         [ 1.8000e+04,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
-         [ 0.0000e+00,  2.0000e+00, -1.0000e+00, -1.0000e+00,  0.0000e+00],
-         [ 0.0000e+00, -4.0000e+00,  9.0000e+00, -3.0000e+00,  0.0000e+00],
-         [ 0.0000e+00,  0.0000e+00,  2.0000e+00,  1.0000e+00,  0.0000e+00],
+        [[-0.0000e+00, -2.0000e+00,  6.0000e+00, -1.0000e+00, -0.0000e+00],
+         [ 1.8000e+04, -0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00],
+         [-0.0000e+00,  2.0000e+00, -1.0000e+00, -1.0000e+00, -0.0000e+00],
+         [-0.0000e+00, -4.0000e+00,  9.0000e+00, -3.0000e+00,  0.0000e+00],
+         [-0.0000e+00,  0.0000e+00,  2.0000e+00,  1.0000e+00,  0.0000e+00],
          [ 0.0000e+00, -2.0000e+00, -5.0000e+00,  0.0000e+00,  1.0000e+00],
-         [ 0.0000e+00,  0.0000e+00,  3.0000e+00, -1.0000e+00,  0.0000e+00],
-         [ 0.0000e+00,  0.0000e+00, -1.0000e+00,  0.0000e+00,  0.0000e+00],
-         [ 0.0000e+00,  1.0000e+00, -2.0000e+00,  0.0000e+00, -1.0000e+00],
-         [ 0.0000e+00,  0.0000e+00, -2.0000e+00,  0.0000e+00,  0.0000e+00]],
+         [ 0.0000e+00, -0.0000e+00,  3.0000e+00, -1.0000e+00, -0.0000e+00],
+         [ 0.0000e+00,  0.0000e+00, -1.0000e+00, -0.0000e+00, -0.0000e+00],
+         [-0.0000e+00,  1.0000e+00, -2.0000e+00, -0.0000e+00, -1.0000e+00],
+         [ 0.0000e+00,  0.0000e+00, -2.0000e+00, -0.0000e+00, -0.0000e+00]],
 
-        [[ 0.0000e+00, -3.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00],
-         [ 3.1536e+04,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
-         [ 0.0000e+00,  3.0000e+00,  0.0000e+00,  1.0000e+00, -1.0000e+00],
-         [-1.0000e+00, -6.0000e+00,  0.0000e+00,  5.0000e+00,  0.0000e+00],
-         [ 0.0000e+00,  1.0000e+00,  0.0000e+00, -1.0000e+00,  0.0000e+00],
-         [ 0.0000e+00, -3.0000e+00,  0.0000e+00,  0.0000e+00,  3.0000e+00],
-         [ 1.0000e+00,  0.0000e+00,  0.0000e+00,  2.0000e+00,  0.0000e+00],
+        [[-0.0000e+00, -3.0000e+00, -0.0000e+00,  1.0000e+00, -0.0000e+00],
+         [ 3.1536e+04, -0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00],
+         [-0.0000e+00,  3.0000e+00,  0.0000e+00,  1.0000e+00, -1.0000e+00],
+         [-1.0000e+00, -6.0000e+00, -0.0000e+00,  5.0000e+00,  0.0000e+00],
+         [-0.0000e+00,  1.0000e+00, -0.0000e+00, -1.0000e+00,  0.0000e+00],
+         [ 0.0000e+00, -3.0000e+00,  0.0000e+00, -0.0000e+00,  3.0000e+00],
+         [ 1.0000e+00, -0.0000e+00, -0.0000e+00,  2.0000e+00, -0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00, -1.0000e+00],
-         [ 0.0000e+00,  2.0000e+00,  0.0000e+00,  1.0000e+00, -3.0000e+00],
-         [ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00]]],
+         [-0.0000e+00,  2.0000e+00,  0.0000e+00,  1.0000e+00, -3.0000e+00],
+         [ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00, -0.0000e+00]]],
        dtype=torch.float16)

Full error stack

Tensor-likes are not close!

Mismatched elements: 5 / 500 (1.0%)
Greatest absolute difference: inf at index (2, 1, 0) (up to 1e-05 allowed)
Greatest relative difference: nan at index (2, 1, 0) (up to 0.001 allowed)
  File "/home/justinchu/dev/onnx-script/onnxscript/tests/function_libs/torch_lib/ops_test.py", line 259, in run_test_output_match
    torch.testing.assert_close(
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/testing/_comparison.py", line 1511, in assert_close
    raise error_metas[0].to_error(msg)

@justinchuby
Collaborator Author

-8.7500 / 6.1035e-05 should overflow to -inf in float16, not inf
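The sign flip is consistent with float16 overflow: the true quotient is about -143360, which is outside the float16 range (largest finite magnitude 65504), so the division overflows to infinity and only the sign distinguishes a correct from an incorrect result. A minimal NumPy sketch of the two computation paths (assuming NumPy's float16 arithmetic mirrors the backend's):

```python
import numpy as np

a = np.float16(-8.75)
b = np.float16(6.1035e-05)

# Dividing directly in float16 overflows: the true quotient
# (about -143360) exceeds the float16 max of 65504.
with np.errstate(over="ignore"):
    q16 = a / b
print(q16)  # -inf

# Casting to float32 first keeps the quotient finite and
# preserves the sign.
q32 = np.float32(a) / np.float32(b)
print(q32)  # about -143360.4
```

This is why casting float16 inputs to float32 before the division (as done in this PR) avoids the wrong-signed infinity.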

@justinchuby
Collaborator Author

aten_trunc(fp16(-143360.3670025395)) evaluates to -inf. Puzzling
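This is less puzzling once the cast itself is inspected: -143360.367 already overflows when rounded to float16, so the value entering Trunc is -inf, and truncating -inf is still -inf. A quick NumPy check (assuming NumPy's float16 cast behaves like the one in the test harness):

```python
import numpy as np

# Rounding -143360.367 to float16 overflows, because the largest
# finite float16 magnitude is 65504; the cast itself produces -inf.
x = np.float16(-143360.3670025395)
print(x)            # -inf

# Truncation then has nothing left to do: trunc(-inf) is -inf.
print(np.trunc(x))  # -inf
```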

@justinchuby
Collaborator Author

justinchuby commented Aug 8, 2023

Summary

The output of ONNX Runtime does not match that of PyTorch when executing test
ops_test.TestOutputConsistencyEagerCPU.test_output_match_opinfo__div_mode_floor_rounding_cpu_float16, sample 5 in ONNX Script TorchLib.

To recreate this report, use

CREATE_REPRODUCTION_REPORT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k test_output_match_opinfo__div_mode_floor_rounding_cpu_float16

Inputs

Shapes: ['Tensor<torch.Size([5, 10, 5]), dtype=torch.float16>', 'Tensor<torch.Size([5, 10, 5]), dtype=torch.float16>']

inputs = (tensor([[[-6.6953,  6.7070, -7.7422, -5.9688, -3.5234],
         [-0.0703,  8.8828,  5.0195, -1.2656, -8.9922],
         [-1.1250,  2.6641, -6.7773,  5.7227, -5.0273],
         [ 0.3164,  1.3975,  8.9453, -0.0879, -8.9297],
         [ 2.0391,  0.8613,  5.7812, -1.8369, -7.1797],
         [ 1.1426, -5.4570, -7.6719, -8.5312,  1.0459],
         [ 5.9062,  8.0781, -3.4883, -3.8496,  5.4844],
         [ 8.5000,  3.6836,  3.6992, -8.0938,  7.4805],
         [-6.7773,  1.9863,  0.9756, -4.5273,  4.1484],
         [-8.1406, -7.6641,  7.5586, -0.9756,  7.8672]],

        [[-1.2920, -6.0391,  6.1953, -1.7842,  1.4238],
         [ 6.3203, -0.5977,  5.9766,  8.7422, -7.5938],
         [ 5.0898, -1.9688, -5.5117,  2.4258, -1.8369],
         [ 8.0391,  4.4219,  7.8672,  3.4375,  1.1777],
         [ 1.7051, -5.4844, -3.3828,  0.2812, -2.9609],
         [-4.9648, -0.6152,  7.7500, -4.8789,  3.2871],
         [ 0.1846, -2.4180,  5.8789, -8.6719, -6.9883],
         [ 3.1992, -4.0625, -5.1602,  6.8125,  2.4785],
         [-8.7734, -7.0234,  2.4258, -1.9951,  5.3438],
         [-6.5117, -1.2832, -0.5713,  4.2188, -5.0273]],

        [[-0.0439,  7.2773, -1.8369, -3.2168,  4.5273],
         [-4.2031,  1.6787,  0.4219,  4.4922, -8.8594],
         [ 1.5029,  1.9248, -7.6211, -4.8164,  0.3955],
         [ 3.2695, -4.4570, -8.1406,  7.8320,  6.5391],
         [ 0.3164, -6.5469,  3.0586, -4.6406, -6.3555],
         [-7.4180, -3.6738,  0.8613, -2.8555, -0.2812],
         [ 1.0107, -7.3281, -6.7852, -0.3867, -0.8525],
         [ 8.6328,  1.4062,  2.2422,  2.8301, -7.8828],
         [-8.8516,  6.9609,  8.7969,  6.4531, -4.0352],
         [-4.0000,  2.8301,  7.6562, -2.9805,  6.9531]],

        [[-3.5078,  2.5234,  8.2812, -1.5029,  5.8203],
         [-0.5537,  5.6094, -7.6484, -4.7031,  3.3828],
         [ 8.6562, -2.1094,  0.9053,  8.1562,  3.4453],
         [ 7.1172, -6.8477,  1.5381,  0.3340,  4.1836],
         [ 5.9844, -2.3379, -4.2812,  7.6094,  7.1797],
         [-0.9141, -5.4219, -4.3945, -1.7314,  4.4375],
         [-1.0723,  7.2422, -8.6953, -2.9883, -8.8359],
         [-6.1875,  6.8125, -4.5078,  8.8359, -5.4922],
         [-3.2617, -0.4658,  7.7500, -4.8672,  1.2480],
         [ 0.7998,  3.1016, -1.6787,  7.0234, -2.6191]],

        [[ 8.4219, -1.5029, -8.5625,  5.8359,  6.1250],
         [-3.9551,  1.6611,  1.0898, -7.6914, -6.2500],
         [-4.0000,  2.9961, -6.3281,  6.0742, -3.0156],
         [-1.6436,  6.8750,  0.4658, -0.1934, -5.1055],
         [-3.6211, -0.2812, -6.6875,  0.4570,  5.0195],
         [-0.1846,  5.4414, -8.9688,  0.8965,  1.7139],
         [-2.2070, -1.3008, -0.2900, -1.2393,  3.0664],
         [-8.5938, -1.3887, -1.7754, -0.7822,  7.1992],
         [-3.2422, -6.6445, -5.7578, -3.4180,  0.9229],
         [ 5.8789,  5.7812, -3.2344, -7.9531, -1.7051]]], dtype=torch.float16), tensor([[[ 6.9062,  5.3008,  0.5010,  8.3516,  3.0312],
         [-2.4961,  4.1562, -7.9805,  7.2852, -2.0566],
         [ 1.1074,  3.4453,  7.9727, -0.1055,  5.6250],
         [ 0.5010, -3.4375, -4.5000,  6.5938, -1.9600],
         [ 0.8350,  5.5977,  1.0283, -4.0859,  5.2734],
         [-4.7188,  2.0742,  7.7695, -1.1250,  2.3828],
         [ 1.6523,  8.7734, -2.8203,  3.9199, -3.9023],
         [-2.0664, -1.4414, -6.8359,  8.4531,  0.6592],
         [-7.6562,  5.7578,  3.8145, -0.2461,  0.7471],
         [-2.8828,  2.4961,  8.5547,  1.9248,  0.5977]],

        [[ 0.4043, -0.1406, -8.1797,  6.5820, -0.5889],
         [-0.1143, -5.0273,  1.7842, -0.9932, -7.0938],
         [-4.8438, -5.2109, -6.0117, -4.1641,  5.3711],
         [ 2.5312,  2.2227,  6.4414, -5.8711, -8.8047],
         [ 6.6875, -4.4219,  6.6523,  7.3477,  0.8701],
         [-1.9600,  1.5732,  1.9512,  7.9531, -2.3125],
         [ 5.7812, -5.1250,  5.3359, -2.2500, -0.0791],
         [ 2.8477,  8.8281, -4.4727,  0.1318, -2.0469],
         [-6.0547,  6.6641, -6.7422, -4.6406,  8.6719],
         [-1.1338, -7.6016,  4.1641, -1.5205,  5.8281]],

        [[ 7.9727, -2.9453,  5.9219, -5.6680,  4.7812],
         [ 8.1562, -0.5977, -2.9355,  3.7441,  5.3789],
         [ 2.3477,  7.3281,  5.0625,  3.6562,  1.4678],
         [ 7.3906,  0.3691,  4.8789, -6.5039, -5.5195],
         [-8.9062,  5.3281,  2.6445,  7.8828,  2.2148],
         [-4.6250, -2.3828, -8.7031, -2.6191, -1.5469],
         [-8.2188, -2.1875, -3.6914,  8.0938, -0.4834],
         [ 5.2812, -0.0264, -5.3359, -8.0000, -3.5859],
         [ 7.9297,  1.7139, -1.7490,  4.3945,  5.4922],
         [ 4.3242,  0.6504, -5.3789,  3.2969,  3.6836]],

        [[-2.6016,  2.4258, -0.0352, -2.3203, -3.0664],
         [ 3.4531,  7.6289, -4.9922, -0.9229, -7.3203],
         [-2.5137,  6.0469,  0.8613,  8.8516,  4.8086],
         [-0.4131, -0.1406,  8.7734, -8.1641,  5.8281],
         [-3.1816,  1.6611,  6.4609, -8.6875,  6.8477],
         [ 3.2344,  6.1016,  4.6055, -6.5547, -7.1016],
         [-1.3887, -1.3359,  7.0039,  2.3906,  7.7344],
         [-8.4609, -3.7695,  7.7266, -2.6016,  2.9961],
         [-5.5195, -2.8652,  0.1582, -7.4531,  5.3711],
         [-5.5273, -5.0977,  2.0391,  8.0312, -8.3438]],

        [[-6.6797, -5.0078,  6.9883, -2.6992, -2.5488],
         [ 3.3672,  4.9492, -1.5293,  4.4375,  3.2168],
         [-7.7422, -1.2305,  0.5977,  0.8613, -5.5273],
         [ 0.2109, -0.6768, -1.1777,  4.1133,  4.2461],
         [ 4.9570,  1.3184,  1.2568, -4.0078,  0.3340],
         [-8.0469,  6.0820,  0.3604,  1.6260, -0.7910],
         [-5.5117,  1.9512, -3.4375,  1.1602,  1.4238],
         [ 1.8809, -2.5664, -6.9453,  7.3984, -3.2266],
         [-2.4082, -3.8398,  0.0703,  0.6416, -0.6240],
         [-4.6836, -0.9844,  2.2148,  2.0117,  4.5547]]], dtype=torch.float16))
kwargs = {'rounding_mode': 'floor'}

Expected output

expected = tensor([[[  -1.,    1.,  -16.,   -1.,   -2.],
         [   0.,    2.,   -1.,   -1.,    4.],
         [  -2.,    0.,   -1.,  -55.,   -1.],
         [   0.,   -1.,   -2.,   -1.,    4.],
         [   2.,    0.,    5.,    0.,   -2.],
         [  -1.,   -3.,   -1.,    7.,    0.],
         [   3.,    0.,    1.,   -1.,   -2.],
         [  -5.,   -3.,   -1.,   -1.,   11.],
         [   0.,    0.,    0.,   18.,    5.],
         [   2.,   -4.,    0.,   -1.,   13.]],

        [[  -4.,   42.,   -1.,   -1.,   -3.],
         [ -56.,    0.,    3.,   -9.,    1.],
         [  -2.,    0.,    0.,   -1.,   -1.],
         [   3.,    1.,    1.,   -1.,   -1.],
         [   0.,    1.,   -1.,    0.,   -4.],
         [   2.,   -1.,    3.,   -1.,   -2.],
         [   0.,    0.,    1.,    3.,   88.],
         [   1.,   -1.,    1.,   51.,   -2.],
         [   1.,   -2.,   -1.,    0.,    0.],
         [   5.,    0.,   -1.,   -3.,   -1.]],

        [[  -1.,   -3.,   -1.,    0.,    0.],
         [  -1.,   -3.,   -1.,    1.,   -2.],
         [   0.,    0.,   -2.,   -2.,    0.],
         [   0.,  -13.,   -2.,   -2.,   -2.],
         [  -1.,   -2.,    1.,   -1.,   -3.],
         [   1.,    1.,   -1.,    1.,    0.],
         [  -1.,    3.,    1.,   -1.,    1.],
         [   1.,  -54.,   -1.,   -1.,    2.],
         [  -2.,    4.,   -6.,    1.,   -1.],
         [  -1.,    4.,   -2.,   -1.,    1.]],

        [[   1.,    1., -236.,    0.,   -2.],
         [  -1.,    0.,    1.,    5.,   -1.],
         [  -4.,   -1.,    1.,    0.,    0.],
         [ -18.,   48.,    0.,   -1.,    0.],
         [  -2.,   -2.,   -1.,   -1.,    1.],
         [  -1.,   -1.,   -1.,    0.,   -1.],
         [   0.,   -6.,   -2.,   -2.,   -2.],
         [   0.,   -2.,   -1.,   -4.,   -2.],
         [   0.,    0.,   48.,    0.,    0.],
         [  -1.,   -1.,   -1.,    0.,    0.]],

        [[  -2.,    0.,   -2.,   -3.,   -3.],
         [  -2.,    0.,   -1.,   -2.,   -2.],
         [   0.,   -3.,  -11.,    7.,    0.],
         [  -8.,  -11.,   -1.,   -1.,   -2.],
         [  -1.,   -1.,   -6.,   -1.,   15.],
         [   0.,    0.,  -25.,    0.,   -3.],
         [   0.,   -1.,    0.,   -2.,    2.],
         [  -5.,    0.,    0.,   -1.,   -3.],
         [   1.,    1.,  -82.,   -6.,   -2.],
         [  -2.,   -6.,   -2.,   -4.,   -1.]]], dtype=torch.float16)

Shape: torch.Size([5, 10, 5])

Actual output

actual = tensor([[[  -1.,    1.,  -16.,   -1.,   -2.],
         [   0.,    2.,   -1.,   -1.,    4.],
         [  -2.,    0.,   -1.,  -55.,   -1.],
         [   0.,   -1.,   -2.,   -1.,    4.],
         [   2.,    0.,    5.,    0.,   -2.],
         [  -1.,   -3.,   -1.,    7.,    0.],
         [   3.,    0.,    1.,   -1.,   -2.],
         [  -5.,   -3.,   -1.,   -1.,   11.],
         [   0.,    0.,    0.,   18.,    5.],
         [   2.,   -4.,    0.,   -1.,   13.]],

        [[  -4.,   42.,   -1.,   -1.,   -3.],
         [ -56.,    0.,    3.,   -9.,    1.],
         [  -2.,    0.,    0.,   -1.,   -1.],
         [   3.,    1.,    1.,   -1.,   -1.],
         [   0.,    1.,   -1.,    0.,   -4.],
         [   2.,   -1.,    3.,   -1.,   -2.],
         [   0.,    0.,    1.,    3.,   88.],
         [   1.,   -1.,    1.,   51.,   -2.],
         [   1.,   -2.,   -1.,    0.,    0.],
         [   5.,    0.,   -1.,   -3.,   -1.]],

        [[  -1.,   -3.,   -1.,    0.,    0.],
         [  -1.,   -3.,   -1.,    1.,   -2.],
         [   0.,    0.,   -2.,   -2.,    0.],
         [   0.,  -13.,   -2.,   -2.,   -2.],
         [  -1.,   -2.,    1.,   -1.,   -3.],
         [   1.,    1.,   -1.,    1.,    0.],
         [  -1.,    3.,    1.,   -1.,    1.],
         [   1.,  -54.,   -1.,   -1.,    2.],
         [  -2.,    4.,   -6.,    1.,   -1.],
         [  -1.,    4.,   -2.,   -1.,    1.]],

        [[   1.,    1., -236.,    0.,   -2.],
         [  -1.,    0.,    1.,    5.,   -1.],
         [  -4.,   -1.,    1.,    0.,    0.],
         [ -18.,   48.,    0.,   -1.,    0.],
         [  -2.,   -2.,   -1.,   -1.,    1.],
         [  -1.,   -1.,   -1.,    0.,   -1.],
         [   0.,   -6.,   -2.,   -2.,   -2.],
         [   0.,   -2.,   -1.,   -4.,   -2.],
         [   0.,    0.,   49.,    0.,    0.],
         [  -1.,   -1.,   -1.,    0.,    0.]],

        [[  -2.,    0.,   -2.,   -3.,   -3.],
         [  -2.,    0.,   -1.,   -2.,   -2.],
         [   0.,   -3.,  -11.,    7.,    0.],
         [  -8.,  -11.,   -1.,   -1.,   -2.],
         [  -1.,   -1.,   -6.,   -1.,   15.],
         [   0.,    0.,  -25.,    0.,   -3.],
         [   0.,   -1.,    0.,   -2.,    2.],
         [  -5.,    0.,    0.,   -1.,   -3.],
         [   1.,    1.,  -82.,   -6.,   -2.],
         [  -2.,   -6.,   -2.,   -4.,   -1.]]], dtype=torch.float16)

Shape: torch.Size([5, 10, 5])

Difference

--- actual
+++ expected
@@ -39,7 +39,7 @@
          [  -1.,   -1.,   -1.,    0.,   -1.],
          [   0.,   -6.,   -2.,   -2.,   -2.],
          [   0.,   -2.,   -1.,   -4.,   -2.],
-         [   0.,    0.,   49.,    0.,    0.],
+         [   0.,    0.,   48.,    0.,    0.],
          [  -1.,   -1.,   -1.,    0.,    0.]],
 
         [[  -2.,    0.,   -2.,   -3.,   -3.],

Full error stack

Tensor-likes are not close!

Mismatched elements: 1 / 250 (0.4%)
Greatest absolute difference: 1.0 at index (3, 8, 2) (up to 1e-05 allowed)
Greatest relative difference: 0.0208282470703125 at index (3, 8, 2) (up to 0.001 allowed)
  File "/home/justinchu/dev/onnx-script/onnxscript/tests/function_libs/torch_lib/ops_test.py", line 259, in run_test_output_match
    torch.testing.assert_close(
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/testing/_comparison.py", line 1511, in assert_close
    raise error_metas[0].to_error(msg)

@justinchuby
Collaborator Author

In [13]: op.Floor(fp16([7.75])/fp16([0.1582]))
Out[13]: Tensor(array([49.], dtype=float16))

In [14]: fp16([7.75])/fp16([0.1582])
Out[14]: array([49.], dtype=float16)

NumPy decides fp16([7.75]) / fp16([0.1582]) is 49, not 48. I think that is OK.
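The off-by-one comes from where the rounding happens: the true quotient 7.75 / 0.1582 is about 48.99, which float16 rounds up to 49.0 before the floor is applied, while float32 keeps enough precision for the floor to land on 48. A small NumPy illustration:

```python
import numpy as np

a = np.float16(7.75)
b = np.float16(0.1582)

# float16 division rounds the true quotient (~48.99) up to 49.0,
# the nearest representable float16, so flooring gives 49.
q16 = a / b
print(np.floor(q16))  # 49.0

# float32 division keeps the quotient just below 49, so flooring
# gives 48, matching PyTorch's float-promoted computation.
q32 = np.float32(a) / np.float32(b)
print(np.floor(q32))  # 48.0
```

Neither answer is wrong per se; the divergence is an inherent consequence of evaluating the quotient at different precisions.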

@justinchuby
Collaborator Author

After casting to float32:

Summary

The output of ONNX Runtime does not match that of PyTorch when executing test
ops_test.TestOutputConsistencyFullGraphCPU.test_output_match_opinfo__div_mode_trunc_rounding_cpu_float16, sample 5 in ONNX Script TorchLib.

To recreate this report, use

CREATE_REPRODUCTION_REPORT=1 python -m pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k test_output_match_opinfo__div_mode_trunc_rounding_cpu_float16

Inputs

Shapes: ['Tensor<torch.Size([5, 10, 5]), dtype=torch.float16>', 'Tensor<torch.Size([5, 10, 5]), dtype=torch.float16>']

inputs = (tensor([[[-6.6953,  6.7070, -7.7422, -5.9688, -3.5234],
         [-0.0703,  8.8828,  5.0195, -1.2656, -8.9922],
         [-1.1250,  2.6641, -6.7773,  5.7227, -5.0273],
         [ 0.3164,  1.3975,  8.9453, -0.0879, -8.9297],
         [ 2.0391,  0.8613,  5.7812, -1.8369, -7.1797],
         [ 1.1426, -5.4570, -7.6719, -8.5312,  1.0459],
         [ 5.9062,  8.0781, -3.4883, -3.8496,  5.4844],
         [ 8.5000,  3.6836,  3.6992, -8.0938,  7.4805],
         [-6.7773,  1.9863,  0.9756, -4.5273,  4.1484],
         [-8.1406, -7.6641,  7.5586, -0.9756,  7.8672]],

        [[-1.2920, -6.0391,  6.1953, -1.7842,  1.4238],
         [ 6.3203, -0.5977,  5.9766,  8.7422, -7.5938],
         [ 5.0898, -1.9688, -5.5117,  2.4258, -1.8369],
         [ 8.0391,  4.4219,  7.8672,  3.4375,  1.1777],
         [ 1.7051, -5.4844, -3.3828,  0.2812, -2.9609],
         [-4.9648, -0.6152,  7.7500, -4.8789,  3.2871],
         [ 0.1846, -2.4180,  5.8789, -8.6719, -6.9883],
         [ 3.1992, -4.0625, -5.1602,  6.8125,  2.4785],
         [-8.7734, -7.0234,  2.4258, -1.9951,  5.3438],
         [-6.5117, -1.2832, -0.5713,  4.2188, -5.0273]],

        [[-0.0439,  7.2773, -1.8369, -3.2168,  4.5273],
         [-4.2031,  1.6787,  0.4219,  4.4922, -8.8594],
         [ 1.5029,  1.9248, -7.6211, -4.8164,  0.3955],
         [ 3.2695, -4.4570, -8.1406,  7.8320,  6.5391],
         [ 0.3164, -6.5469,  3.0586, -4.6406, -6.3555],
         [-7.4180, -3.6738,  0.8613, -2.8555, -0.2812],
         [ 1.0107, -7.3281, -6.7852, -0.3867, -0.8525],
         [ 8.6328,  1.4062,  2.2422,  2.8301, -7.8828],
         [-8.8516,  6.9609,  8.7969,  6.4531, -4.0352],
         [-4.0000,  2.8301,  7.6562, -2.9805,  6.9531]],

        [[-3.5078,  2.5234,  8.2812, -1.5029,  5.8203],
         [-0.5537,  5.6094, -7.6484, -4.7031,  3.3828],
         [ 8.6562, -2.1094,  0.9053,  8.1562,  3.4453],
         [ 7.1172, -6.8477,  1.5381,  0.3340,  4.1836],
         [ 5.9844, -2.3379, -4.2812,  7.6094,  7.1797],
         [-0.9141, -5.4219, -4.3945, -1.7314,  4.4375],
         [-1.0723,  7.2422, -8.6953, -2.9883, -8.8359],
         [-6.1875,  6.8125, -4.5078,  8.8359, -5.4922],
         [-3.2617, -0.4658,  7.7500, -4.8672,  1.2480],
         [ 0.7998,  3.1016, -1.6787,  7.0234, -2.6191]],

        [[ 8.4219, -1.5029, -8.5625,  5.8359,  6.1250],
         [-3.9551,  1.6611,  1.0898, -7.6914, -6.2500],
         [-4.0000,  2.9961, -6.3281,  6.0742, -3.0156],
         [-1.6436,  6.8750,  0.4658, -0.1934, -5.1055],
         [-3.6211, -0.2812, -6.6875,  0.4570,  5.0195],
         [-0.1846,  5.4414, -8.9688,  0.8965,  1.7139],
         [-2.2070, -1.3008, -0.2900, -1.2393,  3.0664],
         [-8.5938, -1.3887, -1.7754, -0.7822,  7.1992],
         [-3.2422, -6.6445, -5.7578, -3.4180,  0.9229],
         [ 5.8789,  5.7812, -3.2344, -7.9531, -1.7051]]], dtype=torch.float16), tensor([[[ 6.9062,  5.3008,  0.5010,  8.3516,  3.0312],
         [-2.4961,  4.1562, -7.9805,  7.2852, -2.0566],
         [ 1.1074,  3.4453,  7.9727, -0.1055,  5.6250],
         [ 0.5010, -3.4375, -4.5000,  6.5938, -1.9600],
         [ 0.8350,  5.5977,  1.0283, -4.0859,  5.2734],
         [-4.7188,  2.0742,  7.7695, -1.1250,  2.3828],
         [ 1.6523,  8.7734, -2.8203,  3.9199, -3.9023],
         [-2.0664, -1.4414, -6.8359,  8.4531,  0.6592],
         [-7.6562,  5.7578,  3.8145, -0.2461,  0.7471],
         [-2.8828,  2.4961,  8.5547,  1.9248,  0.5977]],

        [[ 0.4043, -0.1406, -8.1797,  6.5820, -0.5889],
         [-0.1143, -5.0273,  1.7842, -0.9932, -7.0938],
         [-4.8438, -5.2109, -6.0117, -4.1641,  5.3711],
         [ 2.5312,  2.2227,  6.4414, -5.8711, -8.8047],
         [ 6.6875, -4.4219,  6.6523,  7.3477,  0.8701],
         [-1.9600,  1.5732,  1.9512,  7.9531, -2.3125],
         [ 5.7812, -5.1250,  5.3359, -2.2500, -0.0791],
         [ 2.8477,  8.8281, -4.4727,  0.1318, -2.0469],
         [-6.0547,  6.6641, -6.7422, -4.6406,  8.6719],
         [-1.1338, -7.6016,  4.1641, -1.5205,  5.8281]],

        [[ 7.9727, -2.9453,  5.9219, -5.6680,  4.7812],
         [ 8.1562, -0.5977, -2.9355,  3.7441,  5.3789],
         [ 2.3477,  7.3281,  5.0625,  3.6562,  1.4678],
         [ 7.3906,  0.3691,  4.8789, -6.5039, -5.5195],
         [-8.9062,  5.3281,  2.6445,  7.8828,  2.2148],
         [-4.6250, -2.3828, -8.7031, -2.6191, -1.5469],
         [-8.2188, -2.1875, -3.6914,  8.0938, -0.4834],
         [ 5.2812, -0.0264, -5.3359, -8.0000, -3.5859],
         [ 7.9297,  1.7139, -1.7490,  4.3945,  5.4922],
         [ 4.3242,  0.6504, -5.3789,  3.2969,  3.6836]],

        [[-2.6016,  2.4258, -0.0352, -2.3203, -3.0664],
         [ 3.4531,  7.6289, -4.9922, -0.9229, -7.3203],
         [-2.5137,  6.0469,  0.8613,  8.8516,  4.8086],
         [-0.4131, -0.1406,  8.7734, -8.1641,  5.8281],
         [-3.1816,  1.6611,  6.4609, -8.6875,  6.8477],
         [ 3.2344,  6.1016,  4.6055, -6.5547, -7.1016],
         [-1.3887, -1.3359,  7.0039,  2.3906,  7.7344],
         [-8.4609, -3.7695,  7.7266, -2.6016,  2.9961],
         [-5.5195, -2.8652,  0.1582, -7.4531,  5.3711],
         [-5.5273, -5.0977,  2.0391,  8.0312, -8.3438]],

        [[-6.6797, -5.0078,  6.9883, -2.6992, -2.5488],
         [ 3.3672,  4.9492, -1.5293,  4.4375,  3.2168],
         [-7.7422, -1.2305,  0.5977,  0.8613, -5.5273],
         [ 0.2109, -0.6768, -1.1777,  4.1133,  4.2461],
         [ 4.9570,  1.3184,  1.2568, -4.0078,  0.3340],
         [-8.0469,  6.0820,  0.3604,  1.6260, -0.7910],
         [-5.5117,  1.9512, -3.4375,  1.1602,  1.4238],
         [ 1.8809, -2.5664, -6.9453,  7.3984, -3.2266],
         [-2.4082, -3.8398,  0.0703,  0.6416, -0.6240],
         [-4.6836, -0.9844,  2.2148,  2.0117,  4.5547]]], dtype=torch.float16))
kwargs = {'rounding_mode': 'trunc'}

Expected output

expected = tensor([[[  -0.,    1.,  -15.,   -0.,   -1.],
         [   0.,    2.,   -0.,   -0.,    4.],
         [  -1.,    0.,   -0.,  -54.,   -0.],
         [   0.,   -0.,   -1.,   -0.,    4.],
         [   2.,    0.,    5.,    0.,   -1.],
         [  -0.,   -2.,   -0.,    7.,    0.],
         [   3.,    0.,    1.,   -0.,   -1.],
         [  -4.,   -2.,   -0.,   -0.,   11.],
         [   0.,    0.,    0.,   18.,    5.],
         [   2.,   -3.,    0.,   -0.,   13.]],

        [[  -3.,   42.,   -0.,   -0.,   -2.],
         [ -55.,    0.,    3.,   -8.,    1.],
         [  -1.,    0.,    0.,   -0.,   -0.],
         [   3.,    1.,    1.,   -0.,   -0.],
         [   0.,    1.,   -0.,    0.,   -3.],
         [   2.,   -0.,    3.,   -0.,   -1.],
         [   0.,    0.,    1.,    3.,   88.],
         [   1.,   -0.,    1.,   51.,   -1.],
         [   1.,   -1.,   -0.,    0.,    0.],
         [   5.,    0.,   -0.,   -2.,   -0.]],

        [[  -0.,   -2.,   -0.,    0.,    0.],
         [  -0.,   -2.,   -0.,    1.,   -1.],
         [   0.,    0.,   -1.,   -1.,    0.],
         [   0.,  -12.,   -1.,   -1.,   -1.],
         [  -0.,   -1.,    1.,   -0.,   -2.],
         [   1.,    1.,   -0.,    1.,    0.],
         [  -0.,    3.,    1.,   -0.,    1.],
         [   1.,  -53.,   -0.,   -0.,    2.],
         [  -1.,    4.,   -5.,    1.,   -0.],
         [  -0.,    4.,   -1.,   -0.,    1.]],

        [[   1.,    1., -235.,    0.,   -1.],
         [  -0.,    0.,    1.,    5.,   -0.],
         [  -3.,   -0.,    1.,    0.,    0.],
         [ -17.,   48.,    0.,   -0.,    0.],
         [  -1.,   -1.,   -0.,   -0.,    1.],
         [  -0.,   -0.,   -0.,    0.,   -0.],
         [   0.,   -5.,   -1.,   -1.,   -1.],
         [   0.,   -1.,   -0.,   -3.,   -1.],
         [   0.,    0.,   49.,    0.,    0.],
         [  -0.,   -0.,   -0.,    0.,    0.]],

        [[  -1.,    0.,   -1.,   -2.,   -2.],
         [  -1.,    0.,   -0.,   -1.,   -1.],
         [   0.,   -2.,  -10.,    7.,    0.],
         [  -7.,  -10.,   -0.,   -0.,   -1.],
         [  -0.,   -0.,   -5.,   -0.,   15.],
         [   0.,    0.,  -24.,    0.,   -2.],
         [   0.,   -0.,    0.,   -1.,    2.],
         [  -4.,    0.,    0.,   -0.,   -2.],
         [   1.,    1.,  -81.,   -5.,   -1.],
         [  -1.,   -5.,   -1.,   -3.,   -0.]]], dtype=torch.float16)

Shape: torch.Size([5, 10, 5])

Actual output

actual = tensor([[[   0.,    1.,  -15.,    0.,   -1.],
         [   0.,    2.,    0.,    0.,    4.],
         [  -1.,    0.,    0.,  -54.,    0.],
         [   0.,    0.,   -1.,    0.,    4.],
         [   2.,    0.,    5.,    0.,   -1.],
         [   0.,   -2.,    0.,    7.,    0.],
         [   3.,    0.,    1.,    0.,   -1.],
         [  -4.,   -2.,    0.,    0.,   11.],
         [   0.,    0.,    0.,   18.,    5.],
         [   2.,   -3.,    0.,    0.,   13.]],

        [[  -3.,   42.,    0.,    0.,   -2.],
         [ -55.,    0.,    3.,   -8.,    1.],
         [  -1.,    0.,    0.,    0.,    0.],
         [   3.,    1.,    1.,    0.,    0.],
         [   0.,    1.,    0.,    0.,   -3.],
         [   2.,    0.,    3.,    0.,   -1.],
         [   0.,    0.,    1.,    3.,   88.],
         [   1.,    0.,    1.,   51.,   -1.],
         [   1.,   -1.,    0.,    0.,    0.],
         [   5.,    0.,    0.,   -2.,    0.]],

        [[   0.,   -2.,    0.,    0.,    0.],
         [   0.,   -2.,    0.,    1.,   -1.],
         [   0.,    0.,   -1.,   -1.,    0.],
         [   0.,  -12.,   -1.,   -1.,   -1.],
         [   0.,   -1.,    1.,    0.,   -2.],
         [   1.,    1.,    0.,    1.,    0.],
         [   0.,    3.,    1.,    0.,    1.],
         [   1.,  -53.,    0.,    0.,    2.],
         [  -1.,    4.,   -5.,    1.,    0.],
         [   0.,    4.,   -1.,    0.,    1.]],

        [[   1.,    1., -235.,    0.,   -1.],
         [   0.,    0.,    1.,    5.,    0.],
         [  -3.,    0.,    1.,    0.,    0.],
         [ -17.,   48.,    0.,    0.,    0.],
         [  -1.,   -1.,    0.,    0.,    1.],
         [   0.,    0.,    0.,    0.,    0.],
         [   0.,   -5.,   -1.,   -1.,   -1.],
         [   0.,   -1.,    0.,   -3.,   -1.],
         [   0.,    0.,   48.,    0.,    0.],
         [   0.,    0.,    0.,    0.,    0.]],

        [[  -1.,    0.,   -1.,   -2.,   -2.],
         [  -1.,    0.,    0.,   -1.,   -1.],
         [   0.,   -2.,  -10.,    7.,    0.],
         [  -7.,  -10.,    0.,    0.,   -1.],
         [   0.,    0.,   -5.,    0.,   15.],
         [   0.,    0.,  -24.,    0.,   -2.],
         [   0.,    0.,    0.,   -1.,    2.],
         [  -4.,    0.,    0.,    0.,   -2.],
         [   1.,    1.,  -81.,   -5.,   -1.],
         [  -1.,   -5.,   -1.,   -3.,    0.]]], dtype=torch.float16)

Shape: torch.Size([5, 10, 5])

Difference

--- actual
+++ expected
@@ -1,54 +1,54 @@
-tensor([[[   0.,    1.,  -15.,    0.,   -1.],
-         [   0.,    2.,    0.,    0.,    4.],
-         [  -1.,    0.,    0.,  -54.,    0.],
-         [   0.,    0.,   -1.,    0.,    4.],
+tensor([[[  -0.,    1.,  -15.,   -0.,   -1.],
+         [   0.,    2.,   -0.,   -0.,    4.],
+         [  -1.,    0.,   -0.,  -54.,   -0.],
+         [   0.,   -0.,   -1.,   -0.,    4.],
          [   2.,    0.,    5.,    0.,   -1.],
-         [   0.,   -2.,    0.,    7.,    0.],
-         [   3.,    0.,    1.,    0.,   -1.],
-         [  -4.,   -2.,    0.,    0.,   11.],
+         [  -0.,   -2.,   -0.,    7.,    0.],
+         [   3.,    0.,    1.,   -0.,   -1.],
+         [  -4.,   -2.,   -0.,   -0.,   11.],
          [   0.,    0.,    0.,   18.,    5.],
-         [   2.,   -3.,    0.,    0.,   13.]],
+         [   2.,   -3.,    0.,   -0.,   13.]],
 
-        [[  -3.,   42.,    0.,    0.,   -2.],
+        [[  -3.,   42.,   -0.,   -0.,   -2.],
          [ -55.,    0.,    3.,   -8.,    1.],
-         [  -1.,    0.,    0.,    0.,    0.],
-         [   3.,    1.,    1.,    0.,    0.],
-         [   0.,    1.,    0.,    0.,   -3.],
-         [   2.,    0.,    3.,    0.,   -1.],
+         [  -1.,    0.,    0.,   -0.,   -0.],
+         [   3.,    1.,    1.,   -0.,   -0.],
+         [   0.,    1.,   -0.,    0.,   -3.],
+         [   2.,   -0.,    3.,   -0.,   -1.],
          [   0.,    0.,    1.,    3.,   88.],
-         [   1.,    0.,    1.,   51.,   -1.],
-         [   1.,   -1.,    0.,    0.,    0.],
-         [   5.,    0.,    0.,   -2.,    0.]],
+         [   1.,   -0.,    1.,   51.,   -1.],
+         [   1.,   -1.,   -0.,    0.,    0.],
+         [   5.,    0.,   -0.,   -2.,   -0.]],
 
-        [[   0.,   -2.,    0.,    0.,    0.],
-         [   0.,   -2.,    0.,    1.,   -1.],
+        [[  -0.,   -2.,   -0.,    0.,    0.],
+         [  -0.,   -2.,   -0.,    1.,   -1.],
          [   0.,    0.,   -1.,   -1.,    0.],
          [   0.,  -12.,   -1.,   -1.,   -1.],
-         [   0.,   -1.,    1.,    0.,   -2.],
-         [   1.,    1.,    0.,    1.,    0.],
-         [   0.,    3.,    1.,    0.,    1.],
-         [   1.,  -53.,    0.,    0.,    2.],
-         [  -1.,    4.,   -5.,    1.,    0.],
-         [   0.,    4.,   -1.,    0.,    1.]],
+         [  -0.,   -1.,    1.,   -0.,   -2.],
+         [   1.,    1.,   -0.,    1.,    0.],
+         [  -0.,    3.,    1.,   -0.,    1.],
+         [   1.,  -53.,   -0.,   -0.,    2.],
+         [  -1.,    4.,   -5.,    1.,   -0.],
+         [  -0.,    4.,   -1.,   -0.,    1.]],
 
         [[   1.,    1., -235.,    0.,   -1.],
-         [   0.,    0.,    1.,    5.,    0.],
-         [  -3.,    0.,    1.,    0.,    0.],
-         [ -17.,   48.,    0.,    0.,    0.],
-         [  -1.,   -1.,    0.,    0.,    1.],
-         [   0.,    0.,    0.,    0.,    0.],
+         [  -0.,    0.,    1.,    5.,   -0.],
+         [  -3.,   -0.,    1.,    0.,    0.],
+         [ -17.,   48.,    0.,   -0.,    0.],
+         [  -1.,   -1.,   -0.,   -0.,    1.],
+         [  -0.,   -0.,   -0.,    0.,   -0.],
          [   0.,   -5.,   -1.,   -1.,   -1.],
-         [   0.,   -1.,    0.,   -3.,   -1.],
-         [   0.,    0.,   48.,    0.,    0.],
-         [   0.,    0.,    0.,    0.,    0.]],
+         [   0.,   -1.,   -0.,   -3.,   -1.],
+         [   0.,    0.,   49.,    0.,    0.],
+         [  -0.,   -0.,   -0.,    0.,    0.]],
 
         [[  -1.,    0.,   -1.,   -2.,   -2.],
-         [  -1.,    0.,    0.,   -1.,   -1.],
+         [  -1.,    0.,   -0.,   -1.,   -1.],
          [   0.,   -2.,  -10.,    7.,    0.],
-         [  -7.,  -10.,    0.,    0.,   -1.],
-         [   0.,    0.,   -5.,    0.,   15.],
+         [  -7.,  -10.,   -0.,   -0.,   -1.],
+         [  -0.,   -0.,   -5.,   -0.,   15.],
          [   0.,    0.,  -24.,    0.,   -2.],
-         [   0.,    0.,    0.,   -1.,    2.],
-         [  -4.,    0.,    0.,    0.,   -2.],
+         [   0.,   -0.,    0.,   -1.,    2.],
+         [  -4.,    0.,    0.,   -0.,   -2.],
          [   1.,    1.,  -81.,   -5.,   -1.],
-         [  -1.,   -5.,   -1.,   -3.,    0.]]], dtype=torch.float16)
+         [  -1.,   -5.,   -1.,   -3.,   -0.]]], dtype=torch.float16)

Full error stack

Tensor-likes are not close!

Mismatched elements: 1 / 250 (0.4%)
Greatest absolute difference: 1.0 at index (3, 8, 2) (up to 1e-05 allowed)
Greatest relative difference: 0.0204010009765625 at index (3, 8, 2) (up to 0.001 allowed)
  File "/home/justinchu/dev/onnx-script/onnxscript/tests/function_libs/torch_lib/ops_test.py", line 259, in run_test_output_match
    torch.testing.assert_close(
  File "/home/justinchu/anaconda3/envs/onnx/lib/python3.10/site-packages/torch/testing/_comparison.py", line 1511, in assert_close
    raise error_metas[0].to_error(msg)

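The single mismatched element (index (3, 8, 2): expected 49 vs. actual 48) comes down to where the quotient gets rounded. PyTorch divides in float16, so 7.75 / 0.1582 ≈ 48.99 rounds up to the nearest representable float16 (49.0) before trunc, while dividing in float32 (after the upcast) keeps ≈ 48.99 and trunc yields 48. A minimal NumPy sketch, using the input values from the tensors above:

```python
import numpy as np

x = np.float16(7.75)    # self[3, 8, 2] from the inputs above
y = np.float16(0.1582)  # other[3, 8, 2]; stored exactly as 0.158203125

# float16 arithmetic rounds the quotient to the nearest float16 first:
# 7.75 / 0.158203125 ~= 48.9877, which rounds to 49.0 in float16.
q16 = x / y                          # float16 path (what PyTorch computes)
q32 = np.float32(x) / np.float32(y)  # float32 path (after the upcast)

print(np.trunc(q16))  # 49.0, matching the "expected" tensor
print(np.trunc(q32))  # 48.0, matching the "actual" tensor
```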
@justinchuby
Collaborator Author

I am now adding xfails

@justinchuby justinchuby added the merge at lgtm Reviewers can merge when they approve label Aug 8, 2023
if rounding_mode == "trunc":
# Rounds the results of the division towards zero.
# Equivalent to C-style integer division
result = aten_trunc(op.Div(self, other))
Collaborator Author

Will move to a common function when #834 is done.

Contributor

I missed this. Could you share more about why we can use nested OnnxFunction now?

Collaborator Author
@justinchuby justinchuby Aug 9, 2023

This is a trace-only function, so calling Python functions is fine. However, we still prefer not to call other aten functions. When we can have nested OnnxFunction calls, I will extract the trunc logic into a common function and call it from both aten_trunc and this.

Right now I am doing this so that aten_trunc doesn't become trace_only.
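The trunc/floor semantics discussed here can be sketched in plain Python. This is a hypothetical scalar model for illustration only, not the actual torchlib implementation (which emits ONNX ops on tensors):

```python
import math

def div_mode(a: float, b: float, rounding_mode: str) -> float:
    """Scalar model of aten::div.Tensor_mode's rounding behavior."""
    q = a / b
    if rounding_mode == "trunc":
        # Round toward zero: C-style integer division.
        return float(math.trunc(q))
    if rounding_mode == "floor":
        # Round toward negative infinity: Python's // operator.
        return float(math.floor(q))
    raise ValueError(f"unsupported rounding_mode: {rounding_mode!r}")

# The two modes differ for negative quotients:
print(div_mode(7.0, -2.0, "trunc"))  # -3.0
print(div_mode(7.0, -2.0, "floor"))  # -4.0
```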

"aten::div.Tensor",
"aten::div.Scalar",
# When rounding_mode is None, performs a true division
# https://pytorch.org/docs/stable/generated/torch.div.html
Contributor
@titaiwangms titaiwangms Aug 8, 2023

Is the dispatcher expected to filter out any attribute that is None?

Collaborator Author

I would consider this the better match, I think? Any suggestions?

Collaborator Author

I do think the dispatcher should strip None keyword args.

Contributor

No, I think that makes sense. It's just that we are altering attributes, so the param_schema matching diverges from the inputs/attributes actually sent into the OnnxFunction. There are many implicit assumptions around dispatching and OnnxFunction param_schemas, and that is not good for debugging.

The dispatcher alters inputs/attributes based on hidden assumptions but never returns the altered inputs/attributes. So from the OnnxFunction's perspective, it runs directly as the dispatched function, with attributes it doesn't need (which won't error).

Contributor
@titaiwangms titaiwangms left a comment

Do we want implicit logic so that, when we feed args/kwargs into param_schema matching, attributes that are None are ignored? That is, treated as if they never existed, so they won't match a function signature that has that attribute? I don't think the dispatcher does this yet.

I take it that we do, since I see no function that handles the case where a certain attribute is None. My concern is at what level we should expose this information for debugging purposes.

cc @BowenBao

@justinchuby
Collaborator Author

so they won't match the function signature having that attribute?

I think so, unless that attribute has a default value, of course.
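The None-stripping behavior under discussion could be sketched like this. It is a hypothetical illustration using plain Python signatures (`matches_signature`, `aten_div`, and `aten_div_mode` here are stand-ins); the real dispatcher matches against OnnxFunction param_schemas:

```python
import inspect

def matches_signature(fn, args, kwargs) -> bool:
    # Drop None-valued attributes before matching, so a call passing
    # rounding_mode=None can still match the plain aten_div overload.
    kwargs = {k: v for k, v in kwargs.items() if v is not None}
    try:
        inspect.signature(fn).bind(*args, **kwargs)
        return True
    except TypeError:
        return False

def aten_div(self, other): ...                       # hypothetical stand-ins
def aten_div_mode(self, other, rounding_mode): ...

# torch.div(a, b, rounding_mode=None) dispatches to aten_div:
print(matches_signature(aten_div, ("a", "b"), {"rounding_mode": None}))          # True
print(matches_signature(aten_div_mode, ("a", "b"), {"rounding_mode": "trunc"}))  # True
# After stripping None, the required rounding_mode attribute is missing:
print(matches_signature(aten_div_mode, ("a", "b"), {"rounding_mode": None}))     # False
```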

@titaiwangms
Contributor

I could add it into pytorch/pytorch#106478.

def aten_divide(self: TensorType, other: TensorType) -> TensorType:
"""divide.Tensor(Tensor self, Tensor other) -> Tensor"""
@torch_op(("aten::div.Tensor_mode", "aten::div.Scalar_mode"), trace_only=True)
def aten_div_mode(self: TFloat, other: TFloat, rounding_mode: str) -> TFloat:
Contributor

Do you recall what kinds of attributes can have defaults?

Collaborator Author

str, int, float, and bool attributes can have defaults, I think. But I suppose any attribute should be able to have a default via the attribute proto. Is this what you are asking?

Contributor

I wonder if there is a situation where two ONNX variants differ only in one default attribute. In that case, the dispatcher won't be able to choose between them.

aten_op_attr(X, Y, attr="Good"):
    ...

aten_op(X, Y):
    ...

Collaborator Author
@justinchuby justinchuby Aug 8, 2023

True. I would just hope for, and make sure, that we don't create variants like these.

I wonder if there is a way to test for it. I think the matching logic you created can come in handy here: we can use it to verify that no two variants registered in torchlib are compatible with each other.

Collaborator Author

In the dispatcher, if we do see this case, we can only pick one, I suppose?

Contributor

Supposedly, picking either one shouldn't cause an issue, because they should be equivalent when no attribute is specified.

@justinchuby
Collaborator Author

justinchuby commented Aug 8, 2023

Windows test: please don’t fail here. Fail in #986 instead

Contributor
@titaiwangms titaiwangms left a comment

Just tested it out: we actually don't need to change anything in pytorch/pytorch#106478. An attribute=None can never be a perfect match anyway, since it's a NoneType, while the aten::div param_schemas ignore the None attribute, so it is dispatched correctly.

@justinchuby justinchuby merged commit ba255f7 into main Aug 8, 2023
@justinchuby justinchuby deleted the justinchu/div-mode branch August 8, 2023 22:51
Labels
merge at lgtm Reviewers can merge when they approve module: torchlib Related to the torch/aten function lib in development
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Support aten::div.Tensor_mode
2 participants