Skip to content

Re-apply (#117867): [flang][OpenMP] Implicitly map allocatable record fields #120374

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 18, 2024

Conversation

ergawy
Copy link
Member

@ergawy ergawy commented Dec 18, 2024

This re-applies #117867 with a small fix that hopefully prevents build bot failures. The fix is avoiding dyn_cast for the result of getOperation(). Instead we can assign the result to mlir::ModuleOp directly since the type of the operation is known statically (OpT in OperationPass).

…cord fields"

This re-applies llvm#117867 with a small fix that hopefully prevents build
bot failures. The fix is avoiding `dyn_cast` for the result of `getOperation()`.
Instead we can assign the result to `mlir::ModuleOp` directly since the
type of the operation is known statically (`OpT` in `OperationPass`).
@llvmbot
Copy link
Member

llvmbot commented Dec 18, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-flang-fir-hlfir

Author: Kareem Ergawy (ergawy)

Changes

This re-applies #117867 with a small fix that hopefully prevents build bot failures. The fix is avoiding dyn_cast for the result of getOperation(). Instead we can assign the result to mlir::ModuleOp directly since the type of the operation is known statically (OpT in OperationPass).


Patch is 26.65 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/120374.diff

12 Files Affected:

  • (renamed) flang/include/flang/Lower/DirectivesCommon.h (+33-17)
  • (modified) flang/lib/Lower/Bridge.cpp (+2-1)
  • (modified) flang/lib/Lower/OpenACC.cpp (+2-1)
  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.h (+1-1)
  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+8-15)
  • (modified) flang/lib/Lower/OpenMP/Utils.cpp (+1-1)
  • (modified) flang/lib/Optimizer/OpenMP/CMakeLists.txt (+2)
  • (modified) flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp (+163-6)
  • (added) flang/test/Transforms/omp-map-info-finalization-implicit-field.fir (+63)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td (+7)
  • (added) offload/test/offloading/fortran/explicit-and-implicit-record-field-mapping.f90 (+83)
  • (added) offload/test/offloading/fortran/implicit-record-field-mapping.f90 (+52)
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/include/flang/Lower/DirectivesCommon.h
similarity index 97%
rename from flang/lib/Lower/DirectivesCommon.h
rename to flang/include/flang/Lower/DirectivesCommon.h
index 88514b16743278..6e2c6ee4b1bcdb 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/include/flang/Lower/DirectivesCommon.h
@@ -609,11 +609,10 @@ void createEmptyRegionBlocks(
   }
 }
 
-inline AddrAndBoundsInfo
-getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
-                       fir::FirOpBuilder &builder,
-                       Fortran::lower::SymbolRef sym, mlir::Location loc) {
-  mlir::Value symAddr = converter.getSymbolAddress(sym);
+inline AddrAndBoundsInfo getDataOperandBaseAddr(fir::FirOpBuilder &builder,
+                                                mlir::Value symAddr,
+                                                bool isOptional,
+                                                mlir::Location loc) {
   mlir::Value rawInput = symAddr;
   if (auto declareOp =
           mlir::dyn_cast_or_null<hlfir::DeclareOp>(symAddr.getDefiningOp())) {
@@ -621,20 +620,11 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
     rawInput = declareOp.getResults()[1];
   }
 
-  // TODO: Might need revisiting to handle for non-shared clauses
-  if (!symAddr) {
-    if (const auto *details =
-            sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
-      symAddr = converter.getSymbolAddress(details->symbol());
-      rawInput = symAddr;
-    }
-  }
-
   if (!symAddr)
     llvm::report_fatal_error("could not retrieve symbol address");
 
   mlir::Value isPresent;
-  if (Fortran::semantics::IsOptional(sym))
+  if (isOptional)
     isPresent =
         builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), rawInput);
 
@@ -648,8 +638,7 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
     // all address/dimension retrievals. For Fortran optional though, leave
     // the load generation for later so it can be done in the appropriate
     // if branches.
-    if (mlir::isa<fir::ReferenceType>(symAddr.getType()) &&
-        !Fortran::semantics::IsOptional(sym)) {
+    if (mlir::isa<fir::ReferenceType>(symAddr.getType()) && !isOptional) {
       mlir::Value addr = builder.create<fir::LoadOp>(loc, symAddr);
       return AddrAndBoundsInfo(addr, rawInput, isPresent, boxTy);
     }
@@ -659,6 +648,14 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
   return AddrAndBoundsInfo(symAddr, rawInput, isPresent);
 }
 
+inline AddrAndBoundsInfo
+getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
+                       fir::FirOpBuilder &builder,
+                       Fortran::lower::SymbolRef sym, mlir::Location loc) {
+  return getDataOperandBaseAddr(builder, converter.getSymbolAddress(sym),
+                                Fortran::semantics::IsOptional(sym), loc);
+}
+
 template <typename BoundsOp, typename BoundsType>
 llvm::SmallVector<mlir::Value>
 gatherBoundsOrBoundValues(fir::FirOpBuilder &builder, mlir::Location loc,
@@ -1224,6 +1221,25 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
 
   return info;
 }
+
+template <typename BoundsOp, typename BoundsType>
+llvm::SmallVector<mlir::Value>
+genImplicitBoundsOps(fir::FirOpBuilder &builder, lower::AddrAndBoundsInfo &info,
+                     fir::ExtendedValue dataExv, bool dataExvIsAssumedSize,
+                     mlir::Location loc) {
+  llvm::SmallVector<mlir::Value> bounds;
+
+  mlir::Value baseOp = info.rawInput;
+  if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
+    bounds = lower::genBoundsOpsFromBox<BoundsOp, BoundsType>(builder, loc,
+                                                              dataExv, info);
+  if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
+    bounds = lower::genBaseBoundsOps<BoundsOp, BoundsType>(
+        builder, loc, dataExv, dataExvIsAssumedSize);
+  }
+
+  return bounds;
+}
 } // namespace lower
 } // namespace Fortran
 
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 0650433dbaf394..f5883dcedb2b67 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -11,7 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "flang/Lower/Bridge.h"
-#include "DirectivesCommon.h"
+
 #include "flang/Common/Version.h"
 #include "flang/Lower/Allocatable.h"
 #include "flang/Lower/CallInterface.h"
@@ -22,6 +22,7 @@
 #include "flang/Lower/ConvertType.h"
 #include "flang/Lower/ConvertVariable.h"
 #include "flang/Lower/Cuda.h"
+#include "flang/Lower/DirectivesCommon.h"
 #include "flang/Lower/HostAssociations.h"
 #include "flang/Lower/IO.h"
 #include "flang/Lower/IterationSpace.h"
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 75dcf6ec3e1107..ed18ad89c16ef5 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -11,10 +11,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "flang/Lower/OpenACC.h"
-#include "DirectivesCommon.h"
+
 #include "flang/Common/idioms.h"
 #include "flang/Lower/Bridge.h"
 #include "flang/Lower/ConvertType.h"
+#include "flang/Lower/DirectivesCommon.h"
 #include "flang/Lower/Mangler.h"
 #include "flang/Lower/PFTBuilder.h"
 #include "flang/Lower/StatementContext.h"
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 3942c54e6e935d..7b047d4a7567ad 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -13,11 +13,11 @@
 #define FORTRAN_LOWER_CLAUSEPROCESSOR_H
 
 #include "Clauses.h"
-#include "DirectivesCommon.h"
 #include "ReductionProcessor.h"
 #include "Utils.h"
 #include "flang/Lower/AbstractConverter.h"
 #include "flang/Lower/Bridge.h"
+#include "flang/Lower/DirectivesCommon.h"
 #include "flang/Optimizer/Builder/Todo.h"
 #include "flang/Parser/dump-parse-tree.h"
 #include "flang/Parser/parse-tree.h"
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index c61ab67d95a957..b07e89d201d198 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -16,7 +16,6 @@
 #include "Clauses.h"
 #include "DataSharingProcessor.h"
 #include "Decomposer.h"
-#include "DirectivesCommon.h"
 #include "ReductionProcessor.h"
 #include "Utils.h"
 #include "flang/Common/OpenMP-utils.h"
@@ -24,6 +23,7 @@
 #include "flang/Lower/Bridge.h"
 #include "flang/Lower/ConvertExpr.h"
 #include "flang/Lower/ConvertVariable.h"
+#include "flang/Lower/DirectivesCommon.h"
 #include "flang/Lower/StatementContext.h"
 #include "flang/Lower/SymbolMap.h"
 #include "flang/Optimizer/Builder/BoxValue.h"
@@ -1735,32 +1735,25 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       if (const auto *details =
               sym.template detailsIf<semantics::HostAssocDetails>())
         converter.copySymbolBinding(details->symbol(), sym);
-      llvm::SmallVector<mlir::Value> bounds;
       std::stringstream name;
       fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(sym);
       name << sym.name().ToString();
 
       lower::AddrAndBoundsInfo info = getDataOperandBaseAddr(
           converter, firOpBuilder, sym, converter.getCurrentLocation());
-      mlir::Value baseOp = info.rawInput;
-      if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
-        bounds = lower::genBoundsOpsFromBox<mlir::omp::MapBoundsOp,
-                                            mlir::omp::MapBoundsType>(
-            firOpBuilder, converter.getCurrentLocation(), dataExv, info);
-      if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
-        bool dataExvIsAssumedSize =
-            semantics::IsAssumedSizeArray(sym.GetUltimate());
-        bounds = lower::genBaseBoundsOps<mlir::omp::MapBoundsOp,
-                                         mlir::omp::MapBoundsType>(
-            firOpBuilder, converter.getCurrentLocation(), dataExv,
-            dataExvIsAssumedSize);
-      }
+      llvm::SmallVector<mlir::Value> bounds =
+          lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+                                      mlir::omp::MapBoundsType>(
+              firOpBuilder, info, dataExv,
+              semantics::IsAssumedSizeArray(sym.GetUltimate()),
+              converter.getCurrentLocation());
 
       llvm::omp::OpenMPOffloadMappingFlags mapFlag =
           llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
       mlir::omp::VariableCaptureKind captureKind =
           mlir::omp::VariableCaptureKind::ByRef;
 
+      mlir::Value baseOp = info.rawInput;
       mlir::Type eleType = baseOp.getType();
       if (auto refType = mlir::dyn_cast<fir::ReferenceType>(baseOp.getType()))
         eleType = refType.getElementType();
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 5340dd8c5fb9a2..9971dc8e0b0014 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -13,10 +13,10 @@
 #include "Utils.h"
 
 #include "Clauses.h"
-#include <DirectivesCommon.h>
 
 #include <flang/Lower/AbstractConverter.h>
 #include <flang/Lower/ConvertType.h>
+#include <flang/Lower/DirectivesCommon.h>
 #include <flang/Lower/PFTBuilder.h>
 #include <flang/Optimizer/Builder/FIRBuilder.h>
 #include <flang/Optimizer/Builder/Todo.h>
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index 51ecbe1a664f92..4f23b2b970fa44 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -12,6 +12,7 @@ add_flang_library(FlangOpenMPTransforms
   FIRDialect
   HLFIROpsIncGen
   FlangOpenMPPassesIncGen
+  ${dialect_libs}
 
   LINK_LIBS
   FIRAnalysis
@@ -27,4 +28,5 @@ add_flang_library(FlangOpenMPTransforms
   MLIRIR
   MLIRPass
   MLIRTransformUtils
+  ${dialect_libs}
 )
diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index 4575c90e34acdd..ad7b806ae262ae 100644
--- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -24,10 +24,14 @@
 /// indirectly via a parent object.
 //===----------------------------------------------------------------------===//
 
+#include "flang/Lower/DirectivesCommon.h"
 #include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Builder/HLFIRTools.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Optimizer/Dialect/Support/KindMapping.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/OpenMP/Passes.h"
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/IR/BuiltinDialect.h"
@@ -411,10 +415,10 @@ class MapInfoFinalizationPass
           argIface
               ? argIface.getMapBlockArgsStart() + argIface.numMapBlockArgs()
               : 0;
-      addOperands(
-          mapMutableOpRange,
-          llvm::dyn_cast_or_null<mlir::omp::TargetOp>(argIface.getOperation()),
-          blockArgInsertIndex);
+      addOperands(mapMutableOpRange,
+                  llvm::dyn_cast_if_present<mlir::omp::TargetOp>(
+                      argIface.getOperation()),
+                  blockArgInsertIndex);
     }
 
     if (auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(target)) {
@@ -466,8 +470,7 @@ class MapInfoFinalizationPass
   // operation (usually function) containing the MapInfoOp because this pass
   // will mutate siblings of MapInfoOp.
   void runOnOperation() override {
-    mlir::ModuleOp module =
-        mlir::dyn_cast_or_null<mlir::ModuleOp>(getOperation());
+    mlir::ModuleOp module = getOperation();
     if (!module)
       module = getOperation()->getParentOfType<mlir::ModuleOp>();
     fir::KindMapping kindMap = fir::getKindMapping(module);
@@ -486,6 +489,160 @@ class MapInfoFinalizationPass
       // iterations from previous function scopes.
       localBoxAllocas.clear();
 
+      // First, walk `omp.map.info` ops to see if any record members should be
+      // implicitly mapped.
+      func->walk([&](mlir::omp::MapInfoOp op) {
+        mlir::Type underlyingType =
+            fir::unwrapRefType(op.getVarPtr().getType());
+
+        // TODO Test with and support more complicated cases; like arrays for
+        // records, for example.
+        if (!fir::isRecordWithAllocatableMember(underlyingType))
+          return mlir::WalkResult::advance();
+
+        // TODO For now, only consider `omp.target` ops. Other ops that support
+        // `map` clauses will follow later.
+        mlir::omp::TargetOp target =
+            mlir::dyn_cast_if_present<mlir::omp::TargetOp>(
+                getFirstTargetUser(op));
+
+        if (!target)
+          return mlir::WalkResult::advance();
+
+        auto mapClauseOwner =
+            llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(*target);
+
+        int64_t mapVarIdx = mapClauseOwner.getOperandIndexForMap(op);
+        assert(mapVarIdx >= 0 &&
+               mapVarIdx <
+                   static_cast<int64_t>(mapClauseOwner.getMapVars().size()));
+
+        auto argIface =
+            llvm::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(*target);
+        // TODO How should `map` block argument that correspond to: `private`,
+        // `use_device_addr`, `use_device_ptr`, be handled?
+        mlir::BlockArgument opBlockArg = argIface.getMapBlockArgs()[mapVarIdx];
+        llvm::SetVector<mlir::Operation *> mapVarForwardSlice;
+        mlir::getForwardSlice(opBlockArg, &mapVarForwardSlice);
+
+        mapVarForwardSlice.remove_if([&](mlir::Operation *sliceOp) {
+          // TODO Support coordinate_of ops.
+          //
+          // TODO Support call ops by recursively examining the forward slice of
+          // the corresponding parameter to the field in the called function.
+          return !mlir::isa<hlfir::DesignateOp>(sliceOp);
+        });
+
+        auto recordType = mlir::cast<fir::RecordType>(underlyingType);
+        llvm::SmallVector<mlir::Value> newMapOpsForFields;
+        llvm::SmallVector<int64_t> fieldIndicies;
+
+        for (auto fieldMemTyPair : recordType.getTypeList()) {
+          auto &field = fieldMemTyPair.first;
+          auto memTy = fieldMemTyPair.second;
+
+          bool shouldMapField =
+              llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
+                if (!fir::isAllocatableType(memTy))
+                  return false;
+
+                auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
+                if (!designateOp)
+                  return false;
+
+                return designateOp.getComponent() &&
+                       designateOp.getComponent()->strref() == field;
+              }) != mapVarForwardSlice.end();
+
+          // TODO Handle recursive record types. Adapting
+          // `createParentSymAndGenIntermediateMaps` to work direclty on MLIR
+          // entities might be helpful here.
+
+          if (!shouldMapField)
+            continue;
+
+          int64_t fieldIdx = recordType.getFieldIndex(field);
+          bool alreadyMapped = [&]() {
+            if (op.getMembersIndexAttr())
+              for (auto indexList : op.getMembersIndexAttr()) {
+                auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
+                if (indexListAttr.size() == 1 &&
+                    mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() ==
+                        fieldIdx)
+                  return true;
+              }
+
+            return false;
+          }();
+
+          if (alreadyMapped)
+            continue;
+
+          builder.setInsertionPoint(op);
+          mlir::Value fieldIdxVal = builder.createIntegerConstant(
+              op.getLoc(), mlir::IndexType::get(builder.getContext()),
+              fieldIdx);
+          auto fieldCoord = builder.create<fir::CoordinateOp>(
+              op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
+              fieldIdxVal);
+          Fortran::lower::AddrAndBoundsInfo info =
+              Fortran::lower::getDataOperandBaseAddr(
+                  builder, fieldCoord, /*isOptional=*/false, op.getLoc());
+          llvm::SmallVector<mlir::Value> bounds =
+              Fortran::lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+                                                   mlir::omp::MapBoundsType>(
+                  builder, info,
+                  hlfir::translateToExtendedValue(op.getLoc(), builder,
+                                                  hlfir::Entity{fieldCoord})
+                      .first,
+                  /*dataExvIsAssumedSize=*/false, op.getLoc());
+
+          mlir::omp::MapInfoOp fieldMapOp =
+              builder.create<mlir::omp::MapInfoOp>(
+                  op.getLoc(), fieldCoord.getResult().getType(),
+                  fieldCoord.getResult(),
+                  mlir::TypeAttr::get(
+                      fir::unwrapRefType(fieldCoord.getResult().getType())),
+                  /*varPtrPtr=*/mlir::Value{},
+                  /*members=*/mlir::ValueRange{},
+                  /*members_index=*/mlir::ArrayAttr{},
+                  /*bounds=*/bounds, op.getMapTypeAttr(),
+                  builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
+                      mlir::omp::VariableCaptureKind::ByRef),
+                  builder.getStringAttr(op.getNameAttr().strref() + "." +
+                                        field + ".implicit_map"),
+                  /*partial_map=*/builder.getBoolAttr(false));
+          newMapOpsForFields.emplace_back(fieldMapOp);
+          fieldIndicies.emplace_back(fieldIdx);
+        }
+
+        if (newMapOpsForFields.empty())
+          return mlir::WalkResult::advance();
+
+        op.getMembersMutable().append(newMapOpsForFields);
+        llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices;
+        mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr();
+
+        if (oldMembersIdxAttr)
+          for (mlir::Attribute indexList : oldMembersIdxAttr) {
+            llvm::SmallVector<int64_t> listVec;
+
+            for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList))
+              listVec.push_back(mlir::cast<mlir::IntegerAttr>(index).getInt());
+
+            newMemberIndices.emplace_back(std::move(listVec));
+          }
+
+        for (int64_t newFieldIdx : fieldIndicies)
+          newMemberIndices.emplace_back(
+              llvm::SmallVector<int64_t>(1, newFieldIdx));
+
+        op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices));
+        op.setPartialMap(true);
+
+        return mlir::WalkResult::advance();
+      });
+
       func->walk([&](mlir::omp::MapInfoOp op) {
         // TODO: Currently only supports a single user for the MapInfoOp. This
         // is fine for the moment, as the Fortran frontend will generate a
diff --git a/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir b/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir
new file mode 100644
index 00000000000000..bcf8b63075dbf8
--- /dev/null
+++ b/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir
@@ -0,0 +1,63 @@
+// Tests that we implicitly map alloctable fields of a record when referenced in
+// a target region.
+
+// RUN: fir-opt --split-input-file --omp-map-info-finalization %s | FileCheck %s
+
+!record_t = !fir.type<_QFTrecord_t{
+  not_to_implicitly_map:
+    !fir.box<!fir.heap<!fir.array<?xf32>>>,
+  to_implicitly_map:
+    !fir.box<!fir.heap<!fir.array<?xf32>>>
+}>
+
+fir.global internal @_QFEdst_record : !record_t {
+  %0 = fir.undefined !record_t
+  fir.has_value %0 : !record_t
+}
+
+func.func @_QQmain() {
+  %6 = fir.address_of(@_QFEdst_record) : !fir.ref<!record_t>
+  %7:2 = hlfir.declare %6 {uniq_name = "_QFEdst_record"} : (!fir.ref<!record_t>) -> (!fir.ref<!record_t>, !fir.ref<!record_t>)
+  %16 = omp.map.info var_ptr(%7#1 : !fir.ref<!record_t>, !record_t) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<!record_t> {name = "dst_record"}
+  omp.target map_entries(%16 -> %arg0 : !fir.ref<!record_t>) {
+    %20:2 = hlfir.declare %arg0 {uniq_name = "_QFEdst_record"} : (!fir.ref<!record_t>) -> (...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Dec 18, 2024

@llvm/pr-subscribers-offload

Author: Kareem Ergawy (ergawy)

Changes

This re-applies #117867 with a small fix that hopefully prevents build bot failures. The fix is avoiding dyn_cast for the result of getOperation(). Instead we can assign the result to mlir::ModuleOp directly since the type of the operation is known statically (OpT in OperationPass).


Patch is 26.65 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/120374.diff

12 Files Affected:

  • (renamed) flang/include/flang/Lower/DirectivesCommon.h (+33-17)
  • (modified) flang/lib/Lower/Bridge.cpp (+2-1)
  • (modified) flang/lib/Lower/OpenACC.cpp (+2-1)
  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.h (+1-1)
  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+8-15)
  • (modified) flang/lib/Lower/OpenMP/Utils.cpp (+1-1)
  • (modified) flang/lib/Optimizer/OpenMP/CMakeLists.txt (+2)
  • (modified) flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp (+163-6)
  • (added) flang/test/Transforms/omp-map-info-finalization-implicit-field.fir (+63)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td (+7)
  • (added) offload/test/offloading/fortran/explicit-and-implicit-record-field-mapping.f90 (+83)
  • (added) offload/test/offloading/fortran/implicit-record-field-mapping.f90 (+52)
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/include/flang/Lower/DirectivesCommon.h
similarity index 97%
rename from flang/lib/Lower/DirectivesCommon.h
rename to flang/include/flang/Lower/DirectivesCommon.h
index 88514b16743278..6e2c6ee4b1bcdb 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/include/flang/Lower/DirectivesCommon.h
@@ -609,11 +609,10 @@ void createEmptyRegionBlocks(
   }
 }
 
-inline AddrAndBoundsInfo
-getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
-                       fir::FirOpBuilder &builder,
-                       Fortran::lower::SymbolRef sym, mlir::Location loc) {
-  mlir::Value symAddr = converter.getSymbolAddress(sym);
+inline AddrAndBoundsInfo getDataOperandBaseAddr(fir::FirOpBuilder &builder,
+                                                mlir::Value symAddr,
+                                                bool isOptional,
+                                                mlir::Location loc) {
   mlir::Value rawInput = symAddr;
   if (auto declareOp =
           mlir::dyn_cast_or_null<hlfir::DeclareOp>(symAddr.getDefiningOp())) {
@@ -621,20 +620,11 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
     rawInput = declareOp.getResults()[1];
   }
 
-  // TODO: Might need revisiting to handle for non-shared clauses
-  if (!symAddr) {
-    if (const auto *details =
-            sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
-      symAddr = converter.getSymbolAddress(details->symbol());
-      rawInput = symAddr;
-    }
-  }
-
   if (!symAddr)
     llvm::report_fatal_error("could not retrieve symbol address");
 
   mlir::Value isPresent;
-  if (Fortran::semantics::IsOptional(sym))
+  if (isOptional)
     isPresent =
         builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), rawInput);
 
@@ -648,8 +638,7 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
     // all address/dimension retrievals. For Fortran optional though, leave
     // the load generation for later so it can be done in the appropriate
     // if branches.
-    if (mlir::isa<fir::ReferenceType>(symAddr.getType()) &&
-        !Fortran::semantics::IsOptional(sym)) {
+    if (mlir::isa<fir::ReferenceType>(symAddr.getType()) && !isOptional) {
       mlir::Value addr = builder.create<fir::LoadOp>(loc, symAddr);
       return AddrAndBoundsInfo(addr, rawInput, isPresent, boxTy);
     }
@@ -659,6 +648,14 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
   return AddrAndBoundsInfo(symAddr, rawInput, isPresent);
 }
 
+inline AddrAndBoundsInfo
+getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
+                       fir::FirOpBuilder &builder,
+                       Fortran::lower::SymbolRef sym, mlir::Location loc) {
+  return getDataOperandBaseAddr(builder, converter.getSymbolAddress(sym),
+                                Fortran::semantics::IsOptional(sym), loc);
+}
+
 template <typename BoundsOp, typename BoundsType>
 llvm::SmallVector<mlir::Value>
 gatherBoundsOrBoundValues(fir::FirOpBuilder &builder, mlir::Location loc,
@@ -1224,6 +1221,25 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
 
   return info;
 }
+
+template <typename BoundsOp, typename BoundsType>
+llvm::SmallVector<mlir::Value>
+genImplicitBoundsOps(fir::FirOpBuilder &builder, lower::AddrAndBoundsInfo &info,
+                     fir::ExtendedValue dataExv, bool dataExvIsAssumedSize,
+                     mlir::Location loc) {
+  llvm::SmallVector<mlir::Value> bounds;
+
+  mlir::Value baseOp = info.rawInput;
+  if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
+    bounds = lower::genBoundsOpsFromBox<BoundsOp, BoundsType>(builder, loc,
+                                                              dataExv, info);
+  if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
+    bounds = lower::genBaseBoundsOps<BoundsOp, BoundsType>(
+        builder, loc, dataExv, dataExvIsAssumedSize);
+  }
+
+  return bounds;
+}
 } // namespace lower
 } // namespace Fortran
 
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 0650433dbaf394..f5883dcedb2b67 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -11,7 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "flang/Lower/Bridge.h"
-#include "DirectivesCommon.h"
+
 #include "flang/Common/Version.h"
 #include "flang/Lower/Allocatable.h"
 #include "flang/Lower/CallInterface.h"
@@ -22,6 +22,7 @@
 #include "flang/Lower/ConvertType.h"
 #include "flang/Lower/ConvertVariable.h"
 #include "flang/Lower/Cuda.h"
+#include "flang/Lower/DirectivesCommon.h"
 #include "flang/Lower/HostAssociations.h"
 #include "flang/Lower/IO.h"
 #include "flang/Lower/IterationSpace.h"
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 75dcf6ec3e1107..ed18ad89c16ef5 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -11,10 +11,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "flang/Lower/OpenACC.h"
-#include "DirectivesCommon.h"
+
 #include "flang/Common/idioms.h"
 #include "flang/Lower/Bridge.h"
 #include "flang/Lower/ConvertType.h"
+#include "flang/Lower/DirectivesCommon.h"
 #include "flang/Lower/Mangler.h"
 #include "flang/Lower/PFTBuilder.h"
 #include "flang/Lower/StatementContext.h"
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 3942c54e6e935d..7b047d4a7567ad 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -13,11 +13,11 @@
 #define FORTRAN_LOWER_CLAUSEPROCESSOR_H
 
 #include "Clauses.h"
-#include "DirectivesCommon.h"
 #include "ReductionProcessor.h"
 #include "Utils.h"
 #include "flang/Lower/AbstractConverter.h"
 #include "flang/Lower/Bridge.h"
+#include "flang/Lower/DirectivesCommon.h"
 #include "flang/Optimizer/Builder/Todo.h"
 #include "flang/Parser/dump-parse-tree.h"
 #include "flang/Parser/parse-tree.h"
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index c61ab67d95a957..b07e89d201d198 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -16,7 +16,6 @@
 #include "Clauses.h"
 #include "DataSharingProcessor.h"
 #include "Decomposer.h"
-#include "DirectivesCommon.h"
 #include "ReductionProcessor.h"
 #include "Utils.h"
 #include "flang/Common/OpenMP-utils.h"
@@ -24,6 +23,7 @@
 #include "flang/Lower/Bridge.h"
 #include "flang/Lower/ConvertExpr.h"
 #include "flang/Lower/ConvertVariable.h"
+#include "flang/Lower/DirectivesCommon.h"
 #include "flang/Lower/StatementContext.h"
 #include "flang/Lower/SymbolMap.h"
 #include "flang/Optimizer/Builder/BoxValue.h"
@@ -1735,32 +1735,25 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       if (const auto *details =
               sym.template detailsIf<semantics::HostAssocDetails>())
         converter.copySymbolBinding(details->symbol(), sym);
-      llvm::SmallVector<mlir::Value> bounds;
       std::stringstream name;
       fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(sym);
       name << sym.name().ToString();
 
       lower::AddrAndBoundsInfo info = getDataOperandBaseAddr(
           converter, firOpBuilder, sym, converter.getCurrentLocation());
-      mlir::Value baseOp = info.rawInput;
-      if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
-        bounds = lower::genBoundsOpsFromBox<mlir::omp::MapBoundsOp,
-                                            mlir::omp::MapBoundsType>(
-            firOpBuilder, converter.getCurrentLocation(), dataExv, info);
-      if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
-        bool dataExvIsAssumedSize =
-            semantics::IsAssumedSizeArray(sym.GetUltimate());
-        bounds = lower::genBaseBoundsOps<mlir::omp::MapBoundsOp,
-                                         mlir::omp::MapBoundsType>(
-            firOpBuilder, converter.getCurrentLocation(), dataExv,
-            dataExvIsAssumedSize);
-      }
+      llvm::SmallVector<mlir::Value> bounds =
+          lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+                                      mlir::omp::MapBoundsType>(
+              firOpBuilder, info, dataExv,
+              semantics::IsAssumedSizeArray(sym.GetUltimate()),
+              converter.getCurrentLocation());
 
       llvm::omp::OpenMPOffloadMappingFlags mapFlag =
           llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
       mlir::omp::VariableCaptureKind captureKind =
           mlir::omp::VariableCaptureKind::ByRef;
 
+      mlir::Value baseOp = info.rawInput;
       mlir::Type eleType = baseOp.getType();
       if (auto refType = mlir::dyn_cast<fir::ReferenceType>(baseOp.getType()))
         eleType = refType.getElementType();
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 5340dd8c5fb9a2..9971dc8e0b0014 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -13,10 +13,10 @@
 #include "Utils.h"
 
 #include "Clauses.h"
-#include <DirectivesCommon.h>
 
 #include <flang/Lower/AbstractConverter.h>
 #include <flang/Lower/ConvertType.h>
+#include <flang/Lower/DirectivesCommon.h>
 #include <flang/Lower/PFTBuilder.h>
 #include <flang/Optimizer/Builder/FIRBuilder.h>
 #include <flang/Optimizer/Builder/Todo.h>
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index 51ecbe1a664f92..4f23b2b970fa44 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -12,6 +12,7 @@ add_flang_library(FlangOpenMPTransforms
   FIRDialect
   HLFIROpsIncGen
   FlangOpenMPPassesIncGen
+  ${dialect_libs}
 
   LINK_LIBS
   FIRAnalysis
@@ -27,4 +28,5 @@ add_flang_library(FlangOpenMPTransforms
   MLIRIR
   MLIRPass
   MLIRTransformUtils
+  ${dialect_libs}
 )
diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index 4575c90e34acdd..ad7b806ae262ae 100644
--- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -24,10 +24,14 @@
 /// indirectly via a parent object.
 //===----------------------------------------------------------------------===//
 
+#include "flang/Lower/DirectivesCommon.h"
 #include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Builder/HLFIRTools.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Optimizer/Dialect/Support/KindMapping.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/OpenMP/Passes.h"
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/IR/BuiltinDialect.h"
@@ -411,10 +415,10 @@ class MapInfoFinalizationPass
           argIface
               ? argIface.getMapBlockArgsStart() + argIface.numMapBlockArgs()
               : 0;
-      addOperands(
-          mapMutableOpRange,
-          llvm::dyn_cast_or_null<mlir::omp::TargetOp>(argIface.getOperation()),
-          blockArgInsertIndex);
+      addOperands(mapMutableOpRange,
+                  llvm::dyn_cast_if_present<mlir::omp::TargetOp>(
+                      argIface.getOperation()),
+                  blockArgInsertIndex);
     }
 
     if (auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(target)) {
@@ -466,8 +470,7 @@ class MapInfoFinalizationPass
   // operation (usually function) containing the MapInfoOp because this pass
   // will mutate siblings of MapInfoOp.
   void runOnOperation() override {
-    mlir::ModuleOp module =
-        mlir::dyn_cast_or_null<mlir::ModuleOp>(getOperation());
+    mlir::ModuleOp module = getOperation();
     if (!module)
       module = getOperation()->getParentOfType<mlir::ModuleOp>();
     fir::KindMapping kindMap = fir::getKindMapping(module);
@@ -486,6 +489,160 @@ class MapInfoFinalizationPass
       // iterations from previous function scopes.
       localBoxAllocas.clear();
 
+      // First, walk `omp.map.info` ops to see if any record members should be
+      // implicitly mapped.
+      func->walk([&](mlir::omp::MapInfoOp op) {
+        mlir::Type underlyingType =
+            fir::unwrapRefType(op.getVarPtr().getType());
+
+        // TODO Test with and support more complicated cases; like arrays for
+        // records, for example.
+        if (!fir::isRecordWithAllocatableMember(underlyingType))
+          return mlir::WalkResult::advance();
+
+        // TODO For now, only consider `omp.target` ops. Other ops that support
+        // `map` clauses will follow later.
+        mlir::omp::TargetOp target =
+            mlir::dyn_cast_if_present<mlir::omp::TargetOp>(
+                getFirstTargetUser(op));
+
+        if (!target)
+          return mlir::WalkResult::advance();
+
+        auto mapClauseOwner =
+            llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(*target);
+
+        int64_t mapVarIdx = mapClauseOwner.getOperandIndexForMap(op);
+        assert(mapVarIdx >= 0 &&
+               mapVarIdx <
+                   static_cast<int64_t>(mapClauseOwner.getMapVars().size()));
+
+        auto argIface =
+            llvm::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(*target);
+        // TODO How should `map` block argument that correspond to: `private`,
+        // `use_device_addr`, `use_device_ptr`, be handled?
+        mlir::BlockArgument opBlockArg = argIface.getMapBlockArgs()[mapVarIdx];
+        llvm::SetVector<mlir::Operation *> mapVarForwardSlice;
+        mlir::getForwardSlice(opBlockArg, &mapVarForwardSlice);
+
+        mapVarForwardSlice.remove_if([&](mlir::Operation *sliceOp) {
+          // TODO Support coordinate_of ops.
+          //
+          // TODO Support call ops by recursively examining the forward slice of
+          // the corresponding parameter to the field in the called function.
+          return !mlir::isa<hlfir::DesignateOp>(sliceOp);
+        });
+
+        auto recordType = mlir::cast<fir::RecordType>(underlyingType);
+        llvm::SmallVector<mlir::Value> newMapOpsForFields;
+        llvm::SmallVector<int64_t> fieldIndicies;
+
+        for (auto fieldMemTyPair : recordType.getTypeList()) {
+          auto &field = fieldMemTyPair.first;
+          auto memTy = fieldMemTyPair.second;
+
+          bool shouldMapField =
+              llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
+                if (!fir::isAllocatableType(memTy))
+                  return false;
+
+                auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
+                if (!designateOp)
+                  return false;
+
+                return designateOp.getComponent() &&
+                       designateOp.getComponent()->strref() == field;
+              }) != mapVarForwardSlice.end();
+
+          // TODO Handle recursive record types. Adapting
+          // `createParentSymAndGenIntermediateMaps` to work direclty on MLIR
+          // entities might be helpful here.
+
+          if (!shouldMapField)
+            continue;
+
+          int64_t fieldIdx = recordType.getFieldIndex(field);
+          bool alreadyMapped = [&]() {
+            if (op.getMembersIndexAttr())
+              for (auto indexList : op.getMembersIndexAttr()) {
+                auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
+                if (indexListAttr.size() == 1 &&
+                    mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() ==
+                        fieldIdx)
+                  return true;
+              }
+
+            return false;
+          }();
+
+          if (alreadyMapped)
+            continue;
+
+          builder.setInsertionPoint(op);
+          mlir::Value fieldIdxVal = builder.createIntegerConstant(
+              op.getLoc(), mlir::IndexType::get(builder.getContext()),
+              fieldIdx);
+          auto fieldCoord = builder.create<fir::CoordinateOp>(
+              op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
+              fieldIdxVal);
+          Fortran::lower::AddrAndBoundsInfo info =
+              Fortran::lower::getDataOperandBaseAddr(
+                  builder, fieldCoord, /*isOptional=*/false, op.getLoc());
+          llvm::SmallVector<mlir::Value> bounds =
+              Fortran::lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+                                                   mlir::omp::MapBoundsType>(
+                  builder, info,
+                  hlfir::translateToExtendedValue(op.getLoc(), builder,
+                                                  hlfir::Entity{fieldCoord})
+                      .first,
+                  /*dataExvIsAssumedSize=*/false, op.getLoc());
+
+          mlir::omp::MapInfoOp fieldMapOp =
+              builder.create<mlir::omp::MapInfoOp>(
+                  op.getLoc(), fieldCoord.getResult().getType(),
+                  fieldCoord.getResult(),
+                  mlir::TypeAttr::get(
+                      fir::unwrapRefType(fieldCoord.getResult().getType())),
+                  /*varPtrPtr=*/mlir::Value{},
+                  /*members=*/mlir::ValueRange{},
+                  /*members_index=*/mlir::ArrayAttr{},
+                  /*bounds=*/bounds, op.getMapTypeAttr(),
+                  builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
+                      mlir::omp::VariableCaptureKind::ByRef),
+                  builder.getStringAttr(op.getNameAttr().strref() + "." +
+                                        field + ".implicit_map"),
+                  /*partial_map=*/builder.getBoolAttr(false));
+          newMapOpsForFields.emplace_back(fieldMapOp);
+          fieldIndicies.emplace_back(fieldIdx);
+        }
+
+        if (newMapOpsForFields.empty())
+          return mlir::WalkResult::advance();
+
+        op.getMembersMutable().append(newMapOpsForFields);
+        llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices;
+        mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr();
+
+        if (oldMembersIdxAttr)
+          for (mlir::Attribute indexList : oldMembersIdxAttr) {
+            llvm::SmallVector<int64_t> listVec;
+
+            for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList))
+              listVec.push_back(mlir::cast<mlir::IntegerAttr>(index).getInt());
+
+            newMemberIndices.emplace_back(std::move(listVec));
+          }
+
+        for (int64_t newFieldIdx : fieldIndicies)
+          newMemberIndices.emplace_back(
+              llvm::SmallVector<int64_t>(1, newFieldIdx));
+
+        op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices));
+        op.setPartialMap(true);
+
+        return mlir::WalkResult::advance();
+      });
+
       func->walk([&](mlir::omp::MapInfoOp op) {
         // TODO: Currently only supports a single user for the MapInfoOp. This
         // is fine for the moment, as the Fortran frontend will generate a
diff --git a/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir b/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir
new file mode 100644
index 00000000000000..bcf8b63075dbf8
--- /dev/null
+++ b/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir
@@ -0,0 +1,63 @@
+// Tests that we implicitly map alloctable fields of a record when referenced in
+// a target region.
+
+// RUN: fir-opt --split-input-file --omp-map-info-finalization %s | FileCheck %s
+
+!record_t = !fir.type<_QFTrecord_t{
+  not_to_implicitly_map:
+    !fir.box<!fir.heap<!fir.array<?xf32>>>,
+  to_implicitly_map:
+    !fir.box<!fir.heap<!fir.array<?xf32>>>
+}>
+
+fir.global internal @_QFEdst_record : !record_t {
+  %0 = fir.undefined !record_t
+  fir.has_value %0 : !record_t
+}
+
+func.func @_QQmain() {
+  %6 = fir.address_of(@_QFEdst_record) : !fir.ref<!record_t>
+  %7:2 = hlfir.declare %6 {uniq_name = "_QFEdst_record"} : (!fir.ref<!record_t>) -> (!fir.ref<!record_t>, !fir.ref<!record_t>)
+  %16 = omp.map.info var_ptr(%7#1 : !fir.ref<!record_t>, !record_t) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<!record_t> {name = "dst_record"}
+  omp.target map_entries(%16 -> %arg0 : !fir.ref<!record_t>) {
+    %20:2 = hlfir.declare %arg0 {uniq_name = "_QFEdst_record"} : (!fir.ref<!record_t>) -> (...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Dec 18, 2024

@llvm/pr-subscribers-flang-openmp

Author: Kareem Ergawy (ergawy)

Changes

This re-applies #117867 with a small fix that hopefully prevents build bot failures. The fix is avoiding dyn_cast for the result of getOperation(). Instead we can assign the result to mlir::ModuleOp directly since the type of the operation is known statically (OpT in OperationPass).


Patch is 26.65 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/120374.diff

12 Files Affected:

  • (renamed) flang/include/flang/Lower/DirectivesCommon.h (+33-17)
  • (modified) flang/lib/Lower/Bridge.cpp (+2-1)
  • (modified) flang/lib/Lower/OpenACC.cpp (+2-1)
  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.h (+1-1)
  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+8-15)
  • (modified) flang/lib/Lower/OpenMP/Utils.cpp (+1-1)
  • (modified) flang/lib/Optimizer/OpenMP/CMakeLists.txt (+2)
  • (modified) flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp (+163-6)
  • (added) flang/test/Transforms/omp-map-info-finalization-implicit-field.fir (+63)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td (+7)
  • (added) offload/test/offloading/fortran/explicit-and-implicit-record-field-mapping.f90 (+83)
  • (added) offload/test/offloading/fortran/implicit-record-field-mapping.f90 (+52)
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/include/flang/Lower/DirectivesCommon.h
similarity index 97%
rename from flang/lib/Lower/DirectivesCommon.h
rename to flang/include/flang/Lower/DirectivesCommon.h
index 88514b16743278..6e2c6ee4b1bcdb 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/include/flang/Lower/DirectivesCommon.h
@@ -609,11 +609,10 @@ void createEmptyRegionBlocks(
   }
 }
 
-inline AddrAndBoundsInfo
-getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
-                       fir::FirOpBuilder &builder,
-                       Fortran::lower::SymbolRef sym, mlir::Location loc) {
-  mlir::Value symAddr = converter.getSymbolAddress(sym);
+inline AddrAndBoundsInfo getDataOperandBaseAddr(fir::FirOpBuilder &builder,
+                                                mlir::Value symAddr,
+                                                bool isOptional,
+                                                mlir::Location loc) {
   mlir::Value rawInput = symAddr;
   if (auto declareOp =
           mlir::dyn_cast_or_null<hlfir::DeclareOp>(symAddr.getDefiningOp())) {
@@ -621,20 +620,11 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
     rawInput = declareOp.getResults()[1];
   }
 
-  // TODO: Might need revisiting to handle for non-shared clauses
-  if (!symAddr) {
-    if (const auto *details =
-            sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
-      symAddr = converter.getSymbolAddress(details->symbol());
-      rawInput = symAddr;
-    }
-  }
-
   if (!symAddr)
     llvm::report_fatal_error("could not retrieve symbol address");
 
   mlir::Value isPresent;
-  if (Fortran::semantics::IsOptional(sym))
+  if (isOptional)
     isPresent =
         builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), rawInput);
 
@@ -648,8 +638,7 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
     // all address/dimension retrievals. For Fortran optional though, leave
     // the load generation for later so it can be done in the appropriate
     // if branches.
-    if (mlir::isa<fir::ReferenceType>(symAddr.getType()) &&
-        !Fortran::semantics::IsOptional(sym)) {
+    if (mlir::isa<fir::ReferenceType>(symAddr.getType()) && !isOptional) {
       mlir::Value addr = builder.create<fir::LoadOp>(loc, symAddr);
       return AddrAndBoundsInfo(addr, rawInput, isPresent, boxTy);
     }
@@ -659,6 +648,14 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
   return AddrAndBoundsInfo(symAddr, rawInput, isPresent);
 }
 
+inline AddrAndBoundsInfo
+getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
+                       fir::FirOpBuilder &builder,
+                       Fortran::lower::SymbolRef sym, mlir::Location loc) {
+  return getDataOperandBaseAddr(builder, converter.getSymbolAddress(sym),
+                                Fortran::semantics::IsOptional(sym), loc);
+}
+
 template <typename BoundsOp, typename BoundsType>
 llvm::SmallVector<mlir::Value>
 gatherBoundsOrBoundValues(fir::FirOpBuilder &builder, mlir::Location loc,
@@ -1224,6 +1221,25 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
 
   return info;
 }
+
+template <typename BoundsOp, typename BoundsType>
+llvm::SmallVector<mlir::Value>
+genImplicitBoundsOps(fir::FirOpBuilder &builder, lower::AddrAndBoundsInfo &info,
+                     fir::ExtendedValue dataExv, bool dataExvIsAssumedSize,
+                     mlir::Location loc) {
+  llvm::SmallVector<mlir::Value> bounds;
+
+  mlir::Value baseOp = info.rawInput;
+  if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
+    bounds = lower::genBoundsOpsFromBox<BoundsOp, BoundsType>(builder, loc,
+                                                              dataExv, info);
+  if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
+    bounds = lower::genBaseBoundsOps<BoundsOp, BoundsType>(
+        builder, loc, dataExv, dataExvIsAssumedSize);
+  }
+
+  return bounds;
+}
 } // namespace lower
 } // namespace Fortran
 
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 0650433dbaf394..f5883dcedb2b67 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -11,7 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "flang/Lower/Bridge.h"
-#include "DirectivesCommon.h"
+
 #include "flang/Common/Version.h"
 #include "flang/Lower/Allocatable.h"
 #include "flang/Lower/CallInterface.h"
@@ -22,6 +22,7 @@
 #include "flang/Lower/ConvertType.h"
 #include "flang/Lower/ConvertVariable.h"
 #include "flang/Lower/Cuda.h"
+#include "flang/Lower/DirectivesCommon.h"
 #include "flang/Lower/HostAssociations.h"
 #include "flang/Lower/IO.h"
 #include "flang/Lower/IterationSpace.h"
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 75dcf6ec3e1107..ed18ad89c16ef5 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -11,10 +11,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "flang/Lower/OpenACC.h"
-#include "DirectivesCommon.h"
+
 #include "flang/Common/idioms.h"
 #include "flang/Lower/Bridge.h"
 #include "flang/Lower/ConvertType.h"
+#include "flang/Lower/DirectivesCommon.h"
 #include "flang/Lower/Mangler.h"
 #include "flang/Lower/PFTBuilder.h"
 #include "flang/Lower/StatementContext.h"
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 3942c54e6e935d..7b047d4a7567ad 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -13,11 +13,11 @@
 #define FORTRAN_LOWER_CLAUSEPROCESSOR_H
 
 #include "Clauses.h"
-#include "DirectivesCommon.h"
 #include "ReductionProcessor.h"
 #include "Utils.h"
 #include "flang/Lower/AbstractConverter.h"
 #include "flang/Lower/Bridge.h"
+#include "flang/Lower/DirectivesCommon.h"
 #include "flang/Optimizer/Builder/Todo.h"
 #include "flang/Parser/dump-parse-tree.h"
 #include "flang/Parser/parse-tree.h"
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index c61ab67d95a957..b07e89d201d198 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -16,7 +16,6 @@
 #include "Clauses.h"
 #include "DataSharingProcessor.h"
 #include "Decomposer.h"
-#include "DirectivesCommon.h"
 #include "ReductionProcessor.h"
 #include "Utils.h"
 #include "flang/Common/OpenMP-utils.h"
@@ -24,6 +23,7 @@
 #include "flang/Lower/Bridge.h"
 #include "flang/Lower/ConvertExpr.h"
 #include "flang/Lower/ConvertVariable.h"
+#include "flang/Lower/DirectivesCommon.h"
 #include "flang/Lower/StatementContext.h"
 #include "flang/Lower/SymbolMap.h"
 #include "flang/Optimizer/Builder/BoxValue.h"
@@ -1735,32 +1735,25 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       if (const auto *details =
               sym.template detailsIf<semantics::HostAssocDetails>())
         converter.copySymbolBinding(details->symbol(), sym);
-      llvm::SmallVector<mlir::Value> bounds;
       std::stringstream name;
       fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(sym);
       name << sym.name().ToString();
 
       lower::AddrAndBoundsInfo info = getDataOperandBaseAddr(
           converter, firOpBuilder, sym, converter.getCurrentLocation());
-      mlir::Value baseOp = info.rawInput;
-      if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
-        bounds = lower::genBoundsOpsFromBox<mlir::omp::MapBoundsOp,
-                                            mlir::omp::MapBoundsType>(
-            firOpBuilder, converter.getCurrentLocation(), dataExv, info);
-      if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
-        bool dataExvIsAssumedSize =
-            semantics::IsAssumedSizeArray(sym.GetUltimate());
-        bounds = lower::genBaseBoundsOps<mlir::omp::MapBoundsOp,
-                                         mlir::omp::MapBoundsType>(
-            firOpBuilder, converter.getCurrentLocation(), dataExv,
-            dataExvIsAssumedSize);
-      }
+      llvm::SmallVector<mlir::Value> bounds =
+          lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+                                      mlir::omp::MapBoundsType>(
+              firOpBuilder, info, dataExv,
+              semantics::IsAssumedSizeArray(sym.GetUltimate()),
+              converter.getCurrentLocation());
 
       llvm::omp::OpenMPOffloadMappingFlags mapFlag =
           llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
       mlir::omp::VariableCaptureKind captureKind =
           mlir::omp::VariableCaptureKind::ByRef;
 
+      mlir::Value baseOp = info.rawInput;
       mlir::Type eleType = baseOp.getType();
       if (auto refType = mlir::dyn_cast<fir::ReferenceType>(baseOp.getType()))
         eleType = refType.getElementType();
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 5340dd8c5fb9a2..9971dc8e0b0014 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -13,10 +13,10 @@
 #include "Utils.h"
 
 #include "Clauses.h"
-#include <DirectivesCommon.h>
 
 #include <flang/Lower/AbstractConverter.h>
 #include <flang/Lower/ConvertType.h>
+#include <flang/Lower/DirectivesCommon.h>
 #include <flang/Lower/PFTBuilder.h>
 #include <flang/Optimizer/Builder/FIRBuilder.h>
 #include <flang/Optimizer/Builder/Todo.h>
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index 51ecbe1a664f92..4f23b2b970fa44 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -12,6 +12,7 @@ add_flang_library(FlangOpenMPTransforms
   FIRDialect
   HLFIROpsIncGen
   FlangOpenMPPassesIncGen
+  ${dialect_libs}
 
   LINK_LIBS
   FIRAnalysis
@@ -27,4 +28,5 @@ add_flang_library(FlangOpenMPTransforms
   MLIRIR
   MLIRPass
   MLIRTransformUtils
+  ${dialect_libs}
 )
diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index 4575c90e34acdd..ad7b806ae262ae 100644
--- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -24,10 +24,14 @@
 /// indirectly via a parent object.
 //===----------------------------------------------------------------------===//
 
+#include "flang/Lower/DirectivesCommon.h"
 #include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Builder/HLFIRTools.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Optimizer/Dialect/Support/KindMapping.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/OpenMP/Passes.h"
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/IR/BuiltinDialect.h"
@@ -411,10 +415,10 @@ class MapInfoFinalizationPass
           argIface
               ? argIface.getMapBlockArgsStart() + argIface.numMapBlockArgs()
               : 0;
-      addOperands(
-          mapMutableOpRange,
-          llvm::dyn_cast_or_null<mlir::omp::TargetOp>(argIface.getOperation()),
-          blockArgInsertIndex);
+      addOperands(mapMutableOpRange,
+                  llvm::dyn_cast_if_present<mlir::omp::TargetOp>(
+                      argIface.getOperation()),
+                  blockArgInsertIndex);
     }
 
     if (auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(target)) {
@@ -466,8 +470,7 @@ class MapInfoFinalizationPass
   // operation (usually function) containing the MapInfoOp because this pass
   // will mutate siblings of MapInfoOp.
   void runOnOperation() override {
-    mlir::ModuleOp module =
-        mlir::dyn_cast_or_null<mlir::ModuleOp>(getOperation());
+    mlir::ModuleOp module = getOperation();
     if (!module)
       module = getOperation()->getParentOfType<mlir::ModuleOp>();
     fir::KindMapping kindMap = fir::getKindMapping(module);
@@ -486,6 +489,160 @@ class MapInfoFinalizationPass
       // iterations from previous function scopes.
       localBoxAllocas.clear();
 
+      // First, walk `omp.map.info` ops to see if any record members should be
+      // implicitly mapped.
+      func->walk([&](mlir::omp::MapInfoOp op) {
+        mlir::Type underlyingType =
+            fir::unwrapRefType(op.getVarPtr().getType());
+
+        // TODO Test with and support more complicated cases; like arrays for
+        // records, for example.
+        if (!fir::isRecordWithAllocatableMember(underlyingType))
+          return mlir::WalkResult::advance();
+
+        // TODO For now, only consider `omp.target` ops. Other ops that support
+        // `map` clauses will follow later.
+        mlir::omp::TargetOp target =
+            mlir::dyn_cast_if_present<mlir::omp::TargetOp>(
+                getFirstTargetUser(op));
+
+        if (!target)
+          return mlir::WalkResult::advance();
+
+        auto mapClauseOwner =
+            llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(*target);
+
+        int64_t mapVarIdx = mapClauseOwner.getOperandIndexForMap(op);
+        assert(mapVarIdx >= 0 &&
+               mapVarIdx <
+                   static_cast<int64_t>(mapClauseOwner.getMapVars().size()));
+
+        auto argIface =
+            llvm::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(*target);
+        // TODO How should `map` block argument that correspond to: `private`,
+        // `use_device_addr`, `use_device_ptr`, be handled?
+        mlir::BlockArgument opBlockArg = argIface.getMapBlockArgs()[mapVarIdx];
+        llvm::SetVector<mlir::Operation *> mapVarForwardSlice;
+        mlir::getForwardSlice(opBlockArg, &mapVarForwardSlice);
+
+        mapVarForwardSlice.remove_if([&](mlir::Operation *sliceOp) {
+          // TODO Support coordinate_of ops.
+          //
+          // TODO Support call ops by recursively examining the forward slice of
+          // the corresponding parameter to the field in the called function.
+          return !mlir::isa<hlfir::DesignateOp>(sliceOp);
+        });
+
+        auto recordType = mlir::cast<fir::RecordType>(underlyingType);
+        llvm::SmallVector<mlir::Value> newMapOpsForFields;
+        llvm::SmallVector<int64_t> fieldIndicies;
+
+        for (auto fieldMemTyPair : recordType.getTypeList()) {
+          auto &field = fieldMemTyPair.first;
+          auto memTy = fieldMemTyPair.second;
+
+          bool shouldMapField =
+              llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
+                if (!fir::isAllocatableType(memTy))
+                  return false;
+
+                auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
+                if (!designateOp)
+                  return false;
+
+                return designateOp.getComponent() &&
+                       designateOp.getComponent()->strref() == field;
+              }) != mapVarForwardSlice.end();
+
+          // TODO Handle recursive record types. Adapting
+          // `createParentSymAndGenIntermediateMaps` to work direclty on MLIR
+          // entities might be helpful here.
+
+          if (!shouldMapField)
+            continue;
+
+          int64_t fieldIdx = recordType.getFieldIndex(field);
+          bool alreadyMapped = [&]() {
+            if (op.getMembersIndexAttr())
+              for (auto indexList : op.getMembersIndexAttr()) {
+                auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
+                if (indexListAttr.size() == 1 &&
+                    mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() ==
+                        fieldIdx)
+                  return true;
+              }
+
+            return false;
+          }();
+
+          if (alreadyMapped)
+            continue;
+
+          builder.setInsertionPoint(op);
+          mlir::Value fieldIdxVal = builder.createIntegerConstant(
+              op.getLoc(), mlir::IndexType::get(builder.getContext()),
+              fieldIdx);
+          auto fieldCoord = builder.create<fir::CoordinateOp>(
+              op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
+              fieldIdxVal);
+          Fortran::lower::AddrAndBoundsInfo info =
+              Fortran::lower::getDataOperandBaseAddr(
+                  builder, fieldCoord, /*isOptional=*/false, op.getLoc());
+          llvm::SmallVector<mlir::Value> bounds =
+              Fortran::lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+                                                   mlir::omp::MapBoundsType>(
+                  builder, info,
+                  hlfir::translateToExtendedValue(op.getLoc(), builder,
+                                                  hlfir::Entity{fieldCoord})
+                      .first,
+                  /*dataExvIsAssumedSize=*/false, op.getLoc());
+
+          mlir::omp::MapInfoOp fieldMapOp =
+              builder.create<mlir::omp::MapInfoOp>(
+                  op.getLoc(), fieldCoord.getResult().getType(),
+                  fieldCoord.getResult(),
+                  mlir::TypeAttr::get(
+                      fir::unwrapRefType(fieldCoord.getResult().getType())),
+                  /*varPtrPtr=*/mlir::Value{},
+                  /*members=*/mlir::ValueRange{},
+                  /*members_index=*/mlir::ArrayAttr{},
+                  /*bounds=*/bounds, op.getMapTypeAttr(),
+                  builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
+                      mlir::omp::VariableCaptureKind::ByRef),
+                  builder.getStringAttr(op.getNameAttr().strref() + "." +
+                                        field + ".implicit_map"),
+                  /*partial_map=*/builder.getBoolAttr(false));
+          newMapOpsForFields.emplace_back(fieldMapOp);
+          fieldIndicies.emplace_back(fieldIdx);
+        }
+
+        if (newMapOpsForFields.empty())
+          return mlir::WalkResult::advance();
+
+        op.getMembersMutable().append(newMapOpsForFields);
+        llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices;
+        mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr();
+
+        if (oldMembersIdxAttr)
+          for (mlir::Attribute indexList : oldMembersIdxAttr) {
+            llvm::SmallVector<int64_t> listVec;
+
+            for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList))
+              listVec.push_back(mlir::cast<mlir::IntegerAttr>(index).getInt());
+
+            newMemberIndices.emplace_back(std::move(listVec));
+          }
+
+        for (int64_t newFieldIdx : fieldIndicies)
+          newMemberIndices.emplace_back(
+              llvm::SmallVector<int64_t>(1, newFieldIdx));
+
+        op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices));
+        op.setPartialMap(true);
+
+        return mlir::WalkResult::advance();
+      });
+
       func->walk([&](mlir::omp::MapInfoOp op) {
         // TODO: Currently only supports a single user for the MapInfoOp. This
         // is fine for the moment, as the Fortran frontend will generate a
diff --git a/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir b/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir
new file mode 100644
index 00000000000000..bcf8b63075dbf8
--- /dev/null
+++ b/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir
@@ -0,0 +1,63 @@
+// Tests that we implicitly map alloctable fields of a record when referenced in
+// a target region.
+
+// RUN: fir-opt --split-input-file --omp-map-info-finalization %s | FileCheck %s
+
+!record_t = !fir.type<_QFTrecord_t{
+  not_to_implicitly_map:
+    !fir.box<!fir.heap<!fir.array<?xf32>>>,
+  to_implicitly_map:
+    !fir.box<!fir.heap<!fir.array<?xf32>>>
+}>
+
+fir.global internal @_QFEdst_record : !record_t {
+  %0 = fir.undefined !record_t
+  fir.has_value %0 : !record_t
+}
+
+func.func @_QQmain() {
+  %6 = fir.address_of(@_QFEdst_record) : !fir.ref<!record_t>
+  %7:2 = hlfir.declare %6 {uniq_name = "_QFEdst_record"} : (!fir.ref<!record_t>) -> (!fir.ref<!record_t>, !fir.ref<!record_t>)
+  %16 = omp.map.info var_ptr(%7#1 : !fir.ref<!record_t>, !record_t) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<!record_t> {name = "dst_record"}
+  omp.target map_entries(%16 -> %arg0 : !fir.ref<!record_t>) {
+    %20:2 = hlfir.declare %arg0 {uniq_name = "_QFEdst_record"} : (!fir.ref<!record_t>) -> (...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Dec 18, 2024

@llvm/pr-subscribers-mlir-openmp

Author: Kareem Ergawy (ergawy)

Changes

This re-applies #117867 with a small fix that hopefully prevents build bot failures. The fix is avoiding dyn_cast for the result of getOperation(). Instead we can assign the result to mlir::ModuleOp directly since the type of the operation is known statically (OpT in OperationPass).


Patch is 26.65 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/120374.diff

12 Files Affected:

  • (renamed) flang/include/flang/Lower/DirectivesCommon.h (+33-17)
  • (modified) flang/lib/Lower/Bridge.cpp (+2-1)
  • (modified) flang/lib/Lower/OpenACC.cpp (+2-1)
  • (modified) flang/lib/Lower/OpenMP/ClauseProcessor.h (+1-1)
  • (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+8-15)
  • (modified) flang/lib/Lower/OpenMP/Utils.cpp (+1-1)
  • (modified) flang/lib/Optimizer/OpenMP/CMakeLists.txt (+2)
  • (modified) flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp (+163-6)
  • (added) flang/test/Transforms/omp-map-info-finalization-implicit-field.fir (+63)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td (+7)
  • (added) offload/test/offloading/fortran/explicit-and-implicit-record-field-mapping.f90 (+83)
  • (added) offload/test/offloading/fortran/implicit-record-field-mapping.f90 (+52)
diff --git a/flang/lib/Lower/DirectivesCommon.h b/flang/include/flang/Lower/DirectivesCommon.h
similarity index 97%
rename from flang/lib/Lower/DirectivesCommon.h
rename to flang/include/flang/Lower/DirectivesCommon.h
index 88514b16743278..6e2c6ee4b1bcdb 100644
--- a/flang/lib/Lower/DirectivesCommon.h
+++ b/flang/include/flang/Lower/DirectivesCommon.h
@@ -609,11 +609,10 @@ void createEmptyRegionBlocks(
   }
 }
 
-inline AddrAndBoundsInfo
-getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
-                       fir::FirOpBuilder &builder,
-                       Fortran::lower::SymbolRef sym, mlir::Location loc) {
-  mlir::Value symAddr = converter.getSymbolAddress(sym);
+inline AddrAndBoundsInfo getDataOperandBaseAddr(fir::FirOpBuilder &builder,
+                                                mlir::Value symAddr,
+                                                bool isOptional,
+                                                mlir::Location loc) {
   mlir::Value rawInput = symAddr;
   if (auto declareOp =
           mlir::dyn_cast_or_null<hlfir::DeclareOp>(symAddr.getDefiningOp())) {
@@ -621,20 +620,11 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
     rawInput = declareOp.getResults()[1];
   }
 
-  // TODO: Might need revisiting to handle for non-shared clauses
-  if (!symAddr) {
-    if (const auto *details =
-            sym->detailsIf<Fortran::semantics::HostAssocDetails>()) {
-      symAddr = converter.getSymbolAddress(details->symbol());
-      rawInput = symAddr;
-    }
-  }
-
   if (!symAddr)
     llvm::report_fatal_error("could not retrieve symbol address");
 
   mlir::Value isPresent;
-  if (Fortran::semantics::IsOptional(sym))
+  if (isOptional)
     isPresent =
         builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), rawInput);
 
@@ -648,8 +638,7 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
     // all address/dimension retrievals. For Fortran optional though, leave
     // the load generation for later so it can be done in the appropriate
     // if branches.
-    if (mlir::isa<fir::ReferenceType>(symAddr.getType()) &&
-        !Fortran::semantics::IsOptional(sym)) {
+    if (mlir::isa<fir::ReferenceType>(symAddr.getType()) && !isOptional) {
       mlir::Value addr = builder.create<fir::LoadOp>(loc, symAddr);
       return AddrAndBoundsInfo(addr, rawInput, isPresent, boxTy);
     }
@@ -659,6 +648,14 @@ getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
   return AddrAndBoundsInfo(symAddr, rawInput, isPresent);
 }
 
+inline AddrAndBoundsInfo
+getDataOperandBaseAddr(Fortran::lower::AbstractConverter &converter,
+                       fir::FirOpBuilder &builder,
+                       Fortran::lower::SymbolRef sym, mlir::Location loc) {
+  return getDataOperandBaseAddr(builder, converter.getSymbolAddress(sym),
+                                Fortran::semantics::IsOptional(sym), loc);
+}
+
 template <typename BoundsOp, typename BoundsType>
 llvm::SmallVector<mlir::Value>
 gatherBoundsOrBoundValues(fir::FirOpBuilder &builder, mlir::Location loc,
@@ -1224,6 +1221,25 @@ AddrAndBoundsInfo gatherDataOperandAddrAndBounds(
 
   return info;
 }
+
+template <typename BoundsOp, typename BoundsType>
+llvm::SmallVector<mlir::Value>
+genImplicitBoundsOps(fir::FirOpBuilder &builder, lower::AddrAndBoundsInfo &info,
+                     fir::ExtendedValue dataExv, bool dataExvIsAssumedSize,
+                     mlir::Location loc) {
+  llvm::SmallVector<mlir::Value> bounds;
+
+  mlir::Value baseOp = info.rawInput;
+  if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
+    bounds = lower::genBoundsOpsFromBox<BoundsOp, BoundsType>(builder, loc,
+                                                              dataExv, info);
+  if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
+    bounds = lower::genBaseBoundsOps<BoundsOp, BoundsType>(
+        builder, loc, dataExv, dataExvIsAssumedSize);
+  }
+
+  return bounds;
+}
 } // namespace lower
 } // namespace Fortran
 
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 0650433dbaf394..f5883dcedb2b67 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -11,7 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "flang/Lower/Bridge.h"
-#include "DirectivesCommon.h"
+
 #include "flang/Common/Version.h"
 #include "flang/Lower/Allocatable.h"
 #include "flang/Lower/CallInterface.h"
@@ -22,6 +22,7 @@
 #include "flang/Lower/ConvertType.h"
 #include "flang/Lower/ConvertVariable.h"
 #include "flang/Lower/Cuda.h"
+#include "flang/Lower/DirectivesCommon.h"
 #include "flang/Lower/HostAssociations.h"
 #include "flang/Lower/IO.h"
 #include "flang/Lower/IterationSpace.h"
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 75dcf6ec3e1107..ed18ad89c16ef5 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -11,10 +11,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "flang/Lower/OpenACC.h"
-#include "DirectivesCommon.h"
+
 #include "flang/Common/idioms.h"
 #include "flang/Lower/Bridge.h"
 #include "flang/Lower/ConvertType.h"
+#include "flang/Lower/DirectivesCommon.h"
 #include "flang/Lower/Mangler.h"
 #include "flang/Lower/PFTBuilder.h"
 #include "flang/Lower/StatementContext.h"
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h
index 3942c54e6e935d..7b047d4a7567ad 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.h
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h
@@ -13,11 +13,11 @@
 #define FORTRAN_LOWER_CLAUSEPROCESSOR_H
 
 #include "Clauses.h"
-#include "DirectivesCommon.h"
 #include "ReductionProcessor.h"
 #include "Utils.h"
 #include "flang/Lower/AbstractConverter.h"
 #include "flang/Lower/Bridge.h"
+#include "flang/Lower/DirectivesCommon.h"
 #include "flang/Optimizer/Builder/Todo.h"
 #include "flang/Parser/dump-parse-tree.h"
 #include "flang/Parser/parse-tree.h"
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index c61ab67d95a957..b07e89d201d198 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -16,7 +16,6 @@
 #include "Clauses.h"
 #include "DataSharingProcessor.h"
 #include "Decomposer.h"
-#include "DirectivesCommon.h"
 #include "ReductionProcessor.h"
 #include "Utils.h"
 #include "flang/Common/OpenMP-utils.h"
@@ -24,6 +23,7 @@
 #include "flang/Lower/Bridge.h"
 #include "flang/Lower/ConvertExpr.h"
 #include "flang/Lower/ConvertVariable.h"
+#include "flang/Lower/DirectivesCommon.h"
 #include "flang/Lower/StatementContext.h"
 #include "flang/Lower/SymbolMap.h"
 #include "flang/Optimizer/Builder/BoxValue.h"
@@ -1735,32 +1735,25 @@ genTargetOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
       if (const auto *details =
               sym.template detailsIf<semantics::HostAssocDetails>())
         converter.copySymbolBinding(details->symbol(), sym);
-      llvm::SmallVector<mlir::Value> bounds;
       std::stringstream name;
       fir::ExtendedValue dataExv = converter.getSymbolExtendedValue(sym);
       name << sym.name().ToString();
 
       lower::AddrAndBoundsInfo info = getDataOperandBaseAddr(
           converter, firOpBuilder, sym, converter.getCurrentLocation());
-      mlir::Value baseOp = info.rawInput;
-      if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(baseOp.getType())))
-        bounds = lower::genBoundsOpsFromBox<mlir::omp::MapBoundsOp,
-                                            mlir::omp::MapBoundsType>(
-            firOpBuilder, converter.getCurrentLocation(), dataExv, info);
-      if (mlir::isa<fir::SequenceType>(fir::unwrapRefType(baseOp.getType()))) {
-        bool dataExvIsAssumedSize =
-            semantics::IsAssumedSizeArray(sym.GetUltimate());
-        bounds = lower::genBaseBoundsOps<mlir::omp::MapBoundsOp,
-                                         mlir::omp::MapBoundsType>(
-            firOpBuilder, converter.getCurrentLocation(), dataExv,
-            dataExvIsAssumedSize);
-      }
+      llvm::SmallVector<mlir::Value> bounds =
+          lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+                                      mlir::omp::MapBoundsType>(
+              firOpBuilder, info, dataExv,
+              semantics::IsAssumedSizeArray(sym.GetUltimate()),
+              converter.getCurrentLocation());
 
       llvm::omp::OpenMPOffloadMappingFlags mapFlag =
           llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_IMPLICIT;
       mlir::omp::VariableCaptureKind captureKind =
           mlir::omp::VariableCaptureKind::ByRef;
 
+      mlir::Value baseOp = info.rawInput;
       mlir::Type eleType = baseOp.getType();
       if (auto refType = mlir::dyn_cast<fir::ReferenceType>(baseOp.getType()))
         eleType = refType.getElementType();
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 5340dd8c5fb9a2..9971dc8e0b0014 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -13,10 +13,10 @@
 #include "Utils.h"
 
 #include "Clauses.h"
-#include <DirectivesCommon.h>
 
 #include <flang/Lower/AbstractConverter.h>
 #include <flang/Lower/ConvertType.h>
+#include <flang/Lower/DirectivesCommon.h>
 #include <flang/Lower/PFTBuilder.h>
 #include <flang/Optimizer/Builder/FIRBuilder.h>
 #include <flang/Optimizer/Builder/Todo.h>
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index 51ecbe1a664f92..4f23b2b970fa44 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -12,6 +12,7 @@ add_flang_library(FlangOpenMPTransforms
   FIRDialect
   HLFIROpsIncGen
   FlangOpenMPPassesIncGen
+  ${dialect_libs}
 
   LINK_LIBS
   FIRAnalysis
@@ -27,4 +28,5 @@ add_flang_library(FlangOpenMPTransforms
   MLIRIR
   MLIRPass
   MLIRTransformUtils
+  ${dialect_libs}
 )
diff --git a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
index 4575c90e34acdd..ad7b806ae262ae 100644
--- a/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapInfoFinalization.cpp
@@ -24,10 +24,14 @@
 /// indirectly via a parent object.
 //===----------------------------------------------------------------------===//
 
+#include "flang/Lower/DirectivesCommon.h"
 #include "flang/Optimizer/Builder/FIRBuilder.h"
+#include "flang/Optimizer/Builder/HLFIRTools.h"
 #include "flang/Optimizer/Dialect/FIRType.h"
 #include "flang/Optimizer/Dialect/Support/KindMapping.h"
+#include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/OpenMP/Passes.h"
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/IR/BuiltinDialect.h"
@@ -411,10 +415,10 @@ class MapInfoFinalizationPass
           argIface
               ? argIface.getMapBlockArgsStart() + argIface.numMapBlockArgs()
               : 0;
-      addOperands(
-          mapMutableOpRange,
-          llvm::dyn_cast_or_null<mlir::omp::TargetOp>(argIface.getOperation()),
-          blockArgInsertIndex);
+      addOperands(mapMutableOpRange,
+                  llvm::dyn_cast_if_present<mlir::omp::TargetOp>(
+                      argIface.getOperation()),
+                  blockArgInsertIndex);
     }
 
     if (auto targetDataOp = llvm::dyn_cast<mlir::omp::TargetDataOp>(target)) {
@@ -466,8 +470,7 @@ class MapInfoFinalizationPass
   // operation (usually function) containing the MapInfoOp because this pass
   // will mutate siblings of MapInfoOp.
   void runOnOperation() override {
-    mlir::ModuleOp module =
-        mlir::dyn_cast_or_null<mlir::ModuleOp>(getOperation());
+    mlir::ModuleOp module = getOperation();
     if (!module)
       module = getOperation()->getParentOfType<mlir::ModuleOp>();
     fir::KindMapping kindMap = fir::getKindMapping(module);
@@ -486,6 +489,160 @@ class MapInfoFinalizationPass
       // iterations from previous function scopes.
       localBoxAllocas.clear();
 
+      // First, walk `omp.map.info` ops to see if any record members should be
+      // implicitly mapped.
+      func->walk([&](mlir::omp::MapInfoOp op) {
+        mlir::Type underlyingType =
+            fir::unwrapRefType(op.getVarPtr().getType());
+
+        // TODO Test with and support more complicated cases; like arrays for
+        // records, for example.
+        if (!fir::isRecordWithAllocatableMember(underlyingType))
+          return mlir::WalkResult::advance();
+
+        // TODO For now, only consider `omp.target` ops. Other ops that support
+        // `map` clauses will follow later.
+        mlir::omp::TargetOp target =
+            mlir::dyn_cast_if_present<mlir::omp::TargetOp>(
+                getFirstTargetUser(op));
+
+        if (!target)
+          return mlir::WalkResult::advance();
+
+        auto mapClauseOwner =
+            llvm::dyn_cast<mlir::omp::MapClauseOwningOpInterface>(*target);
+
+        int64_t mapVarIdx = mapClauseOwner.getOperandIndexForMap(op);
+        assert(mapVarIdx >= 0 &&
+               mapVarIdx <
+                   static_cast<int64_t>(mapClauseOwner.getMapVars().size()));
+
+        auto argIface =
+            llvm::dyn_cast<mlir::omp::BlockArgOpenMPOpInterface>(*target);
+        // TODO How should `map` block argument that correspond to: `private`,
+        // `use_device_addr`, `use_device_ptr`, be handled?
+        mlir::BlockArgument opBlockArg = argIface.getMapBlockArgs()[mapVarIdx];
+        llvm::SetVector<mlir::Operation *> mapVarForwardSlice;
+        mlir::getForwardSlice(opBlockArg, &mapVarForwardSlice);
+
+        mapVarForwardSlice.remove_if([&](mlir::Operation *sliceOp) {
+          // TODO Support coordinate_of ops.
+          //
+          // TODO Support call ops by recursively examining the forward slice of
+          // the corresponding parameter to the field in the called function.
+          return !mlir::isa<hlfir::DesignateOp>(sliceOp);
+        });
+
+        auto recordType = mlir::cast<fir::RecordType>(underlyingType);
+        llvm::SmallVector<mlir::Value> newMapOpsForFields;
+        llvm::SmallVector<int64_t> fieldIndicies;
+
+        for (auto fieldMemTyPair : recordType.getTypeList()) {
+          auto &field = fieldMemTyPair.first;
+          auto memTy = fieldMemTyPair.second;
+
+          bool shouldMapField =
+              llvm::find_if(mapVarForwardSlice, [&](mlir::Operation *sliceOp) {
+                if (!fir::isAllocatableType(memTy))
+                  return false;
+
+                auto designateOp = mlir::dyn_cast<hlfir::DesignateOp>(sliceOp);
+                if (!designateOp)
+                  return false;
+
+                return designateOp.getComponent() &&
+                       designateOp.getComponent()->strref() == field;
+              }) != mapVarForwardSlice.end();
+
+          // TODO Handle recursive record types. Adapting
+          // `createParentSymAndGenIntermediateMaps` to work direclty on MLIR
+          // entities might be helpful here.
+
+          if (!shouldMapField)
+            continue;
+
+          int64_t fieldIdx = recordType.getFieldIndex(field);
+          bool alreadyMapped = [&]() {
+            if (op.getMembersIndexAttr())
+              for (auto indexList : op.getMembersIndexAttr()) {
+                auto indexListAttr = mlir::cast<mlir::ArrayAttr>(indexList);
+                if (indexListAttr.size() == 1 &&
+                    mlir::cast<mlir::IntegerAttr>(indexListAttr[0]).getInt() ==
+                        fieldIdx)
+                  return true;
+              }
+
+            return false;
+          }();
+
+          if (alreadyMapped)
+            continue;
+
+          builder.setInsertionPoint(op);
+          mlir::Value fieldIdxVal = builder.createIntegerConstant(
+              op.getLoc(), mlir::IndexType::get(builder.getContext()),
+              fieldIdx);
+          auto fieldCoord = builder.create<fir::CoordinateOp>(
+              op.getLoc(), builder.getRefType(memTy), op.getVarPtr(),
+              fieldIdxVal);
+          Fortran::lower::AddrAndBoundsInfo info =
+              Fortran::lower::getDataOperandBaseAddr(
+                  builder, fieldCoord, /*isOptional=*/false, op.getLoc());
+          llvm::SmallVector<mlir::Value> bounds =
+              Fortran::lower::genImplicitBoundsOps<mlir::omp::MapBoundsOp,
+                                                   mlir::omp::MapBoundsType>(
+                  builder, info,
+                  hlfir::translateToExtendedValue(op.getLoc(), builder,
+                                                  hlfir::Entity{fieldCoord})
+                      .first,
+                  /*dataExvIsAssumedSize=*/false, op.getLoc());
+
+          mlir::omp::MapInfoOp fieldMapOp =
+              builder.create<mlir::omp::MapInfoOp>(
+                  op.getLoc(), fieldCoord.getResult().getType(),
+                  fieldCoord.getResult(),
+                  mlir::TypeAttr::get(
+                      fir::unwrapRefType(fieldCoord.getResult().getType())),
+                  /*varPtrPtr=*/mlir::Value{},
+                  /*members=*/mlir::ValueRange{},
+                  /*members_index=*/mlir::ArrayAttr{},
+                  /*bounds=*/bounds, op.getMapTypeAttr(),
+                  builder.getAttr<mlir::omp::VariableCaptureKindAttr>(
+                      mlir::omp::VariableCaptureKind::ByRef),
+                  builder.getStringAttr(op.getNameAttr().strref() + "." +
+                                        field + ".implicit_map"),
+                  /*partial_map=*/builder.getBoolAttr(false));
+          newMapOpsForFields.emplace_back(fieldMapOp);
+          fieldIndicies.emplace_back(fieldIdx);
+        }
+
+        if (newMapOpsForFields.empty())
+          return mlir::WalkResult::advance();
+
+        op.getMembersMutable().append(newMapOpsForFields);
+        llvm::SmallVector<llvm::SmallVector<int64_t>> newMemberIndices;
+        mlir::ArrayAttr oldMembersIdxAttr = op.getMembersIndexAttr();
+
+        if (oldMembersIdxAttr)
+          for (mlir::Attribute indexList : oldMembersIdxAttr) {
+            llvm::SmallVector<int64_t> listVec;
+
+            for (mlir::Attribute index : mlir::cast<mlir::ArrayAttr>(indexList))
+              listVec.push_back(mlir::cast<mlir::IntegerAttr>(index).getInt());
+
+            newMemberIndices.emplace_back(std::move(listVec));
+          }
+
+        for (int64_t newFieldIdx : fieldIndicies)
+          newMemberIndices.emplace_back(
+              llvm::SmallVector<int64_t>(1, newFieldIdx));
+
+        op.setMembersIndexAttr(builder.create2DI64ArrayAttr(newMemberIndices));
+        op.setPartialMap(true);
+
+        return mlir::WalkResult::advance();
+      });
+
       func->walk([&](mlir::omp::MapInfoOp op) {
         // TODO: Currently only supports a single user for the MapInfoOp. This
         // is fine for the moment, as the Fortran frontend will generate a
diff --git a/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir b/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir
new file mode 100644
index 00000000000000..bcf8b63075dbf8
--- /dev/null
+++ b/flang/test/Transforms/omp-map-info-finalization-implicit-field.fir
@@ -0,0 +1,63 @@
+// Tests that we implicitly map alloctable fields of a record when referenced in
+// a target region.
+
+// RUN: fir-opt --split-input-file --omp-map-info-finalization %s | FileCheck %s
+
+!record_t = !fir.type<_QFTrecord_t{
+  not_to_implicitly_map:
+    !fir.box<!fir.heap<!fir.array<?xf32>>>,
+  to_implicitly_map:
+    !fir.box<!fir.heap<!fir.array<?xf32>>>
+}>
+
+fir.global internal @_QFEdst_record : !record_t {
+  %0 = fir.undefined !record_t
+  fir.has_value %0 : !record_t
+}
+
+func.func @_QQmain() {
+  %6 = fir.address_of(@_QFEdst_record) : !fir.ref<!record_t>
+  %7:2 = hlfir.declare %6 {uniq_name = "_QFEdst_record"} : (!fir.ref<!record_t>) -> (!fir.ref<!record_t>, !fir.ref<!record_t>)
+  %16 = omp.map.info var_ptr(%7#1 : !fir.ref<!record_t>, !record_t) map_clauses(implicit, tofrom) capture(ByRef) -> !fir.ref<!record_t> {name = "dst_record"}
+  omp.target map_entries(%16 -> %arg0 : !fir.ref<!record_t>) {
+    %20:2 = hlfir.declare %arg0 {uniq_name = "_QFEdst_record"} : (!fir.ref<!record_t>) -> (...
[truncated]

@ergawy
Copy link
Member Author

ergawy commented Dec 18, 2024

@kazutakahirata it would be great if you can run locally with the patch to verify it compiles 🙏.

@kazutakahirata
Copy link
Contributor

@kazutakahirata it would be great if you can run locally with the patch to verify it compiles 🙏.

Thanks! I'm on it. I'll let you know soon.

@kazutakahirata
Copy link
Contributor

@kazutakahirata it would be great if you can run locally with the patch to verify it compiles 🙏.

Thanks! I'm on it. I'll let you know soon.

With your new patch, my local tree builds fine (x86_64-linux + clang 16 + -DLLVM_ENABLE_WERROR=On). Both check-flang and check-mlir seem to be happy. Thanks!

Copy link
Contributor

@kazutakahirata kazutakahirata left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks!

@ergawy ergawy merged commit e532241 into llvm:main Dec 18, 2024
16 checks passed
vzakhari added a commit to vzakhari/llvm-project that referenced this pull request Dec 31, 2024
My local build with the shared libraries is broken.
I suppose this was introduced by llvm#120374.

`flang/include/flang/Evaluate/constant.h` ends up being included
by `MapInfoFinalization.cpp` via `flang/Lower/DirectivesCommon.h`.
The undefined references are related to `ConstantBase` classes.
vzakhari added a commit that referenced this pull request Jan 4, 2025
My local build with the shared libraries is broken.
I suppose this was introduced by #120374.

`flang/include/flang/Evaluate/constant.h` ends up being included
by `MapInfoFinalization.cpp` via `flang/Lower/DirectivesCommon.h`.
The undefined references are related to `ConstantBase` classes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants