Skip to content

Commit ecb1fba

Browse files
authored
[flang][openacc] Generate data bounds for array addressing. (#71254)
In cases like `copy(array(N))` it is still useful to represent the data operand uniformly with `copy(array(N:N))`. This change generates data bounds even if it is not an array section with the triplets. The lower and the upper bounds are the same and the extent is one in this case.
1 parent 4f31d32 commit ecb1fba

File tree

4 files changed

+95
-50
lines changed

4 files changed

+95
-50
lines changed

flang/lib/Lower/DirectivesCommon.h

+69-44
Original file line numberDiff line numberDiff line change
@@ -660,7 +660,7 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
660660
Fortran::lower::StatementContext &stmtCtx,
661661
const std::list<Fortran::parser::SectionSubscript> &subscripts,
662662
std::stringstream &asFortran, fir::ExtendedValue &dataExv,
663-
mlir::Value baseAddr) {
663+
mlir::Value baseAddr, bool treatIndexAsSection = false) {
664664
int dimension = 0;
665665
mlir::Type idxTy = builder.getIndexType();
666666
mlir::Type boundTy = builder.getType<BoundsType>();
@@ -669,8 +669,9 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
669669
mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
670670
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
671671
for (const auto &subscript : subscripts) {
672-
if (const auto *triplet{
673-
std::get_if<Fortran::parser::SubscriptTriplet>(&subscript.u)}) {
672+
const auto *triplet{
673+
std::get_if<Fortran::parser::SubscriptTriplet>(&subscript.u)};
674+
if (triplet || treatIndexAsSection) {
674675
if (dimension != 0)
675676
asFortran << ',';
676677
mlir::Value lbound, ubound, extent;
@@ -689,9 +690,21 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
689690
strideInBytes = true;
690691
}
691692

692-
const auto &lower{std::get<0>(triplet->t)};
693+
const Fortran::lower::SomeExpr *lower{nullptr};
694+
if (triplet) {
695+
if (const auto &tripletLb{std::get<0>(triplet->t)})
696+
lower = Fortran::semantics::GetExpr(*tripletLb);
697+
} else {
698+
const auto &index{std::get<Fortran::parser::IntExpr>(subscript.u)};
699+
lower = Fortran::semantics::GetExpr(index);
700+
if (lower->Rank() > 0) {
701+
mlir::emitError(
702+
loc, "vector subscript cannot be used for an array section");
703+
break;
704+
}
705+
}
693706
if (lower) {
694-
lval = Fortran::semantics::GetIntValue(lower);
707+
lval = Fortran::evaluate::ToInt64(*lower);
695708
if (lval) {
696709
if (defaultLb) {
697710
lbound = builder.createIntegerConstant(loc, idxTy, *lval - 1);
@@ -701,59 +714,66 @@ genBoundsOps(fir::FirOpBuilder &builder, mlir::Location loc,
701714
}
702715
asFortran << *lval;
703716
} else {
704-
const Fortran::lower::SomeExpr *lexpr =
705-
Fortran::semantics::GetExpr(*lower);
706717
mlir::Value lb =
707-
fir::getBase(converter.genExprValue(loc, *lexpr, stmtCtx));
718+
fir::getBase(converter.genExprValue(loc, *lower, stmtCtx));
708719
lb = builder.createConvert(loc, baseLb.getType(), lb);
709720
lbound = builder.create<mlir::arith::SubIOp>(loc, lb, baseLb);
710-
asFortran << lexpr->AsFortran();
721+
asFortran << lower->AsFortran();
711722
}
712723
} else {
713724
// If the lower bound is not specified, then the section
714725
// starts from offset 0 of the dimension.
715726
// Note that the lowerbound in the BoundsOp is always 0-based.
716727
lbound = zero;
717728
}
718-
asFortran << ':';
719-
const auto &upper{std::get<1>(triplet->t)};
720-
if (upper) {
721-
uval = Fortran::semantics::GetIntValue(upper);
722-
if (uval) {
723-
if (defaultLb) {
724-
ubound = builder.createIntegerConstant(loc, idxTy, *uval - 1);
729+
730+
if (!triplet) {
731+
// If it is a scalar subscript, then the upper bound
732+
// is equal to the lower bound, and the extent is one.
733+
ubound = lbound;
734+
extent = one;
735+
} else {
736+
asFortran << ':';
737+
const auto &upper{std::get<1>(triplet->t)};
738+
739+
if (upper) {
740+
uval = Fortran::semantics::GetIntValue(upper);
741+
if (uval) {
742+
if (defaultLb) {
743+
ubound = builder.createIntegerConstant(loc, idxTy, *uval - 1);
744+
} else {
745+
mlir::Value ub = builder.createIntegerConstant(loc, idxTy, *uval);
746+
ubound = builder.create<mlir::arith::SubIOp>(loc, ub, baseLb);
747+
}
748+
asFortran << *uval;
725749
} else {
726-
mlir::Value ub = builder.createIntegerConstant(loc, idxTy, *uval);
750+
const Fortran::lower::SomeExpr *uexpr =
751+
Fortran::semantics::GetExpr(*upper);
752+
mlir::Value ub =
753+
fir::getBase(converter.genExprValue(loc, *uexpr, stmtCtx));
754+
ub = builder.createConvert(loc, baseLb.getType(), ub);
727755
ubound = builder.create<mlir::arith::SubIOp>(loc, ub, baseLb);
756+
asFortran << uexpr->AsFortran();
728757
}
729-
asFortran << *uval;
730-
} else {
731-
const Fortran::lower::SomeExpr *uexpr =
732-
Fortran::semantics::GetExpr(*upper);
733-
mlir::Value ub =
734-
fir::getBase(converter.genExprValue(loc, *uexpr, stmtCtx));
735-
ub = builder.createConvert(loc, baseLb.getType(), ub);
736-
ubound = builder.create<mlir::arith::SubIOp>(loc, ub, baseLb);
737-
asFortran << uexpr->AsFortran();
738758
}
739-
}
740-
if (lower && upper) {
741-
if (lval && uval && *uval < *lval) {
742-
mlir::emitError(loc, "zero sized array section");
743-
break;
744-
} else if (std::get<2>(triplet->t)) {
745-
const auto &strideExpr{std::get<2>(triplet->t)};
746-
if (strideExpr) {
747-
mlir::emitError(loc, "stride cannot be specified on "
748-
"an array section");
759+
if (lower && upper) {
760+
if (lval && uval && *uval < *lval) {
761+
mlir::emitError(loc, "zero sized array section");
749762
break;
763+
} else if (std::get<2>(triplet->t)) {
764+
const auto &strideExpr{std::get<2>(triplet->t)};
765+
if (strideExpr) {
766+
mlir::emitError(loc, "stride cannot be specified on "
767+
"an array section");
768+
break;
769+
}
750770
}
751771
}
752-
}
753-
if (!ubound) {
754-
// ub = extent - 1
755-
extent = fir::factory::readExtent(builder, loc, dataExv, dimension);
756-
ubound = builder.create<mlir::arith::SubIOp>(loc, extent, one);
772+
if (!ubound) {
773+
// ub = extent - 1
774+
extent = fir::factory::readExtent(builder, loc, dataExv, dimension);
775+
ubound = builder.create<mlir::arith::SubIOp>(loc, extent, one);
776+
}
757777
}
758778
mlir::Value bound = builder.create<BoundsOp>(
759779
loc, boundTy, lbound, ubound, extent, stride, strideInBytes, baseLb);
@@ -770,15 +790,15 @@ mlir::Value gatherDataOperandAddrAndBounds(
770790
Fortran::semantics::SemanticsContext &semanticsContext,
771791
Fortran::lower::StatementContext &stmtCtx, const ObjectType &object,
772792
mlir::Location operandLocation, std::stringstream &asFortran,
773-
llvm::SmallVector<mlir::Value> &bounds) {
793+
llvm::SmallVector<mlir::Value> &bounds, bool treatIndexAsSection = false) {
774794
mlir::Value baseAddr;
775795

776796
std::visit(
777797
Fortran::common::visitors{
778798
[&](const Fortran::parser::Designator &designator) {
779799
if (auto expr{Fortran::semantics::AnalyzeExpr(semanticsContext,
780800
designator)}) {
781-
if ((*expr).Rank() > 0 &&
801+
if (((*expr).Rank() > 0 || treatIndexAsSection) &&
782802
Fortran::parser::Unwrap<Fortran::parser::ArrayElement>(
783803
designator)) {
784804
const auto *arrayElement =
@@ -809,7 +829,8 @@ mlir::Value gatherDataOperandAddrAndBounds(
809829
asFortran << '(';
810830
bounds = genBoundsOps<BoundsType, BoundsOp>(
811831
builder, operandLocation, converter, stmtCtx,
812-
arrayElement->subscripts, asFortran, dataExv, baseAddr);
832+
arrayElement->subscripts, asFortran, dataExv, baseAddr,
833+
treatIndexAsSection);
813834
}
814835
asFortran << ')';
815836
} else if (Fortran::parser::Unwrap<
@@ -845,6 +866,10 @@ mlir::Value gatherDataOperandAddrAndBounds(
845866
if (Fortran::parser::Unwrap<Fortran::parser::ArrayElement>(
846867
designator)) {
847868
// Single array element.
869+
const auto *arrayElement =
870+
Fortran::parser::Unwrap<Fortran::parser::ArrayElement>(
871+
designator);
872+
(void)arrayElement;
848873
fir::ExtendedValue compExv =
849874
converter.genExprAddr(operandLocation, *expr, stmtCtx);
850875
baseAddr = fir::getBase(compExv);

flang/lib/Lower/OpenACC.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,8 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
265265
mlir::Value baseAddr = Fortran::lower::gatherDataOperandAddrAndBounds<
266266
Fortran::parser::AccObject, mlir::acc::DataBoundsType,
267267
mlir::acc::DataBoundsOp>(converter, builder, semanticsContext, stmtCtx,
268-
accObject, operandLocation, asFortran, bounds);
268+
accObject, operandLocation, asFortran, bounds,
269+
/*treatIndexAsSection=*/true);
269270
Op op = createDataEntryOp<Op>(builder, operandLocation, baseAddr, asFortran,
270271
bounds, structured, implicit, dataClause,
271272
baseAddr.getType());

flang/lib/Lower/OpenMP.cpp

+9-3
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@
2929
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
3030
#include "mlir/Dialect/SCF/IR/SCF.h"
3131
#include "llvm/Frontend/OpenMP/OMPConstants.h"
32+
#include "llvm/Support/CommandLine.h"
33+
34+
static llvm::cl::opt<bool> treatIndexAsSection(
35+
"openmp-treat-index-as-section",
36+
llvm::cl::desc("In the OpenMP data clauses treat `a(N)` as `a(N:N)`."),
37+
llvm::cl::init(true));
3238

3339
using DeclareTargetCapturePair =
3440
std::pair<mlir::omp::DeclareTargetCaptureClause,
@@ -1788,9 +1794,9 @@ bool ClauseProcessor::processMap(
17881794
std::stringstream asFortran;
17891795
mlir::Value baseAddr = Fortran::lower::gatherDataOperandAddrAndBounds<
17901796
Fortran::parser::OmpObject, mlir::omp::DataBoundsType,
1791-
mlir::omp::DataBoundsOp>(converter, firOpBuilder,
1792-
semanticsContext, stmtCtx, ompObject,
1793-
clauseLocation, asFortran, bounds);
1797+
mlir::omp::DataBoundsOp>(
1798+
converter, firOpBuilder, semanticsContext, stmtCtx, ompObject,
1799+
clauseLocation, asFortran, bounds, treatIndexAsSection);
17941800

17951801
// Explicit map captures are captured ByRef by default,
17961802
// optimisation passes may alter this to ByCopy or other capture

flang/test/Lower/OpenACC/acc-enter-data.f90

+15-2
Original file line numberDiff line numberDiff line change
@@ -813,7 +813,20 @@ subroutine acc_enter_data_single_array_element()
813813

814814
!$acc enter data create(e(2)%a(1,2))
815815

816-
!CHECK: %[[CREATE:.*]] = acc.create varPtr(%{{.*}} : !fir.ref<f32>) -> !fir.ref<f32> {name = "e(2_8)%a(1_8,2_8)", structured = false}
817-
!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.ref<f32>)
816+
!CHECK-LABEL: func.func @_QPacc_enter_data_single_array_element() {
817+
!CHECK-DAG: %[[VAL_38:.*]]:3 = fir.box_dims %[[BOX:.*]], %[[VAL_37:.*]] : (!fir.box<!fir.heap<!fir.array<?x?xf32>>>, index) -> (index, index, index)
818+
!CHECK-DAG: %[[VAL_37]] = arith.constant 0 : index
819+
!CHECK-DAG: %[[VAL_40:.*]]:3 = fir.box_dims %[[BOX]], %[[VAL_39:.*]] : (!fir.box<!fir.heap<!fir.array<?x?xf32>>>, index) -> (index, index, index)
820+
!CHECK-DAG: %[[VAL_39]] = arith.constant 1 : index
821+
!CHECK-DAG: %[[VAL_41:.*]] = fir.box_addr %[[BOX]] : (!fir.box<!fir.heap<!fir.array<?x?xf32>>>) -> !fir.heap<!fir.array<?x?xf32>>
822+
!CHECK: %[[VAL_42:.*]] = arith.constant 1 : index
823+
!CHECK: %[[VAL_43:.*]] = arith.constant 1 : index
824+
!CHECK: %[[VAL_44:.*]] = arith.subi %[[VAL_43]], %[[VAL_38]]#0 : index
825+
!CHECK: %[[VAL_45:.*]] = acc.bounds lowerbound(%[[VAL_44]] : index) upperbound(%[[VAL_44]] : index) extent(%[[VAL_42]] : index) stride(%[[VAL_42]] : index) startIdx(%[[VAL_38]]#0 : index)
826+
!CHECK: %[[VAL_46:.*]] = arith.constant 2 : index
827+
!CHECK: %[[VAL_47:.*]] = arith.subi %[[VAL_46]], %[[VAL_40]]#0 : index
828+
!CHECK: %[[VAL_48:.*]] = acc.bounds lowerbound(%[[VAL_47]] : index) upperbound(%[[VAL_47]] : index) extent(%[[VAL_42]] : index) stride(%[[VAL_42]] : index) startIdx(%[[VAL_40]]#0 : index)
829+
!CHECK: %[[CREATE:.*]] = acc.create varPtr(%[[VAL_41]] : !fir.heap<!fir.array<?x?xf32>>) bounds(%[[VAL_45]], %[[VAL_48]]) -> !fir.heap<!fir.array<?x?xf32>> {name = "e(2_8)%a(1,2)", structured = false}
830+
!CHECK: acc.enter_data dataOperands(%[[CREATE]] : !fir.heap<!fir.array<?x?xf32>>)
818831

819832
end subroutine

0 commit comments

Comments
 (0)