Skip to content

Commit 546c3d7

Browse files
authored
[OpenMP][mlir] Added num_teams, thread_limit translation to LLVM IR (#68821)
This patch adds translation to LLVM IR for the `num_teams` and `thread_limit` clauses of the `omp.teams` operation.
1 parent 4698b99 commit 546c3d7

File tree

2 files changed

+127
-6
lines changed

2 files changed

+127
-6
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

+16-6
Original file line numberDiff line numberDiff line change
@@ -666,11 +666,9 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
666666
LLVM::ModuleTranslation &moduleTranslation) {
667667
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
668668
LogicalResult bodyGenStatus = success();
669-
if (op.getNumTeamsLower() || op.getNumTeamsUpper() || op.getIfExpr() ||
670-
op.getThreadLimit() || !op.getAllocatorsVars().empty() ||
671-
op.getReductions()) {
669+
if (op.getIfExpr() || !op.getAllocatorsVars().empty() || op.getReductions())
672670
return op.emitError("unhandled clauses for translation to LLVM IR");
673-
}
671+
674672
auto bodyCB = [&](InsertPointTy allocaIP, InsertPointTy codegenIP) {
675673
LLVM::ModuleTranslation::SaveStack<OpenMPAllocaStackFrame> frame(
676674
moduleTranslation, allocaIP);
@@ -679,9 +677,21 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
679677
moduleTranslation, bodyGenStatus);
680678
};
681679

680+
llvm::Value *numTeamsLower = nullptr;
681+
if (Value numTeamsLowerVar = op.getNumTeamsLower())
682+
numTeamsLower = moduleTranslation.lookupValue(numTeamsLowerVar);
683+
684+
llvm::Value *numTeamsUpper = nullptr;
685+
if (Value numTeamsUpperVar = op.getNumTeamsUpper())
686+
numTeamsUpper = moduleTranslation.lookupValue(numTeamsUpperVar);
687+
688+
llvm::Value *threadLimit = nullptr;
689+
if (Value threadLimitVar = op.getThreadLimit())
690+
threadLimit = moduleTranslation.lookupValue(threadLimitVar);
691+
682692
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
683-
builder.restoreIP(
684-
moduleTranslation.getOpenMPBuilder()->createTeams(ompLoc, bodyCB));
693+
builder.restoreIP(moduleTranslation.getOpenMPBuilder()->createTeams(
694+
ompLoc, bodyCB, numTeamsLower, numTeamsUpper, threadLimit));
685695
return bodyGenStatus;
686696
}
687697

mlir/test/Target/LLVMIR/openmp-teams.mlir

+111
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,114 @@ llvm.func @omp_teams_branching_shared(%condition: i1, %arg0: i32, %arg1: f32, %a
124124
// CHECK-NEXT: br label
125125
// CHECK: ret void
126126

127+
// -----
128+
129+
llvm.func @beforeTeams()
130+
llvm.func @duringTeams()
131+
llvm.func @afterTeams()
132+
133+
// CHECK-LABEL: @omp_teams_thread_limit
134+
// CHECK-SAME: (i32 [[THREAD_LIMIT:.+]])
135+
llvm.func @omp_teams_thread_limit(%threadLimit: i32) {
136+
// CHECK-NEXT: call void @beforeTeams()
137+
llvm.call @beforeTeams() : () -> ()
138+
// CHECK: [[THREAD_NUM:%.+]] = call i32 @__kmpc_global_thread_num
139+
// CHECK-NEXT: call void @__kmpc_push_num_teams_51({{.+}}, i32 [[THREAD_NUM]], i32 0, i32 0, i32 [[THREAD_LIMIT]])
140+
// CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_teams(ptr @1, i32 0, ptr [[OUTLINED_FN:.+]])
141+
omp.teams thread_limit(%threadLimit : i32) {
142+
llvm.call @duringTeams() : () -> ()
143+
omp.terminator
144+
}
145+
// CHECK: call void @afterTeams
146+
llvm.call @afterTeams() : () -> ()
147+
// CHECK: ret void
148+
llvm.return
149+
}
150+
151+
// CHECK: define internal void [[OUTLINED_FN]](ptr {{.+}}, ptr {{.+}})
152+
// CHECK: call void @duringTeams()
153+
// CHECK: ret void
154+
155+
// -----
156+
157+
llvm.func @beforeTeams()
158+
llvm.func @duringTeams()
159+
llvm.func @afterTeams()
160+
161+
// CHECK-LABEL: @omp_teams_num_teams_upper
162+
// CHECK-SAME: (i32 [[NUM_TEAMS_UPPER:.+]])
163+
llvm.func @omp_teams_num_teams_upper(%numTeamsUpper: i32) {
164+
// CHECK-NEXT: call void @beforeTeams()
165+
llvm.call @beforeTeams() : () -> ()
166+
// CHECK: [[THREAD_NUM:%.+]] = call i32 @__kmpc_global_thread_num
167+
// CHECK-NEXT: call void @__kmpc_push_num_teams_51({{.+}}, i32 [[THREAD_NUM]], i32 [[NUM_TEAMS_UPPER]], i32 [[NUM_TEAMS_UPPER]], i32 0)
168+
// CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_teams(ptr @1, i32 0, ptr [[OUTLINED_FN:.+]])
169+
omp.teams num_teams(to %numTeamsUpper : i32) {
170+
llvm.call @duringTeams() : () -> ()
171+
omp.terminator
172+
}
173+
// CHECK: call void @afterTeams
174+
llvm.call @afterTeams() : () -> ()
175+
// CHECK: ret void
176+
llvm.return
177+
}
178+
179+
// CHECK: define internal void [[OUTLINED_FN]](ptr {{.+}}, ptr {{.+}})
180+
// CHECK: call void @duringTeams()
181+
// CHECK: ret void
182+
183+
// -----
184+
185+
llvm.func @beforeTeams()
186+
llvm.func @duringTeams()
187+
llvm.func @afterTeams()
188+
189+
// CHECK-LABEL: @omp_teams_num_teams_lower_and_upper
190+
// CHECK-SAME: (i32 [[NUM_TEAMS_LOWER:.+]], i32 [[NUM_TEAMS_UPPER:.+]])
191+
llvm.func @omp_teams_num_teams_lower_and_upper(%numTeamsLower: i32, %numTeamsUpper: i32) {
192+
// CHECK-NEXT: call void @beforeTeams()
193+
llvm.call @beforeTeams() : () -> ()
194+
// CHECK: [[THREAD_NUM:%.+]] = call i32 @__kmpc_global_thread_num
195+
// CHECK-NEXT: call void @__kmpc_push_num_teams_51({{.+}}, i32 [[THREAD_NUM]], i32 [[NUM_TEAMS_LOWER]], i32 [[NUM_TEAMS_UPPER]], i32 0)
196+
// CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_teams(ptr @1, i32 0, ptr [[OUTLINED_FN:.+]])
197+
omp.teams num_teams(%numTeamsLower : i32 to %numTeamsUpper: i32) {
198+
llvm.call @duringTeams() : () -> ()
199+
omp.terminator
200+
}
201+
// CHECK: call void @afterTeams
202+
llvm.call @afterTeams() : () -> ()
203+
// CHECK: ret void
204+
llvm.return
205+
}
206+
207+
// CHECK: define internal void [[OUTLINED_FN]](ptr {{.+}}, ptr {{.+}})
208+
// CHECK: call void @duringTeams()
209+
// CHECK: ret void
210+
211+
// -----
212+
213+
llvm.func @beforeTeams()
214+
llvm.func @duringTeams()
215+
llvm.func @afterTeams()
216+
217+
// CHECK-LABEL: @omp_teams_num_teams_and_thread_limit
218+
// CHECK-SAME: (i32 [[NUM_TEAMS_LOWER:.+]], i32 [[NUM_TEAMS_UPPER:.+]], i32 [[THREAD_LIMIT:.+]])
219+
llvm.func @omp_teams_num_teams_and_thread_limit(%numTeamsLower: i32, %numTeamsUpper: i32, %threadLimit: i32) {
220+
// CHECK-NEXT: call void @beforeTeams()
221+
llvm.call @beforeTeams() : () -> ()
222+
// CHECK: [[THREAD_NUM:%.+]] = call i32 @__kmpc_global_thread_num
223+
// CHECK-NEXT: call void @__kmpc_push_num_teams_51({{.+}}, i32 [[THREAD_NUM]], i32 [[NUM_TEAMS_LOWER]], i32 [[NUM_TEAMS_UPPER]], i32 [[THREAD_LIMIT]])
224+
// CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_teams(ptr @1, i32 0, ptr [[OUTLINED_FN:.+]])
225+
omp.teams num_teams(%numTeamsLower : i32 to %numTeamsUpper: i32) thread_limit(%threadLimit: i32) {
226+
llvm.call @duringTeams() : () -> ()
227+
omp.terminator
228+
}
229+
// CHECK: call void @afterTeams
230+
llvm.call @afterTeams() : () -> ()
231+
// CHECK: ret void
232+
llvm.return
233+
}
234+
235+
// CHECK: define internal void [[OUTLINED_FN]](ptr {{.+}}, ptr {{.+}})
236+
// CHECK: call void @duringTeams()
237+
// CHECK: ret void

0 commit comments

Comments
 (0)