Skip to content

Commit 393db9b

Browse files
committed
Add test case for condition where broadcast is not required
1 parent ffd429c commit 393db9b

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2972,6 +2972,7 @@
29722972
"HardtanhBackward_basic",
29732973
"ElementwiseHeavisideModule_basic",
29742974
"ElementwiseHeavisideIntModule_basic",
2975+
"ElementwiseHeavisideNoBroadcastModule_basic",
29752976
"HstackBasicComplexModule_basic",
29762977
"HstackBasicFloatModule_basic",
29772978
"HstackBasicIntFloatModule_basic",
@@ -3989,6 +3990,7 @@
39893990
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
39903991
"ElementwiseHeavisideModule_basic",
39913992
"ElementwiseHeavisideIntModule_basic",
3993+
"ElementwiseHeavisideNoBroadcastModule_basic",
39923994
"RreluWithNoiseBackwardEvalModule_basic",
39933995
"RreluWithNoiseBackwardEvalStaticModule_basic",
39943996
"RreluWithNoiseBackwardTrainModule_basic",

projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,26 @@ def ElementwiseHeavisideIntModule_basic(module, tu: TestUtils):
335335
)
336336

337337

338+
class ElementwiseHeavisideNoBroadcastModule(torch.nn.Module):
339+
def __init__(self):
340+
super().__init__()
341+
342+
@export
343+
@annotate_args(
344+
[None, ([-1, -1], torch.float32, True), ([-1, -1], torch.float32, True)]
345+
)
346+
def forward(self, x, values):
347+
return torch.heaviside(x, values)
348+
349+
350+
@register_test_case(module_factory=lambda: ElementwiseHeavisideNoBroadcastModule())
351+
def ElementwiseHeavisideNoBroadcastModule_basic(module, tu: TestUtils):
352+
module.forward(
353+
tu.rand(5, 8),
354+
tu.rand(5, 8),
355+
)
356+
357+
338358
# ==============================================================================
339359

340360

0 commit comments

Comments
 (0)