Skip to content

Commit a7bbcc4

Browse files
[RISCV][GISEL] Add support for lowerFormalArguments that contain scalable vector types (#70882)
Scalable vector types from LLVM IR can be lowered to scalable vector types in MIR according to the RISCVAssignFn.
1 parent 506a30d commit a7bbcc4

File tree

9 files changed

+984
-9
lines changed

9 files changed

+984
-9
lines changed

llvm/lib/CodeGen/GlobalISel/CallLowering.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ static void buildCopyFromRegs(MachineIRBuilder &B, ArrayRef<Register> OrigRegs,
358358
if (PartLLT.isVector() == LLTy.isVector() &&
359359
PartLLT.getScalarSizeInBits() > LLTy.getScalarSizeInBits() &&
360360
(!PartLLT.isVector() ||
361-
PartLLT.getNumElements() == LLTy.getNumElements()) &&
361+
PartLLT.getElementCount() == LLTy.getElementCount()) &&
362362
OrigRegs.size() == 1 && Regs.size() == 1) {
363363
Register SrcReg = Regs[0];
364364

@@ -406,6 +406,7 @@ static void buildCopyFromRegs(MachineIRBuilder &B, ArrayRef<Register> OrigRegs,
406406
// If PartLLT is a mismatched vector in both number of elements and element
407407
// size, e.g. PartLLT == v2s64 and LLTy is v3s32, then first coerce it to
408408
// have the same elt type, i.e. v4s32.
409+
// TODO: Extend this coersion to element multiples other than just 2.
409410
if (PartLLT.getSizeInBits() > LLTy.getSizeInBits() &&
410411
PartLLT.getScalarSizeInBits() == LLTy.getScalarSizeInBits() * 2 &&
411412
Regs.size() == 1) {

llvm/lib/CodeGen/GlobalISel/MachineIRBuilder.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -1065,16 +1065,16 @@ void MachineIRBuilder::validateTruncExt(const LLT DstTy, const LLT SrcTy,
10651065
#ifndef NDEBUG
10661066
if (DstTy.isVector()) {
10671067
assert(SrcTy.isVector() && "mismatched cast between vector and non-vector");
1068-
assert(SrcTy.getNumElements() == DstTy.getNumElements() &&
1068+
assert(SrcTy.getElementCount() == DstTy.getElementCount() &&
10691069
"different number of elements in a trunc/ext");
10701070
} else
10711071
assert(DstTy.isScalar() && SrcTy.isScalar() && "invalid extend/trunc");
10721072

10731073
if (IsExtend)
1074-
assert(DstTy.getSizeInBits() > SrcTy.getSizeInBits() &&
1074+
assert(TypeSize::isKnownGT(DstTy.getSizeInBits(), SrcTy.getSizeInBits()) &&
10751075
"invalid narrowing extend");
10761076
else
1077-
assert(DstTy.getSizeInBits() < SrcTy.getSizeInBits() &&
1077+
assert(TypeSize::isKnownLT(DstTy.getSizeInBits(), SrcTy.getSizeInBits()) &&
10781078
"invalid widening trunc");
10791079
#endif
10801080
}

llvm/lib/CodeGen/LowLevelType.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ using namespace llvm;
1717

1818
LLT::LLT(MVT VT) {
1919
if (VT.isVector()) {
20-
bool asVector = VT.getVectorMinNumElements() > 1;
20+
bool asVector = VT.getVectorMinNumElements() > 1 || VT.isScalableVector();
2121
init(/*IsPointer=*/false, asVector, /*IsScalar=*/!asVector,
2222
VT.getVectorElementCount(), VT.getVectorElementType().getSizeInBits(),
2323
/*AddressSpace=*/0);

llvm/lib/CodeGen/MachineVerifier.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -965,7 +965,7 @@ bool MachineVerifier::verifyVectorElementMatch(LLT Ty0, LLT Ty1,
965965
return false;
966966
}
967967

968-
if (Ty0.isVector() && Ty0.getNumElements() != Ty1.getNumElements()) {
968+
if (Ty0.isVector() && Ty0.getElementCount() != Ty1.getElementCount()) {
969969
report("operand types must preserve number of vector elements", MI);
970970
return false;
971971
}

llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp

+35-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "RISCVCallLowering.h"
1616
#include "RISCVISelLowering.h"
17+
#include "RISCVMachineFunctionInfo.h"
1718
#include "RISCVSubtarget.h"
1819
#include "llvm/CodeGen/Analysis.h"
1920
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
@@ -185,6 +186,9 @@ struct RISCVIncomingValueAssigner : public CallLowering::IncomingValueAssigner {
185186
const DataLayout &DL = MF.getDataLayout();
186187
const RISCVSubtarget &Subtarget = MF.getSubtarget<RISCVSubtarget>();
187188

189+
if (LocVT.isScalableVector())
190+
MF.getInfo<RISCVMachineFunctionInfo>()->setIsVectorCall();
191+
188192
if (RISCVAssignFn(DL, Subtarget.getTargetABI(), ValNo, ValVT, LocVT,
189193
LocInfo, Flags, State, /*IsFixed=*/true, IsRet, Info.Ty,
190194
*Subtarget.getTargetLowering(),
@@ -301,8 +305,31 @@ struct RISCVCallReturnHandler : public RISCVIncomingValueHandler {
301305
RISCVCallLowering::RISCVCallLowering(const RISCVTargetLowering &TLI)
302306
: CallLowering(&TLI) {}
303307

308+
/// Return true if scalable vector with ScalarTy is legal for lowering.
309+
static bool isLegalElementTypeForRVV(Type *EltTy,
310+
const RISCVSubtarget &Subtarget) {
311+
if (EltTy->isPointerTy())
312+
return Subtarget.is64Bit() ? Subtarget.hasVInstructionsI64() : true;
313+
if (EltTy->isIntegerTy(1) || EltTy->isIntegerTy(8) ||
314+
EltTy->isIntegerTy(16) || EltTy->isIntegerTy(32))
315+
return true;
316+
if (EltTy->isIntegerTy(64))
317+
return Subtarget.hasVInstructionsI64();
318+
if (EltTy->isHalfTy())
319+
return Subtarget.hasVInstructionsF16();
320+
if (EltTy->isBFloatTy())
321+
return Subtarget.hasVInstructionsBF16();
322+
if (EltTy->isFloatTy())
323+
return Subtarget.hasVInstructionsF32();
324+
if (EltTy->isDoubleTy())
325+
return Subtarget.hasVInstructionsF64();
326+
return false;
327+
}
328+
304329
// TODO: Support all argument types.
305-
static bool isSupportedArgumentType(Type *T, const RISCVSubtarget &Subtarget) {
330+
// TODO: Remove IsLowerArgs argument by adding support for vectors in lowerCall.
331+
static bool isSupportedArgumentType(Type *T, const RISCVSubtarget &Subtarget,
332+
bool IsLowerArgs = false) {
306333
// TODO: Integers larger than 2*XLen are passed indirectly which is not
307334
// supported yet.
308335
if (T->isIntegerTy())
@@ -311,6 +338,11 @@ static bool isSupportedArgumentType(Type *T, const RISCVSubtarget &Subtarget) {
311338
return true;
312339
if (T->isPointerTy())
313340
return true;
341+
// TODO: Support fixed vector types.
342+
if (IsLowerArgs && T->isVectorTy() && Subtarget.hasVInstructions() &&
343+
T->isScalableTy() &&
344+
isLegalElementTypeForRVV(T->getScalarType(), Subtarget))
345+
return true;
314346
return false;
315347
}
316348

@@ -398,7 +430,8 @@ bool RISCVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
398430
const RISCVSubtarget &Subtarget =
399431
MIRBuilder.getMF().getSubtarget<RISCVSubtarget>();
400432
for (auto &Arg : F.args()) {
401-
if (!isSupportedArgumentType(Arg.getType(), Subtarget))
433+
if (!isSupportedArgumentType(Arg.getType(), Subtarget,
434+
/*IsLowerArgs=*/true))
402435
return false;
403436
}
404437

llvm/test/CodeGen/RISCV/GlobalISel/irtranslator/fallback.ll

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ declare <vscale x 1 x i8> @llvm.riscv.vadd.nxv1i8.nxv1i8(
99
<vscale x 1 x i8>,
1010
i64)
1111

12-
; FALLBACK-WITH-REPORT-ERR: remark: <unknown>:0:0: unable to lower arguments{{.*}}scalable_arg
12+
; FALLBACK_WITH_REPORT_ERR: <unknown>:0:0: unable to translate instruction: call:
1313
; FALLBACK-WITH-REPORT-OUT-LABEL: scalable_arg
1414
define <vscale x 1 x i8> @scalable_arg(<vscale x 1 x i8> %0, <vscale x 1 x i8> %1, i64 %2) nounwind {
1515
entry:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
; RUN: not --crash llc -mtriple=riscv32 -mattr=+v -global-isel -stop-after=irtranslator \
2+
; RUN: -verify-machineinstrs < %s 2>&1 | FileCheck %s
3+
; RUN: not --crash llc -mtriple=riscv64 -mattr=+v -global-isel -stop-after=irtranslator \
4+
; RUN: -verify-machineinstrs < %s 2>&1 | FileCheck %s
5+
6+
; The purpose of this test is to show that the compiler throws an error when
7+
; there is no support for bf16 vectors. If the compiler did not throw an error,
8+
; then it will try to scalarize the argument to an s32, which may drop elements.
9+
define void @test_args_nxv1bf16(<vscale x 1 x bfloat> %a) {
10+
entry:
11+
ret void
12+
}
13+
14+
; CHECK: LLVM ERROR: unable to lower arguments: ptr (in function: test_args_nxv1bf16)
15+
16+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
; RUN: not --crash llc -mtriple=riscv32 -mattr=+v -global-isel -stop-after=irtranslator \
2+
; RUN: -verify-machineinstrs < %s 2>&1 | FileCheck %s
3+
; RUN: not --crash llc -mtriple=riscv64 -mattr=+v -global-isel -stop-after=irtranslator \
4+
; RUN: -verify-machineinstrs < %s 2>&1 | FileCheck %s
5+
6+
; The purpose of this test is to show that the compiler throws an error when
7+
; there is no support for f16 vectors. If the compiler did not throw an error,
8+
; then it will try to scalarize the argument to an s32, which may drop elements.
9+
define void @test_args_nxv1f16(<vscale x 1 x half> %a) {
10+
entry:
11+
ret void
12+
}
13+
14+
; CHECK: LLVM ERROR: unable to lower arguments: ptr (in function: test_args_nxv1f16)
15+
16+

0 commit comments

Comments
 (0)