diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp index 424566462e8fe..cba1bfc74e922 100644 --- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp +++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp @@ -660,10 +660,7 @@ OrderedAssignmentRewriter::generateYieldedEntity( return castIfNeeded(loc, builder, {maskedValue, std::nullopt}, castToType); } - assert(region.hasOneBlock() && "region must contain one block"); auto oldYield = getYield(region); - mlir::Block::OpListType &ops = region.back().getOperations(); - // Inside Forall, scalars that do not depend on forall indices can be hoisted // here because their evaluation is required to only call pure procedures, and // if they depend on a variable previously assigned to in a forall assignment, @@ -674,24 +671,24 @@ OrderedAssignmentRewriter::generateYieldedEntity( bool hoistComputation = false; if (fir::isa_trivial(oldYield.getEntity().getType()) && !constructStack.empty()) { - hoistComputation = true; - for (mlir::Operation &op : ops) - if (llvm::any_of(op.getOperands(), [](mlir::Value value) { - return isForallIndex(value); - })) { - hoistComputation = false; - break; - } + mlir::WalkResult walkResult = + region.walk([&](mlir::Operation *op) -> mlir::WalkResult { + if (llvm::any_of(op->getOperands(), [](mlir::Value value) { + return isForallIndex(value); + })) + return mlir::WalkResult::interrupt(); + return mlir::WalkResult::advance(); + }); + hoistComputation = !walkResult.wasInterrupted(); } auto insertionPoint = builder.saveInsertionPoint(); if (hoistComputation) builder.setInsertionPoint(constructStack[0]); // Clone all operations except the final hlfir.yield. - assert(!ops.empty() && "yield block cannot be empty"); - auto end = ops.end(); - for (auto opIt = ops.begin(); std::next(opIt) != end; ++opIt) - (void)builder.clone(*opIt, mapper); + assert(region.hasOneBlock() && "region must contain one block"); + for (auto &op : region.back().without_terminator()) + (void)builder.clone(op, mapper); // Get the value for the yielded entity, it may be the result of an operation // that was cloned, or it may be the same as the previous value if the yield // operand was created before the ordered assignment tree. diff --git a/flang/test/HLFIR/order_assignments/forall-issue120190.fir b/flang/test/HLFIR/order_assignments/forall-issue120190.fir new file mode 100644 index 0000000000000..ca10bbfefad57 --- /dev/null +++ b/flang/test/HLFIR/order_assignments/forall-issue120190.fir @@ -0,0 +1,64 @@ +// Regression test for https://github.com/llvm/llvm-project/issues/120190 +// Verify that hlfir.forall lowering does not try hoisting mask evaluation +// that refer to the forall index inside nested regions only. +// RUN: fir-opt %s --lower-hlfir-ordered-assignments | FileCheck %s + +func.func @issue120190(%array: !fir.ref>, %cdt: i1) { + %cst = arith.constant 0.000000e+00 : f32 + %c1 = arith.constant 1 : i64 + %c50 = arith.constant 50 : i64 + %c100 = arith.constant 100 : i64 + hlfir.forall lb { + hlfir.yield %c1 : i64 + } ub { + hlfir.yield %c100 : i64 + } (%forall_index: i64) { + hlfir.forall_mask { + %mask = fir.if %cdt -> i1 { + // Reference to %forall_index is not directly in + // hlfir.forall_mask region, but is nested. + %res = arith.cmpi slt, %forall_index, %c50 : i64 + fir.result %res : i1 + } else { + %res = arith.cmpi sgt, %forall_index, %c50 : i64 + fir.result %res : i1 + } + hlfir.yield %mask : i1 + } do { + hlfir.region_assign { + hlfir.yield %cst : f32 + } to { + %6 = hlfir.designate %array (%forall_index) : (!fir.ref>, i64) -> !fir.ref + hlfir.yield %6 : !fir.ref + } + } + } + return +} + +// CHECK-LABEL: func.func @issue120190( +// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref>, +// CHECK-SAME: %[[VAL_1:.*]]: i1) { +// CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i64 +// CHECK: %[[VAL_4:.*]] = arith.constant 50 : i64 +// CHECK: %[[VAL_5:.*]] = arith.constant 100 : i64 +// CHECK: %[[VAL_6:.*]] = fir.convert %[[VAL_3]] : (i64) -> index +// CHECK: %[[VAL_7:.*]] = fir.convert %[[VAL_5]] : (i64) -> index +// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index +// CHECK: fir.do_loop %[[VAL_9:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] { +// CHECK: %[[VAL_10:.*]] = fir.convert %[[VAL_9]] : (index) -> i64 +// CHECK: %[[VAL_11:.*]] = fir.if %[[VAL_1]] -> (i1) { +// CHECK: %[[VAL_12:.*]] = arith.cmpi slt, %[[VAL_10]], %[[VAL_4]] : i64 +// CHECK: fir.result %[[VAL_12]] : i1 +// CHECK: } else { +// CHECK: %[[VAL_13:.*]] = arith.cmpi sgt, %[[VAL_10]], %[[VAL_4]] : i64 +// CHECK: fir.result %[[VAL_13]] : i1 +// CHECK: } +// CHECK: fir.if %[[VAL_11]] { +// CHECK: %[[VAL_14:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_10]]) : (!fir.ref>, i64) -> !fir.ref +// CHECK: hlfir.assign %[[VAL_2]] to %[[VAL_14]] : f32, !fir.ref +// CHECK: } +// CHECK: } +// CHECK: return +// CHECK: }