diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 0d4d734edd2b6..ecfc5dd2e8e1d 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -574,6 +574,26 @@ def NVVM_SyncWarpOp : } +def NVVM_ElectSyncOp : NVVM_Op<"elect.sync", + [DeclareOpInterfaceMethods]> +{ + let results = (outs I1:$pred); + let assemblyFormat = "attr-dict `->` type(results)"; + let extraClassDefinition = [{ + std::string $cppClass::getPtx() { + return std::string( + "{ \n" + ".reg .u32 rx; \n" + ".reg .pred px; \n" + " mov.u32 %0, 0; \n" + " elect.sync rx | px, 0xFFFFFFFF;\n" + "@px mov.u32 %0, 1; \n" + "}\n" + ); + } + }]; +} + def LoadCacheModifierCA : I32EnumAttrCase<"CA", 0, "ca">; def LoadCacheModifierCG : I32EnumAttrCase<"CG", 1, "cg">; def LoadCacheModifierCS : I32EnumAttrCase<"CS", 2, "cs">; diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp index 2d7a441e95004..15703fb99339e 100644 --- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp +++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp @@ -63,6 +63,8 @@ class PtxBuilder { // https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html#constraints char getRegisterType(Type type) { + if (type.isInteger(1)) + return 'b'; if (type.isInteger(16)) return 'h'; if (type.isInteger(32)) diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir index 7ffe1ad2bb2b1..bf10ddbb4016a 100644 --- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir +++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir @@ -363,3 +363,17 @@ func.func @wgmma_f32_e5m2_e4m3(%descA : i64, %descB : i64) -> !mat32f32 { : !mat32f32 -> !mat32f32 return %result2 : !mat32f32 } + +// ----- + +func.func @elect_one_leader_sync() { + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{ + // CHECK-SAME: .reg .u32 rx; + // CHECK-SAME: .reg .pred px; + // CHECK-SAME: mov.u32 $0, 0; + // CHECK-SAME: elect.sync rx | px, 0xFFFFFFFF; + // CHECK-SAME: @px mov.u32 $0, 1; + // CHECK-SAME: "=b" : () -> i1 + %cnd = nvvm.elect.sync -> i1 + return +}