Skip to content

Commit 25d3695

Browse files
skatrakjsjodin
andcommitted
Fix remaining issues after merge.
- Cherry-pick some unit test fixes from PR #44. - Cherry-pick reduction processing fix from PR llvm#85807. - Update argument lists of modified functions. - Rewrite `ReductionProcessor::addReductionSym` based on new clause structures. - Remove leftover references to `RIManager`. - Op renames to match new approach. Co-authored-by: Jan Leyonberg <[email protected]>
1 parent 21a0a58 commit 25d3695

9 files changed

+62
-61
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

+18-8
Original file line numberDiff line numberDiff line change
@@ -907,27 +907,37 @@ bool ClauseProcessor::processMap(
907907
bool ClauseProcessor::processTargetReduction(
908908
llvm::SmallVector<const Fortran::semantics::Symbol *> &reductionSymbols)
909909
const {
910-
return findRepeatableClause<ClauseTy::Reduction>(
911-
[&](const ClauseTy::Reduction *reductionClause,
910+
return findRepeatableClause<omp::clause::Reduction>(
911+
[&](const omp::clause::Reduction &clause,
912912
const Fortran::parser::CharBlock &) {
913913
ReductionProcessor rp;
914-
rp.addReductionSym(reductionClause->v, reductionSymbols);
914+
rp.addReductionSym(clause, reductionSymbols);
915915
});
916916
}
917917

918918
bool ClauseProcessor::processReduction(
919919
mlir::Location currentLocation,
920-
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
921-
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
922-
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *reductionSymbols)
923-
const {
920+
llvm::SmallVectorImpl<mlir::Value> &outReductionVars,
921+
llvm::SmallVectorImpl<mlir::Attribute> &outReductionDeclSymbols,
922+
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
923+
*outReductionSymbols) const {
924924
return findRepeatableClause<omp::clause::Reduction>(
925925
[&](const omp::clause::Reduction &clause,
926926
const Fortran::parser::CharBlock &) {
927+
llvm::SmallVector<mlir::Value> reductionVars;
928+
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
929+
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
927930
ReductionProcessor rp;
928931
rp.addDeclareReduction(currentLocation, converter, clause,
929932
reductionVars, reductionDeclSymbols,
930-
reductionSymbols);
933+
outReductionSymbols ? &reductionSymbols
934+
: nullptr);
935+
llvm::copy(reductionVars, std::back_inserter(outReductionVars));
936+
llvm::copy(reductionDeclSymbols,
937+
std::back_inserter(outReductionDeclSymbols));
938+
if (outReductionSymbols)
939+
llvm::copy(reductionSymbols,
940+
std::back_inserter(*outReductionSymbols));
931941
});
932942
}
933943

flang/lib/Lower/OpenMP/OpenMP.cpp

+3-4
Original file line numberDiff line numberDiff line change
@@ -1339,7 +1339,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
13391339
Fortran::parser::OmpClause::Defaultmap>(
13401340
currentLocation, llvm::omp::Directive::OMPD_target);
13411341

1342-
DataSharingProcessor localDSP(converter, clauseList, eval);
1342+
DataSharingProcessor localDSP(converter, semaCtx, clauseList, eval);
13431343
DataSharingProcessor &actualDSP = dsp ? *dsp : localDSP;
13441344
actualDSP.processStep1();
13451345

@@ -1458,8 +1458,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
14581458
/*teams_thread_limit=*/nullptr, /*num_threads=*/nullptr);
14591459

14601460
genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSymTypes,
1461-
mapSymLocs, mapSymbols, currentLocation, clauseList,
1462-
actualDSP);
1461+
mapSymLocs, mapSymbols, currentLocation, actualDSP);
14631462

14641463
return targetOp;
14651464
}
@@ -2074,7 +2073,7 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
20742073
}();
20752074

20762075
bool validDirective = false;
2077-
DataSharingProcessor dsp(converter, loopOpClauseList, eval);
2076+
DataSharingProcessor dsp(converter, semaCtx, loopOpClauseList, eval);
20782077

20792078
if (llvm::omp::topTaskloopSet.test(ompDirective)) {
20802079
validDirective = true;

flang/lib/Lower/OpenMP/ReductionProcessor.cpp

+4-10
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,11 @@ ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
6767
}
6868

6969
void ReductionProcessor::addReductionSym(
70-
const Fortran::parser::OmpReductionClause &reduction,
70+
const omp::clause::Reduction &reduction,
7171
llvm::SmallVector<const Fortran::semantics::Symbol *> &symbols) {
72-
const auto &objectList{std::get<Fortran::parser::OmpObjectList>(reduction.t)};
73-
74-
for (const Fortran::parser::OmpObject &ompObject : objectList.v) {
75-
if (const auto *name{
76-
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
77-
if (const Fortran::semantics::Symbol * symbol{name->symbol})
78-
symbols.push_back(symbol);
79-
}
80-
}
72+
const auto &objectList{std::get<omp::ObjectList>(reduction.t)};
73+
llvm::transform(objectList, std::back_inserter(symbols),
74+
[](const Object &object) { return object.id(); });
8175
}
8276

8377
bool ReductionProcessor::supportedIntrinsicProcReduction(

flang/lib/Lower/OpenMP/ReductionProcessor.h

+1-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
#include "Clauses.h"
1717
#include "flang/Optimizer/Builder/FIRBuilder.h"
1818
#include "flang/Optimizer/Dialect/FIRType.h"
19-
#include "flang/Parser/parse-tree.h"
2019
#include "flang/Semantics/symbol.h"
2120
#include "flang/Semantics/type.h"
2221
#include "mlir/IR/Location.h"
@@ -106,7 +105,7 @@ class ReductionProcessor {
106105
mlir::Value op2);
107106

108107
static void addReductionSym(
109-
const Fortran::parser::OmpReductionClause &reduction,
108+
const omp::clause::Reduction &reduction,
110109
llvm::SmallVector<const Fortran::semantics::Symbol *> &symbols);
111110

112111
/// Creates an OpenMP reduction declaration and inserts it into the provided

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

+22-32
Original file line numberDiff line numberDiff line change
@@ -1125,22 +1125,19 @@ convertOmpWsloop(
11251125
tempTerminator->eraseFromParent();
11261126
builder.restoreIP(nextInsertionPoint);
11271127

1128-
if (!ompBuilder->Config.isGPU())
1129-
ompBuilder->RIManager.clear();
1130-
11311128
return success();
11321129
}
11331130

11341131
static LogicalResult
1135-
convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
1132+
convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
11361133
LLVM::ModuleTranslation &moduleTranslation) {
11371134
llvm::OpenMPIRBuilder::InsertPointTy redAllocaIP =
11381135
findAllocaInsertPoint(builder, moduleTranslation);
11391136
SmallVector<OwningReductionGen> owningReductionGens;
11401137
SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
11411138
SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
11421139

1143-
return convertOmpWsLoop(opInst, builder, moduleTranslation, redAllocaIP,
1140+
return convertOmpWsloop(opInst, builder, moduleTranslation, redAllocaIP,
11441141
owningReductionGens, owningAtomicReductionGens,
11451142
reductionInfos);
11461143
}
@@ -1414,9 +1411,6 @@ convertOmpParallel(Operation &opInst1, llvm::IRBuilderBase &builder,
14141411
ompBuilder->createParallel(ompLoc, allocaIP, bodyGenCB, privCB, finiCB,
14151412
ifCond, numThreads, pbKind, isCancellable));
14161413

1417-
if (!ompBuilder->Config.isGPU())
1418-
ompBuilder->RIManager.clear();
1419-
14201414
return bodyGenStatus;
14211415
}
14221416

@@ -3330,10 +3324,9 @@ class ConversionDispatchList {
33303324
}
33313325
};
33323326

3333-
static LogicalResult convertOmpDistributeParallelWsLoop(
3334-
Operation *op,
3335-
omp::DistributeOp distribute, omp::ParallelOp parallel,
3336-
omp::WsLoopOp wsloop, llvm::IRBuilderBase &builder,
3327+
static LogicalResult convertOmpDistributeParallelWsloop(
3328+
Operation *op, omp::DistributeOp distribute, omp::ParallelOp parallel,
3329+
omp::WsloopOp wsloop, llvm::IRBuilderBase &builder,
33373330
LLVM::ModuleTranslation &moduleTranslation,
33383331
ConversionDispatchList &dispatchList);
33393332

@@ -3534,10 +3527,10 @@ convertInternalTargetOp(Operation *op, llvm::IRBuilderBase &builder,
35343527

35353528
omp::DistributeOp distribute;
35363529
omp::ParallelOp parallel;
3537-
omp::WsLoopOp wsloop;
3530+
omp::WsloopOp wsloop;
35383531
// Match composite constructs
35393532
if (matchOpNest(op, distribute, parallel, wsloop)) {
3540-
return convertOmpDistributeParallelWsLoop(op, distribute, parallel, wsloop,
3533+
return convertOmpDistributeParallelWsloop(op, distribute, parallel, wsloop,
35413534
builder, moduleTranslation,
35423535
dispatchList);
35433536
}
@@ -3586,9 +3579,9 @@ class OpenMPDialectLLVMIRTranslationInterface
35863579
// Implementation converting a nest of operations in a single function. This
35873580
// just overrides the parallel and wsloop dispatches but does the normal
35883581
// lowering for now.
3589-
static LogicalResult convertOmpDistributeParallelWsLoop(
3582+
static LogicalResult convertOmpDistributeParallelWsloop(
35903583
Operation *op, omp::DistributeOp distribute, omp::ParallelOp parallel,
3591-
omp::WsLoopOp wsloop, llvm::IRBuilderBase &builder,
3584+
omp::WsloopOp wsloop, llvm::IRBuilderBase &builder,
35923585
LLVM::ModuleTranslation &moduleTranslation,
35933586
ConversionDispatchList &dispatchList) {
35943587

@@ -3599,25 +3592,22 @@ static LogicalResult convertOmpDistributeParallelWsLoop(
35993592
llvm::OpenMPIRBuilder::InsertPointTy redAllocaIP;
36003593

36013594
// Convert wsloop alternative implementation
3602-
ConvertFunctionTy convertWsLoop = [&redAllocaIP, &owningReductionGens,
3603-
&owningAtomicReductionGens,
3604-
&reductionInfos](
3605-
Operation *op,
3606-
llvm::IRBuilderBase &builder,
3607-
LLVM::ModuleTranslation
3608-
&moduleTranslation) {
3609-
if (!isa<omp::WsLoopOp>(op)) {
3610-
return std::make_pair(false, failure());
3611-
}
3595+
ConvertFunctionTy convertWsloop =
3596+
[&redAllocaIP, &owningReductionGens, &owningAtomicReductionGens,
3597+
&reductionInfos](Operation *op, llvm::IRBuilderBase &builder,
3598+
LLVM::ModuleTranslation &moduleTranslation) {
3599+
if (!isa<omp::WsloopOp>(op)) {
3600+
return std::make_pair(false, failure());
3601+
}
36123602

3613-
LogicalResult result = convertOmpWsLoop(
3614-
*op, builder, moduleTranslation, redAllocaIP, owningReductionGens,
3615-
owningAtomicReductionGens, reductionInfos);
3616-
return std::make_pair(true, result);
3617-
};
3603+
LogicalResult result = convertOmpWsloop(
3604+
*op, builder, moduleTranslation, redAllocaIP, owningReductionGens,
3605+
owningAtomicReductionGens, reductionInfos);
3606+
return std::make_pair(true, result);
3607+
};
36183608

36193609
// Push the new alternative functions
3620-
dispatchList.pushConversionFunction(convertWsLoop);
3610+
dispatchList.pushConversionFunction(convertWsloop);
36213611

36223612
// Lower the current distribute operation
36233613
LogicalResult result = convertOmpDistribute(*op, builder, moduleTranslation,

mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memory_space", 5 : ui32>>, llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8", llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true } {
77
llvm.func @target_parallel_wsloop(%arg0: !llvm.ptr) attributes {
88
target_cpu = "gfx90a",
9-
target_features = #llvm.target_features<["+gfx9-insts", "+wavefrontsize64"]>
9+
target_features = #llvm.target_features<["+gfx9-insts", "+wavefrontsize64"]>,
10+
omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>
1011
} {
1112
omp.parallel {
1213
%loop_ub = llvm.mlir.constant(9 : i32) : i32

mlir/test/Target/LLVMIR/omptarget-teams-llvm.mlir

+3-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55

66
module attributes {omp.is_target_device = true} {
77
llvm.func @foo(i32)
8-
llvm.func @omp_target_teams_shared_simple(%arg0 : i32) {
8+
llvm.func @omp_target_teams_shared_simple(%arg0 : i32) attributes {
9+
omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>
10+
} {
911
omp.teams {
1012
llvm.call @foo(%arg0) : (i32) -> ()
1113
omp.terminator

mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir

+3-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
// for nested omp do loop with collapse clause inside omp target region
55

66
module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memory_space", 5 : ui32>>, llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8", llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true } {
7-
llvm.func @target_collapsed_wsloop(%arg0: !llvm.ptr) {
7+
llvm.func @target_collapsed_wsloop(%arg0: !llvm.ptr) attributes {
8+
omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>
9+
} {
810
%loop_ub = llvm.mlir.constant(99 : i32) : i32
911
%loop_lb = llvm.mlir.constant(0 : i32) : i32
1012
%loop_step = llvm.mlir.constant(1 : index) : i32

mlir/test/Target/LLVMIR/omptarget-wsloop.mlir

+6-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
// for nested omp do loop inside omp target region
55

66
module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memory_space", 5 : ui32>>, llvm.data_layout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8", llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true } {
7-
llvm.func @target_wsloop(%arg0: !llvm.ptr ){
7+
llvm.func @target_wsloop(%arg0: !llvm.ptr) attributes {
8+
omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>
9+
} {
810
%loop_ub = llvm.mlir.constant(9 : i32) : i32
911
%loop_lb = llvm.mlir.constant(0 : i32) : i32
1012
%loop_step = llvm.mlir.constant(1 : i32) : i32
@@ -16,7 +18,9 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
1618
llvm.return
1719
}
1820

19-
llvm.func @target_empty_wsloop(){
21+
llvm.func @target_empty_wsloop() attributes {
22+
omp.declare_target = #omp.declaretarget<device_type = (any), capture_clause = (to)>
23+
} {
2024
%loop_ub = llvm.mlir.constant(9 : i32) : i32
2125
%loop_lb = llvm.mlir.constant(0 : i32) : i32
2226
%loop_step = llvm.mlir.constant(1 : i32) : i32

0 commit comments

Comments
 (0)