[RISCV] Allow folding vmerge into masked ops when mask is the same #97989

Merged 3 commits on Jul 9, 2024
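
This patch extends the vmerge peephole in RISCVISelDAGToDAG.cpp. performCombineVMergeAndVOps folds a vmerge.vvm into the instruction that defines its true operand; when that instruction is itself masked, the fold previously required the vmerge to use an all-ones mask, and it is now also performed when the vmerge uses the same mask as the masked instruction, so the pair can be emitted as a single masked instruction in tail-undisturbed, mask-undisturbed (tu, mu) form. A minimal IR sketch of the newly handled pattern, mirroring the new test added below (the function name is illustrative):

declare <vscale x 2 x i32> @llvm.riscv.vadd.mask.nxv2i32.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i1>, i64, i64 immarg)
declare <vscale x 2 x i32> @llvm.riscv.vmerge.nxv2i32.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i1>, i64)

; The masked vadd computes under %m; the vmerge then keeps the vadd result
; where %m is set and %passthru elsewhere. With this patch the pair lowers to
; a single "vadd.vv v8, v9, v10, v0.t" under a "tu, mu" vsetvli instead of a
; vadd followed by a vmerge.vvm.
define <vscale x 2 x i32> @fold_same_mask(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %x, <vscale x 2 x i32> %y, <vscale x 2 x i1> %m, i64 %vl) {
  %a = call <vscale x 2 x i32> @llvm.riscv.vadd.mask.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %x, <vscale x 2 x i32> %y, <vscale x 2 x i1> %m, i64 %vl, i64 1)
  %b = call <vscale x 2 x i32> @llvm.riscv.vmerge.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %passthru, <vscale x 2 x i32> %a, <vscale x 2 x i1> %m, i64 %vl)
  ret <vscale x 2 x i32> %b
}

The vfwmacc-vp.ll diff below shows the same effect on an existing test: the separate vmerge.vvm and the extra vsetvli toggle it required disappear, leaving one masked vfwmacc.vv executed under a tu, mu policy.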
31 changes: 23 additions & 8 deletions llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -3523,24 +3523,26 @@ bool RISCVDAGToDAGISel::doPeepholeSExtW(SDNode *N) {
return false;
}

static bool usesAllOnesMask(SDValue MaskOp, SDValue GlueOp) {
// After ISel, a vector pseudo's mask will be copied to V0 via a CopyToReg
// that's glued to the pseudo. This tries to look up the value that was copied
// to V0.
static SDValue getMaskSetter(SDValue MaskOp, SDValue GlueOp) {
// Check that we're using V0 as a mask register.
if (!isa<RegisterSDNode>(MaskOp) ||
cast<RegisterSDNode>(MaskOp)->getReg() != RISCV::V0)
return false;
return SDValue();

// The glued user defines V0.
const auto *Glued = GlueOp.getNode();

if (!Glued || Glued->getOpcode() != ISD::CopyToReg)
return false;
return SDValue();

// Check that we're defining V0 as a mask register.
if (!isa<RegisterSDNode>(Glued->getOperand(1)) ||
cast<RegisterSDNode>(Glued->getOperand(1))->getReg() != RISCV::V0)
return false;
return SDValue();

// Check the instruction defining V0; it needs to be a VMSET pseudo.
SDValue MaskSetter = Glued->getOperand(2);

// Sometimes the VMSET is wrapped in a COPY_TO_REGCLASS, e.g. if the mask came
@@ -3549,6 +3551,15 @@ static bool usesAllOnesMask(SDValue MaskOp, SDValue GlueOp) {
MaskSetter->getMachineOpcode() == RISCV::COPY_TO_REGCLASS)
MaskSetter = MaskSetter->getOperand(0);

return MaskSetter;
}

static bool usesAllOnesMask(SDValue MaskOp, SDValue GlueOp) {
// Check the instruction defining V0; it needs to be a VMSET pseudo.
SDValue MaskSetter = getMaskSetter(MaskOp, GlueOp);
if (!MaskSetter)
return false;

const auto IsVMSet = [](unsigned Opc) {
return Opc == RISCV::PseudoVMSET_M_B1 || Opc == RISCV::PseudoVMSET_M_B16 ||
Opc == RISCV::PseudoVMSET_M_B2 || Opc == RISCV::PseudoVMSET_M_B32 ||
@@ -3755,12 +3766,16 @@ bool RISCVDAGToDAGISel::performCombineVMergeAndVOps(SDNode *N) {
return false;
}

// If True is masked then the vmerge must have an all 1s mask, since we're
// going to keep the mask from True.
// If True is masked then the vmerge must have either the same mask or an all
// 1s mask, since we're going to keep the mask from True.
if (IsMasked && Mask) {
// FIXME: Support mask agnostic True instruction which would have an
// undef merge operand.
if (!usesAllOnesMask(Mask, Glue))
SDValue TrueMask =
getMaskSetter(True->getOperand(Info->MaskOpIdx),
True->getOperand(True->getNumOperands() - 1));
[Comment on lines +3774 to +3776, @lukel97 (Contributor Author), Jul 8, 2024]
This is a bit gnarly but I still hope to eventually move all this into RISCVFoldMasks.cpp #71764. Sorry for dropping the ball on that.

assert(TrueMask);
if (!usesAllOnesMask(Mask, Glue) && getMaskSetter(Mask, Glue) != TrueMask)
return false;
}

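The new check above compares the value copied into V0 for the vmerge against the value copied into V0 for the masked instruction (both looked up through getMaskSetter), so the fold only happens when the two operations are provably masked by the same value, or when the vmerge mask is all ones. As a contrast to the positive test added below, a hedged sketch (hypothetical function, assuming no other combine applies) of a case the peephole still rejects, because the two masks are distinct values and the vmerge.vvm therefore stays in the output:

declare <vscale x 2 x i32> @llvm.riscv.vadd.mask.nxv2i32.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i1>, i64, i64 immarg)
declare <vscale x 2 x i32> @llvm.riscv.vmerge.nxv2i32.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i1>, i64)

; %m1 and %m2 are not known to be the same mask (and %m2 is not all ones), so
; the vmerge is not folded into the masked vadd here.
define <vscale x 2 x i32> @no_fold_different_mask(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %x, <vscale x 2 x i32> %y, <vscale x 2 x i1> %m1, <vscale x 2 x i1> %m2, i64 %vl) {
  %a = call <vscale x 2 x i32> @llvm.riscv.vadd.mask.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %x, <vscale x 2 x i32> %y, <vscale x 2 x i1> %m1, i64 %vl, i64 1)
  %b = call <vscale x 2 x i32> @llvm.riscv.vmerge.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %passthru, <vscale x 2 x i32> %a, <vscale x 2 x i1> %m2, i64 %vl)
  ret <vscale x 2 x i32> %b
}
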
11 changes: 11 additions & 0 deletions llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-masked-vops.ll
@@ -240,3 +240,14 @@ define <vscale x 2 x i32> @vpmerge_viota(<vscale x 2 x i32> %passthru, <vscale x
%b = call <vscale x 2 x i32> @llvm.riscv.vmerge.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %passthru, <vscale x 2 x i32> %a, <vscale x 2 x i1> splat (i1 -1), i64 %1)
ret <vscale x 2 x i32> %b
}

define <vscale x 2 x i32> @vpmerge_vadd_same_mask(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %x, <vscale x 2 x i32> %y, <vscale x 2 x i1> %m, i64 %vl) {
; CHECK-LABEL: vpmerge_vadd_same_mask:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a0, e32, m1, tu, mu
; CHECK-NEXT: vadd.vv v8, v9, v10, v0.t
; CHECK-NEXT: ret
%a = call <vscale x 2 x i32> @llvm.riscv.vadd.mask.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %x, <vscale x 2 x i32> %y, <vscale x 2 x i1> %m, i64 %vl, i64 1)
%b = call <vscale x 2 x i32> @llvm.riscv.vmerge.nxv2i32.nxv2i32(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %passthru, <vscale x 2 x i32> %a, <vscale x 2 x i1> %m, i64 %vl)
ret <vscale x 2 x i32> %b
}
7 changes: 2 additions & 5 deletions llvm/test/CodeGen/RISCV/rvv/vfwmacc-vp.ll
@@ -85,11 +85,8 @@ define <vscale x 1 x float> @vfmacc_vv_nxv1f32_tu(<vscale x 1 x half> %a, <vscal
define <vscale x 1 x float> @vfmacc_vv_nxv1f32_masked__tu(<vscale x 1 x half> %a, <vscale x 1 x half> %b, <vscale x 1 x float> %c, <vscale x 1 x i1> %m, i32 zeroext %evl) {
; ZVFH-LABEL: vfmacc_vv_nxv1f32_masked__tu:
; ZVFH: # %bb.0:
; ZVFH-NEXT: vmv1r.v v11, v10
; ZVFH-NEXT: vsetvli zero, a0, e16, mf4, ta, ma
; ZVFH-NEXT: vfwmacc.vv v11, v8, v9, v0.t
; ZVFH-NEXT: vsetvli zero, zero, e32, mf2, tu, ma
; ZVFH-NEXT: vmerge.vvm v10, v10, v11, v0
; ZVFH-NEXT: vsetvli zero, a0, e16, mf4, tu, mu
; ZVFH-NEXT: vfwmacc.vv v10, v8, v9, v0.t
; ZVFH-NEXT: vmv1r.v v8, v10
; ZVFH-NEXT: ret
;
27 changes: 19 additions & 8 deletions llvm/test/CodeGen/RISCV/rvv/vpload.ll
@@ -26,6 +26,17 @@ define <vscale x 1 x i8> @vpload_nxv1i8_allones_mask(ptr %ptr, i32 zeroext %evl)
ret <vscale x 1 x i8> %load
}

define <vscale x 1 x i8> @vpload_nxv1i8_passthru(ptr %ptr, <vscale x 1 x i1> %m, <vscale x 1 x i8> %passthru, i32 zeroext %evl) {
; CHECK-LABEL: vpload_nxv1i8_passthru:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a1, e8, mf8, tu, mu
; CHECK-NEXT: vle8.v v8, (a0), v0.t
; CHECK-NEXT: ret
%load = call <vscale x 1 x i8> @llvm.vp.load.nxv1i8.p0(ptr %ptr, <vscale x 1 x i1> %m, i32 %evl)
%merge = call <vscale x 1 x i8> @llvm.vp.merge.nxv1i8(<vscale x 1 x i1> %m, <vscale x 1 x i8> %load, <vscale x 1 x i8> %passthru, i32 %evl)
ret <vscale x 1 x i8> %merge
}

declare <vscale x 2 x i8> @llvm.vp.load.nxv2i8.p0(ptr, <vscale x 2 x i1>, i32)

define <vscale x 2 x i8> @vpload_nxv2i8(ptr %ptr, <vscale x 2 x i1> %m, i32 zeroext %evl) {
@@ -450,10 +461,10 @@ define <vscale x 16 x double> @vpload_nxv16f64(ptr %ptr, <vscale x 16 x i1> %m,
; CHECK-NEXT: add a4, a0, a4
; CHECK-NEXT: vsetvli zero, a3, e64, m8, ta, ma
; CHECK-NEXT: vle64.v v16, (a4), v0.t
; CHECK-NEXT: bltu a1, a2, .LBB37_2
; CHECK-NEXT: bltu a1, a2, .LBB38_2
; CHECK-NEXT: # %bb.1:
; CHECK-NEXT: mv a1, a2
; CHECK-NEXT: .LBB37_2:
; CHECK-NEXT: .LBB38_2:
; CHECK-NEXT: vmv1r.v v0, v8
; CHECK-NEXT: vsetvli zero, a1, e64, m8, ta, ma
; CHECK-NEXT: vle64.v v8, (a0), v0.t
@@ -480,10 +491,10 @@ define <vscale x 16 x double> @vpload_nxv17f64(ptr %ptr, ptr %out, <vscale x 17
; CHECK-NEXT: slli a5, a3, 1
; CHECK-NEXT: vmv1r.v v8, v0
; CHECK-NEXT: mv a4, a2
; CHECK-NEXT: bltu a2, a5, .LBB38_2
; CHECK-NEXT: bltu a2, a5, .LBB39_2
; CHECK-NEXT: # %bb.1:
; CHECK-NEXT: mv a4, a5
; CHECK-NEXT: .LBB38_2:
; CHECK-NEXT: .LBB39_2:
; CHECK-NEXT: sub a6, a4, a3
; CHECK-NEXT: sltu a7, a4, a6
; CHECK-NEXT: addi a7, a7, -1
@@ -499,21 +510,21 @@
; CHECK-NEXT: sltu a2, a2, a5
; CHECK-NEXT: addi a2, a2, -1
; CHECK-NEXT: and a2, a2, a5
; CHECK-NEXT: bltu a2, a3, .LBB38_4
; CHECK-NEXT: bltu a2, a3, .LBB39_4
; CHECK-NEXT: # %bb.3:
; CHECK-NEXT: mv a2, a3
; CHECK-NEXT: .LBB38_4:
; CHECK-NEXT: .LBB39_4:
; CHECK-NEXT: slli a5, a3, 4
; CHECK-NEXT: srli a6, a3, 2
; CHECK-NEXT: vsetvli a7, zero, e8, mf2, ta, ma
; CHECK-NEXT: vslidedown.vx v0, v8, a6
; CHECK-NEXT: add a5, a0, a5
; CHECK-NEXT: vsetvli zero, a2, e64, m8, ta, ma
; CHECK-NEXT: vle64.v v24, (a5), v0.t
; CHECK-NEXT: bltu a4, a3, .LBB38_6
; CHECK-NEXT: bltu a4, a3, .LBB39_6
; CHECK-NEXT: # %bb.5:
; CHECK-NEXT: mv a4, a3
; CHECK-NEXT: .LBB38_6:
; CHECK-NEXT: .LBB39_6:
; CHECK-NEXT: vmv1r.v v0, v8
; CHECK-NEXT: vsetvli zero, a4, e64, m8, ta, ma
; CHECK-NEXT: vle64.v v8, (a0), v0.t