Skip to content

[OpenMP][mlir] Added num_teams, thread_limit translation to LLVM IR #68821

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 15, 2023

Conversation

shraiysh
Copy link
Member

This patch adds translation to LLVM IR for num_teams and thread_limit in the omp.teams operation.

This patch adds translation to LLVM IR for `num_teams` and
`thread_limit` in the `omp.teams` operation.
@llvmbot
Copy link
Member

llvmbot commented Oct 11, 2023

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-flang-openmp
@llvm/pr-subscribers-mlir-openmp

@llvm/pr-subscribers-mlir-llvm

Author: Shraiysh (shraiysh)

Changes

This patch adds translation to LLVM IR for num_teams and thread_limit in the omp.teams operation.


Full diff: https://github.com/llvm/llvm-project/pull/68821.diff

2 Files Affected:

  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+16-6)
  • (modified) mlir/test/Target/LLVMIR/openmp-teams.mlir (+111)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 1ec3bb8e7562a9e..ae974c14fac41a6 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -667,11 +667,9 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
                 LLVM::ModuleTranslation &moduleTranslation) {
   using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
   LogicalResult bodyGenStatus = success();
-  if (op.getNumTeamsLower() || op.getNumTeamsUpper() || op.getIfExpr() ||
-      op.getThreadLimit() || !op.getAllocatorsVars().empty() ||
-      op.getReductions()) {
+  if (op.getIfExpr() || !op.getAllocatorsVars().empty() || op.getReductions())
     return op.emitError("unhandled clauses for translation to LLVM IR");
-  }
+
   auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
     LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
         moduleTranslation, allocaIP);
@@ -680,9 +678,21 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
                         moduleTranslation, bodyGenStatus);
   };
 
+  llvm::Value *numTeamsLower = nullptr;
+  if (auto numTeamsLowerVar = op.getNumTeamsLower())
+    numTeamsLower = moduleTranslation.lookupValue(numTeamsLowerVar);
+
+  llvm::Value *numTeamsUpper = nullptr;
+  if (auto numTeamsUpperVar = op.getNumTeamsUpper())
+    numTeamsUpper = moduleTranslation.lookupValue(numTeamsUpperVar);
+
+  llvm::Value *threadLimit = nullptr;
+  if (auto threadLimitVar = op.getThreadLimit())
+    threadLimit = moduleTranslation.lookupValue(threadLimitVar);
+
   llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
-  builder.restoreIP(
-      moduleTranslation.getOpenMPBuilder()->createTeams(ompLoc, bodyCB));
+  builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTeams(
+      ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit));
   return bodyGenStatus;
 }
 
diff --git a/mlir/test/Target/LLVMIR/openmp-teams.mlir b/mlir/test/Target/LLVMIR/openmp-teams.mlir
index 18fc2bb5a3c61b2..87ef90223ed704a 100644
--- a/mlir/test/Target/LLVMIR/openmp-teams.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-teams.mlir
@@ -124,3 +124,114 @@ llvm.func @omp_teams_branching_shared(%condition: i1, %arg0: i32, %arg1: f32, %a
 // CHECK-NEXT: br label
 // CHECK: ret void
 
+// -----
+
+llvm.func @beforeTeams()
+llvm.func @duringTeams()
+llvm.func @afterTeams()
+
+// CHECK-LABEL: @omp_teams_thread_limit
+// CHECK-SAME: (i32 [[THREAD_LIMIT:.+]])
+llvm.func @omp_teams_thread_limit(%threadLimit: i32) {
+    // CHECK-NEXT: call void @beforeTeams()
+    llvm.call @beforeTeams() : () -> ()
+    // CHECK: [[THREAD_NUM:%.+]] = call i32 @__kmpc_global_thread_num
+    // CHECK-NEXT: call void @__kmpc_push_num_teams_51({{.+}}, i32 [[THREAD_NUM]], i32 0, i32 0, i32 [[THREAD_LIMIT]])
+    // CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_teams(ptr @1, i32 0, ptr [[OUTLINED_FN:.+]])
+    omp.teams thread_limit(%threadLimit : i32) {
+        llvm.call @duringTeams() : () -> ()
+        omp.terminator
+    }
+    // CHECK: call void @afterTeams
+    llvm.call @afterTeams() : () -> ()
+    // CHECK: ret void
+    llvm.return
+}
+
+// CHECK: define internal void [[OUTLINED_FN]](ptr {{.+}}, ptr {{.+}})
+// CHECK: call void @duringTeams()
+// CHECK: ret void
+
+// -----
+
+llvm.func @beforeTeams()
+llvm.func @duringTeams()
+llvm.func @afterTeams()
+
+// CHECK-LABEL: @omp_teams_num_teams_upper
+// CHECK-SAME: (i32 [[NUM_TEAMS_UPPER:.+]])
+llvm.func @omp_teams_num_teams_upper(%numTeamsUpper: i32) {
+    // CHECK-NEXT: call void @beforeTeams()
+    llvm.call @beforeTeams() : () -> ()
+    // CHECK: [[THREAD_NUM:%.+]] = call i32 @__kmpc_global_thread_num
+    // CHECK-NEXT: call void @__kmpc_push_num_teams_51({{.+}}, i32 [[THREAD_NUM]], i32 [[NUM_TEAMS_UPPER]], i32 [[NUM_TEAMS_UPPER]], i32 0)
+    // CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_teams(ptr @1, i32 0, ptr [[OUTLINED_FN:.+]])
+    omp.teams num_teams(to %numTeamsUpper : i32) {
+        llvm.call @duringTeams() : () -> ()
+        omp.terminator
+    }
+    // CHECK: call void @afterTeams
+    llvm.call @afterTeams() : () -> ()
+    // CHECK: ret void
+    llvm.return
+}
+
+// CHECK: define internal void [[OUTLINED_FN]](ptr {{.+}}, ptr {{.+}})
+// CHECK: call void @duringTeams()
+// CHECK: ret void
+
+// -----
+
+llvm.func @beforeTeams()
+llvm.func @duringTeams()
+llvm.func @afterTeams()
+
+// CHECK-LABEL: @omp_teams_num_teams_lower_and_upper
+// CHECK-SAME: (i32 [[NUM_TEAMS_LOWER:.+]], i32 [[NUM_TEAMS_UPPER:.+]])
+llvm.func @omp_teams_num_teams_lower_and_upper(%numTeamsLower: i32, %numTeamsUpper: i32) {
+    // CHECK-NEXT: call void @beforeTeams()
+    llvm.call @beforeTeams() : () -> ()
+    // CHECK: [[THREAD_NUM:%.+]] = call i32 @__kmpc_global_thread_num
+    // CHECK-NEXT: call void @__kmpc_push_num_teams_51({{.+}}, i32 [[THREAD_NUM]], i32 [[NUM_TEAMS_LOWER]], i32 [[NUM_TEAMS_UPPER]], i32 0)
+    // CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_teams(ptr @1, i32 0, ptr [[OUTLINED_FN:.+]])
+    omp.teams num_teams(%numTeamsLower : i32 to %numTeamsUpper: i32) {
+        llvm.call @duringTeams() : () -> ()
+        omp.terminator
+    }
+    // CHECK: call void @afterTeams
+    llvm.call @afterTeams() : () -> ()
+    // CHECK: ret void
+    llvm.return
+}
+
+// CHECK: define internal void [[OUTLINED_FN]](ptr {{.+}}, ptr {{.+}})
+// CHECK: call void @duringTeams()
+// CHECK: ret void
+
+// -----
+
+llvm.func @beforeTeams()
+llvm.func @duringTeams()
+llvm.func @afterTeams()
+
+// CHECK-LABEL: @omp_teams_num_teams_and_thread_limit
+// CHECK-SAME: (i32 [[NUM_TEAMS_LOWER:.+]], i32 [[NUM_TEAMS_UPPER:.+]], i32 [[THREAD_LIMIT:.+]])
+llvm.func @omp_teams_num_teams_and_thread_limit(%numTeamsLower: i32, %numTeamsUpper: i32, %threadLimit: i32) {
+    // CHECK-NEXT: call void @beforeTeams()
+    llvm.call @beforeTeams() : () -> ()
+    // CHECK: [[THREAD_NUM:%.+]] = call i32 @__kmpc_global_thread_num
+    // CHECK-NEXT: call void @__kmpc_push_num_teams_51({{.+}}, i32 [[THREAD_NUM]], i32 [[NUM_TEAMS_LOWER]], i32 [[NUM_TEAMS_UPPER]], i32 [[THREAD_LIMIT]])
+    // CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_teams(ptr @1, i32 0, ptr [[OUTLINED_FN:.+]])
+    omp.teams num_teams(%numTeamsLower : i32 to %numTeamsUpper: i32) thread_limit(%threadLimit: i32) {
+        llvm.call @duringTeams() : () -> ()
+        omp.terminator
+    }
+    // CHECK: call void @afterTeams
+    llvm.call @afterTeams() : () -> ()
+    // CHECK: ret void
+    llvm.return
+}
+
+// CHECK: define internal void [[OUTLINED_FN]](ptr {{.+}}, ptr {{.+}})
+// CHECK: call void @duringTeams()
+// CHECK: ret void

Copy link
Contributor

@kiranchandramohan kiranchandramohan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG.

Comment on lines 681 to 691
llvm::Value *numTeamsLower = nullptr;
if (auto numTeamsLowerVar = op.getNumTeamsLower())
numTeamsLower = moduleTranslation.lookupValue(numTeamsLowerVar);

llvm::Value *numTeamsUpper = nullptr;
if (auto numTeamsUpperVar = op.getNumTeamsUpper())
numTeamsUpper = moduleTranslation.lookupValue(numTeamsUpperVar);

llvm::Value *threadLimit = nullptr;
if (auto threadLimitVar = op.getThreadLimit())
threadLimit = moduleTranslation.lookupValue(threadLimitVar);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Spell out the type instead of using `auto`?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants