We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 9f65061 commit c477cc6Copy full SHA for c477cc6
onnxscript/tests/function_libs/torch_lib/ops_test_data.py
@@ -1691,6 +1691,7 @@ def _where_input_wrangler(
1691
"nn.functional.scaled_dot_product_attention",
1692
nn_ops.aten_scaled_dot_product_attention,
1693
trace_only=True,
1694
+ tolerance={torch.float32: (1e-5, 1e-5)},
1695
)
1696
.skip(
1697
matcher=lambda sample: (attn_mask := sample.kwargs.get("attn_mask")) is not None
@@ -1705,6 +1706,7 @@ def _where_input_wrangler(
1705
1706
"nn.functional.scaled_dot_product_attention_bool_mask",
1707
nn_ops.aten_scaled_dot_product_attention_bool_mask,
1708
1709
1710
1711
1712
0 commit comments