
Commit e80f7ae

add the code for heaviside op
1 parent acf7fdd commit e80f7ae

File tree

8 files changed: +187 −0 lines changed


include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 47 additions & 0 deletions
@@ -13127,6 +13127,53 @@ def Torch_AtenWhereScalarSelfOp : Torch_Op<"aten.where.ScalarSelf", [
   let hasFolder = 1;
 }
 
+def Torch_AtenHeavisideOp : Torch_Op<"aten.heaviside", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::heaviside : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchTensorType:$values
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenHeavisideOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenHeavisideOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
+def Torch_AtenHeaviside_Op : Torch_Op<"aten.heaviside_", [
+    IsTrailingUnderscoreInplaceVariant,
+    AllowsTypeRefinement
+  ]> {
+  let summary = "Generated op for `aten::heaviside_ : (Tensor, Tensor) -> (Tensor)`";
+  let arguments = (ins
+    Torch_NonValueTensorType:$self,
+    Torch_NonValueTensorType:$values
+  );
+  let results = (outs
+    AnyTorchOptionalNonValueTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenHeaviside_Op::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenHeaviside_Op::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenNanToNumOp : Torch_Op<"aten.nan_to_num", [
   AllowsTypeRefinement,
   HasValueSemantics,

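These ODS definitions register the aten.heaviside op and its in-place variant
aten.heaviside_ in the Torch dialect. The semantics follow PyTorch's
torch.heaviside(input, values): 0 where the input is negative, the
corresponding element of values where it is zero, and 1 where it is positive.
A minimal eager-mode illustration (not part of the commit):

    import torch

    x = torch.tensor([-1.5, 0.0, 2.0])
    v = torch.tensor([0.5])
    print(torch.heaviside(x, v))  # tensor([0.0000, 0.5000, 1.0000])
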
lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 12 additions & 0 deletions
@@ -9675,6 +9675,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg2) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.heaviside\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.nan_to_num\"(%arg0: !torch.list<int>, %arg1: !torch.optional<float>, %arg2: !torch.optional<float>, %arg3: !torch.optional<float>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -15239,6 +15243,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
 "    return %4 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.heaviside\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
+"    %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
+"    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
+"    return %4 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.nan_to_num\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<float>, %arg2: !torch.optional<float>, %arg3: !torch.optional<float>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"

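The two entries added here are the serialized form of the Python shape and
dtype functions introduced further down in this commit
(abstract_interp_lib_gen.py): the result shape is the broadcast of the two
operand shapes, and the result dtype is their promotion. In terms of public
PyTorch helpers, the intent is roughly the following (illustration only, not
part of the commit):

    import torch

    # Result shape: broadcast of the operand shapes.
    torch.broadcast_shapes((5, 1), (3,))             # torch.Size([5, 3])
    # Result dtype: standard type promotion of the operand dtypes.
    torch.promote_types(torch.int32, torch.float32)  # torch.float32
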
lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 71 additions & 0 deletions
@@ -11150,6 +11150,76 @@ class DecomposeAtenSgnOp : public OpRewritePattern<AtenSgnOp> {
 };
 } // namespace
 
+namespace {
+// Decomposes aten.heaviside op into
+// using aten.eq, aten.lt, aten.logical_or, aten.where
+// Heaviside(x, y) returns:
+// 0 if x < 0
+// y if x == 0
+// 1 if x > 0
+class DecomposeAtenHeaviside : public OpRewritePattern<AtenHeavisideOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenHeavisideOp op,
+                                PatternRewriter &rewriter) const override {
+    auto input = op.getSelf();
+    auto value = op.getValues();
+    auto loc = op.getLoc();
+    auto inputTy = dyn_cast<BaseTensorType>(input.getType());
+    if (!inputTy || !inputTy.hasDtype() || !inputTy.hasSizes())
+      return rewriter.notifyMatchFailure(op, "input must have dtype and size.");
+
+    auto valueTy = dyn_cast<BaseTensorType>(value.getType());
+    if (!valueTy || !valueTy.hasDtype() || !valueTy.hasSizes())
+      return rewriter.notifyMatchFailure(op, "value must have dtype and size.");
+    auto resultTy = dyn_cast<BaseTensorType>(op.getType());
+    SmallVector<int64_t> broadcastShape;
+    SmallVector<Value> broadcastShapeValue;
+    computeBroadcastShape(rewriter, loc, input, value, broadcastShape,
+                          broadcastShapeValue);
+
+    auto broadcastType = ValueTensorType::get(
+        op.getContext(), llvm::ArrayRef(broadcastShape), resultTy.getDtype());
+    auto boolBroadcastType = ValueTensorType::get(
+        op.getContext(), llvm::ArrayRef(broadcastShape), rewriter.getI1Type());
+    Value indexBroadcastShapeTorchList = rewriter.create<PrimListConstructOp>(
+        loc, Torch::ListType::get(Torch::IntType::get(op.getContext())),
+        broadcastShapeValue);
+    auto inputBroadcasted = rewriter.create<AtenBroadcastToOp>(
+        loc, broadcastType, input, indexBroadcastShapeTorchList);
+    auto valueBroadcasted = rewriter.create<AtenBroadcastToOp>(
+        loc, broadcastType, value, indexBroadcastShapeTorchList);
+
+    Value zero = getConstantWithGivenDtypeAndValue(rewriter, loc, 0,
+                                                   resultTy.getDtype());
+    Value one = getConstantWithGivenDtypeAndValue(rewriter, loc, 1,
+                                                  resultTy.getDtype());
+    // Compute mask: input == 0
+    auto inputEqZero = rewriter
+                           .create<AtenEqScalarOp>(loc, boolBroadcastType,
+                                                   inputBroadcasted, zero)
+                           ->getResult(0);
+    // Compute mask: input < 0
+    auto inputLtZero = rewriter.create<AtenLtScalarOp>(loc, boolBroadcastType,
+                                                       inputBroadcasted, zero);
+    // Compute mask: isnan(input)
+    auto isNan =
+        rewriter.create<AtenIsnanOp>(loc, boolBroadcastType, inputBroadcasted);
+    // Combine: input < 0 || isnan(input)
+    auto inputNegativeOrNan = rewriter.create<AtenLogicalOrOp>(
+        loc, boolBroadcastType, inputLtZero, isNan);
+    // Select 0 if input < 0 or input is nan, else 1
+    auto zerosOrOnes = rewriter.create<AtenWhereScalarOp>(
+        loc, resultTy, inputNegativeOrNan, zero, one);
+    // Final result: if input == 0, take from valueBroadcasted, else take from
+    // zerosOrOnes
+    rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, resultTy, inputEqZero,
+                                                 valueBroadcasted, zerosOrOnes);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 // Unconditionally decompose `torch.type_as` into `prim.dtype` +
 // `torch.to.dtype`.
@@ -12374,6 +12444,7 @@ class DecomposeComplexOpsPass
       DecomposeConstantTensorNewLikeOp<AtenNewOnesOp, AtenOnesOp>>(patterns);
   addPatternIfTargetOpIsIllegal<DecomposeAtenHardtanhOp>(patterns);
   addPatternIfTargetOpIsIllegal<DecomposeAtenFullOp>(patterns);
+  addPatternIfTargetOpIsIllegal<DecomposeAtenHeaviside>(patterns);
   addPatternIfTargetOpIsIllegal<DecomposeAtenLinearOp>(patterns);
   addPatternIfTargetOpIsIllegal<DecomposeAtenMishOp>(patterns);
   addPatternIfTargetOpIsIllegal<DecomposeAtenFullLikeOp>(patterns);

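For reference, the rewrite pattern above is equivalent to the following
eager-mode sketch (an illustration of the decomposition logic only, not code
from the commit; torch.broadcast_tensors stands in for the explicit
aten.broadcast_to step):

    import torch

    def heaviside_decomposed(x: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
        # Broadcast both operands to the common result shape.
        x, values = torch.broadcast_tensors(x, values)
        # 0 where x < 0 or x is NaN, 1 everywhere else.
        neg_or_nan = torch.logical_or(x < 0, torch.isnan(x))
        zeros_or_ones = torch.where(neg_or_nan, torch.zeros_like(x), torch.ones_like(x))
        # Where x == 0, select the broadcast values instead.
        return torch.where(x == 0, values, zeros_or_ones)
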
lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 1 addition & 0 deletions
@@ -461,6 +461,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenSquareOp>();
   target.addIllegalOp<AtenVarOp>();
   target.addIllegalOp<AtenStdOp>();
+  target.addIllegalOp<AtenHeavisideOp>();
   target.addIllegalOp<Aten_UnsafeViewOp>();
   target.addIllegalOp<Aten_ReshapeAliasOp>();
   target.addIllegalOp<AtenBernoulliOp>();

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 6 additions & 0 deletions
@@ -1251,6 +1251,7 @@
     "ElementwiseToDtypeI64ToI8Module_basic",
     "ElementwiseToDtypeIdentityModule_basic",
     "ElementwiseUnaryModule_basic",
+    "ElementwiseHeavisideModule_basic",
     "EmptyLikeMemoryFormatModule_basic",
     "EmptyLikeModule_defaultDtype",
     "EmptyLikeModule_falsePinMemory",
@@ -1855,6 +1856,7 @@
     "ElementwiseFracModule_basic",
     "ElementwiseLdexpModule_basic",
     "ElementwiseSignbitIntModule_basic",
+    "ElementwiseHeavisideModule_basic",
     "Exp2StaticIntModule_basic",
     "MaxPool1dEmptyStrideStaticModule_basic",
     "MaxPool1dStaticCeilModeTrueModule_basic",
@@ -2968,6 +2970,8 @@
     "GtFloatIntModule_basic",
     "GtIntModule_basic",
     "HardtanhBackward_basic",
+    "ElementwiseHeavisideModule_basic",
+    "ElementwiseHeavisideIntModule_basic",
     "HstackBasicComplexModule_basic",
     "HstackBasicFloatModule_basic",
     "HstackBasicIntFloatModule_basic",
@@ -3983,6 +3987,8 @@
     "ElementwiseRreluWithNoiseEvalStaticModule_basic",
     "ElementwiseRreluWithNoiseTrainModule_basic",
     "ElementwiseRreluWithNoiseTrainStaticModule_basic",
+    "ElementwiseHeavisideModule_basic",
+    "ElementwiseHeavisideIntModule_basic",
     "RreluWithNoiseBackwardEvalModule_basic",
     "RreluWithNoiseBackwardEvalStaticModule_basic",
     "RreluWithNoiseBackwardTrainModule_basic",

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 11 additions & 0 deletions
@@ -1775,6 +1775,9 @@ def aten〇where〇ScalarOther〡shape(condition: List[int], self: List[int], ot
 def aten〇where〇ScalarSelf〡shape(condition: List[int], self: float, other: List[int]) -> List[int]:
     return upstream_shape_functions.broadcast(condition, other)
 
+def aten〇heaviside〡shape(self: List[int], values: List[int]) -> List[int]:
+    return upstream_shape_functions.broadcast(self, values)
+
 def aten〇nan_to_num〡shape(self: List[int], nan: Optional[float] = None, posinf: Optional[float] = None, neginf: Optional[float] = None) -> List[int]:
     return upstream_shape_functions.unary(self)
 
@@ -5088,6 +5091,14 @@ def aten〇where〇ScalarSelf〡dtype(condition_rank_dtype: Tuple[int, int], sel
     dtypes = [get_dtype_of_scalar(self), other_dtype]
     return promote_dtypes(ranks, dtypes)
 
+def aten〇heaviside〡dtype(self_rank_dtype: Tuple[int, int], values_rank_dtype: Tuple[int, int]) -> int:
+    self_rank,self_dtype = self_rank_dtype
+    values_rank,values_dtype = values_rank_dtype
+    ranks: List[Optional[int]] = [self_rank, values_rank]
+    dtypes = [self_dtype, values_dtype]
+    promoted_dtype = promote_dtypes(ranks, dtypes)
+    return promoted_dtype
+
 @check_dtype_function(
     _check_tensors_with_the_same_dtype(num_of_tensors=1))
 def aten〇nan_to_num〡dtype(self_rank_dtype: Tuple[int, int], nan: Optional[float] = None, posinf: Optional[float] = None, neginf: Optional[float] = None) -> int:

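These Python definitions are the source from which the C++ abstract
interpretation library shown earlier in this commit is generated. Under the
library's usual conventions (upstream broadcast semantics and standard dtype
promotion), the new functions are expected to behave as in the following
hand-written examples (illustration only, not part of the commit):

    # aten〇heaviside〡shape([5, 1], [3])                          -> [5, 3]
    # aten〇heaviside〡dtype((2, torch.int32), (1, torch.float32)) -> torch.float32
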
projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py

Lines changed: 1 addition & 0 deletions
@@ -960,6 +960,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit(
         "aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)", has_folder=True
     )
+    emit_with_mutating_variants("aten::heaviside : (Tensor, Tensor) -> (Tensor)")
     emit("aten::nan_to_num : (Tensor, float?, float?, float?) -> (Tensor)")
     emit(
         "aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)",

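The single emit_with_mutating_variants registration above is what produces both
TableGen definitions shown in GeneratedTorchOps.td at the top of this commit:
the value-semantics aten.heaviside op and its trailing-underscore in-place
variant aten.heaviside_.
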
projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py

Lines changed: 38 additions & 0 deletions
@@ -298,6 +298,44 @@ def ElementwiseLtFloatScalarModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseHeavisideModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([5], torch.float32, True), ([1], torch.float32, True)])
+    def forward(self, x, values):
+        return torch.heaviside(x, values)
+
+
+@register_test_case(module_factory=lambda: ElementwiseHeavisideModule())
+def ElementwiseHeavisideModule_basic(module, tu: TestUtils):
+    module.forward(
+        torch.tensor([1.0, -2.0, torch.inf, torch.nan, -torch.inf]), torch.tensor([5.0])
+    )
+
+
+class ElementwiseHeavisideIntModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1, -1], torch.int32, True), ([-1], torch.int32, True)])
+    def forward(self, x, values):
+        return torch.heaviside(x, values)
+
+
+@register_test_case(module_factory=lambda: ElementwiseHeavisideIntModule())
+def ElementwiseHeavisideIntModule_basic(module, tu: TestUtils):
+    module.forward(
+        tu.randint(5, 1, low=-100, high=1000).to(torch.int32),
+        tu.randint(1, low=-100, high=1000).to(torch.int32),
+    )
+
+
+# ==============================================================================
+
+
 class ElementwiseLtIntScalarModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
