Skip to content

Commit 52c8531

Browse files
authored
Use the builtin bernoulli test | test(torchlib) (#850)
There is builtin test for the PyTorch op `bernoulli`. We should use that to ensure consistency with PyTorch. Also updated deterministic tests for bernoulli_p.
1 parent 39f3f8e commit 52c8531

2 files changed

Lines changed: 21 additions & 25 deletions

File tree

onnxscript/tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,7 @@ def sample_inputs_bernoulli_p(op_info, device, dtype, requires_grad, **kwargs):
501501
yield opinfo_core.SampleInput(t, kwargs={"p": p})
502502

503503

504-
def sample_inputs_bernoulli_default(op_info, device, dtype, requires_grad, **kwargs):
504+
def sample_inputs_bernoulli_p_deterministic(op_info, device, dtype, requires_grad, **kwargs):
505505
del op_info
506506

507507
shapes = [
@@ -512,16 +512,18 @@ def sample_inputs_bernoulli_default(op_info, device, dtype, requires_grad, **kwa
512512
]
513513

514514
for shape in shapes:
515-
t = torch_testing.make_tensor(
516-
shape,
517-
low=0,
518-
high=1,
519-
device=device,
520-
dtype=dtype,
521-
requires_grad=requires_grad,
522-
**kwargs,
523-
)
524-
yield opinfo_core.SampleInput(t)
515+
for p in (0, 1):
516+
t = torch_testing.make_tensor(
517+
shape,
518+
low=0,
519+
high=1,
520+
device=device,
521+
dtype=dtype,
522+
requires_grad=requires_grad,
523+
**kwargs,
524+
)
525+
yield opinfo_core.SampleInput(t, args=(p,))
526+
yield opinfo_core.SampleInput(t, kwargs={"p": p})
525527

526528

527529
OP_DB: List[opinfo_core.OpInfo] = [
@@ -680,10 +682,11 @@ def sample_inputs_bernoulli_default(op_info, device, dtype, requires_grad, **kwa
680682
sample_inputs_func=sample_inputs_bernoulli_p,
681683
),
682684
opinfo_core.OpInfo(
683-
"aten.bernoulli",
684-
aten_name="bernoulli",
685-
op=torch.ops.aten.bernoulli.default,
686-
dtypes=common_dtype.floating_types(),
687-
sample_inputs_func=sample_inputs_bernoulli_default,
685+
# Deterministic bernoulli sampling where p is either 0 or 1
686+
"aten.bernoulli.p_deterministic",
687+
aten_name="bernoulli.p",
688+
op=torch.ops.aten.bernoulli.p,
689+
dtypes=common_dtype.all_types(),
690+
sample_inputs_func=sample_inputs_bernoulli_p_deterministic,
688691
),
689692
]

onnxscript/tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -495,15 +495,7 @@ def _where_input_wrangler(
495495
reason="atleast_3d_single_tensor overload takes single tensor as input",
496496
),
497497
TorchLibOpInfo("baddbmm", core_ops.aten_baddbmm),
498-
TorchLibOpInfo(
499-
# This string is a unique ID. In extra_opinfo.py, we
500-
# also define test data for this ID with
501-
# `opinfo_core.OpInfo("aten.bernoulli.p", ...)`.
502-
"aten.bernoulli",
503-
core_ops.aten_bernoulli,
504-
# Skip comparison for the output of this op because it is a random tensor.
505-
nondeterministic=True,
506-
),
498+
TorchLibOpInfo("bernoulli", core_ops.aten_bernoulli, nondeterministic=True),
507499
TorchLibOpInfo(
508500
# This string is a unique ID. In extra_opinfo.py, we
509501
# also define test data for this ID with
@@ -513,6 +505,7 @@ def _where_input_wrangler(
513505
# Skip comparison for the output of this op because it is a random tensor.
514506
nondeterministic=True,
515507
),
508+
TorchLibOpInfo("aten.bernoulli.p_deterministic", core_ops.aten_bernoulli_p),
516509
TorchLibOpInfo("bmm", core_ops.aten_bmm),
517510
TorchLibOpInfo("broadcast_to", core_ops.aten_broadcast_to),
518511
TorchLibOpInfo("cat", core_ops.aten_cat).skip(

0 commit comments

Comments
 (0)