[mlir][nvvm] Introduce elect.sync Op #68323

Merged 2 commits on Oct 9, 2023
20 changes: 20 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -574,6 +574,26 @@ def NVVM_SyncWarpOp :
}


def NVVM_ElectSyncOp : NVVM_Op<"elect.sync",
    [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>
{
  let results = (outs I1:$pred);
  let assemblyFormat = "attr-dict `->` type(results)";
  let extraClassDefinition = [{
    std::string $cppClass::getPtx() {
      return std::string(
        "{ \n"
        ".reg .u32 rx; \n"
        ".reg .pred px; \n"
        " mov.u32 %0, 0; \n"
        " elect.sync rx | px, 0xFFFFFFFF;\n"
Member:

The instruction is only available on sm_90/PTX8.0. Does MLIR have any constraints based on which GPU model we're compiling for?

Member Author:

That is a good point, and thanks for the review.

The NVVM dialect doesn't know the SM architecture. You can think of NVVM as similar to PTX: you can compile it, but you must make sure that the SM architecture you run on supports it.
A higher-level dialect knows the SM architecture and is responsible for generating the right NVVM ops. For example, a transformation from the vector dialect to NVVM should know the SM target.
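
As an illustration of the division of labor described in this reply, the following is a minimal C++ sketch of the kind of check a higher-level lowering could run before emitting `nvvm.elect.sync`. `GpuTarget` and `supportsElectSync` are hypothetical names, not MLIR APIs; the thresholds are taken from the reviewer's note that the instruction needs sm_90 and PTX 8.0.

```cpp
// Hypothetical description of a compilation target; not an MLIR type.
struct GpuTarget {
  int smVersion;  // e.g. 90 for sm_90
  int ptxVersion; // e.g. 80 for PTX 8.0
};

// Returns true when the target can execute elect.sync, assuming the
// sm_90 / PTX 8.0 requirement quoted in the review comment above.
constexpr bool supportsElectSync(GpuTarget target) {
  return target.smVersion >= 90 && target.ptxVersion >= 80;
}
```

A lowering that owns the target information would consult a check like this and fall back to another leader-selection scheme when it returns false; the NVVM op itself stays target-agnostic.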

Member:

Thank you for the explanation. So it would be similar to using LLVM intrinsics that may or may not be available for a particular target.

Speaking of intrinsics, it may make sense to make elect.sync an LLVM intrinsic. We'll eventually need it there in any case.

Member Author:

> So it would be similar to using LLVM intrinsics that may or may not be available for a particular target.

Exactly, it is similar to using LLVM intrinsics. NVVM was designed to map one-to-one to LLVM intrinsics.

FWIW, generating PTX the way I do here is relatively new in MLIR. We noticed that LLVM doesn't support all of PTX, including predicates. Therefore, I recently implemented the BasicPtxBuilderOpInterface, which generates inline assembly from the given PTX. That is what I use here.

> Speaking of intrinsics, it may make sense to make elect.sync an LLVM intrinsic. We'll eventually need it there in any case.

Yes, it does make sense. We would like to generate LLVM intrinsics whenever possible.

        "@px mov.u32 %0, 1; \n"
        "}\n"
      );
    }
  }];
}

def LoadCacheModifierCA : I32EnumAttrCase<"CA", 0, "ca">;
def LoadCacheModifierCG : I32EnumAttrCase<"CG", 1, "cg">;
def LoadCacheModifierCS : I32EnumAttrCase<"CS", 2, "cs">;
2 changes: 2 additions & 0 deletions mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -63,6 +63,8 @@ class PtxBuilder {

  // https://docs.nvidia.com/cuda/inline-ptx-assembly/index.html#constraints
  char getRegisterType(Type type) {
    if (type.isInteger(1))
      return 'b';
    if (type.isInteger(16))
      return 'h';
    if (type.isInteger(32))
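
The hunk above adds an `i1` case to the type-to-constraint mapping. The following standalone C++ sketch shows the same idea without MLIR types. It is not the `PtxBuilder` code: `constraintForIntWidth` and `outputConstraint` are hypothetical names, the `'b'` and `'h'` letters come from the diff, and the `'r'` and `'l'` letters are an assumption based on NVIDIA's inline PTX constraint documentation rather than on the truncated hunk.

```cpp
#include <string>

// Hypothetical helper: maps an integer bit width to an inline-asm register
// constraint letter. Returns '\0' for widths this sketch does not cover.
char constraintForIntWidth(unsigned bits) {
  switch (bits) {
  case 1:
    return 'b'; // predicate-sized value, as added in this PR
  case 16:
    return 'h'; // as shown in the diff
  case 32:
    return 'r'; // assumption: from NVIDIA's inline PTX documentation
  case 64:
    return 'l'; // assumption: from NVIDIA's inline PTX documentation
  default:
    return '\0';
  }
}

// Hypothetical helper: builds a write-only output constraint such as "=b".
// Returns an empty string when the width is not covered.
std::string outputConstraint(unsigned bits) {
  char letter = constraintForIntWidth(bits);
  return letter == '\0' ? std::string() : std::string("=") + letter;
}
```

With this mapping, an `i1` result yields the `"=b"` constraint that the new test for `nvvm.elect.sync` checks for.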
14 changes: 14 additions & 0 deletions mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -363,3 +363,17 @@ func.func @wgmma_f32_e5m2_e4m3(%descA : i64, %descB : i64) -> !mat32f32 {
      : !mat32f32 -> !mat32f32
  return %result2 : !mat32f32
}

// -----

func.func @elect_one_leader_sync() {
  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{
  // CHECK-SAME: .reg .u32 rx;
  // CHECK-SAME: .reg .pred px;
  // CHECK-SAME: mov.u32 $0, 0;
  // CHECK-SAME: elect.sync rx | px, 0xFFFFFFFF;
  // CHECK-SAME: @px mov.u32 $0, 1;
  // CHECK-SAME: "=b" : () -> i1
  %cnd = nvvm.elect.sync -> i1
  return
}
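
Note that `getPtx()` writes the result operand as `%0`, while the CHECK lines expect `$0` in the generated `llvm.inline_asm`. The C++ sketch below illustrates that placeholder rewrite in isolation. It is a hedged, standalone example: `rewritePlaceholders` is a hypothetical name, and the actual conversion code may perform the substitution differently.

```cpp
#include <cctype>
#include <cstddef>
#include <string>

// Hypothetical helper, not the MLIR implementation: rewrites PTX-style
// operand placeholders (%0, %1, ...) into the $0, $1, ... form used by
// LLVM inline assembly. A '%' that is not followed by a digit is kept.
std::string rewritePlaceholders(const std::string &ptx) {
  std::string out;
  out.reserve(ptx.size());
  for (std::size_t i = 0; i < ptx.size(); ++i) {
    bool isPlaceholder =
        ptx[i] == '%' && i + 1 < ptx.size() &&
        std::isdigit(static_cast<unsigned char>(ptx[i + 1])) != 0;
    out.push_back(isPlaceholder ? '$' : ptx[i]);
  }
  return out;
}
```

Applied to the body returned by `getPtx()`, this turns `mov.u32 %0, 0;` into `mov.u32 $0, 0;`, which is the form the test matches. Locally declared registers such as `rx` and `px` carry no `%` prefix and pass through unchanged.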