Skip to content

Commit 01e80a0

Browse files
[mlir] Add maxnumf and minnumf to AtomicRMWKind (llvm#66442)
This commit adds the mentioned kinds of `AtomicRMWKind` as well as code generation for them.
1 parent 52b33ff commit 01e80a0

File tree

4 files changed

+27
-6
lines changed

4 files changed

+27
-6
lines changed

mlir/include/mlir/Dialect/Arith/IR/ArithBase.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,14 +82,16 @@ def ATOMIC_RMW_KIND_MULF : I64EnumAttrCase<"mulf", 9>;
8282
def ATOMIC_RMW_KIND_MULI : I64EnumAttrCase<"muli", 10>;
8383
def ATOMIC_RMW_KIND_ORI : I64EnumAttrCase<"ori", 11>;
8484
def ATOMIC_RMW_KIND_ANDI : I64EnumAttrCase<"andi", 12>;
85+
def ATOMIC_RMW_KIND_MAXNUMF : I64EnumAttrCase<"maxnumf", 13>;
86+
def ATOMIC_RMW_KIND_MINNUMF : I64EnumAttrCase<"minnumf", 14>;
8587

8688
def AtomicRMWKindAttr : I64EnumAttr<
8789
"AtomicRMWKind", "",
8890
[ATOMIC_RMW_KIND_ADDF, ATOMIC_RMW_KIND_ADDI, ATOMIC_RMW_KIND_ASSIGN,
8991
ATOMIC_RMW_KIND_MAXIMUMF, ATOMIC_RMW_KIND_MAXS, ATOMIC_RMW_KIND_MAXU,
9092
ATOMIC_RMW_KIND_MINIMUMF, ATOMIC_RMW_KIND_MINS, ATOMIC_RMW_KIND_MINU,
9193
ATOMIC_RMW_KIND_MULF, ATOMIC_RMW_KIND_MULI, ATOMIC_RMW_KIND_ORI,
92-
ATOMIC_RMW_KIND_ANDI]> {
94+
ATOMIC_RMW_KIND_ANDI, ATOMIC_RMW_KIND_MAXNUMF, ATOMIC_RMW_KIND_MINNUMF]> {
9395
let cppNamespace = "::mlir::arith";
9496
}
9597

mlir/lib/Dialect/Arith/IR/ArithOps.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2523,6 +2523,10 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
25232523
return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
25242524
case AtomicRMWKind::minimumf:
25252525
return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
2526+
case AtomicRMWKind::maxnumf:
2527+
return builder.create<arith::MaxNumFOp>(loc, lhs, rhs);
2528+
case AtomicRMWKind::minnumf:
2529+
return builder.create<arith::MinNumFOp>(loc, lhs, rhs);
25262530
case AtomicRMWKind::maxs:
25272531
return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
25282532
case AtomicRMWKind::mins:

mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
2121
#include "mlir/IR/TypeUtilities.h"
2222
#include "mlir/Transforms/DialectConversion.h"
23+
#include "llvm/ADT/STLExtras.h"
2324

2425
namespace mlir {
2526
namespace memref {
@@ -126,8 +127,10 @@ struct ExpandOpsPass : public memref::impl::ExpandOpsBase<ExpandOpsPass> {
126127
target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect>();
127128
target.addDynamicallyLegalOp<memref::AtomicRMWOp>(
128129
[](memref::AtomicRMWOp op) {
129-
return op.getKind() != arith::AtomicRMWKind::maximumf &&
130-
op.getKind() != arith::AtomicRMWKind::minimumf;
130+
constexpr std::array shouldBeExpandedKinds = {
131+
arith::AtomicRMWKind::maximumf, arith::AtomicRMWKind::minimumf,
132+
arith::AtomicRMWKind::minnumf, arith::AtomicRMWKind::maxnumf};
133+
return !llvm::is_contained(shouldBeExpandedKinds, op.getKind());
131134
});
132135
target.addDynamicallyLegalOp<memref::ReshapeOp>([](memref::ReshapeOp op) {
133136
return !cast<MemRefType>(op.getShape().getType()).hasStaticShape();

mlir/test/Dialect/MemRef/expand-ops.mlir

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
// CHECK-LABEL: func @atomic_rmw_to_generic
44
// CHECK-SAME: ([[F:%.*]]: memref<10xf32>, [[f:%.*]]: f32, [[i:%.*]]: index)
55
func.func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
6-
%x = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
7-
%y = memref.atomic_rmw minimumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
8-
return %x : f32
6+
%a = memref.atomic_rmw maximumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
7+
%b = memref.atomic_rmw minimumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
8+
%c = memref.atomic_rmw maxnumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
9+
%d = memref.atomic_rmw minnumf %f, %F[%i] : (f32, memref<10xf32>) -> f32
10+
return %a : f32
911
}
1012
// CHECK: [[RESULT:%.*]] = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
1113
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
@@ -17,6 +19,16 @@ func.func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32
1719
// CHECK: [[MINIMUM:%.*]] = arith.minimumf [[CUR_VAL]], [[f]] : f32
1820
// CHECK: memref.atomic_yield [[MINIMUM]] : f32
1921
// CHECK: }
22+
// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
23+
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
24+
// CHECK: [[MAXNUM:%.*]] = arith.maxnumf [[CUR_VAL]], [[f]] : f32
25+
// CHECK: memref.atomic_yield [[MAXNUM]] : f32
26+
// CHECK: }
27+
// CHECK: memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
28+
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
29+
// CHECK: [[MINNUM:%.*]] = arith.minnumf [[CUR_VAL]], [[f]] : f32
30+
// CHECK: memref.atomic_yield [[MINNUM]] : f32
31+
// CHECK: }
2032
// CHECK: return [[RESULT]] : f32
2133

2234
// -----

0 commit comments

Comments
 (0)