
Commit 8e58016

Remove skips in scaled_dot_product_attention | test(torchlib) (#970)
Fixes #968

The blanket skips ("fixme: ORT crashes on Windows, segfaults randomly on Linux") on both scaled_dot_product_attention test entries are removed and replaced with a float32 tolerance of (1e-5, 1e-5).
Parent: 2f39b94

File tree: 1 file changed, +2 −6 lines

onnxscript/tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 2 additions & 6 deletions
```diff
@@ -1691,9 +1691,7 @@ def _where_input_wrangler(
         "nn.functional.scaled_dot_product_attention",
         nn_ops.aten_scaled_dot_product_attention,
         trace_only=True,
-    )
-    .skip(
-        reason="fixme: ORT crashes on Windows, segfaults randomly on Linux",
+        tolerance={torch.float32: (1e-5, 1e-5)},
     )
     .skip(
         matcher=lambda sample: (attn_mask := sample.kwargs.get("attn_mask")) is not None
@@ -1708,9 +1706,7 @@ def _where_input_wrangler(
         "nn.functional.scaled_dot_product_attention_bool_mask",
         nn_ops.aten_scaled_dot_product_attention_bool_mask,
         trace_only=True,
-    )
-    .skip(
-        reason="fixme: ORT crashes on Windows, segfaults randomly on Linux",
+        tolerance={torch.float32: (1e-5, 1e-5)},
     )
     .skip(
         matcher=lambda sample: (attn_mask := sample.kwargs.get("attn_mask")) is not None
```
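
For orientation, here is a minimal sketch of how a per-dtype tolerance entry like `{torch.float32: (1e-5, 1e-5)}` is typically consumed: the pair is read as (rtol, atol) when comparing the ONNX Runtime output against the PyTorch eager result. The `TOLERANCES` table and `assert_outputs_close` helper below are illustrative assumptions, not the repository's actual test harness.

```python
# Minimal sketch (assumption, not the actual torch_lib harness) of applying
# a per-dtype tolerance table such as {torch.float32: (1e-5, 1e-5)}.
import torch

# Hypothetical table mirroring the diff above: dtype -> (rtol, atol).
TOLERANCES = {torch.float32: (1e-5, 1e-5)}

def assert_outputs_close(actual: torch.Tensor, expected: torch.Tensor) -> None:
    """Compare ORT output against the eager result with per-dtype tolerances."""
    rtol, atol = TOLERANCES.get(expected.dtype, (None, None))
    # With rtol=atol=None, assert_close falls back to its dtype defaults.
    torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)

# Tiny self-check: drift well inside 1e-5 now passes instead of forcing a skip.
q = k = v = torch.randn(2, 4, 8, 16)
expected = torch.nn.functional.scaled_dot_product_attention(q, k, v)
assert_outputs_close(expected + 1e-7, expected)  # stand-in for the ORT output
```

Relaxing the comparison this way lets the tests run again rather than being skipped wholesale over small numeric drift.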

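
The surviving `.skip(matcher=...)` entries filter individual OpInfo samples rather than skipping the whole test. A rough sketch of that pattern, assuming the harness skips a sample when any registered matcher returns True (`Sample`, `should_skip`, and the matcher list are illustrative names, not the actual ops_test machinery):

```python
# Sketch (illustrative, not the real ops_test machinery) of matcher-based
# skips: a sample is skipped when any registered matcher returns True.
from dataclasses import dataclass, field
from typing import Any, Callable

@dataclass
class Sample:
    """Stand-in for an OpInfo sample (hypothetical, for illustration)."""
    args: tuple = ()
    kwargs: dict[str, Any] = field(default_factory=dict)

def should_skip(sample: Sample, matchers: list[Callable[[Sample], bool]]) -> bool:
    return any(matcher(sample) for matcher in matchers)

# The matcher from the diff (truncated there by the hunk boundary): skip
# samples that supply an attn_mask. The walrus operator binds attn_mask
# so any follow-up conditions in the full lambda can reuse it.
matchers = [
    lambda sample: (attn_mask := sample.kwargs.get("attn_mask")) is not None
]

print(should_skip(Sample(kwargs={"attn_mask": object()}), matchers))  # True
print(should_skip(Sample(), matchers))  # False
```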