Skip to content

Commit 44aa476

Browse files
authored
[flang] AArch64 ABI for BIND(C) VALUE parameters (#118305)
This patch adds handling for derived type VALUE parameters in BIND(C) functions for AArch64.
1 parent 3666de9 commit 44aa476

File tree

2 files changed

+193
-26
lines changed

2 files changed

+193
-26
lines changed

flang/lib/Optimizer/CodeGen/Target.cpp

Lines changed: 120 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,8 @@ struct TargetX86_64Win : public GenericTarget<TargetX86_64Win> {
788788
//===----------------------------------------------------------------------===//
789789

790790
namespace {
791+
// AArch64 procedure call standard:
792+
// https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#parameter-passing
791793
struct TargetAArch64 : public GenericTarget<TargetAArch64> {
792794
using GenericTarget::GenericTarget;
793795

@@ -826,7 +828,7 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> {
826828
return marshal;
827829
}
828830

829-
// Flatten a RecordType::TypeList containing more record types or array types
831+
// Flatten a RecordType::TypeList containing more record types or array types
830832
static std::optional<std::vector<mlir::Type>>
831833
flattenTypeList(const RecordType::TypeList &types) {
832834
std::vector<mlir::Type> flatTypes;
@@ -870,52 +872,144 @@ struct TargetAArch64 : public GenericTarget<TargetAArch64> {
870872

871873
// Determine if the type is a Homogeneous Floating-point Aggregate (HFA). An
872874
// HFA is a record type with up to 4 floating-point members of the same type.
873-
static bool isHFA(fir::RecordType ty) {
875+
// Returns the number of flattened members when `ty` passes every HFA check
// below, and std::nullopt as soon as one of them fails.
// NOTE(review): this span is a rendered diff. A line that follows a marker
// ending in `-` is the removed isHFA form; a line that follows a marker ending
// in `+` is its replacement.
static std::optional<int> usedRegsForHFA(fir::RecordType ty) {
874876
RecordType::TypeList types = ty.getTypeList();
875877
// Reject empty records and records with more than four direct members.
if (types.empty() || types.size() > 4)
876-
return false;
878+
return std::nullopt;
877879

878880
// Flattening can fail (empty optional); more than four flattened members
// also disqualifies the record.
std::optional<std::vector<mlir::Type>> flatTypes = flattenTypeList(types);
879881
if (!flatTypes || flatTypes->size() > 4) {
880-
return false;
882+
return std::nullopt;
881883
}
882884

883885
// Only the first flattened member is tested for being a real type; the
// all_equal test at the end extends that property to the other members.
if (!isa_real(flatTypes->front())) {
884-
return false;
886+
return std::nullopt;
885887
}
886888

887-
return llvm::all_equal(*flatTypes);
889+
return llvm::all_equal(*flatTypes) ? std::optional<int>{flatTypes->size()}
890+
: std::nullopt;
888891
}
889892

890-
// AArch64 procedure call ABI:
891-
// https://github.com/ARM-software/abi-aa/blob/main/aapcs64/aapcs64.rst#parameter-passing
892-
CodeGenSpecifics::Marshalling
893-
structReturnType(mlir::Location loc, fir::RecordType ty) const override {
894-
CodeGenSpecifics::Marshalling marshal;
893+
// Register demand of a single value: `n` registers, with `isSimd` selecting
// which of the two register counts the caller charges (hasEnoughRegisters
// subtracts from the SIMD count when it is set, from the integer count
// otherwise). A default-constructed NRegs has n == 0 and isSimd == false.
struct NRegs {
894+
int n{0};
895+
bool isSimd{false};
896+
};
895897

896-
if (isHFA(ty)) {
897-
// Just return the existing record type
898-
marshal.emplace_back(ty, AT{});
899-
return marshal;
898+
// Computes how many registers a record occupies.
//  - HFA: one SIMD register per flattened member.
//  - Any other record of at most 16 bytes: its size rounded up to 8-byte
//    units, counted as integer registers.
//  - Anything else: a default NRegs (n == 0), meaning no registers are used.
NRegs usedRegsForRecordType(mlir::Location loc, fir::RecordType type) const {
899+
if (std::optional<int> size = usedRegsForHFA(type))
900+
return {*size, true};
901+
902+
// NOTE(review): `align` is bound here but not consulted in this function.
auto [size, align] = fir::getTypeSizeAndAlignmentOrCrash(
903+
loc, type, getDataLayout(), kindMap);
904+
905+
if (size <= 16)
906+
// (size + 7) / 8 rounds the size up to a whole number of 8-byte units.
return {static_cast<int>((size + 7) / 8), false};
907+
908+
// Pass on the stack, i.e. no registers used
909+
return {};
910+
}
911+
912+
// Maps an argument type to the registers it is counted as using. Integer,
// logical and character types count against the integer registers; float and
// complex types count against the SIMD registers; arrays scale their element's
// count by the extent; records defer to usedRegsForRecordType.
// NOTE(review): no .Default case is visible in this chain, so a type that is
// not listed here is not given a count by this function - confirm that every
// type reaching it is covered.
NRegs usedRegsForType(mlir::Location loc, mlir::Type type) const {
913+
return llvm::TypeSwitch<mlir::Type, NRegs>(type)
914+
// A 128-bit integer is counted as two integer registers, any other width
// as one.
.Case<mlir::IntegerType>([&](auto intTy) {
915+
return intTy.getWidth() == 128 ? NRegs{2, false} : NRegs{1, false};
916+
})
917+
.Case<mlir::FloatType>([&](auto) { return NRegs{1, true}; })
918+
.Case<mlir::ComplexType>([&](auto) { return NRegs{2, true}; })
919+
.Case<fir::LogicalType>([&](auto) { return NRegs{1, false}; })
920+
.Case<fir::CharacterType>([&](auto) { return NRegs{1, false}; })
921+
// Arrays: only rank-1 shapes are expected (asserted); the element's count
// is multiplied by the single extent and the element's bank is kept.
.Case<fir::SequenceType>([&](auto ty) {
922+
assert(ty.getShape().size() == 1 &&
923+
"invalid array dimensions in BIND(C)");
924+
NRegs nregs = usedRegsForType(loc, ty.getEleTy());
925+
nregs.n *= ty.getShape()[0];
926+
return nregs;
927+
})
928+
.Case<fir::RecordType>(
929+
[&](auto ty) { return usedRegsForRecordType(loc, ty); })
930+
// Vector arguments are reported through the TODO macro.
.Case<fir::VectorType>([&](auto) {
931+
TODO(loc, "passing vector argument to C by value is not supported");
932+
return NRegs{};
933+
});
934+
}
935+
936+
// Decides whether the record `type` can still be passed in registers after the
// arguments in `previousArguments` have claimed theirs.
//
// The integer and SIMD register counts are tracked independently, starting at
// 8 each. Previous arguments carrying the byval attribute are skipped because
// they were passed on the stack.
//
// loc               - location forwarded to the register-counting helpers.
// type              - record type of the argument being classified.
// previousArguments - already marshalled arguments of the same signature.
//
// Returns true when the count that `type` draws from still covers the number
// of registers it needs.
bool hasEnoughRegisters(mlir::Location loc, fir::RecordType type,
                        const Marshalling &previousArguments) const {
  int availIntRegisters = 8;
  int availSIMDRegisters = 8;

  // Check previous arguments to see how many registers are used already.
  // The binding is called `argType` so it does not shadow the `type`
  // parameter, which is needed again after the loop.
  for (const auto &[argType, attr] : previousArguments) {
    // Stop early only once BOTH counts are exhausted. Stopping as soon as one
    // of them reaches zero would leave the other count too high, and a later
    // record could then be reported as fitting into registers that earlier
    // arguments already occupy.
    if (availIntRegisters <= 0 && availSIMDRegisters <= 0)
      break;

    if (attr.isByVal())
      continue; // Previous argument passed on the stack

    const NRegs used = usedRegsForType(loc, argType);
    if (used.isSimd)
      availSIMDRegisters -= used.n;
    else
      availIntRegisters -= used.n;
  }

  const NRegs needed = usedRegsForRecordType(loc, type);
  return needed.n <= (needed.isSimd ? availSIMDRegisters : availIntRegisters);
}
963+
964+
// Builds the marshalling for a value that is not passed in registers: a single
// entry holding a reference to `ty`, with an alignment of at least 8, marked
// sret when `isResult` is true and byval otherwise.
CodeGenSpecifics::Marshalling
965+
passOnTheStack(mlir::Location loc, mlir::Type ty, bool isResult) const {
966+
CodeGenSpecifics::Marshalling marshal;
967+
auto sizeAndAlign =
903968
fir::getTypeSizeAndAlignmentOrCrash(loc, ty, getDataLayout(), kindMap);
969+
// Use the type's own alignment, but never less than 8
970+
unsigned short align =
971+
std::max(sizeAndAlign.second, static_cast<unsigned short>(8));
972+
marshal.emplace_back(fir::ReferenceType::get(ty),
973+
AT{align, /*byval=*/!isResult, /*sret=*/isResult});
974+
return marshal;
975+
}
904976

905-
// return in registers if size <= 16 bytes
906-
if (size <= 16) {
907-
std::size_t dwordSize = (size + 7) / 8;
908-
auto newTy = fir::SequenceType::get(
909-
dwordSize, mlir::IntegerType::get(ty.getContext(), 64));
910-
marshal.emplace_back(newTy, AT{});
911-
return marshal;
977+
// Shared classification for record arguments and record results.
//  - A record that needs no registers is handed to passOnTheStack.
//  - An HFA keeps its record type.
//  - Any other record becomes an array of `nregs.n` 64-bit integers.
// NOTE(review): this span is a rendered diff. Lines that follow a marker
// ending in `-` belong to the removed implementation, not to this function.
CodeGenSpecifics::Marshalling
978+
structType(mlir::Location loc, fir::RecordType type, bool isResult) const {
979+
NRegs nregs = usedRegsForRecordType(loc, type);
980+
981+
// A record that needs no registers has to be passed on the stack
982+
if (nregs.n == 0)
983+
return passOnTheStack(loc, type, isResult);
984+
985+
CodeGenSpecifics::Marshalling marshal;
986+
987+
mlir::Type pcsType;
988+
if (nregs.isSimd) {
989+
pcsType = type;
990+
} else {
991+
pcsType = fir::SequenceType::get(
992+
nregs.n, mlir::IntegerType::get(type.getContext(), 64));
912993
}
913994

914-
unsigned short stackAlign = std::max<unsigned short>(align, 8u);
915-
marshal.emplace_back(fir::ReferenceType::get(ty),
916-
AT{stackAlign, false, true});
995+
marshal.emplace_back(pcsType, AT{});
917996
return marshal;
918997
}
998+
999+
// Classifies a by-value record argument. When hasEnoughRegisters reports that
// the registers left over by `previousArguments` do not suffice, the record is
// marshalled through passOnTheStack; otherwise structType decides.
CodeGenSpecifics::Marshalling
1000+
structArgumentType(mlir::Location loc, fir::RecordType ty,
1001+
const Marshalling &previousArguments) const override {
1002+
if (!hasEnoughRegisters(loc, ty, previousArguments)) {
1003+
return passOnTheStack(loc, ty, /*isResult=*/false);
1004+
}
1005+
1006+
return structType(loc, ty, /*isResult=*/false);
1007+
}
1008+
1009+
// Classifies a record result by delegating to structType with isResult set,
// so a record that cannot use registers is marshalled as an sret reference.
CodeGenSpecifics::Marshalling
1010+
structReturnType(mlir::Location loc, fir::RecordType ty) const override {
1011+
return structType(loc, ty, /*isResult=*/true);
1012+
}
9191013
};
9201014
} // namespace
9211015

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
// Test AArch64 ABI rewrite of structs passed by value (BIND(C), VALUE derived types).
2+
// RUN: fir-opt --target-rewrite="target=aarch64-unknown-linux-gnu" %s | FileCheck %s
3+
4+
// Group 1: records that are not HFAs and fit in 16 bytes are expected to be
// rewritten to !fir.array<2xi64>.
// CHECK-LABEL: func.func private @small_i32(!fir.array<2xi64>)
5+
func.func private @small_i32(!fir.type<small_i32{i:i32,j:i32,k:i32}>)
6+
// CHECK-LABEL: func.func private @small_i64(!fir.array<2xi64>)
7+
func.func private @small_i64(!fir.type<small_i64{i:i64,j:i64}>)
8+
// CHECK-LABEL: func.func private @small_mixed(!fir.array<2xi64>)
9+
func.func private @small_mixed(!fir.type<small_mixed{i:i64,j:f32,k:i32}>)
10+
// CHECK-LABEL: func.func private @small_non_hfa(!fir.array<2xi64>)
11+
func.func private @small_non_hfa(!fir.type<small_non_hfa{i:f64,j:f32,k:f16}>)
12+
13+
// Group 2: records whose members are all the same floating-point type (up to
// four of them) are expected to keep their record type.
// CHECK-LABEL: func.func private @hfa_f16(!fir.type<hfa_f16{i:f16,j:f16}>)
14+
func.func private @hfa_f16(!fir.type<hfa_f16{i:f16,j:f16}>)
15+
// CHECK-LABEL: func.func private @hfa_bf16(!fir.type<hfa_bf16{i:bf16,j:bf16,k:bf16,l:bf16}>)
16+
func.func private @hfa_bf16(!fir.type<hfa_bf16{i:bf16,j:bf16,k:bf16,l:bf16}>)
17+
// CHECK-LABEL: func.func private @hfa_f32(!fir.type<hfa_f32{i:f32,j:f32}>)
18+
func.func private @hfa_f32(!fir.type<hfa_f32{i:f32,j:f32}>)
19+
// CHECK-LABEL: func.func private @hfa_f64(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>)
20+
func.func private @hfa_f64(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>)
21+
// CHECK-LABEL: func.func private @hfa_f128(!fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>)
22+
func.func private @hfa_f128(!fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>)
23+
24+
// Group 3: several record arguments in one signature; each one is expected to
// be rewritten the same way as when it appears alone.
// CHECK-LABEL: func.func private @multi_small_integer(!fir.array<2xi64>, !fir.array<2xi64>)
25+
func.func private @multi_small_integer(!fir.type<small_i32{i:i32,j:i32,k:i32}>, !fir.type<small_i64{i:i64,j:i64}>)
26+
// CHECK-LABEL: func.func private @multi_hfas(!fir.type<hfa_f16{i:f16,j:f16}>, !fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>)
27+
func.func private @multi_hfas(!fir.type<hfa_f16{i:f16,j:f16}>, !fir.type<hfa_f128{i:f128,j:f128,k:f128,l:f128}>)
28+
// CHECK-LABEL: func.func private @multi_mixed(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>, !fir.array<2xi64>, !fir.type<hfa_f32{i:f32,j:f32}>, !fir.array<2xi64>)
29+
func.func private @multi_mixed(!fir.type<hfa_f64{i:f64,j:f64,k:f64}>,!fir.type<small_non_hfa{i:f64,j:f32,k:f16}>,!fir.type<hfa_f32{i:f32,j:f32}>,!fir.type<small_i64{i:i64,j:i64}>)
30+
31+
// Group 4: signatures that use the registers up to the limit (four two-word
// records, two four-member HFAs, or both together); every argument is still
// expected to stay in registers.
// CHECK-LABEL: func.func private @int_max(!fir.array<2xi64>,
32+
// CHECK-SAME: !fir.array<2xi64>,
33+
// CHECK-SAME: !fir.array<2xi64>,
34+
// CHECK-SAME: !fir.array<2xi64>)
35+
func.func private @int_max(!fir.type<int_max{i:i64,j:i64}>,
36+
!fir.type<int_max{i:i64,j:i64}>,
37+
!fir.type<int_max{i:i64,j:i64}>,
38+
!fir.type<int_max{i:i64,j:i64}>)
39+
// CHECK-LABEL: func.func private @hfa_max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>, !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>)
40+
func.func private @hfa_max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>, !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>)
41+
// CHECK-LABEL: func.func private @max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
42+
// CHECK-SAME: !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
43+
// CHECK-SAME: !fir.array<2xi64>,
44+
// CHECK-SAME: !fir.array<2xi64>,
45+
// CHECK-SAME: !fir.array<2xi64>,
46+
// CHECK-SAME: !fir.array<2xi64>)
47+
func.func private @max(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
48+
!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
49+
!fir.type<int_max{i:i64,j:i64}>,
50+
!fir.type<int_max{i:i64,j:i64}>,
51+
!fir.type<int_max{i:i64,j:i64}>,
52+
!fir.type<int_max{i:i64,j:i64}>)
53+
54+
55+
// Group 5: one record more than the registers allow; the last argument is
// expected to become a byval reference.
// CHECK-LABEL: func.func private @too_many_int(!fir.array<2xi64>,
56+
// CHECK-SAME: !fir.array<2xi64>,
57+
// CHECK-SAME: !fir.array<2xi64>,
58+
// CHECK-SAME: !fir.array<2xi64>,
59+
// CHECK-SAME: !fir.ref<!fir.type<int_max{i:i64,j:i64}>> {{{.*}}, llvm.byval = !fir.type<int_max{i:i64,j:i64}>})
60+
func.func private @too_many_int(!fir.type<int_max{i:i64,j:i64}>,
61+
!fir.type<int_max{i:i64,j:i64}>,
62+
!fir.type<int_max{i:i64,j:i64}>,
63+
!fir.type<int_max{i:i64,j:i64}>,
64+
!fir.type<int_max{i:i64,j:i64}>)
65+
// CHECK-LABEL: func.func private @too_many_hfa(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
66+
// CHECK-SAME: !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
67+
// CHECK-SAME: !fir.ref<!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>> {{{.*}}, llvm.byval = !fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>})
68+
func.func private @too_many_hfa(!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
69+
!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>,
70+
!fir.type<hfa_max{i:f128,j:f128,k:f128,l:f128}>)
71+
72+
// Group 6: a record holding five i32 values is expected to be passed as a
// byval reference even though it is the only argument.
// CHECK-LABEL: func.func private @too_big(!fir.ref<!fir.type<too_big{i:!fir.array<5xi32>}>> {{{.*}}, llvm.byval = !fir.type<too_big{i:!fir.array<5xi32>}>})
73+
func.func private @too_big(!fir.type<too_big{i:!fir.array<5xi32>}>)

0 commit comments

Comments
 (0)