Skip to content

[Flang][OpenMP] Remove use of non reference values from MapInfoOp #72444

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 89 additions & 87 deletions flang/lib/Lower/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "flang/Semantics/tools.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Support/CommandLine.h"

Expand Down Expand Up @@ -1709,26 +1710,22 @@ static mlir::omp::MapInfoOp
createMapInfoOp(fir::FirOpBuilder &builder, mlir::Location loc,
mlir::Value baseAddr, std::stringstream &name,
mlir::SmallVector<mlir::Value> bounds, uint64_t mapType,
mlir::omp::VariableCaptureKind mapCaptureType, mlir::Type retTy,
bool isVal = false) {
mlir::Value val, varPtr, varPtrPtr;
mlir::omp::VariableCaptureKind mapCaptureType,
mlir::Type retTy) {
mlir::Value varPtr, varPtrPtr;
mlir::TypeAttr varType;

if (auto boxTy = baseAddr.getType().dyn_cast<fir::BaseBoxType>()) {
baseAddr = builder.create<fir::BoxAddrOp>(loc, baseAddr);
retTy = baseAddr.getType();
}

if (isVal)
val = baseAddr;
else
varPtr = baseAddr;

if (auto ptrType = llvm::dyn_cast<mlir::omp::PointerLikeType>(retTy))
varType = mlir::TypeAttr::get(ptrType.getElementType());
varPtr = baseAddr;
varType = mlir::TypeAttr::get(
llvm::cast<mlir::omp::PointerLikeType>(retTy).getElementType());

mlir::omp::MapInfoOp op = builder.create<mlir::omp::MapInfoOp>(
loc, retTy, val, varPtr, varType, varPtrPtr, bounds,
loc, retTy, varPtr, varType, varPtrPtr, bounds,
builder.getIntegerAttr(builder.getIntegerType(64, false), mapType),
builder.getAttr<mlir::omp::VariableCaptureKindAttr>(mapCaptureType),
builder.getStringAttr(name.str()));
Expand Down Expand Up @@ -2489,21 +2486,27 @@ static void genBodyOfTargetOp(
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::Region &region = targetOp.getRegion();

firOpBuilder.createBlock(&region, {}, mapSymTypes, mapSymLocs);
auto *regionBlock =
firOpBuilder.createBlock(&region, {}, mapSymTypes, mapSymLocs);

unsigned argIndex = 0;
unsigned blockArgsIndex = mapSymbols.size();

// The block arguments contain the map_operands followed by the bounds in
// order. This returns a vector containing the next 'n' block arguments for
// the bounds.
auto extractBoundArgs = [&](auto n) {
llvm::SmallVector<mlir::Value> argExtents;
while (n--) {
argExtents.push_back(fir::getBase(region.getArgument(blockArgsIndex)));
blockArgsIndex++;

// Clones the `bounds` placing them inside the target region and returns them.
auto cloneBound = [&](mlir::Value bound) {
if (mlir::isMemoryEffectFree(bound.getDefiningOp())) {
mlir::Operation *clonedOp = bound.getDefiningOp()->clone();
regionBlock->push_back(clonedOp);
return clonedOp->getResult(0);
}
return argExtents;
TODO(converter.getCurrentLocation(),
"target map clause operand unsupported bound type");
};

auto cloneBounds = [cloneBound](llvm::ArrayRef<mlir::Value> bounds) {
llvm::SmallVector<mlir::Value> clonedBounds;
for (mlir::Value bound : bounds)
clonedBounds.emplace_back(cloneBound(bound));
return clonedBounds;
};

// Bind the symbols to their corresponding block arguments.
Expand All @@ -2512,34 +2515,31 @@ static void genBodyOfTargetOp(
fir::ExtendedValue extVal = converter.getSymbolExtendedValue(*sym);
extVal.match(
[&](const fir::BoxValue &v) {
converter.bindSymbol(
*sym,
fir::BoxValue(arg, extractBoundArgs(v.getLBounds().size()),
v.getExplicitParameters(), v.getExplicitExtents()));
converter.bindSymbol(*sym,
fir::BoxValue(arg, cloneBounds(v.getLBounds()),
v.getExplicitParameters(),
v.getExplicitExtents()));
},
[&](const fir::MutableBoxValue &v) {
converter.bindSymbol(
*sym,
fir::MutableBoxValue(arg, extractBoundArgs(v.getLBounds().size()),
v.getMutableProperties()));
*sym, fir::MutableBoxValue(arg, cloneBounds(v.getLBounds()),
v.getMutableProperties()));
},
[&](const fir::ArrayBoxValue &v) {
converter.bindSymbol(
*sym,
fir::ArrayBoxValue(arg, extractBoundArgs(v.getExtents().size()),
extractBoundArgs(v.getLBounds().size()),
v.getSourceBox()));
*sym, fir::ArrayBoxValue(arg, cloneBounds(v.getExtents()),
cloneBounds(v.getLBounds()),
v.getSourceBox()));
},
[&](const fir::CharArrayBoxValue &v) {
converter.bindSymbol(
*sym,
fir::CharArrayBoxValue(arg, extractBoundArgs(1).front(),
extractBoundArgs(v.getExtents().size()),
extractBoundArgs(v.getLBounds().size())));
*sym, fir::CharArrayBoxValue(arg, cloneBound(v.getLen()),
cloneBounds(v.getExtents()),
cloneBounds(v.getLBounds())));
},
[&](const fir::CharBoxValue &v) {
converter.bindSymbol(
*sym, fir::CharBoxValue(arg, extractBoundArgs(1).front()));
converter.bindSymbol(*sym,
fir::CharBoxValue(arg, cloneBound(v.getLen())));
},
[&](const fir::UnboxedValue &v) { converter.bindSymbol(*sym, arg); },
[&](const auto &) {
Expand All @@ -2549,6 +2549,55 @@ static void genBodyOfTargetOp(
argIndex++;
}

// Check if cloning the bounds introduced any dependency on the outer region.
// If so, then either clone them as well if they are MemoryEffectFree, or else
// copy them to a new temporary and add them to the map and block_argument
// lists and replace their uses with the new temporary.
llvm::SetVector<mlir::Value> valuesDefinedAbove;
mlir::getUsedValuesDefinedAbove(region, valuesDefinedAbove);
while (!valuesDefinedAbove.empty()) {
for (mlir::Value val : valuesDefinedAbove) {
mlir::Operation *valOp = val.getDefiningOp();
if (mlir::isMemoryEffectFree(valOp)) {
mlir::Operation *clonedOp = valOp->clone();
regionBlock->push_front(clonedOp);
val.replaceUsesWithIf(
clonedOp->getResult(0), [regionBlock](mlir::OpOperand &use) {
return use.getOwner()->getBlock() == regionBlock;
});
} else {
auto savedIP = firOpBuilder.getInsertionPoint();
firOpBuilder.setInsertionPointAfter(valOp);
auto copyVal =
firOpBuilder.createTemporary(val.getLoc(), val.getType());
firOpBuilder.createStoreWithConvert(copyVal.getLoc(), val, copyVal);

llvm::SmallVector<mlir::Value> bounds;
std::stringstream name;
firOpBuilder.setInsertionPoint(targetOp);
mlir::Value mapOp = createMapInfoOp(
firOpBuilder, copyVal.getLoc(), copyVal, name, bounds,
static_cast<
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT),
mlir::omp::VariableCaptureKind::ByCopy, copyVal.getType());
targetOp.getMapOperandsMutable().append(mapOp);
mlir::Value clonedValArg =
region.addArgument(copyVal.getType(), copyVal.getLoc());
firOpBuilder.setInsertionPointToStart(regionBlock);
auto loadOp = firOpBuilder.create<fir::LoadOp>(clonedValArg.getLoc(),
clonedValArg);
val.replaceUsesWithIf(
loadOp->getResult(0), [regionBlock](mlir::OpOperand &use) {
return use.getOwner()->getBlock() == regionBlock;
});
firOpBuilder.setInsertionPoint(regionBlock, savedIP);
}
}
valuesDefinedAbove.clear();
mlir::getUsedValuesDefinedAbove(region, valuesDefinedAbove);
}

// Insert dummy instruction to remember the insertion position. The
// marker will be deleted since there are not uses.
// In the HLFIR flow there are hlfir.declares inserted above while
Expand Down Expand Up @@ -2671,53 +2720,6 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
};
Fortran::lower::pft::visitAllSymbols(eval, captureImplicitMap);

// Add the bounds and extents for box values to mapOperands
auto addMapInfoForBounds = [&](const auto &bounds) {
for (auto &val : bounds) {
mapSymLocs.push_back(val.getLoc());
mapSymTypes.push_back(val.getType());

llvm::SmallVector<mlir::Value> bounds;
std::stringstream name;

mlir::Value mapOp = createMapInfoOp(
converter.getFirOpBuilder(), val.getLoc(), val, name, bounds,
static_cast<
std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT),
mlir::omp::VariableCaptureKind::ByCopy, val.getType(), true);
mapOperands.push_back(mapOp);
}
};

for (const Fortran::semantics::Symbol *sym : mapSymbols) {
fir::ExtendedValue extVal = converter.getSymbolExtendedValue(*sym);
extVal.match(
[&](const fir::BoxValue &v) { addMapInfoForBounds(v.getLBounds()); },
[&](const fir::MutableBoxValue &v) {
addMapInfoForBounds(v.getLBounds());
},
[&](const fir::ArrayBoxValue &v) {
addMapInfoForBounds(v.getExtents());
addMapInfoForBounds(v.getLBounds());
},
[&](const fir::CharArrayBoxValue &v) {
llvm::SmallVector<mlir::Value> len;
len.push_back(v.getLen());
addMapInfoForBounds(len);
addMapInfoForBounds(v.getExtents());
addMapInfoForBounds(v.getLBounds());
},
[&](const fir::CharBoxValue &v) {
llvm::SmallVector<mlir::Value> len;
len.push_back(v.getLen());
addMapInfoForBounds(len);
},
[&](const auto &) {
// Nothing to do for non-box values.
});
}

auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>(
currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
nowaitAttr, mapOperands);
Expand Down
4 changes: 2 additions & 2 deletions flang/test/Lower/OpenMP/FIR/array-bounds.f90
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
!ALL: %[[BOUNDS1:.*]] = omp.bounds lower_bound(%[[C5]] : index) upper_bound(%[[C6]] : index) stride(%[[C4]] : index) start_idx(%[[C4]] : index)
!ALL: %[[MAP1:.*]] = omp.map_info var_ptr(%[[WRITE]] : !fir.ref<!fir.array<10xi32>>, !fir.array<10xi32>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS1]]) -> !fir.ref<!fir.array<10xi32>> {name = "sp_write(2:5)"}
!ALL: %[[MAP2:.*]] = omp.map_info var_ptr(%[[ITER]] : !fir.ref<i32>, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !fir.ref<i32> {name = "i"}
!ALL: omp.target map_entries(%[[MAP0]] -> %{{.*}}, %[[MAP1]] -> %{{.*}}, %[[MAP2]] -> %{{.*}}, %{{.*}} -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>, !fir.ref<i32>, index, index) {
!ALL: omp.target map_entries(%[[MAP0]] -> %{{.*}}, %[[MAP1]] -> %{{.*}}, %[[MAP2]] -> %{{.*}} : !fir.ref<!fir.array<10xi32>>, !fir.ref<!fir.array<10xi32>>, !fir.ref<i32>) {

subroutine read_write_section()
integer :: sp_read(10) = (/1,2,3,4,5,6,7,8,9,10/)
Expand Down Expand Up @@ -64,7 +64,7 @@ end subroutine assumed_shape_array
!ALL: %[[BOUNDS:.*]] = omp.bounds lower_bound(%[[C1]] : index) upper_bound(%[[C2]] : index) stride(%[[C0]] : index) start_idx(%[[C0]] : index)
!ALL: %[[MAP:.*]] = omp.map_info var_ptr(%[[ARG0]] : !fir.ref<!fir.array<?xi32>>, !fir.array<?xi32>) map_clauses(tofrom) capture(ByRef) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<?xi32>> {name = "arr_read_write(2:5)"}
!ALL: %[[MAP2:.*]] = omp.map_info var_ptr(%[[ALLOCA]] : !fir.ref<i32>, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !fir.ref<i32> {name = "i"}
!ALL: omp.target map_entries(%[[MAP]] -> %{{.*}}, %[[MAP2]] -> %{{.*}}, %{{.*}} -> %{{.*}} : !fir.ref<!fir.array<?xi32>>, !fir.ref<i32>, index) {
!ALL: omp.target map_entries(%[[MAP]] -> %{{.*}}, %[[MAP2]] -> %{{.*}} : !fir.ref<!fir.array<?xi32>>, !fir.ref<i32>) {

subroutine assumed_size_array(arr_read_write)
integer, intent(inout) :: arr_read_write(*)
Expand Down
Loading