-
Notifications
You must be signed in to change notification settings - Fork 14.3k
Re-apply (#117867): [flang][OpenMP] Implicitly map allocatable record fields #120374
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…cord fields" This re-applies llvm#117867 with a small fix that hopefully prevents build bot failures. The fix is avoiding `dyn_cast` for the result of `getOperation()`. Instead we can assign the result to `mlir::ModuleOp` directly since the type of the operation is known statically (`OpT` in `OperationPass`).
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-flang-fir-hlfir Author: Kareem Ergawy (ergawy) ChangesThis re-applies #117867 with a small fix that hopefully prevents build bot failures. The fix is avoiding Patch is 26.65 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/120374.diff 12 Files Affected:
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/include/flang/Lower/DirectivesCommon.h
similarity index 97%
rename from flang/lib/Lower/DirectivesCommon.h
rename to flang/include/flang/Lower/DirectivesCommon.h
index 88514b16743278..6e2c6ee4b1bcdb 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/include/flang/Lower/DirectivesCommon.h
@@ -609,11 +609,10 @@ void createEmptyRegionBlocks(
}
}
-inline AddrAndBoundsInfo
-getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
- fir::FirOpBuilder &builder,
- Fortran::lower::SymbolRef sym, mlir::Location loc) {
- mlir::Value symAddr = converter.getSymbolAddress(sym);
+inline AddrAndBoundsInfo getDataOperandBaseAddr(fir::FirOpBuilder &builder,
+ mlir::Value symAddr,
+ bool isOptional,
+ mlir::Location loc) {
mlir::Value rawInput = symAddr;
if (auto declareOp =
mlir::dyn_cast_or_null<hlfir::DeclareOp>(symAddr.getDefiningOp())) {
@@ -621,20 +620,11 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
rawInput = declareOp.getResults()[1];
}
- // TODO: Might need revisiting to handle for non-shared clauses
- if (!symAddr) {
- if (const auto *details =
- sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
- symAddr = converter.getSymbolAddress(details->symbol());
- rawInput = symAddr;
- }
- }
-
if (!symAddr)
llvm::report_fatal_error("could not retrieve symbol address");
mlir::Value isPresent;
- if (Fortran::semantics::IsOptional(sym))
+ if (isOptional)
isPresent =
builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), rawInput);
@@ -648,8 +638,7 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
// all address/dimension retrievals. For Fortran optional though, leave
// the load generation for later so it can be done in the appropriate
// if branches.
- if (mlir::isa<fir::ReferenceType>(symAddr.getType()) &&
- !Fortran::semantics::IsOptional(sym)) {
+ if (mlir::isa<fir::ReferenceType>(symAddr.getType()) && !isOptional) {
mlir::Value addr = builder.create<fir::LoadOp>(loc, symAddr);
return AddrAndBoundsInfo(addr, rawInput, isPresent, boxTy);
}
@@ -659,6 +648,14 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
return AddrAndBoundsInfo(symAddr, rawInput, isPresent);
}
+inline AddrAndBoundsInfo
+getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
+ fir::FirOpBuilder &builder,
+ Fortran::lower::SymbolRef sym, mlir::Location loc) {
+ return getDataOperandBaseAddr(builder, converter.getSymbolAddress(sym),
+ Fortran::semantics::IsOptional(sym), loc);
+}
+
template <typename BoundsOp, typename BoundsType>
llvm::SmallVector<mlir::Value>
gatherBoundsOrBoundValues(fir::FirOpBuilder &builder, mlir::Location loc,
@@ -1224,6 +1221,25 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
return info;
}
+
+template <typename BoundsOp, typename BoundsType>
+llvm::SmallVector<mlir::Value>
+genImplicitBoundsOps(fir::FirOpBuilder &builder, lower::AddrAndBoundsInfo &info,
+ fir::ExtendedValue dataExv, bool dataExvIsAssumedSize,
+ mlir::Location loc) {
+ llvm::SmallVector<mlir::Value> bounds;
+
+ mlir::Value baseOp = info.rawInput;
+ if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
+ bounds = lower::genBoundsOpsFromBox<BoundsOp, BoundsType>(builder, loc,
+ dataExv, info);
+ if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
+ bounds = lower::genBaseBoundsOps<BoundsOp, BoundsType>(
+ builder, loc, dataExv, dataExvIsAssumedSize);
+ }
+
+ return bounds;
+}
} // namespace lower
} // namespace Fortran
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 0650433dbaf394..f5883dcedb2b67 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -11,7 +11,7 @@
//===----------------------------------------------------------------------===//
#include "flang/Lower/Bridge.h"
-#include "DirectivesCommon.h"
+
#include "flang/Common/Version.h"
#include "flang/Lower/Allocatable.h"
#include "flang/Lower/CallInterface.h"
@@ -22,6 +22,7 @@
#include "flang/Lower/ConvertType.h"
#include "flang/Lower/ConvertVariable.h"
#include "flang/Lower/Cuda.h"
+#include "flang/Lower/DirectivesCommon.h"
#include "flang/Lower/HostAssociations.h"
#include "flang/Lower/IO.h"
#include "flang/Lower/IterationSpace.h"
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 75dcf6ec3e1107..ed18ad89c16ef5 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -11,10 +11,11 @@
//===----------------------------------------------------------------------===//
#include "flang/Lower/OpenACC.h"
-#include "DirectivesCommon.h"
+
#include "flang/Common/idioms.h"
#include "flang/Lower/Bridge.h"
#include "flang/Lower/ConvertType.h"
+#include "flang/Lower/DirectivesCommon.h"
#include "flang/Lower/Mangler.h"
#include "flang/Lower/PFTBuilder.h"
#include "flang/Lower/StatementContext.h"
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 3942c54e6e935d..7b047d4a7567ad 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -13,11 +13,11 @@
#define FORTRAN_LOWER_CLAUSEPROCESSOR_H
#include "Clauses.h"
-#include "DirectivesCommon.h"
#include "ReductionProcessor.h"
#include "Utils.h"
#include "flang/Lower/AbstractConverter.h"
#include "flang/Lower/Bridge.h"
+#include "flang/Lower/DirectivesCommon.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Parser/dump-parse-tree.h"
#include "flang/Parser/parse-tree.h"
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index c61ab67d95a957..b07e89d201d198 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -16,7 +16,6 @@
#include "Clauses.h"
#include "DataSharingProcessor.h"
#include "Decomposer.h"
-#include "DirectivesCommon.h"
#include "ReductionProcessor.h"
#include "Utils.h"
#include "flang/Common/OpenMP-utils.h"
@@ -24,6 +23,7 @@
#include "flang/Lower/Bridge.h"
#include "flang/Lower/ConvertExpr.h"
#include "flang/Lower/ConvertVariable.h"
+#include "flang/Lower/DirectivesCommon.h"
#include "flang/Lower/StatementContext.h"
#include "flang/Lower/SymbolMap.h"
#include "flang/Optimizer/Builder/BoxValue.h"
@@ -1735,32 +1735,25 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
if (const auto *details =
sym.template detailsIf<semantics::HostAssocDetails>())
converter.copySymbolBinding(details->symbol(), sym);
- llvm::SmallVector<mlir::Value> bounds;
std::stringstream name;
fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(sym);
name << sym.name().ToString();
lower::AddrAndBoundsInfo info = getDataOperandBaseAddr(
converter, firOpBuilder, sym, converter.getCurrentLocation());
- mlir::Value baseOp = info.rawInput;
- if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
- bounds = lower::genBoundsOpsFromBox<mlir::omp::MapBoundsOp,
- mlir::omp::MapBoundsType>(
- firOpBuilder, converter.getCurrentLocation(), dataExv, info);
- if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
- bool dataExvIsAssumedSize =
- semantics::IsAssumedSizeArray(sym.GetUltimate());
- bounds = lower::genBaseBoundsOps<mlir::omp::MapBoundsOp,
- mlir::omp::MapBoundsType>(
- firOpBuilder, converter.getCurrentLocation(), dataExv,
- dataExvIsAssumedSize);
- }
+ llvm::SmallVector<mlir::Value> bounds =
+ lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+ mlir::omp::MapBoundsType>(
+ firOpBuilder, info, dataExv,
+ semantics::IsAssumedSizeArray(sym.GetUltimate()),
+ converter.getCurrentLocation());
llvm::omp::OpenMPOffloadMappingFlags mapFlag =
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
mlir::omp::VariableCaptureKind captureKind =
mlir::omp::VariableCaptureKind::ByRef;
+ mlir::Value baseOp = info.rawInput;
mlir::Type eleType = baseOp.getType();
if (auto refType = mlir::dyn_cast<fir::ReferenceType>(baseOp.getType()))
eleType = refType.getElementType();
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 5340dd8c5fb9a2..9971dc8e0b0014 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -13,10 +13,10 @@
#include "Utils.h"
#include "Clauses.h"
-#include <DirectivesCommon.h>
#include <flang/Lower/AbstractConverter.h>
#include <flang/Lower/ConvertType.h>
+#include <flang/Lower/DirectivesCommon.h>
#include <flang/Lower/PFTBuilder.h>
#include <flang/Optimizer/Builder/FIRBuilder.h>
#include <flang/Optimizer/Builder/Todo.h>
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index 51ecbe1a664f92..4f23b2b970fa44 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -12,6 +12,7 @@ add_flang_library(FlangOpenMPTransforms
FIRDialect
HLFIROpsIncGen
FlangOpenMPPassesIncGen
+ ${dialect_libs}
LINK_LIBS
FIRAnalysis
@@ -27,4 +28,5 @@ add_flang_library(FlangOpenMPTransforms
MLIRIR
MLIRPass
MLIRTransformUtils
+ ${dialect_libs}
)
diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index 4575c90e34acdd..ad7b806ae262ae 100644
--- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -24,10 +24,14 @@
/// indirectly via a parent object.
//===----------------------------------------------------------------------===//
+#include "flang/Lower/DirectivesCommon.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/OpenMP/Passes.h"
+#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/BuiltinDialect.h"
@@ -411,10 +415,10 @@ class MapInfoFinalizationPass
argIface
? argIface.getMapBlockArgsStart() + argIface.numMapBlockArgs()
: 0;
- addOperands(
- mapMutableOpRange,
- llvm::dyn_cast_or_null<mlir::omp::TargetOp>(argIface.getOperation()),
- blockArgInsertIndex);
+ addOperands(mapMutableOpRange,
+ llvm::dyn_cast_if_present<mlir::omp::TargetOp>(
+ argIface.getOperation()),
+ blockArgInsertIndex);
}
if (auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(target)) {
@@ -466,8 +470,7 @@ class MapInfoFinalizationPass
// operation (usually function) containing the MapInfoOp because this pass
// will mutate siblings of MapInfoOp.
void runOnOperation() override {
- mlir::ModuleOp module =
- mlir::dyn_cast_or_null<mlir::ModuleOp>(getOperation());
+ mlir::ModuleOp module = getOperation();
if (!module)
module = getOperation()->getParentOfType<mlir::ModuleOp>();
fir::KindMapping kindMap = fir::getKindMapping(module);
@@ -486,6 +489,160 @@ class MapInfoFinalizationPass
// iterations from previous function scopes.
localBoxAllocas.clear();
+ // First, walk `omp.map.info` ops to see if any record members should be
+ // implicitly mapped.
+ func->walk([&](mlir::omp::MapInfoOp op) {
+ mlir::Type underlyingType =
+ fir::unwrapRefType(op.getVarPtr().getType());
+
+ // TODO Test with and support more complicated cases; like arrays for
+ // records, for example.
+ if (!fir::isRecordWithAllocatableMember(underlyingType))
+ return mlir::WalkResult::advance();
+
+ // TODO For now, only consider `omp.target` ops. Other ops that support
+ // `map` clauses will follow later.
+ mlir::omp::TargetOp target =
+ mlir::dyn_cast_if_present<mlir::omp::TargetOp>(
+ getFirstTargetUser(op));
+
+ if (!target)
+ return mlir::WalkResult::advance();
+
+ auto mapClauseOwner =
+ llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(*target);
+
+ int64_t mapVarIdx = mapClauseOwner.getOperandIndexForMap(op);
+ assert(mapVarIdx >= 0 &&
+ mapVarIdx <
+ static_cast<int64_t>(mapClauseOwner.getMapVars().size()));
+
+ auto argIface =
+ llvm::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(*target);
+ // TODO How should `map` block argument that correspond to: `private`,
+ // `use_device_addr`, `use_device_ptr`, be handled?
+ mlir::BlockArgument opBlockArg = argIface.getMapBlockArgs()[mapVarIdx];
+ llvm::SetVector<mlir::Operation *> mapVarForwardSlice;
+ mlir::getForwardSlice(opBlockArg, &mapVarForwardSlice);
+
+ mapVarForwardSlice.remove_if([&](mlir::Operation *sliceOp) {
+ // TODO Support coordinate_of ops.
+ //
+ // TODO Support call ops by recursively examining the forward slice of
+ // the corresponding parameter to the field in the called function.
+ return !mlir::isa<hlfir::DesignateOp>(sliceOp);
+ });
+
+ auto recordType = mlir::cast<fir::RecordType>(underlyingType);
+ llvm::SmallVector<mlir::Value> newMapOpsForFields;
+ llvm::SmallVector<int64_t> fieldIndicies;
+
+ for (auto fieldMemTyPair : recordType.getTypeList()) {
+ auto &field = fieldMemTyPair.first;
+ auto memTy = fieldMemTyPair.second;
+
+ bool shouldMapField =
+ llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
+ if (!fir::isAllocatableType(memTy))
+ return false;
+
+ auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
+ if (!designateOp)
+ return false;
+
+ return designateOp.getComponent() &&
+ designateOp.getComponent()->strref() == field;
+ }) != mapVarForwardSlice.end();
+
+ // TODO Handle recursive record types. Adapting
+ // `createParentSymAndGenIntermediateMaps` to work direclty on MLIR
+ // entities might be helpful here.
+
+ if (!shouldMapField)
+ continue;
+
+ int64_t fieldIdx = recordType.getFieldIndex(field);
+ bool alreadyMapped = [&]() {
+ if (op.getMembersIndexAttr())
+ for (auto indexList : op.getMembersIndexAttr()) {
+ auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
+ if (indexListAttr.size() == 1 &&
+ mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() ==
+ fieldIdx)
+ return true;
+ }
+
+ return false;
+ }();
+
+ if (alreadyMapped)
+ continue;
+
+ builder.setInsertionPoint(op);
+ mlir::Value fieldIdxVal = builder.createIntegerConstant(
+ op.getLoc(), mlir::IndexType::get(builder.getContext()),
+ fieldIdx);
+ auto fieldCoord = builder.create<fir::CoordinateOp>(
+ op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
+ fieldIdxVal);
+ Fortran::lower::AddrAndBoundsInfo info =
+ Fortran::lower::getDataOperandBaseAddr(
+ builder, fieldCoord, /*isOptional=*/false, op.getLoc());
+ llvm::SmallVector<mlir::Value> bounds =
+ Fortran::lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+ mlir::omp::MapBoundsType>(
+ builder, info,
+ hlfir::translateToExtendedValue(op.getLoc(), builder,
+ hlfir::Entity{fieldCoord})
+ .first,
+ /*dataExvIsAssumedSize=*/false, op.getLoc());
+
+ mlir::omp::MapInfoOp fieldMapOp =
+ builder.create<mlir::omp::MapInfoOp>(
+ op.getLoc(), fieldCoord.getResult().getType(),
+ fieldCoord.getResult(),
+ mlir::TypeAttr::get(
+ fir::unwrapRefType(fieldCoord.getResult().getType())),
+ /*varPtrPtr=*/mlir::Value{},
+ /*members=*/mlir::ValueRange{},
+ /*members_index=*/mlir::ArrayAttr{},
+ /*bounds=*/bounds, op.getMapTypeAttr(),
+ builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
+ mlir::omp::VariableCaptureKind::ByRef),
+ builder.getStringAttr(op.getNameAttr().strref() + "." +
+ field + ".implicit_map"),
+ /*partial_map=*/builder.getBoolAttr(false));
+ newMapOpsForFields.emplace_back(fieldMapOp);
+ fieldIndicies.emplace_back(fieldIdx);
+ }
+
+ if (newMapOpsForFields.empty())
+ return mlir::WalkResult::advance();
+
+ op.getMembersMutable().append(newMapOpsForFields);
+ llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices;
+ mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr();
+
+ if (oldMembersIdxAttr)
+ for (mlir::Attribute indexList : oldMembersIdxAttr) {
+ llvm::SmallVector<int64_t> listVec;
+
+ for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList))
+ listVec.push_back(mlir::cast<mlir::IntegerAttr>(index).getInt());
+
+ newMemberIndices.emplace_back(std::move(listVec));
+ }
+
+ for (int64_t newFieldIdx : fieldIndicies)
+ newMemberIndices.emplace_back(
+ llvm::SmallVector<int64_t>(1, newFieldIdx));
+
+ op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices));
+ op.setPartialMap(true);
+
+ return mlir::WalkResult::advance();
+ });
+
func->walk([&](mlir::omp::MapInfoOp op) {
// TODO: Currently only supports a single user for the MapInfoOp. This
// is fine for the moment, as the Fortran frontend will generate a
diff --git a/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir b/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir
new file mode 100644
index 00000000000000..bcf8b63075dbf8
--- /dev/null
+++ b/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir
@@ -0,0 +1,63 @@
+// Tests that we implicitly map alloctable fields of a record when referenced in
+// a target region.
+
+// RUN: fir-opt --split-input-file --omp-map-info-finalization %s | FileCheck %s
+
+!record_t = !fir.type<_QFTrecord_t{
+ not_to_implicitly_map:
+ !fir.box<!fir.heap<!fir.array<?xf32>>>,
+ to_implicitly_map:
+ !fir.box<!fir.heap<!fir.array<?xf32>>>
+}>
+
+fir.global internal @_QFEdst_record : !record_t {
+ %0 = fir.undefined !record_t
+ fir.has_value %0 : !record_t
+}
+
+func.func @_QQmain() {
+ %6 = fir.address_of(@_QFEdst_record) : !fir.ref<!record_t>
+ %7:2 = hlfir.declare %6 {uniq_name = "_QFEdst_record"} : (!fir.ref<!record_t>) -> (!fir.ref<!record_t>, !fir.ref<!record_t>)
+ %16 = omp.map.info var_ptr(%7#1 : !fir.ref<!record_t>, !record_t) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<!record_t> {name = "dst_record"}
+ omp.target map_entries(%16 -> %arg0 : !fir.ref<!record_t>) {
+ %20:2 = hlfir.declare %arg0 {uniq_name = "_QFEdst_record"} : (!fir.ref<!record_t>) -> (...
[truncated]
|
@llvm/pr-subscribers-offload Author: Kareem Ergawy (ergawy) ChangesThis re-applies #117867 with a small fix that hopefully prevents build bot failures. The fix is avoiding Patch is 26.65 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/120374.diff 12 Files Affected:
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/include/flang/Lower/DirectivesCommon.h
similarity index 97%
rename from flang/lib/Lower/DirectivesCommon.h
rename to flang/include/flang/Lower/DirectivesCommon.h
index 88514b16743278..6e2c6ee4b1bcdb 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/include/flang/Lower/DirectivesCommon.h
@@ -609,11 +609,10 @@ void createEmptyRegionBlocks(
}
}
-inline AddrAndBoundsInfo
-getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
- fir::FirOpBuilder &builder,
- Fortran::lower::SymbolRef sym, mlir::Location loc) {
- mlir::Value symAddr = converter.getSymbolAddress(sym);
+inline AddrAndBoundsInfo getDataOperandBaseAddr(fir::FirOpBuilder &builder,
+ mlir::Value symAddr,
+ bool isOptional,
+ mlir::Location loc) {
mlir::Value rawInput = symAddr;
if (auto declareOp =
mlir::dyn_cast_or_null<hlfir::DeclareOp>(symAddr.getDefiningOp())) {
@@ -621,20 +620,11 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
rawInput = declareOp.getResults()[1];
}
- // TODO: Might need revisiting to handle for non-shared clauses
- if (!symAddr) {
- if (const auto *details =
- sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
- symAddr = converter.getSymbolAddress(details->symbol());
- rawInput = symAddr;
- }
- }
-
if (!symAddr)
llvm::report_fatal_error("could not retrieve symbol address");
mlir::Value isPresent;
- if (Fortran::semantics::IsOptional(sym))
+ if (isOptional)
isPresent =
builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), rawInput);
@@ -648,8 +638,7 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
// all address/dimension retrievals. For Fortran optional though, leave
// the load generation for later so it can be done in the appropriate
// if branches.
- if (mlir::isa<fir::ReferenceType>(symAddr.getType()) &&
- !Fortran::semantics::IsOptional(sym)) {
+ if (mlir::isa<fir::ReferenceType>(symAddr.getType()) && !isOptional) {
mlir::Value addr = builder.create<fir::LoadOp>(loc, symAddr);
return AddrAndBoundsInfo(addr, rawInput, isPresent, boxTy);
}
@@ -659,6 +648,14 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
return AddrAndBoundsInfo(symAddr, rawInput, isPresent);
}
+inline AddrAndBoundsInfo
+getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
+ fir::FirOpBuilder &builder,
+ Fortran::lower::SymbolRef sym, mlir::Location loc) {
+ return getDataOperandBaseAddr(builder, converter.getSymbolAddress(sym),
+ Fortran::semantics::IsOptional(sym), loc);
+}
+
template <typename BoundsOp, typename BoundsType>
llvm::SmallVector<mlir::Value>
gatherBoundsOrBoundValues(fir::FirOpBuilder &builder, mlir::Location loc,
@@ -1224,6 +1221,25 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
return info;
}
+
+template <typename BoundsOp, typename BoundsType>
+llvm::SmallVector<mlir::Value>
+genImplicitBoundsOps(fir::FirOpBuilder &builder, lower::AddrAndBoundsInfo &info,
+ fir::ExtendedValue dataExv, bool dataExvIsAssumedSize,
+ mlir::Location loc) {
+ llvm::SmallVector<mlir::Value> bounds;
+
+ mlir::Value baseOp = info.rawInput;
+ if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
+ bounds = lower::genBoundsOpsFromBox<BoundsOp, BoundsType>(builder, loc,
+ dataExv, info);
+ if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
+ bounds = lower::genBaseBoundsOps<BoundsOp, BoundsType>(
+ builder, loc, dataExv, dataExvIsAssumedSize);
+ }
+
+ return bounds;
+}
} // namespace lower
} // namespace Fortran
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 0650433dbaf394..f5883dcedb2b67 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -11,7 +11,7 @@
//===----------------------------------------------------------------------===//
#include "flang/Lower/Bridge.h"
-#include "DirectivesCommon.h"
+
#include "flang/Common/Version.h"
#include "flang/Lower/Allocatable.h"
#include "flang/Lower/CallInterface.h"
@@ -22,6 +22,7 @@
#include "flang/Lower/ConvertType.h"
#include "flang/Lower/ConvertVariable.h"
#include "flang/Lower/Cuda.h"
+#include "flang/Lower/DirectivesCommon.h"
#include "flang/Lower/HostAssociations.h"
#include "flang/Lower/IO.h"
#include "flang/Lower/IterationSpace.h"
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 75dcf6ec3e1107..ed18ad89c16ef5 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -11,10 +11,11 @@
//===----------------------------------------------------------------------===//
#include "flang/Lower/OpenACC.h"
-#include "DirectivesCommon.h"
+
#include "flang/Common/idioms.h"
#include "flang/Lower/Bridge.h"
#include "flang/Lower/ConvertType.h"
+#include "flang/Lower/DirectivesCommon.h"
#include "flang/Lower/Mangler.h"
#include "flang/Lower/PFTBuilder.h"
#include "flang/Lower/StatementContext.h"
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 3942c54e6e935d..7b047d4a7567ad 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -13,11 +13,11 @@
#define FORTRAN_LOWER_CLAUSEPROCESSOR_H
#include "Clauses.h"
-#include "DirectivesCommon.h"
#include "ReductionProcessor.h"
#include "Utils.h"
#include "flang/Lower/AbstractConverter.h"
#include "flang/Lower/Bridge.h"
+#include "flang/Lower/DirectivesCommon.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Parser/dump-parse-tree.h"
#include "flang/Parser/parse-tree.h"
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index c61ab67d95a957..b07e89d201d198 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -16,7 +16,6 @@
#include "Clauses.h"
#include "DataSharingProcessor.h"
#include "Decomposer.h"
-#include "DirectivesCommon.h"
#include "ReductionProcessor.h"
#include "Utils.h"
#include "flang/Common/OpenMP-utils.h"
@@ -24,6 +23,7 @@
#include "flang/Lower/Bridge.h"
#include "flang/Lower/ConvertExpr.h"
#include "flang/Lower/ConvertVariable.h"
+#include "flang/Lower/DirectivesCommon.h"
#include "flang/Lower/StatementContext.h"
#include "flang/Lower/SymbolMap.h"
#include "flang/Optimizer/Builder/BoxValue.h"
@@ -1735,32 +1735,25 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
if (const auto *details =
sym.template detailsIf<semantics::HostAssocDetails>())
converter.copySymbolBinding(details->symbol(), sym);
- llvm::SmallVector<mlir::Value> bounds;
std::stringstream name;
fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(sym);
name << sym.name().ToString();
lower::AddrAndBoundsInfo info = getDataOperandBaseAddr(
converter, firOpBuilder, sym, converter.getCurrentLocation());
- mlir::Value baseOp = info.rawInput;
- if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
- bounds = lower::genBoundsOpsFromBox<mlir::omp::MapBoundsOp,
- mlir::omp::MapBoundsType>(
- firOpBuilder, converter.getCurrentLocation(), dataExv, info);
- if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
- bool dataExvIsAssumedSize =
- semantics::IsAssumedSizeArray(sym.GetUltimate());
- bounds = lower::genBaseBoundsOps<mlir::omp::MapBoundsOp,
- mlir::omp::MapBoundsType>(
- firOpBuilder, converter.getCurrentLocation(), dataExv,
- dataExvIsAssumedSize);
- }
+ llvm::SmallVector<mlir::Value> bounds =
+ lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+ mlir::omp::MapBoundsType>(
+ firOpBuilder, info, dataExv,
+ semantics::IsAssumedSizeArray(sym.GetUltimate()),
+ converter.getCurrentLocation());
llvm::omp::OpenMPOffloadMappingFlags mapFlag =
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
mlir::omp::VariableCaptureKind captureKind =
mlir::omp::VariableCaptureKind::ByRef;
+ mlir::Value baseOp = info.rawInput;
mlir::Type eleType = baseOp.getType();
if (auto refType = mlir::dyn_cast<fir::ReferenceType>(baseOp.getType()))
eleType = refType.getElementType();
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 5340dd8c5fb9a2..9971dc8e0b0014 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -13,10 +13,10 @@
#include "Utils.h"
#include "Clauses.h"
-#include <DirectivesCommon.h>
#include <flang/Lower/AbstractConverter.h>
#include <flang/Lower/ConvertType.h>
+#include <flang/Lower/DirectivesCommon.h>
#include <flang/Lower/PFTBuilder.h>
#include <flang/Optimizer/Builder/FIRBuilder.h>
#include <flang/Optimizer/Builder/Todo.h>
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index 51ecbe1a664f92..4f23b2b970fa44 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -12,6 +12,7 @@ add_flang_library(FlangOpenMPTransforms
FIRDialect
HLFIROpsIncGen
FlangOpenMPPassesIncGen
+ ${dialect_libs}
LINK_LIBS
FIRAnalysis
@@ -27,4 +28,5 @@ add_flang_library(FlangOpenMPTransforms
MLIRIR
MLIRPass
MLIRTransformUtils
+ ${dialect_libs}
)
diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index 4575c90e34acdd..ad7b806ae262ae 100644
--- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -24,10 +24,14 @@
/// indirectly via a parent object.
//===----------------------------------------------------------------------===//
+#include "flang/Lower/DirectivesCommon.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/OpenMP/Passes.h"
+#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/BuiltinDialect.h"
@@ -411,10 +415,10 @@ class MapInfoFinalizationPass
argIface
? argIface.getMapBlockArgsStart() + argIface.numMapBlockArgs()
: 0;
- addOperands(
- mapMutableOpRange,
- llvm::dyn_cast_or_null<mlir::omp::TargetOp>(argIface.getOperation()),
- blockArgInsertIndex);
+ addOperands(mapMutableOpRange,
+ llvm::dyn_cast_if_present<mlir::omp::TargetOp>(
+ argIface.getOperation()),
+ blockArgInsertIndex);
}
if (auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(target)) {
@@ -466,8 +470,7 @@ class MapInfoFinalizationPass
// operation (usually function) containing the MapInfoOp because this pass
// will mutate siblings of MapInfoOp.
void runOnOperation() override {
- mlir::ModuleOp module =
- mlir::dyn_cast_or_null<mlir::ModuleOp>(getOperation());
+ mlir::ModuleOp module = getOperation();
if (!module)
module = getOperation()->getParentOfType<mlir::ModuleOp>();
fir::KindMapping kindMap = fir::getKindMapping(module);
@@ -486,6 +489,160 @@ class MapInfoFinalizationPass
// iterations from previous function scopes.
localBoxAllocas.clear();
+ // First, walk `omp.map.info` ops to see if any record members should be
+ // implicitly mapped.
+ func->walk([&](mlir::omp::MapInfoOp op) {
+ mlir::Type underlyingType =
+ fir::unwrapRefType(op.getVarPtr().getType());
+
+ // TODO Test with and support more complicated cases; like arrays for
+ // records, for example.
+ if (!fir::isRecordWithAllocatableMember(underlyingType))
+ return mlir::WalkResult::advance();
+
+ // TODO For now, only consider `omp.target` ops. Other ops that support
+ // `map` clauses will follow later.
+ mlir::omp::TargetOp target =
+ mlir::dyn_cast_if_present<mlir::omp::TargetOp>(
+ getFirstTargetUser(op));
+
+ if (!target)
+ return mlir::WalkResult::advance();
+
+ auto mapClauseOwner =
+ llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(*target);
+
+ int64_t mapVarIdx = mapClauseOwner.getOperandIndexForMap(op);
+ assert(mapVarIdx >= 0 &&
+ mapVarIdx <
+ static_cast<int64_t>(mapClauseOwner.getMapVars().size()));
+
+ auto argIface =
+ llvm::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(*target);
+ // TODO How should `map` block argument that correspond to: `private`,
+ // `use_device_addr`, `use_device_ptr`, be handled?
+ mlir::BlockArgument opBlockArg = argIface.getMapBlockArgs()[mapVarIdx];
+ llvm::SetVector<mlir::Operation *> mapVarForwardSlice;
+ mlir::getForwardSlice(opBlockArg, &mapVarForwardSlice);
+
+ mapVarForwardSlice.remove_if([&](mlir::Operation *sliceOp) {
+ // TODO Support coordinate_of ops.
+ //
+ // TODO Support call ops by recursively examining the forward slice of
+ // the corresponding parameter to the field in the called function.
+ return !mlir::isa<hlfir::DesignateOp>(sliceOp);
+ });
+
+ auto recordType = mlir::cast<fir::RecordType>(underlyingType);
+ llvm::SmallVector<mlir::Value> newMapOpsForFields;
+ llvm::SmallVector<int64_t> fieldIndicies;
+
+ for (auto fieldMemTyPair : recordType.getTypeList()) {
+ auto &field = fieldMemTyPair.first;
+ auto memTy = fieldMemTyPair.second;
+
+ bool shouldMapField =
+ llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
+ if (!fir::isAllocatableType(memTy))
+ return false;
+
+ auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
+ if (!designateOp)
+ return false;
+
+ return designateOp.getComponent() &&
+ designateOp.getComponent()->strref() == field;
+ }) != mapVarForwardSlice.end();
+
+ // TODO Handle recursive record types. Adapting
+ // `createParentSymAndGenIntermediateMaps` to work direclty on MLIR
+ // entities might be helpful here.
+
+ if (!shouldMapField)
+ continue;
+
+ int64_t fieldIdx = recordType.getFieldIndex(field);
+ bool alreadyMapped = [&]() {
+ if (op.getMembersIndexAttr())
+ for (auto indexList : op.getMembersIndexAttr()) {
+ auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
+ if (indexListAttr.size() == 1 &&
+ mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() ==
+ fieldIdx)
+ return true;
+ }
+
+ return false;
+ }();
+
+ if (alreadyMapped)
+ continue;
+
+ builder.setInsertionPoint(op);
+ mlir::Value fieldIdxVal = builder.createIntegerConstant(
+ op.getLoc(), mlir::IndexType::get(builder.getContext()),
+ fieldIdx);
+ auto fieldCoord = builder.create<fir::CoordinateOp>(
+ op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
+ fieldIdxVal);
+ Fortran::lower::AddrAndBoundsInfo info =
+ Fortran::lower::getDataOperandBaseAddr(
+ builder, fieldCoord, /*isOptional=*/false, op.getLoc());
+ llvm::SmallVector<mlir::Value> bounds =
+ Fortran::lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+ mlir::omp::MapBoundsType>(
+ builder, info,
+ hlfir::translateToExtendedValue(op.getLoc(), builder,
+ hlfir::Entity{fieldCoord})
+ .first,
+ /*dataExvIsAssumedSize=*/false, op.getLoc());
+
+ mlir::omp::MapInfoOp fieldMapOp =
+ builder.create<mlir::omp::MapInfoOp>(
+ op.getLoc(), fieldCoord.getResult().getType(),
+ fieldCoord.getResult(),
+ mlir::TypeAttr::get(
+ fir::unwrapRefType(fieldCoord.getResult().getType())),
+ /*varPtrPtr=*/mlir::Value{},
+ /*members=*/mlir::ValueRange{},
+ /*members_index=*/mlir::ArrayAttr{},
+ /*bounds=*/bounds, op.getMapTypeAttr(),
+ builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
+ mlir::omp::VariableCaptureKind::ByRef),
+ builder.getStringAttr(op.getNameAttr().strref() + "." +
+ field + ".implicit_map"),
+ /*partial_map=*/builder.getBoolAttr(false));
+ newMapOpsForFields.emplace_back(fieldMapOp);
+ fieldIndicies.emplace_back(fieldIdx);
+ }
+
+ if (newMapOpsForFields.empty())
+ return mlir::WalkResult::advance();
+
+ op.getMembersMutable().append(newMapOpsForFields);
+ llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices;
+ mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr();
+
+ if (oldMembersIdxAttr)
+ for (mlir::Attribute indexList : oldMembersIdxAttr) {
+ llvm::SmallVector<int64_t> listVec;
+
+ for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList))
+ listVec.push_back(mlir::cast<mlir::IntegerAttr>(index).getInt());
+
+ newMemberIndices.emplace_back(std::move(listVec));
+ }
+
+ for (int64_t newFieldIdx : fieldIndicies)
+ newMemberIndices.emplace_back(
+ llvm::SmallVector<int64_t>(1, newFieldIdx));
+
+ op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices));
+ op.setPartialMap(true);
+
+ return mlir::WalkResult::advance();
+ });
+
func->walk([&](mlir::omp::MapInfoOp op) {
// TODO: Currently only supports a single user for the MapInfoOp. This
// is fine for the moment, as the Fortran frontend will generate a
diff --git a/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir b/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir
new file mode 100644
index 00000000000000..bcf8b63075dbf8
--- /dev/null
+++ b/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir
@@ -0,0 +1,63 @@
+// Tests that we implicitly map alloctable fields of a record when referenced in
+// a target region.
+
+// RUN: fir-opt --split-input-file --omp-map-info-finalization %s | FileCheck %s
+
+!record_t = !fir.type<_QFTrecord_t{
+ not_to_implicitly_map:
+ !fir.box<!fir.heap<!fir.array<?xf32>>>,
+ to_implicitly_map:
+ !fir.box<!fir.heap<!fir.array<?xf32>>>
+}>
+
+fir.global internal @_QFEdst_record : !record_t {
+ %0 = fir.undefined !record_t
+ fir.has_value %0 : !record_t
+}
+
+func.func @_QQmain() {
+ %6 = fir.address_of(@_QFEdst_record) : !fir.ref<!record_t>
+ %7:2 = hlfir.declare %6 {uniq_name = "_QFEdst_record"} : (!fir.ref<!record_t>) -> (!fir.ref<!record_t>, !fir.ref<!record_t>)
+ %16 = omp.map.info var_ptr(%7#1 : !fir.ref<!record_t>, !record_t) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<!record_t> {name = "dst_record"}
+ omp.target map_entries(%16 -> %arg0 : !fir.ref<!record_t>) {
+ %20:2 = hlfir.declare %arg0 {uniq_name = "_QFEdst_record"} : (!fir.ref<!record_t>) -> (...
[truncated]
|
@llvm/pr-subscribers-flang-openmp Author: Kareem Ergawy (ergawy) ChangesThis re-applies #117867 with a small fix that hopefully prevents build bot failures. The fix is avoiding Patch is 26.65 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/120374.diff 12 Files Affected:
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/include/flang/Lower/DirectivesCommon.h
similarity index 97%
rename from flang/lib/Lower/DirectivesCommon.h
rename to flang/include/flang/Lower/DirectivesCommon.h
index 88514b16743278..6e2c6ee4b1bcdb 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/include/flang/Lower/DirectivesCommon.h
@@ -609,11 +609,10 @@ void createEmptyRegionBlocks(
}
}
-inline AddrAndBoundsInfo
-getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
- fir::FirOpBuilder &builder,
- Fortran::lower::SymbolRef sym, mlir::Location loc) {
- mlir::Value symAddr = converter.getSymbolAddress(sym);
+inline AddrAndBoundsInfo getDataOperandBaseAddr(fir::FirOpBuilder &builder,
+ mlir::Value symAddr,
+ bool isOptional,
+ mlir::Location loc) {
mlir::Value rawInput = symAddr;
if (auto declareOp =
mlir::dyn_cast_or_null<hlfir::DeclareOp>(symAddr.getDefiningOp())) {
@@ -621,20 +620,11 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
rawInput = declareOp.getResults()[1];
}
- // TODO: Might need revisiting to handle for non-shared clauses
- if (!symAddr) {
- if (const auto *details =
- sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
- symAddr = converter.getSymbolAddress(details->symbol());
- rawInput = symAddr;
- }
- }
-
if (!symAddr)
llvm::report_fatal_error("could not retrieve symbol address");
mlir::Value isPresent;
- if (Fortran::semantics::IsOptional(sym))
+ if (isOptional)
isPresent =
builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), rawInput);
@@ -648,8 +638,7 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
// all address/dimension retrievals. For Fortran optional though, leave
// the load generation for later so it can be done in the appropriate
// if branches.
- if (mlir::isa<fir::ReferenceType>(symAddr.getType()) &&
- !Fortran::semantics::IsOptional(sym)) {
+ if (mlir::isa<fir::ReferenceType>(symAddr.getType()) && !isOptional) {
mlir::Value addr = builder.create<fir::LoadOp>(loc, symAddr);
return AddrAndBoundsInfo(addr, rawInput, isPresent, boxTy);
}
@@ -659,6 +648,14 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
return AddrAndBoundsInfo(symAddr, rawInput, isPresent);
}
+inline AddrAndBoundsInfo
+getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
+ fir::FirOpBuilder &builder,
+ Fortran::lower::SymbolRef sym, mlir::Location loc) {
+ return getDataOperandBaseAddr(builder, converter.getSymbolAddress(sym),
+ Fortran::semantics::IsOptional(sym), loc);
+}
+
template <typename BoundsOp, typename BoundsType>
llvm::SmallVector<mlir::Value>
gatherBoundsOrBoundValues(fir::FirOpBuilder &builder, mlir::Location loc,
@@ -1224,6 +1221,25 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
return info;
}
+
+template <typename BoundsOp, typename BoundsType>
+llvm::SmallVector<mlir::Value>
+genImplicitBoundsOps(fir::FirOpBuilder &builder, lower::AddrAndBoundsInfo &info,
+ fir::ExtendedValue dataExv, bool dataExvIsAssumedSize,
+ mlir::Location loc) {
+ llvm::SmallVector<mlir::Value> bounds;
+
+ mlir::Value baseOp = info.rawInput;
+ if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
+ bounds = lower::genBoundsOpsFromBox<BoundsOp, BoundsType>(builder, loc,
+ dataExv, info);
+ if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
+ bounds = lower::genBaseBoundsOps<BoundsOp, BoundsType>(
+ builder, loc, dataExv, dataExvIsAssumedSize);
+ }
+
+ return bounds;
+}
} // namespace lower
} // namespace Fortran
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 0650433dbaf394..f5883dcedb2b67 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -11,7 +11,7 @@
//===----------------------------------------------------------------------===//
#include "flang/Lower/Bridge.h"
-#include "DirectivesCommon.h"
+
#include "flang/Common/Version.h"
#include "flang/Lower/Allocatable.h"
#include "flang/Lower/CallInterface.h"
@@ -22,6 +22,7 @@
#include "flang/Lower/ConvertType.h"
#include "flang/Lower/ConvertVariable.h"
#include "flang/Lower/Cuda.h"
+#include "flang/Lower/DirectivesCommon.h"
#include "flang/Lower/HostAssociations.h"
#include "flang/Lower/IO.h"
#include "flang/Lower/IterationSpace.h"
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 75dcf6ec3e1107..ed18ad89c16ef5 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -11,10 +11,11 @@
//===----------------------------------------------------------------------===//
#include "flang/Lower/OpenACC.h"
-#include "DirectivesCommon.h"
+
#include "flang/Common/idioms.h"
#include "flang/Lower/Bridge.h"
#include "flang/Lower/ConvertType.h"
+#include "flang/Lower/DirectivesCommon.h"
#include "flang/Lower/Mangler.h"
#include "flang/Lower/PFTBuilder.h"
#include "flang/Lower/StatementContext.h"
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 3942c54e6e935d..7b047d4a7567ad 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -13,11 +13,11 @@
#define FORTRAN_LOWER_CLAUSEPROCESSOR_H
#include "Clauses.h"
-#include "DirectivesCommon.h"
#include "ReductionProcessor.h"
#include "Utils.h"
#include "flang/Lower/AbstractConverter.h"
#include "flang/Lower/Bridge.h"
+#include "flang/Lower/DirectivesCommon.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Parser/dump-parse-tree.h"
#include "flang/Parser/parse-tree.h"
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index c61ab67d95a957..b07e89d201d198 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -16,7 +16,6 @@
#include "Clauses.h"
#include "DataSharingProcessor.h"
#include "Decomposer.h"
-#include "DirectivesCommon.h"
#include "ReductionProcessor.h"
#include "Utils.h"
#include "flang/Common/OpenMP-utils.h"
@@ -24,6 +23,7 @@
#include "flang/Lower/Bridge.h"
#include "flang/Lower/ConvertExpr.h"
#include "flang/Lower/ConvertVariable.h"
+#include "flang/Lower/DirectivesCommon.h"
#include "flang/Lower/StatementContext.h"
#include "flang/Lower/SymbolMap.h"
#include "flang/Optimizer/Builder/BoxValue.h"
@@ -1735,32 +1735,25 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
if (const auto *details =
sym.template detailsIf<semantics::HostAssocDetails>())
converter.copySymbolBinding(details->symbol(), sym);
- llvm::SmallVector<mlir::Value> bounds;
std::stringstream name;
fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(sym);
name << sym.name().ToString();
lower::AddrAndBoundsInfo info = getDataOperandBaseAddr(
converter, firOpBuilder, sym, converter.getCurrentLocation());
- mlir::Value baseOp = info.rawInput;
- if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
- bounds = lower::genBoundsOpsFromBox<mlir::omp::MapBoundsOp,
- mlir::omp::MapBoundsType>(
- firOpBuilder, converter.getCurrentLocation(), dataExv, info);
- if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
- bool dataExvIsAssumedSize =
- semantics::IsAssumedSizeArray(sym.GetUltimate());
- bounds = lower::genBaseBoundsOps<mlir::omp::MapBoundsOp,
- mlir::omp::MapBoundsType>(
- firOpBuilder, converter.getCurrentLocation(), dataExv,
- dataExvIsAssumedSize);
- }
+ llvm::SmallVector<mlir::Value> bounds =
+ lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+ mlir::omp::MapBoundsType>(
+ firOpBuilder, info, dataExv,
+ semantics::IsAssumedSizeArray(sym.GetUltimate()),
+ converter.getCurrentLocation());
llvm::omp::OpenMPOffloadMappingFlags mapFlag =
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
mlir::omp::VariableCaptureKind captureKind =
mlir::omp::VariableCaptureKind::ByRef;
+ mlir::Value baseOp = info.rawInput;
mlir::Type eleType = baseOp.getType();
if (auto refType = mlir::dyn_cast<fir::ReferenceType>(baseOp.getType()))
eleType = refType.getElementType();
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 5340dd8c5fb9a2..9971dc8e0b0014 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -13,10 +13,10 @@
#include "Utils.h"
#include "Clauses.h"
-#include <DirectivesCommon.h>
#include <flang/Lower/AbstractConverter.h>
#include <flang/Lower/ConvertType.h>
+#include <flang/Lower/DirectivesCommon.h>
#include <flang/Lower/PFTBuilder.h>
#include <flang/Optimizer/Builder/FIRBuilder.h>
#include <flang/Optimizer/Builder/Todo.h>
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index 51ecbe1a664f92..4f23b2b970fa44 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -12,6 +12,7 @@ add_flang_library(FlangOpenMPTransforms
FIRDialect
HLFIROpsIncGen
FlangOpenMPPassesIncGen
+ ${dialect_libs}
LINK_LIBS
FIRAnalysis
@@ -27,4 +28,5 @@ add_flang_library(FlangOpenMPTransforms
MLIRIR
MLIRPass
MLIRTransformUtils
+ ${dialect_libs}
)
diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index 4575c90e34acdd..ad7b806ae262ae 100644
--- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -24,10 +24,14 @@
/// indirectly via a parent object.
//===----------------------------------------------------------------------===//
+#include "flang/Lower/DirectivesCommon.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/OpenMP/Passes.h"
+#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/BuiltinDialect.h"
@@ -411,10 +415,10 @@ class MapInfoFinalizationPass
argIface
? argIface.getMapBlockArgsStart() + argIface.numMapBlockArgs()
: 0;
- addOperands(
- mapMutableOpRange,
- llvm::dyn_cast_or_null<mlir::omp::TargetOp>(argIface.getOperation()),
- blockArgInsertIndex);
+ addOperands(mapMutableOpRange,
+ llvm::dyn_cast_if_present<mlir::omp::TargetOp>(
+ argIface.getOperation()),
+ blockArgInsertIndex);
}
if (auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(target)) {
@@ -466,8 +470,7 @@ class MapInfoFinalizationPass
// operation (usually function) containing the MapInfoOp because this pass
// will mutate siblings of MapInfoOp.
void runOnOperation() override {
- mlir::ModuleOp module =
- mlir::dyn_cast_or_null<mlir::ModuleOp>(getOperation());
+ mlir::ModuleOp module = getOperation();
if (!module)
module = getOperation()->getParentOfType<mlir::ModuleOp>();
fir::KindMapping kindMap = fir::getKindMapping(module);
@@ -486,6 +489,160 @@ class MapInfoFinalizationPass
// iterations from previous function scopes.
localBoxAllocas.clear();
+ // First, walk `omp.map.info` ops to see if any record members should be
+ // implicitly mapped.
+ func->walk([&](mlir::omp::MapInfoOp op) {
+ mlir::Type underlyingType =
+ fir::unwrapRefType(op.getVarPtr().getType());
+
+ // TODO Test with and support more complicated cases; like arrays for
+ // records, for example.
+ if (!fir::isRecordWithAllocatableMember(underlyingType))
+ return mlir::WalkResult::advance();
+
+ // TODO For now, only consider `omp.target` ops. Other ops that support
+ // `map` clauses will follow later.
+ mlir::omp::TargetOp target =
+ mlir::dyn_cast_if_present<mlir::omp::TargetOp>(
+ getFirstTargetUser(op));
+
+ if (!target)
+ return mlir::WalkResult::advance();
+
+ auto mapClauseOwner =
+ llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(*target);
+
+ int64_t mapVarIdx = mapClauseOwner.getOperandIndexForMap(op);
+ assert(mapVarIdx >= 0 &&
+ mapVarIdx <
+ static_cast<int64_t>(mapClauseOwner.getMapVars().size()));
+
+ auto argIface =
+ llvm::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(*target);
+ // TODO How should `map` block argument that correspond to: `private`,
+ // `use_device_addr`, `use_device_ptr`, be handled?
+ mlir::BlockArgument opBlockArg = argIface.getMapBlockArgs()[mapVarIdx];
+ llvm::SetVector<mlir::Operation *> mapVarForwardSlice;
+ mlir::getForwardSlice(opBlockArg, &mapVarForwardSlice);
+
+ mapVarForwardSlice.remove_if([&](mlir::Operation *sliceOp) {
+ // TODO Support coordinate_of ops.
+ //
+ // TODO Support call ops by recursively examining the forward slice of
+ // the corresponding parameter to the field in the called function.
+ return !mlir::isa<hlfir::DesignateOp>(sliceOp);
+ });
+
+ auto recordType = mlir::cast<fir::RecordType>(underlyingType);
+ llvm::SmallVector<mlir::Value> newMapOpsForFields;
+ llvm::SmallVector<int64_t> fieldIndicies;
+
+ for (auto fieldMemTyPair : recordType.getTypeList()) {
+ auto &field = fieldMemTyPair.first;
+ auto memTy = fieldMemTyPair.second;
+
+ bool shouldMapField =
+ llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
+ if (!fir::isAllocatableType(memTy))
+ return false;
+
+ auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
+ if (!designateOp)
+ return false;
+
+ return designateOp.getComponent() &&
+ designateOp.getComponent()->strref() == field;
+ }) != mapVarForwardSlice.end();
+
+ // TODO Handle recursive record types. Adapting
+ // `createParentSymAndGenIntermediateMaps` to work direclty on MLIR
+ // entities might be helpful here.
+
+ if (!shouldMapField)
+ continue;
+
+ int64_t fieldIdx = recordType.getFieldIndex(field);
+ bool alreadyMapped = [&]() {
+ if (op.getMembersIndexAttr())
+ for (auto indexList : op.getMembersIndexAttr()) {
+ auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
+ if (indexListAttr.size() == 1 &&
+ mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() ==
+ fieldIdx)
+ return true;
+ }
+
+ return false;
+ }();
+
+ if (alreadyMapped)
+ continue;
+
+ builder.setInsertionPoint(op);
+ mlir::Value fieldIdxVal = builder.createIntegerConstant(
+ op.getLoc(), mlir::IndexType::get(builder.getContext()),
+ fieldIdx);
+ auto fieldCoord = builder.create<fir::CoordinateOp>(
+ op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
+ fieldIdxVal);
+ Fortran::lower::AddrAndBoundsInfo info =
+ Fortran::lower::getDataOperandBaseAddr(
+ builder, fieldCoord, /*isOptional=*/false, op.getLoc());
+ llvm::SmallVector<mlir::Value> bounds =
+ Fortran::lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+ mlir::omp::MapBoundsType>(
+ builder, info,
+ hlfir::translateToExtendedValue(op.getLoc(), builder,
+ hlfir::Entity{fieldCoord})
+ .first,
+ /*dataExvIsAssumedSize=*/false, op.getLoc());
+
+ mlir::omp::MapInfoOp fieldMapOp =
+ builder.create<mlir::omp::MapInfoOp>(
+ op.getLoc(), fieldCoord.getResult().getType(),
+ fieldCoord.getResult(),
+ mlir::TypeAttr::get(
+ fir::unwrapRefType(fieldCoord.getResult().getType())),
+ /*varPtrPtr=*/mlir::Value{},
+ /*members=*/mlir::ValueRange{},
+ /*members_index=*/mlir::ArrayAttr{},
+ /*bounds=*/bounds, op.getMapTypeAttr(),
+ builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
+ mlir::omp::VariableCaptureKind::ByRef),
+ builder.getStringAttr(op.getNameAttr().strref() + "." +
+ field + ".implicit_map"),
+ /*partial_map=*/builder.getBoolAttr(false));
+ newMapOpsForFields.emplace_back(fieldMapOp);
+ fieldIndicies.emplace_back(fieldIdx);
+ }
+
+ if (newMapOpsForFields.empty())
+ return mlir::WalkResult::advance();
+
+ op.getMembersMutable().append(newMapOpsForFields);
+ llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices;
+ mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr();
+
+ if (oldMembersIdxAttr)
+ for (mlir::Attribute indexList : oldMembersIdxAttr) {
+ llvm::SmallVector<int64_t> listVec;
+
+ for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList))
+ listVec.push_back(mlir::cast<mlir::IntegerAttr>(index).getInt());
+
+ newMemberIndices.emplace_back(std::move(listVec));
+ }
+
+ for (int64_t newFieldIdx : fieldIndicies)
+ newMemberIndices.emplace_back(
+ llvm::SmallVector<int64_t>(1, newFieldIdx));
+
+ op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices));
+ op.setPartialMap(true);
+
+ return mlir::WalkResult::advance();
+ });
+
func->walk([&](mlir::omp::MapInfoOp op) {
// TODO: Currently only supports a single user for the MapInfoOp. This
// is fine for the moment, as the Fortran frontend will generate a
diff --git a/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir b/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir
new file mode 100644
index 00000000000000..bcf8b63075dbf8
--- /dev/null
+++ b/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir
@@ -0,0 +1,63 @@
+// Tests that we implicitly map alloctable fields of a record when referenced in
+// a target region.
+
+// RUN: fir-opt --split-input-file --omp-map-info-finalization %s | FileCheck %s
+
+!record_t = !fir.type<_QFTrecord_t{
+ not_to_implicitly_map:
+ !fir.box<!fir.heap<!fir.array<?xf32>>>,
+ to_implicitly_map:
+ !fir.box<!fir.heap<!fir.array<?xf32>>>
+}>
+
+fir.global internal @_QFEdst_record : !record_t {
+ %0 = fir.undefined !record_t
+ fir.has_value %0 : !record_t
+}
+
+func.func @_QQmain() {
+ %6 = fir.address_of(@_QFEdst_record) : !fir.ref<!record_t>
+ %7:2 = hlfir.declare %6 {uniq_name = "_QFEdst_record"} : (!fir.ref<!record_t>) -> (!fir.ref<!record_t>, !fir.ref<!record_t>)
+ %16 = omp.map.info var_ptr(%7#1 : !fir.ref<!record_t>, !record_t) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<!record_t> {name = "dst_record"}
+ omp.target map_entries(%16 -> %arg0 : !fir.ref<!record_t>) {
+ %20:2 = hlfir.declare %arg0 {uniq_name = "_QFEdst_record"} : (!fir.ref<!record_t>) -> (...
[truncated]
|
@llvm/pr-subscribers-mlir-openmp Author: Kareem Ergawy (ergawy) ChangesThis re-applies #117867 with a small fix that hopefully prevents build bot failures. The fix is avoiding Patch is 26.65 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/120374.diff 12 Files Affected:
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/include/flang/Lower/DirectivesCommon.h
similarity index 97%
rename from flang/lib/Lower/DirectivesCommon.h
rename to flang/include/flang/Lower/DirectivesCommon.h
index 88514b16743278..6e2c6ee4b1bcdb 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/include/flang/Lower/DirectivesCommon.h
@@ -609,11 +609,10 @@ void createEmptyRegionBlocks(
}
}
-inline AddrAndBoundsInfo
-getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
- fir::FirOpBuilder &builder,
- Fortran::lower::SymbolRef sym, mlir::Location loc) {
- mlir::Value symAddr = converter.getSymbolAddress(sym);
+inline AddrAndBoundsInfo getDataOperandBaseAddr(fir::FirOpBuilder &builder,
+ mlir::Value symAddr,
+ bool isOptional,
+ mlir::Location loc) {
mlir::Value rawInput = symAddr;
if (auto declareOp =
mlir::dyn_cast_or_null<hlfir::DeclareOp>(symAddr.getDefiningOp())) {
@@ -621,20 +620,11 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
rawInput = declareOp.getResults()[1];
}
- // TODO: Might need revisiting to handle for non-shared clauses
- if (!symAddr) {
- if (const auto *details =
- sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
- symAddr = converter.getSymbolAddress(details->symbol());
- rawInput = symAddr;
- }
- }
-
if (!symAddr)
llvm::report_fatal_error("could not retrieve symbol address");
mlir::Value isPresent;
- if (Fortran::semantics::IsOptional(sym))
+ if (isOptional)
isPresent =
builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), rawInput);
@@ -648,8 +638,7 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
// all address/dimension retrievals. For Fortran optional though, leave
// the load generation for later so it can be done in the appropriate
// if branches.
- if (mlir::isa<fir::ReferenceType>(symAddr.getType()) &&
- !Fortran::semantics::IsOptional(sym)) {
+ if (mlir::isa<fir::ReferenceType>(symAddr.getType()) && !isOptional) {
mlir::Value addr = builder.create<fir::LoadOp>(loc, symAddr);
return AddrAndBoundsInfo(addr, rawInput, isPresent, boxTy);
}
@@ -659,6 +648,14 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
return AddrAndBoundsInfo(symAddr, rawInput, isPresent);
}
+inline AddrAndBoundsInfo
+getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
+ fir::FirOpBuilder &builder,
+ Fortran::lower::SymbolRef sym, mlir::Location loc) {
+ return getDataOperandBaseAddr(builder, converter.getSymbolAddress(sym),
+ Fortran::semantics::IsOptional(sym), loc);
+}
+
template <typename BoundsOp, typename BoundsType>
llvm::SmallVector<mlir::Value>
gatherBoundsOrBoundValues(fir::FirOpBuilder &builder, mlir::Location loc,
@@ -1224,6 +1221,25 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
return info;
}
+
+template <typename BoundsOp, typename BoundsType>
+llvm::SmallVector<mlir::Value>
+genImplicitBoundsOps(fir::FirOpBuilder &builder, lower::AddrAndBoundsInfo &info,
+ fir::ExtendedValue dataExv, bool dataExvIsAssumedSize,
+ mlir::Location loc) {
+ llvm::SmallVector<mlir::Value> bounds;
+
+ mlir::Value baseOp = info.rawInput;
+ if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
+ bounds = lower::genBoundsOpsFromBox<BoundsOp, BoundsType>(builder, loc,
+ dataExv, info);
+ if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
+ bounds = lower::genBaseBoundsOps<BoundsOp, BoundsType>(
+ builder, loc, dataExv, dataExvIsAssumedSize);
+ }
+
+ return bounds;
+}
} // namespace lower
} // namespace Fortran
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 0650433dbaf394..f5883dcedb2b67 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -11,7 +11,7 @@
//===----------------------------------------------------------------------===//
#include "flang/Lower/Bridge.h"
-#include "DirectivesCommon.h"
+
#include "flang/Common/Version.h"
#include "flang/Lower/Allocatable.h"
#include "flang/Lower/CallInterface.h"
@@ -22,6 +22,7 @@
#include "flang/Lower/ConvertType.h"
#include "flang/Lower/ConvertVariable.h"
#include "flang/Lower/Cuda.h"
+#include "flang/Lower/DirectivesCommon.h"
#include "flang/Lower/HostAssociations.h"
#include "flang/Lower/IO.h"
#include "flang/Lower/IterationSpace.h"
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 75dcf6ec3e1107..ed18ad89c16ef5 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -11,10 +11,11 @@
//===----------------------------------------------------------------------===//
#include "flang/Lower/OpenACC.h"
-#include "DirectivesCommon.h"
+
#include "flang/Common/idioms.h"
#include "flang/Lower/Bridge.h"
#include "flang/Lower/ConvertType.h"
+#include "flang/Lower/DirectivesCommon.h"
#include "flang/Lower/Mangler.h"
#include "flang/Lower/PFTBuilder.h"
#include "flang/Lower/StatementContext.h"
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 3942c54e6e935d..7b047d4a7567ad 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -13,11 +13,11 @@
#define FORTRAN_LOWER_CLAUSEPROCESSOR_H
#include "Clauses.h"
-#include "DirectivesCommon.h"
#include "ReductionProcessor.h"
#include "Utils.h"
#include "flang/Lower/AbstractConverter.h"
#include "flang/Lower/Bridge.h"
+#include "flang/Lower/DirectivesCommon.h"
#include "flang/Optimizer/Builder/Todo.h"
#include "flang/Parser/dump-parse-tree.h"
#include "flang/Parser/parse-tree.h"
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index c61ab67d95a957..b07e89d201d198 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -16,7 +16,6 @@
#include "Clauses.h"
#include "DataSharingProcessor.h"
#include "Decomposer.h"
-#include "DirectivesCommon.h"
#include "ReductionProcessor.h"
#include "Utils.h"
#include "flang/Common/OpenMP-utils.h"
@@ -24,6 +23,7 @@
#include "flang/Lower/Bridge.h"
#include "flang/Lower/ConvertExpr.h"
#include "flang/Lower/ConvertVariable.h"
+#include "flang/Lower/DirectivesCommon.h"
#include "flang/Lower/StatementContext.h"
#include "flang/Lower/SymbolMap.h"
#include "flang/Optimizer/Builder/BoxValue.h"
@@ -1735,32 +1735,25 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
if (const auto *details =
sym.template detailsIf<semantics::HostAssocDetails>())
converter.copySymbolBinding(details->symbol(), sym);
- llvm::SmallVector<mlir::Value> bounds;
std::stringstream name;
fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(sym);
name << sym.name().ToString();
lower::AddrAndBoundsInfo info = getDataOperandBaseAddr(
converter, firOpBuilder, sym, converter.getCurrentLocation());
- mlir::Value baseOp = info.rawInput;
- if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
- bounds = lower::genBoundsOpsFromBox<mlir::omp::MapBoundsOp,
- mlir::omp::MapBoundsType>(
- firOpBuilder, converter.getCurrentLocation(), dataExv, info);
- if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
- bool dataExvIsAssumedSize =
- semantics::IsAssumedSizeArray(sym.GetUltimate());
- bounds = lower::genBaseBoundsOps<mlir::omp::MapBoundsOp,
- mlir::omp::MapBoundsType>(
- firOpBuilder, converter.getCurrentLocation(), dataExv,
- dataExvIsAssumedSize);
- }
+ llvm::SmallVector<mlir::Value> bounds =
+ lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+ mlir::omp::MapBoundsType>(
+ firOpBuilder, info, dataExv,
+ semantics::IsAssumedSizeArray(sym.GetUltimate()),
+ converter.getCurrentLocation());
llvm::omp::OpenMPOffloadMappingFlags mapFlag =
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
mlir::omp::VariableCaptureKind captureKind =
mlir::omp::VariableCaptureKind::ByRef;
+ mlir::Value baseOp = info.rawInput;
mlir::Type eleType = baseOp.getType();
if (auto refType = mlir::dyn_cast<fir::ReferenceType>(baseOp.getType()))
eleType = refType.getElementType();
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 5340dd8c5fb9a2..9971dc8e0b0014 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -13,10 +13,10 @@
#include "Utils.h"
#include "Clauses.h"
-#include <DirectivesCommon.h>
#include <flang/Lower/AbstractConverter.h>
#include <flang/Lower/ConvertType.h>
+#include <flang/Lower/DirectivesCommon.h>
#include <flang/Lower/PFTBuilder.h>
#include <flang/Optimizer/Builder/FIRBuilder.h>
#include <flang/Optimizer/Builder/Todo.h>
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index 51ecbe1a664f92..4f23b2b970fa44 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -12,6 +12,7 @@ add_flang_library(FlangOpenMPTransforms
FIRDialect
HLFIROpsIncGen
FlangOpenMPPassesIncGen
+ ${dialect_libs}
LINK_LIBS
FIRAnalysis
@@ -27,4 +28,5 @@ add_flang_library(FlangOpenMPTransforms
MLIRIR
MLIRPass
MLIRTransformUtils
+ ${dialect_libs}
)
diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index 4575c90e34acdd..ad7b806ae262ae 100644
--- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -24,10 +24,14 @@
/// indirectly via a parent object.
//===----------------------------------------------------------------------===//
+#include "flang/Lower/DirectivesCommon.h"
#include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Builder/HLFIRTools.h"
#include "flang/Optimizer/Dialect/FIRType.h"
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/OpenMP/Passes.h"
+#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/BuiltinDialect.h"
@@ -411,10 +415,10 @@ class MapInfoFinalizationPass
argIface
? argIface.getMapBlockArgsStart() + argIface.numMapBlockArgs()
: 0;
- addOperands(
- mapMutableOpRange,
- llvm::dyn_cast_or_null<mlir::omp::TargetOp>(argIface.getOperation()),
- blockArgInsertIndex);
+ addOperands(mapMutableOpRange,
+ llvm::dyn_cast_if_present<mlir::omp::TargetOp>(
+ argIface.getOperation()),
+ blockArgInsertIndex);
}
if (auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(target)) {
@@ -466,8 +470,7 @@ class MapInfoFinalizationPass
// operation (usually function) containing the MapInfoOp because this pass
// will mutate siblings of MapInfoOp.
void runOnOperation() override {
- mlir::ModuleOp module =
- mlir::dyn_cast_or_null<mlir::ModuleOp>(getOperation());
+ mlir::ModuleOp module = getOperation();
if (!module)
module = getOperation()->getParentOfType<mlir::ModuleOp>();
fir::KindMapping kindMap = fir::getKindMapping(module);
@@ -486,6 +489,160 @@ class MapInfoFinalizationPass
// iterations from previous function scopes.
localBoxAllocas.clear();
+ // First, walk `omp.map.info` ops to see if any record members should be
+ // implicitly mapped.
+ func->walk([&](mlir::omp::MapInfoOp op) {
+ mlir::Type underlyingType =
+ fir::unwrapRefType(op.getVarPtr().getType());
+
+ // TODO Test with and support more complicated cases; like arrays for
+ // records, for example.
+ if (!fir::isRecordWithAllocatableMember(underlyingType))
+ return mlir::WalkResult::advance();
+
+ // TODO For now, only consider `omp.target` ops. Other ops that support
+ // `map` clauses will follow later.
+ mlir::omp::TargetOp target =
+ mlir::dyn_cast_if_present<mlir::omp::TargetOp>(
+ getFirstTargetUser(op));
+
+ if (!target)
+ return mlir::WalkResult::advance();
+
+ auto mapClauseOwner =
+ llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(*target);
+
+ int64_t mapVarIdx = mapClauseOwner.getOperandIndexForMap(op);
+ assert(mapVarIdx >= 0 &&
+ mapVarIdx <
+ static_cast<int64_t>(mapClauseOwner.getMapVars().size()));
+
+ auto argIface =
+ llvm::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(*target);
+ // TODO How should `map` block argument that correspond to: `private`,
+ // `use_device_addr`, `use_device_ptr`, be handled?
+ mlir::BlockArgument opBlockArg = argIface.getMapBlockArgs()[mapVarIdx];
+ llvm::SetVector<mlir::Operation *> mapVarForwardSlice;
+ mlir::getForwardSlice(opBlockArg, &mapVarForwardSlice);
+
+ mapVarForwardSlice.remove_if([&](mlir::Operation *sliceOp) {
+ // TODO Support coordinate_of ops.
+ //
+ // TODO Support call ops by recursively examining the forward slice of
+ // the corresponding parameter to the field in the called function.
+ return !mlir::isa<hlfir::DesignateOp>(sliceOp);
+ });
+
+ auto recordType = mlir::cast<fir::RecordType>(underlyingType);
+ llvm::SmallVector<mlir::Value> newMapOpsForFields;
+ llvm::SmallVector<int64_t> fieldIndicies;
+
+ for (auto fieldMemTyPair : recordType.getTypeList()) {
+ auto &field = fieldMemTyPair.first;
+ auto memTy = fieldMemTyPair.second;
+
+ bool shouldMapField =
+ llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
+ if (!fir::isAllocatableType(memTy))
+ return false;
+
+ auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
+ if (!designateOp)
+ return false;
+
+ return designateOp.getComponent() &&
+ designateOp.getComponent()->strref() == field;
+ }) != mapVarForwardSlice.end();
+
+ // TODO Handle recursive record types. Adapting
+ // `createParentSymAndGenIntermediateMaps` to work direclty on MLIR
+ // entities might be helpful here.
+
+ if (!shouldMapField)
+ continue;
+
+ int64_t fieldIdx = recordType.getFieldIndex(field);
+ bool alreadyMapped = [&]() {
+ if (op.getMembersIndexAttr())
+ for (auto indexList : op.getMembersIndexAttr()) {
+ auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
+ if (indexListAttr.size() == 1 &&
+ mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() ==
+ fieldIdx)
+ return true;
+ }
+
+ return false;
+ }();
+
+ if (alreadyMapped)
+ continue;
+
+ builder.setInsertionPoint(op);
+ mlir::Value fieldIdxVal = builder.createIntegerConstant(
+ op.getLoc(), mlir::IndexType::get(builder.getContext()),
+ fieldIdx);
+ auto fieldCoord = builder.create<fir::CoordinateOp>(
+ op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
+ fieldIdxVal);
+ Fortran::lower::AddrAndBoundsInfo info =
+ Fortran::lower::getDataOperandBaseAddr(
+ builder, fieldCoord, /*isOptional=*/false, op.getLoc());
+ llvm::SmallVector<mlir::Value> bounds =
+ Fortran::lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+ mlir::omp::MapBoundsType>(
+ builder, info,
+ hlfir::translateToExtendedValue(op.getLoc(), builder,
+ hlfir::Entity{fieldCoord})
+ .first,
+ /*dataExvIsAssumedSize=*/false, op.getLoc());
+
+ mlir::omp::MapInfoOp fieldMapOp =
+ builder.create<mlir::omp::MapInfoOp>(
+ op.getLoc(), fieldCoord.getResult().getType(),
+ fieldCoord.getResult(),
+ mlir::TypeAttr::get(
+ fir::unwrapRefType(fieldCoord.getResult().getType())),
+ /*varPtrPtr=*/mlir::Value{},
+ /*members=*/mlir::ValueRange{},
+ /*members_index=*/mlir::ArrayAttr{},
+ /*bounds=*/bounds, op.getMapTypeAttr(),
+ builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
+ mlir::omp::VariableCaptureKind::ByRef),
+ builder.getStringAttr(op.getNameAttr().strref() + "." +
+ field + ".implicit_map"),
+ /*partial_map=*/builder.getBoolAttr(false));
+ newMapOpsForFields.emplace_back(fieldMapOp);
+ fieldIndicies.emplace_back(fieldIdx);
+ }
+
+ if (newMapOpsForFields.empty())
+ return mlir::WalkResult::advance();
+
+ op.getMembersMutable().append(newMapOpsForFields);
+ llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices;
+ mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr();
+
+ if (oldMembersIdxAttr)
+ for (mlir::Attribute indexList : oldMembersIdxAttr) {
+ llvm::SmallVector<int64_t> listVec;
+
+ for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList))
+ listVec.push_back(mlir::cast<mlir::IntegerAttr>(index).getInt());
+
+ newMemberIndices.emplace_back(std::move(listVec));
+ }
+
+ for (int64_t newFieldIdx : fieldIndicies)
+ newMemberIndices.emplace_back(
+ llvm::SmallVector<int64_t>(1, newFieldIdx));
+
+ op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices));
+ op.setPartialMap(true);
+
+ return mlir::WalkResult::advance();
+ });
+
func->walk([&](mlir::omp::MapInfoOp op) {
// TODO: Currently only supports a single user for the MapInfoOp. This
// is fine for the moment, as the Fortran frontend will generate a
diff --git a/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir b/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir
new file mode 100644
index 00000000000000..bcf8b63075dbf8
--- /dev/null
+++ b/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir
@@ -0,0 +1,63 @@
+// Tests that we implicitly map alloctable fields of a record when referenced in
+// a target region.
+
+// RUN: fir-opt --split-input-file --omp-map-info-finalization %s | FileCheck %s
+
+!record_t = !fir.type<_QFTrecord_t{
+ not_to_implicitly_map:
+ !fir.box<!fir.heap<!fir.array<?xf32>>>,
+ to_implicitly_map:
+ !fir.box<!fir.heap<!fir.array<?xf32>>>
+}>
+
+fir.global internal @_QFEdst_record : !record_t {
+ %0 = fir.undefined !record_t
+ fir.has_value %0 : !record_t
+}
+
+func.func @_QQmain() {
+ %6 = fir.address_of(@_QFEdst_record) : !fir.ref<!record_t>
+ %7:2 = hlfir.declare %6 {uniq_name = "_QFEdst_record"} : (!fir.ref<!record_t>) -> (!fir.ref<!record_t>, !fir.ref<!record_t>)
+ %16 = omp.map.info var_ptr(%7#1 : !fir.ref<!record_t>, !record_t) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<!record_t> {name = "dst_record"}
+ omp.target map_entries(%16 -> %arg0 : !fir.ref<!record_t>) {
+ %20:2 = hlfir.declare %arg0 {uniq_name = "_QFEdst_record"} : (!fir.ref<!record_t>) -> (...
[truncated]
|
@kazutakahirata it would be great if you can run locally with the patch to verify it compiles 🙏. |
Thanks! I'm on it. I'll let you know soon. |
With your new patch, my local tree builds fine ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks!
My local build with the shared libraries is broken. I suppose this was introduced by llvm#120374. `flang/include/flang/Evaluate/constant.h` ends up being included by `MapInfoFinalization.cpp` via `flang/Lower/DirectivesCommon.h`. The undefined references are related to `ConstantBase` classes.
My local build with the shared libraries is broken. I suppose this was introduced by #120374. `flang/include/flang/Evaluate/constant.h` ends up being included by `MapInfoFinalization.cpp` via `flang/Lower/DirectivesCommon.h`. The undefined references are related to `ConstantBase` classes.
This re-applies #117867 with a small fix that hopefully prevents build bot failures. The fix is avoiding
dyn_cast
for the result ofgetOperation()
. Instead we can assign the result tomlir::ModuleOp
directly since the type of the operation is known statically (OpT
inOperationPass
).