Skip to content

Commit 62423e1

Browse files
committed
[flang][OpenMP] Extend do concurrent mapping to device.
For simple loops, we can now choose to map `do concurrent` to either the host (i.e. `omp parallel do`) or the device (i.e. `omp target teams distribute parallel do`). In order to use this from `flang-new`, you can pass: `-fopenmp -fdo-concurrent-parallel=[host|device|none]`; where `none` will disable the `do concurrent` mapping altogether.
1 parent e967097 commit 62423e1

File tree

18 files changed

+639
-263
lines changed

18 files changed

+639
-263
lines changed

flang/lib/Lower/OpenMP/Utils.h renamed to flang/include/flang/Lower/OpenMP/Utils.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ using DeclareTargetCapturePair =
5050
const Fortran::semantics::Symbol &>;
5151

5252
mlir::omp::MapInfoOp
53-
createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
53+
createMapInfoOp(mlir::OpBuilder &builder, mlir::Location loc,
5454
mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name,
5555
mlir::ArrayRef<mlir::Value> bounds,
5656
mlir::ArrayRef<mlir::Value> members, uint64_t mapType,
@@ -73,6 +73,8 @@ void genObjectList(const ObjectList &objects,
7373
Fortran::lower::AbstractConverter &converter,
7474
llvm::SmallVectorImpl<mlir::Value> &operands);
7575

76+
mlir::Value calculateTripCount(fir::FirOpBuilder &builder, mlir::Location loc,
77+
const mlir::omp::CollapseClauseOps &ops);
7678
} // namespace omp
7779
} // namespace lower
7880
} // namespace Fortran

flang/include/flang/Optimizer/Transforms/Passes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ namespace fir {
3838
#define GEN_PASS_DECL_ARRAYVALUECOPY
3939
#define GEN_PASS_DECL_CHARACTERCONVERSION
4040
#define GEN_PASS_DECL_CFGCONVERSION
41+
#define GEN_PASS_DECL_DOCONCURRENTCONVERSIONPASS
4142
#define GEN_PASS_DECL_EXTERNALNAMECONVERSION
4243
#define GEN_PASS_DECL_MEMREFDATAFLOWOPT
4344
#define GEN_PASS_DECL_SIMPLIFYINTRINSICS
@@ -96,6 +97,7 @@ createFunctionAttrPass(FunctionAttrTypes &functionAttr, bool noInfsFPMath,
9697
bool noSignedZerosFPMath, bool unsafeFPMath);
9798

9899
std::unique_ptr<mlir::Pass> createDoConcurrentConversionPass();
100+
std::unique_ptr<mlir::Pass> createDoConcurrentConversionPass(bool mapToDevice);
99101

100102
void populateCfgConversionRewrites(mlir::RewritePatternSet &patterns,
101103
bool forceLoopToExecuteOnce = false);

flang/include/flang/Optimizer/Transforms/Passes.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,8 +423,12 @@ def DoConcurrentConversionPass : Pass<"fopenmp-do-concurrent-conversion", "mlir:
423423
target.
424424
}];
425425

426-
let constructor = "::fir::createDoConcurrentConversionPass()";
427426
let dependentDialects = ["mlir::omp::OpenMPDialect"];
427+
428+
let options = [
429+
Option<"mapTo", "map-to", "std::string", "",
430+
"Try to map `do concurrent` loops to OpenMP (on host or device)">,
431+
];
428432
}
429433

430434
#endif // FLANG_OPTIMIZER_TRANSFORMS_PASSES

flang/lib/Frontend/FrontendActions.cpp

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -320,22 +320,14 @@ bool CodeGenAction::beginSourceFileAction() {
320320
// Add OpenMP-related passes
321321
// WARNING: These passes must be run immediately after the lowering to ensure
322322
// that the FIR is correct with respect to OpenMP operations/attributes.
323-
bool isOpenMPEnabled = ci.getInvocation().getFrontendOpts().features.IsEnabled(
323+
bool isOpenMPEnabled =
324+
ci.getInvocation().getFrontendOpts().features.IsEnabled(
324325
Fortran::common::LanguageFeature::OpenMP);
325-
if (isOpenMPEnabled) {
326-
bool isDevice = false;
327-
if (auto offloadMod = llvm::dyn_cast<mlir::omp::OffloadModuleInterface>(
328-
mlirModule->getOperation()))
329-
isDevice = offloadMod.getIsTargetDevice();
330-
// WARNING: This pipeline must be run immediately after the lowering to
331-
// ensure that the FIR is correct with respect to OpenMP operations/
332-
// attributes.
333-
fir::createOpenMPFIRPassPipeline(pm, isDevice);
334-
}
335326

336327
using DoConcurrentMappingKind =
337328
Fortran::frontend::CodeGenOptions::DoConcurrentMappingKind;
338-
DoConcurrentMappingKind selectedKind = ci.getInvocation().getCodeGenOpts().getDoConcurrentMapping();
329+
DoConcurrentMappingKind selectedKind =
330+
ci.getInvocation().getCodeGenOpts().getDoConcurrentMapping();
339331
if (selectedKind != DoConcurrentMappingKind::DCMK_None) {
340332
if (!isOpenMPEnabled) {
341333
unsigned diagID = ci.getDiagnostics().getCustomDiagID(
@@ -345,18 +337,21 @@ bool CodeGenAction::beginSourceFileAction() {
345337
ci.getDiagnostics().Report(diagID);
346338
} else {
347339
bool mapToDevice = selectedKind == DoConcurrentMappingKind::DCMK_Device;
348-
349-
if (mapToDevice) {
350-
unsigned diagID = ci.getDiagnostics().getCustomDiagID(
351-
clang::DiagnosticsEngine::Warning,
352-
"TODO: lowering `do concurrent` loops to OpenMP device is not "
353-
"supported yet");
354-
ci.getDiagnostics().Report(diagID);
355-
} else
356-
pm.addPass(fir::createDoConcurrentConversionPass());
340+
pm.addPass(fir::createDoConcurrentConversionPass(mapToDevice));
357341
}
358342
}
359343

344+
if (isOpenMPEnabled) {
345+
bool isDevice = false;
346+
if (auto offloadMod = llvm::dyn_cast<mlir::omp::OffloadModuleInterface>(
347+
mlirModule->getOperation()))
348+
isDevice = offloadMod.getIsTargetDevice();
349+
// WARNING: This pipeline must be run immediately after the lowering to
350+
// ensure that the FIR is correct with respect to OpenMP operations/
351+
// attributes.
352+
fir::createOpenMPFIRPassPipeline(pm, isDevice);
353+
}
354+
360355
pm.enableVerifier(/*verifyPasses=*/true);
361356
pm.addPass(std::make_unique<Fortran::lower::VerifierPass>());
362357

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "ClauseProcessor.h"
14-
#include "Clauses.h"
1514

15+
#include "flang/Lower/OpenMP/Clauses.h"
1616
#include "flang/Lower/PFTBuilder.h"
1717
#include "flang/Parser/tools.h"
1818
#include "flang/Semantics/tools.h"
@@ -807,30 +807,6 @@ bool ClauseProcessor::processLink(
807807
});
808808
}
809809

810-
mlir::omp::MapInfoOp
811-
createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
812-
mlir::Value baseAddr, mlir::Value varPtrPtr, std::string name,
813-
llvm::ArrayRef<mlir::Value> bounds,
814-
llvm::ArrayRef<mlir::Value> members, uint64_t mapType,
815-
mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
816-
bool isVal) {
817-
if (auto boxTy = baseAddr.getType().dyn_cast<fir::BaseBoxType>()) {
818-
baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
819-
retTy = baseAddr.getType();
820-
}
821-
822-
mlir::TypeAttr varType = mlir::TypeAttr::get(
823-
llvm::cast<mlir::omp::PointerLikeType>(retTy).getElementType());
824-
825-
mlir::omp::MapInfoOp op = builder.create<mlir::omp::MapInfoOp>(
826-
loc, retTy, baseAddr, varType, varPtrPtr, members, bounds,
827-
builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
828-
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType),
829-
builder.getStringAttr(name));
830-
831-
return op;
832-
}
833-
834810
bool ClauseProcessor::processMap(
835811
mlir::Location currentLocation, Fortran::lower::StatementContext &stmtCtx,
836812
mlir::omp::MapClauseOps &result,

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@
1212
#ifndef FORTRAN_LOWER_CLAUASEPROCESSOR_H
1313
#define FORTRAN_LOWER_CLAUASEPROCESSOR_H
1414

15-
#include "Clauses.h"
1615
#include "DirectivesCommon.h"
1716
#include "ReductionProcessor.h"
18-
#include "Utils.h"
1917
#include "flang/Lower/AbstractConverter.h"
2018
#include "flang/Lower/Bridge.h"
19+
#include "flang/Lower/OpenMP/Clauses.h"
20+
#include "flang/Lower/OpenMP/Utils.h"
2121
#include "flang/Optimizer/Builder/Todo.h"
2222
#include "flang/Parser/dump-parse-tree.h"
2323
#include "flang/Parser/parse-tree.h"

flang/lib/Lower/OpenMP/Clauses.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#include "Clauses.h"
9+
#include "flang/Lower/OpenMP/Clauses.h"
1010

1111
#include "flang/Common/idioms.h"
1212
#include "flang/Evaluate/expression.h"

flang/lib/Lower/OpenMP/DataSharingProcessor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
#include "DataSharingProcessor.h"
1414

15-
#include "Utils.h"
15+
#include "flang/Lower/OpenMP/Utils.h"
1616
#include "flang/Lower/PFTBuilder.h"
1717
#include "flang/Lower/SymbolMap.h"
1818
#include "flang/Optimizer/Builder/Todo.h"

flang/lib/Lower/OpenMP/DataSharingProcessor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
#ifndef FORTRAN_LOWER_DATASHARINGPROCESSOR_H
1313
#define FORTRAN_LOWER_DATASHARINGPROCESSOR_H
1414

15-
#include "Clauses.h"
1615
#include "flang/Lower/AbstractConverter.h"
1716
#include "flang/Lower/OpenMP.h"
17+
#include "flang/Lower/OpenMP/Clauses.h"
1818
#include "flang/Optimizer/Builder/FIRBuilder.h"
1919
#include "flang/Parser/parse-tree.h"
2020
#include "flang/Semantics/symbol.h"

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 4 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@
1313
#include "flang/Lower/OpenMP.h"
1414

1515
#include "ClauseProcessor.h"
16-
#include "Clauses.h"
1716
#include "DataSharingProcessor.h"
1817
#include "DirectivesCommon.h"
1918
#include "ReductionProcessor.h"
20-
#include "Utils.h"
2119
#include "flang/Common/idioms.h"
2220
#include "flang/Lower/Bridge.h"
2321
#include "flang/Lower/ConvertExpr.h"
2422
#include "flang/Lower/ConvertVariable.h"
23+
#include "flang/Lower/OpenMP/Clauses.h"
24+
#include "flang/Lower/OpenMP/Utils.h"
2525
#include "flang/Lower/StatementContext.h"
2626
#include "flang/Lower/SymbolMap.h"
2727
#include "flang/Optimizer/Builder/BoxValue.h"
@@ -280,84 +280,6 @@ static void threadPrivatizeVars(Fortran::lower::AbstractConverter &converter,
280280
}
281281
}
282282

283-
static mlir::Value
284-
calculateTripCount(Fortran::lower::AbstractConverter &converter,
285-
mlir::Location loc,
286-
const mlir::omp::CollapseClauseOps &ops) {
287-
using namespace mlir::arith;
288-
assert(ops.loopLBVar.size() == ops.loopUBVar.size() &&
289-
ops.loopLBVar.size() == ops.loopStepVar.size() &&
290-
!ops.loopLBVar.empty() && "Invalid bounds or step");
291-
292-
fir::FirOpBuilder &b = converter.getFirOpBuilder();
293-
294-
// Get the bit width of an integer-like type.
295-
auto widthOf = [](mlir::Type ty) -> unsigned {
296-
if (mlir::isa<mlir::IndexType>(ty)) {
297-
return mlir::IndexType::kInternalStorageBitWidth;
298-
}
299-
if (auto tyInt = mlir::dyn_cast<mlir::IntegerType>(ty)) {
300-
return tyInt.getWidth();
301-
}
302-
llvm_unreachable("Unexpected type");
303-
};
304-
305-
// For a type that is either IntegerType or IndexType, return the
306-
// equivalent IntegerType. In the former case this is a no-op.
307-
auto asIntTy = [&](mlir::Type ty) -> mlir::IntegerType {
308-
if (ty.isIndex()) {
309-
return mlir::IntegerType::get(ty.getContext(), widthOf(ty));
310-
}
311-
assert(ty.isIntOrIndex() && "Unexpected type");
312-
return mlir::cast<mlir::IntegerType>(ty);
313-
};
314-
315-
// For two given values, establish a common signless IntegerType
316-
// that can represent any value of type of x and of type of y,
317-
// and return the pair of x, y converted to the new type.
318-
auto unifyToSignless =
319-
[&](fir::FirOpBuilder &b, mlir::Value x,
320-
mlir::Value y) -> std::pair<mlir::Value, mlir::Value> {
321-
auto tyX = asIntTy(x.getType()), tyY = asIntTy(y.getType());
322-
unsigned width = std::max(widthOf(tyX), widthOf(tyY));
323-
auto wideTy = mlir::IntegerType::get(b.getContext(), width,
324-
mlir::IntegerType::Signless);
325-
return std::make_pair(b.createConvert(loc, wideTy, x),
326-
b.createConvert(loc, wideTy, y));
327-
};
328-
329-
// Start with signless i32 by default.
330-
auto tripCount = b.createIntegerConstant(loc, b.getI32Type(), 1);
331-
332-
for (auto [origLb, origUb, origStep] :
333-
llvm::zip(ops.loopLBVar, ops.loopUBVar, ops.loopStepVar)) {
334-
auto tmpS0 = b.createIntegerConstant(loc, origStep.getType(), 0);
335-
auto [step, step0] = unifyToSignless(b, origStep, tmpS0);
336-
auto reverseCond = b.create<CmpIOp>(loc, CmpIPredicate::slt, step, step0);
337-
auto negStep = b.create<SubIOp>(loc, step0, step);
338-
mlir::Value absStep = b.create<SelectOp>(loc, reverseCond, negStep, step);
339-
340-
auto [lb, ub] = unifyToSignless(b, origLb, origUb);
341-
auto start = b.create<SelectOp>(loc, reverseCond, ub, lb);
342-
auto end = b.create<SelectOp>(loc, reverseCond, lb, ub);
343-
344-
mlir::Value range = b.create<SubIOp>(loc, end, start);
345-
auto rangeCond = b.create<CmpIOp>(loc, CmpIPredicate::slt, end, start);
346-
std::tie(range, absStep) = unifyToSignless(b, range, absStep);
347-
// numSteps = (range /u absStep) + 1
348-
auto numSteps =
349-
b.create<AddIOp>(loc, b.create<DivUIOp>(loc, range, absStep),
350-
b.createIntegerConstant(loc, range.getType(), 1));
351-
352-
auto trip0 = b.createIntegerConstant(loc, numSteps.getType(), 0);
353-
auto loopTripCount = b.create<SelectOp>(loc, rangeCond, trip0, numSteps);
354-
auto [totalTC, thisTC] = unifyToSignless(b, tripCount, loopTripCount);
355-
tripCount = b.create<MulIOp>(loc, totalTC, thisTC);
356-
}
357-
358-
return tripCount;
359-
}
360-
361283
static mlir::Operation *
362284
createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter,
363285
mlir::Location loc, mlir::Value indexVal,
@@ -1572,8 +1494,8 @@ genLoopNestOp(Fortran::lower::AbstractConverter &converter,
15721494
llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
15731495
ClauseProcessor cp(converter, semaCtx, clauses);
15741496
cp.processCollapse(loc, eval, collapseClauseOps, iv);
1575-
targetOp.getTripCountMutable().assign(
1576-
calculateTripCount(converter, loc, collapseClauseOps));
1497+
targetOp.getTripCountMutable().assign(calculateTripCount(
1498+
converter.getFirOpBuilder(), loc, collapseClauseOps));
15771499
}
15781500
return loopNestOp;
15791501
}

flang/lib/Lower/OpenMP/ReductionProcessor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
#ifndef FORTRAN_LOWER_REDUCTIONPROCESSOR_H
1414
#define FORTRAN_LOWER_REDUCTIONPROCESSOR_H
1515

16-
#include "Clauses.h"
16+
#include "flang/Lower/OpenMP/Clauses.h"
1717
#include "flang/Optimizer/Builder/FIRBuilder.h"
1818
#include "flang/Optimizer/Dialect/FIRType.h"
1919
#include "flang/Semantics/symbol.h"

0 commit comments

Comments
 (0)