Skip to content

Commit c477cc6

Browse files
committed
Update tolerance
1 parent 9f65061 commit c477cc6

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

onnxscript/tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1691,6 +1691,7 @@ def _where_input_wrangler(
16911691
"nn.functional.scaled_dot_product_attention",
16921692
nn_ops.aten_scaled_dot_product_attention,
16931693
trace_only=True,
1694+
tolerance={torch.float32: (1e-5, 1e-5)},
16941695
)
16951696
.skip(
16961697
matcher=lambda sample: (attn_mask := sample.kwargs.get("attn_mask")) is not None
@@ -1705,6 +1706,7 @@ def _where_input_wrangler(
17051706
"nn.functional.scaled_dot_product_attention_bool_mask",
17061707
nn_ops.aten_scaled_dot_product_attention_bool_mask,
17071708
trace_only=True,
1709+
tolerance={torch.float32: (1e-5, 1e-5)},
17081710
)
17091711
.skip(
17101712
matcher=lambda sample: (attn_mask := sample.kwargs.get("attn_mask")) is not None

0 commit comments

Comments
 (0)