diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index e99c6208594e3..f4da46f82a810 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -16128,23 +16128,26 @@ static bool matchIndexAsWiderOp(EVT VT, SDValue Index, SDValue Mask,
   return true;
 }
 
+// trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
+// This would be benefit for the cases where X and Y are both the same value
+// type of low precision vectors. Since the truncate would be lowered into
+// n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
+// restriction, such pattern would be expanded into a series of "vsetvli"
+// and "vnsrl" instructions later to reach this point.
 static SDValue combineTruncOfSraSext(SDNode *N, SelectionDAG &DAG) {
-  // trunc (sra sext (X), zext (Y)) -> sra (X, smin (Y, scalarsize(Y) - 1))
-  // This would be benefit for the cases where X and Y are both the same value
-  // type of low precision vectors. Since the truncate would be lowered into
-  // n-levels TRUNCATE_VECTOR_VL to satisfy RVV's SEW*2->SEW truncate
-  // restriction, such pattern would be expanded into a series of "vsetvli"
-  // and "vnsrl" instructions later to reach this point.
-  auto IsTruncNode = [](SDValue V) {
-    if (V.getOpcode() != RISCVISD::TRUNCATE_VECTOR_VL)
-      return false;
-    SDValue VL = V.getOperand(2);
-    auto *C = dyn_cast<ConstantSDNode>(VL);
-    // Assume all TRUNCATE_VECTOR_VL nodes use VLMAX for VMSET_VL operand
-    bool IsVLMAXForVMSET = (C && C->isAllOnes()) ||
-                           (isa<RegisterSDNode>(VL) &&
-                            cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
-    return V.getOperand(1).getOpcode() == RISCVISD::VMSET_VL && IsVLMAXForVMSET;
+  SDValue Mask = N->getOperand(1);
+  SDValue VL = N->getOperand(2);
+
+  bool IsVLMAX = isAllOnesConstant(VL) ||
+                 (isa<RegisterSDNode>(VL) &&
+                  cast<RegisterSDNode>(VL)->getReg() == RISCV::X0);
+  if (!IsVLMAX || Mask.getOpcode() != RISCVISD::VMSET_VL ||
+      Mask.getOperand(0) != VL)
+    return SDValue();
+
+  auto IsTruncNode = [&](SDValue V) {
+    return V.getOpcode() == RISCVISD::TRUNCATE_VECTOR_VL &&
+           V.getOperand(1) == Mask && V.getOperand(2) == VL;
   };
 
   SDValue Op = N->getOperand(0);
diff --git a/llvm/test/CodeGen/RISCV/rvv/vsra-sdnode.ll b/llvm/test/CodeGen/RISCV/rvv/vsra-sdnode.ll
index 8dbb57fd15cf1..382c8297473b7 100644
--- a/llvm/test/CodeGen/RISCV/rvv/vsra-sdnode.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vsra-sdnode.ll
@@ -937,13 +937,17 @@ define <vscale x 8 x i32> @vsra_vi_mask_nxv8i32(<vscale x 8 x i32> %va, <vscale
   ret <vscale x 8 x i32> %vs
 }
 
-define <vscale x 1 x i8> @vsra_vv_nxv1i8_sext_zext_mixed_trunc(<vscale x 1 x i8> %va, <vscale x 1 x i8> %vb, <vscale x 1 x i1> %m, i32 %evl) {
+define <vscale x 1 x i8> @vsra_vv_nxv1i8_sext_zext_mixed_trunc(<vscale x 1 x i8> %va, <vscale x 1 x i8> %vb, <vscale x 1 x i1> %m, i32 zeroext %evl) {
 ; CHECK-LABEL: vsra_vv_nxv1i8_sext_zext_mixed_trunc:
 ; CHECK:       # %bb.0:
-; CHECK-NEXT:    li a0, 7
-; CHECK-NEXT:    vsetvli a1, zero, e8, mf8, ta, ma
-; CHECK-NEXT:    vmin.vx v9, v8, a0
-; CHECK-NEXT:    vsra.vv v8, v8, v9
+; CHECK-NEXT:    vsetvli a1, zero, e32, mf2, ta, ma
+; CHECK-NEXT:    vsext.vf4 v9, v8
+; CHECK-NEXT:    vzext.vf4 v10, v8
+; CHECK-NEXT:    vsra.vv v8, v9, v10
+; CHECK-NEXT:    vsetvli zero, zero, e16, mf4, ta, ma
+; CHECK-NEXT:    vnsrl.wi v8, v8, 0
+; CHECK-NEXT:    vsetvli zero, a0, e8, mf8, ta, ma
+; CHECK-NEXT:    vnsrl.wi v8, v8, 0, v0.t
 ; CHECK-NEXT:    ret
   %sexted_va = sext <vscale x 1 x i8> %va to <vscale x 1 x i32>
   %zexted_vb = zext <vscale x 1 x i8> %va to <vscale x 1 x i32>
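
For reference, a minimal LLVM IR sketch (not part of the patch; the function
name is hypothetical) of the scalable-vector shape combineTruncOfSraSext
targets, i.e. trunc (sra (sext X), (zext Y)):

  define <vscale x 1 x i8> @sra_sext_zext_trunc(<vscale x 1 x i8> %x, <vscale x 1 x i8> %y) {
    ; Widen both i8 operands to i32, shift, then narrow back to i8. On RVV the
    ; i32->i8 trunc lowers to two TRUNCATE_VECTOR_VL (vnsrl) steps, which the
    ; combine folds to "sra (X, smin (Y, 7))" when the mask/VL checks pass.
    %sx = sext <vscale x 1 x i8> %x to <vscale x 1 x i32>
    %zy = zext <vscale x 1 x i8> %y to <vscale x 1 x i32>
    %sh = ashr <vscale x 1 x i32> %sx, %zy
    %t = trunc <vscale x 1 x i32> %sh to <vscale x 1 x i8>
    ret <vscale x 1 x i8> %t
  }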