-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[OpenMP][mlir] Added `num_teams`, `thread_limit` translation to LLVM IR
#68821
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This patch adds translation to LLVM IR for the `num_teams` and `thread_limit` clauses of the `omp.teams` operation.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm
Author: Shraiysh (shraiysh)
Changes: This patch adds translation to LLVM IR for the `num_teams` and `thread_limit` clauses of the `omp.teams` operation.
Full diff: https://github.com/llvm/llvm-project/pull/68821.diff
2 Files Affected:
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 1ec3bb8e7562a9e..ae974c14fac41a6 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -667,11 +667,9 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
LogicalResult bodyGenStatus = success();
- if (op.getNumTeamsLower() || op.getNumTeamsUpper() || op.getIfExpr() ||
- op.getThreadLimit() || !op.getAllocatorsVars().empty() ||
- op.getReductions()) {
+ if (op.getIfExpr() || !op.getAllocatorsVars().empty() || op.getReductions())
return op.emitError("unhandled clauses for translation to LLVM IR");
- }
+
auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
moduleTranslation, allocaIP);
@@ -680,9 +678,21 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
moduleTranslation, bodyGenStatus);
};
+ llvm::Value *numTeamsLower = nullptr;
+ if (auto numTeamsLowerVar = op.getNumTeamsLower())
+ numTeamsLower = moduleTranslation.lookupValue(numTeamsLowerVar);
+
+ llvm::Value *numTeamsUpper = nullptr;
+ if (auto numTeamsUpperVar = op.getNumTeamsUpper())
+ numTeamsUpper = moduleTranslation.lookupValue(numTeamsUpperVar);
+
+ llvm::Value *threadLimit = nullptr;
+ if (auto threadLimitVar = op.getThreadLimit())
+ threadLimit = moduleTranslation.lookupValue(threadLimitVar);
+
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
- builder.restoreIP(
- moduleTranslation.getOpenMPBuilder()->createTeams(ompLoc, bodyCB));
+ builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTeams(
+ ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit));
return bodyGenStatus;
}
diff --git a/mlir/test/Target/LLVMIR/openmp-teams.mlir b/mlir/test/Target/LLVMIR/openmp-teams.mlir
index 18fc2bb5a3c61b2..87ef90223ed704a 100644
--- a/mlir/test/Target/LLVMIR/openmp-teams.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-teams.mlir
@@ -124,3 +124,114 @@ llvm.func @omp_teams_branching_shared(%condition: i1, %arg0: i32, %arg1: f32, %a
// CHECK-NEXT: br label
// CHECK: ret void
+// -----
+
+llvm.func @beforeTeams()
+llvm.func @duringTeams()
+llvm.func @afterTeams()
+
+// CHECK-LABEL: @omp_teams_thread_limit
+// CHECK-SAME: (i32 [[THREAD_LIMIT:.+]])
+llvm.func @omp_teams_thread_limit(%threadLimit: i32) {
+ // CHECK-NEXT: call void @beforeTeams()
+ llvm.call @beforeTeams() : () -> ()
+ // CHECK: [[THREAD_NUM:%.+]] = call i32 @__kmpc_global_thread_num
+ // CHECK-NEXT: call void @__kmpc_push_num_teams_51({{.+}}, i32 [[THREAD_NUM]], i32 0, i32 0, i32 [[THREAD_LIMIT]])
+ // CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_teams(ptr @1, i32 0, ptr [[OUTLINED_FN:.+]])
+ omp.teams thread_limit(%threadLimit : i32) {
+ llvm.call @duringTeams() : () -> ()
+ omp.terminator
+ }
+ // CHECK: call void @afterTeams
+ llvm.call @afterTeams() : () -> ()
+ // CHECK: ret void
+ llvm.return
+}
+
+// CHECK: define internal void [[OUTLINED_FN]](ptr {{.+}}, ptr {{.+}})
+// CHECK: call void @duringTeams()
+// CHECK: ret void
+
+// -----
+
+llvm.func @beforeTeams()
+llvm.func @duringTeams()
+llvm.func @afterTeams()
+
+// CHECK-LABEL: @omp_teams_num_teams_upper
+// CHECK-SAME: (i32 [[NUM_TEAMS_UPPER:.+]])
+llvm.func @omp_teams_num_teams_upper(%numTeamsUpper: i32) {
+ // CHECK-NEXT: call void @beforeTeams()
+ llvm.call @beforeTeams() : () -> ()
+ // CHECK: [[THREAD_NUM:%.+]] = call i32 @__kmpc_global_thread_num
+ // CHECK-NEXT: call void @__kmpc_push_num_teams_51({{.+}}, i32 [[THREAD_NUM]], i32 [[NUM_TEAMS_UPPER]], i32 [[NUM_TEAMS_UPPER]], i32 0)
+ // CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_teams(ptr @1, i32 0, ptr [[OUTLINED_FN:.+]])
+ omp.teams num_teams(to %numTeamsUpper : i32) {
+ llvm.call @duringTeams() : () -> ()
+ omp.terminator
+ }
+ // CHECK: call void @afterTeams
+ llvm.call @afterTeams() : () -> ()
+ // CHECK: ret void
+ llvm.return
+}
+
+// CHECK: define internal void [[OUTLINED_FN]](ptr {{.+}}, ptr {{.+}})
+// CHECK: call void @duringTeams()
+// CHECK: ret void
+
+// -----
+
+llvm.func @beforeTeams()
+llvm.func @duringTeams()
+llvm.func @afterTeams()
+
+// CHECK-LABEL: @omp_teams_num_teams_lower_and_upper
+// CHECK-SAME: (i32 [[NUM_TEAMS_LOWER:.+]], i32 [[NUM_TEAMS_UPPER:.+]])
+llvm.func @omp_teams_num_teams_lower_and_upper(%numTeamsLower: i32, %numTeamsUpper: i32) {
+ // CHECK-NEXT: call void @beforeTeams()
+ llvm.call @beforeTeams() : () -> ()
+ // CHECK: [[THREAD_NUM:%.+]] = call i32 @__kmpc_global_thread_num
+ // CHECK-NEXT: call void @__kmpc_push_num_teams_51({{.+}}, i32 [[THREAD_NUM]], i32 [[NUM_TEAMS_LOWER]], i32 [[NUM_TEAMS_UPPER]], i32 0)
+ // CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_teams(ptr @1, i32 0, ptr [[OUTLINED_FN:.+]])
+ omp.teams num_teams(%numTeamsLower : i32 to %numTeamsUpper: i32) {
+ llvm.call @duringTeams() : () -> ()
+ omp.terminator
+ }
+ // CHECK: call void @afterTeams
+ llvm.call @afterTeams() : () -> ()
+ // CHECK: ret void
+ llvm.return
+}
+
+// CHECK: define internal void [[OUTLINED_FN]](ptr {{.+}}, ptr {{.+}})
+// CHECK: call void @duringTeams()
+// CHECK: ret void
+
+// -----
+
+llvm.func @beforeTeams()
+llvm.func @duringTeams()
+llvm.func @afterTeams()
+
+// CHECK-LABEL: @omp_teams_num_teams_and_thread_limit
+// CHECK-SAME: (i32 [[NUM_TEAMS_LOWER:.+]], i32 [[NUM_TEAMS_UPPER:.+]], i32 [[THREAD_LIMIT:.+]])
+llvm.func @omp_teams_num_teams_and_thread_limit(%numTeamsLower: i32, %numTeamsUpper: i32, %threadLimit: i32) {
+ // CHECK-NEXT: call void @beforeTeams()
+ llvm.call @beforeTeams() : () -> ()
+ // CHECK: [[THREAD_NUM:%.+]] = call i32 @__kmpc_global_thread_num
+ // CHECK-NEXT: call void @__kmpc_push_num_teams_51({{.+}}, i32 [[THREAD_NUM]], i32 [[NUM_TEAMS_LOWER]], i32 [[NUM_TEAMS_UPPER]], i32 [[THREAD_LIMIT]])
+ // CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_teams(ptr @1, i32 0, ptr [[OUTLINED_FN:.+]])
+ omp.teams num_teams(%numTeamsLower : i32 to %numTeamsUpper: i32) thread_limit(%threadLimit: i32) {
+ llvm.call @duringTeams() : () -> ()
+ omp.terminator
+ }
+ // CHECK: call void @afterTeams
+ llvm.call @afterTeams() : () -> ()
+ // CHECK: ret void
+ llvm.return
+}
+
+// CHECK: define internal void [[OUTLINED_FN]](ptr {{.+}}, ptr {{.+}})
+// CHECK: call void @duringTeams()
+// CHECK: ret void
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG.
llvm::Value *numTeamsLower = nullptr;
if (auto numTeamsLowerVar = op.getNumTeamsLower())
  numTeamsLower = moduleTranslation.lookupValue(numTeamsLowerVar);

llvm::Value *numTeamsUpper = nullptr;
if (auto numTeamsUpperVar = op.getNumTeamsUpper())
  numTeamsUpper = moduleTranslation.lookupValue(numTeamsUpperVar);

llvm::Value *threadLimit = nullptr;
if (auto threadLimitVar = op.getThreadLimit())
  threadLimit = moduleTranslation.lookupValue(threadLimitVar);
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Spell the auto?
This patch adds translation to LLVM IR for the `num_teams` and `thread_limit` clauses of the `omp.teams` operation.