Skip to content

Commit ec8e1c6

Browse files
authored
[RISCV] Custom promote f16/bf16 (s/u)int_to_fp. (#107026)
This avoids having isel patterns that emit two instructions. It also allows us to remove sext.w and slli+srli pairs by using fcvt.s.w(u) on RV64.
1 parent 70f3511 commit ec8e1c6

8 files changed

+132
-183
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
460460
setOperationAction(ISD::FABS, MVT::bf16, Custom);
461461
setOperationAction(ISD::FNEG, MVT::bf16, Custom);
462462
setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Custom);
463+
setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP}, XLenVT, Custom);
463464
}
464465

465466
if (Subtarget.hasStdExtZfhminOrZhinxmin()) {
@@ -478,6 +479,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
478479
setOperationAction(ISD::FABS, MVT::f16, Custom);
479480
setOperationAction(ISD::FNEG, MVT::f16, Custom);
480481
setOperationAction(ISD::FCOPYSIGN, MVT::f16, Custom);
482+
setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP}, XLenVT, Custom);
481483
}
482484

483485
setOperationAction(ISD::STRICT_FP_ROUND, MVT::f16, Legal);
@@ -590,9 +592,10 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
590592
setOperationAction({ISD::FP_TO_UINT_SAT, ISD::FP_TO_SINT_SAT}, XLenVT,
591593
Custom);
592594

593-
setOperationAction({ISD::STRICT_FP_TO_UINT, ISD::STRICT_FP_TO_SINT,
594-
ISD::STRICT_UINT_TO_FP, ISD::STRICT_SINT_TO_FP},
595-
XLenVT, Legal);
595+
setOperationAction({ISD::STRICT_FP_TO_UINT, ISD::STRICT_FP_TO_SINT}, XLenVT,
596+
Legal);
597+
setOperationAction({ISD::STRICT_UINT_TO_FP, ISD::STRICT_SINT_TO_FP}, XLenVT,
598+
Custom);
596599

597600
setOperationAction(ISD::GET_ROUNDING, XLenVT, Custom);
598601
setOperationAction(ISD::SET_ROUNDING, MVT::Other, Custom);
@@ -2953,6 +2956,33 @@ InstructionCost RISCVTargetLowering::getVSlideVICost(MVT VT) const {
29532956
return getLMULCost(VT);
29542957
}
29552958

2959+
static SDValue lowerINT_TO_FP(SDValue Op, SelectionDAG &DAG,
2960+
const RISCVSubtarget &Subtarget) {
2961+
// f16 conversions are promoted to f32 when Zfh/Zhinx are not supported.
2962+
// bf16 conversions are always promoted to f32.
2963+
if ((Op.getValueType() == MVT::f16 && !Subtarget.hasStdExtZfhOrZhinx()) ||
2964+
Op.getValueType() == MVT::bf16) {
2965+
bool IsStrict = Op->isStrictFPOpcode();
2966+
2967+
SDLoc DL(Op);
2968+
if (IsStrict) {
2969+
SDValue Val = DAG.getNode(Op.getOpcode(), DL, {MVT::f32, MVT::Other},
2970+
{Op.getOperand(0), Op.getOperand(1)});
2971+
return DAG.getNode(ISD::STRICT_FP_ROUND, DL,
2972+
{Op.getValueType(), MVT::Other},
2973+
{Val.getValue(1), Val.getValue(0),
2974+
DAG.getIntPtrConstant(0, DL, /*isTarget=*/true)});
2975+
}
2976+
return DAG.getNode(
2977+
ISD::FP_ROUND, DL, Op.getValueType(),
2978+
DAG.getNode(Op.getOpcode(), DL, MVT::f32, Op.getOperand(0)),
2979+
DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
2980+
}
2981+
2982+
// Other operations are legal.
2983+
return Op;
2984+
}
2985+
29562986
static SDValue lowerFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG,
29572987
const RISCVSubtarget &Subtarget) {
29582988
// RISC-V FP-to-int conversions saturate to the destination register size, but
@@ -6631,13 +6661,15 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
66316661
// the source. We custom-lower any conversions that do two hops into
66326662
// sequences.
66336663
MVT VT = Op.getSimpleValueType();
6664+
bool IsStrict = Op->isStrictFPOpcode();
6665+
SDValue Src = Op.getOperand(0 + IsStrict);
6666+
MVT SrcVT = Src.getSimpleValueType();
6667+
if (SrcVT.isScalarInteger())
6668+
return lowerINT_TO_FP(Op, DAG, Subtarget);
66346669
if (!VT.isVector())
66356670
return Op;
66366671
SDLoc DL(Op);
6637-
bool IsStrict = Op->isStrictFPOpcode();
6638-
SDValue Src = Op.getOperand(0 + IsStrict);
66396672
MVT EltVT = VT.getVectorElementType();
6640-
MVT SrcVT = Src.getSimpleValueType();
66416673
MVT SrcEltVT = SrcVT.getVectorElementType();
66426674
unsigned EltSize = EltVT.getSizeInBits();
66436675
unsigned SrcEltSize = SrcEltVT.getSizeInBits();

llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,21 +63,13 @@ let Predicates = [HasStdExtZfbfmin] in {
6363
// rounding mode has no effect for bf16->f32.
6464
def : Pat<(i32 (any_fp_to_sint (bf16 FPR16:$rs1))), (FCVT_W_S (FCVT_S_BF16 $rs1, FRM_RNE), FRM_RTZ)>;
6565
def : Pat<(i32 (any_fp_to_uint (bf16 FPR16:$rs1))), (FCVT_WU_S (FCVT_S_BF16 $rs1, FRM_RNE), FRM_RTZ)>;
66-
67-
// [u]int->bf16. Match GCC and default to using dynamic rounding mode.
68-
def : Pat<(bf16 (any_sint_to_fp (i32 GPR:$rs1))), (FCVT_BF16_S (FCVT_S_W $rs1, FRM_DYN), FRM_DYN)>;
69-
def : Pat<(bf16 (any_uint_to_fp (i32 GPR:$rs1))), (FCVT_BF16_S (FCVT_S_WU $rs1, FRM_DYN), FRM_DYN)>;
7066
}
7167

7268
let Predicates = [HasStdExtZfbfmin, IsRV64] in {
7369
// bf16->[u]int64. Round-to-zero must be used for the f32->int step, the
7470
// rounding mode has no effect for bf16->f32.
7571
def : Pat<(i64 (any_fp_to_sint (bf16 FPR16:$rs1))), (FCVT_L_S (FCVT_S_BF16 $rs1, FRM_RNE), FRM_RTZ)>;
7672
def : Pat<(i64 (any_fp_to_uint (bf16 FPR16:$rs1))), (FCVT_LU_S (FCVT_S_BF16 $rs1, FRM_RNE), FRM_RTZ)>;
77-
78-
// [u]int->bf16. Match GCC and default to using dynamic rounding mode.
79-
def : Pat<(bf16 (any_sint_to_fp (i64 GPR:$rs1))), (FCVT_BF16_S (FCVT_S_L $rs1, FRM_DYN), FRM_DYN)>;
80-
def : Pat<(bf16 (any_uint_to_fp (i64 GPR:$rs1))), (FCVT_BF16_S (FCVT_S_LU $rs1, FRM_DYN), FRM_DYN)>;
8173
}
8274

8375
let Predicates = [HasStdExtZfbfmin, HasStdExtD] in {

llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -604,38 +604,22 @@ let Predicates = [HasStdExtZfhmin, NoStdExtZfh] in {
604604
// half->[u]int. Round-to-zero must be used.
605605
def : Pat<(i32 (any_fp_to_sint (f16 FPR16:$rs1))), (FCVT_W_S (FCVT_S_H $rs1, FRM_RNE), FRM_RTZ)>;
606606
def : Pat<(i32 (any_fp_to_uint (f16 FPR16:$rs1))), (FCVT_WU_S (FCVT_S_H $rs1, FRM_RNE), FRM_RTZ)>;
607-
608-
// [u]int->half. Match GCC and default to using dynamic rounding mode.
609-
def : Pat<(f16 (any_sint_to_fp (i32 GPR:$rs1))), (FCVT_H_S (FCVT_S_W $rs1, FRM_DYN), FRM_DYN)>;
610-
def : Pat<(f16 (any_uint_to_fp (i32 GPR:$rs1))), (FCVT_H_S (FCVT_S_WU $rs1, FRM_DYN), FRM_DYN)>;
611607
} // Predicates = [HasStdExtZfhmin, NoStdExtZfh]
612608

613609
let Predicates = [HasStdExtZhinxmin, NoStdExtZhinx] in {
614610
// half->[u]int. Round-to-zero must be used.
615611
def : Pat<(i32 (any_fp_to_sint FPR16INX:$rs1)), (FCVT_W_S_INX (FCVT_S_H_INX $rs1, FRM_RNE), FRM_RTZ)>;
616612
def : Pat<(i32 (any_fp_to_uint FPR16INX:$rs1)), (FCVT_WU_S_INX (FCVT_S_H_INX $rs1, FRM_RNE), FRM_RTZ)>;
617-
618-
// [u]int->half. Match GCC and default to using dynamic rounding mode.
619-
def : Pat<(any_sint_to_fp (i32 GPR:$rs1)), (FCVT_H_S_INX (FCVT_S_W_INX $rs1, FRM_DYN), FRM_DYN)>;
620-
def : Pat<(any_uint_to_fp (i32 GPR:$rs1)), (FCVT_H_S_INX (FCVT_S_WU_INX $rs1, FRM_DYN), FRM_DYN)>;
621613
} // Predicates = [HasStdExtZhinxmin, NoStdExtZhinx]
622614

623615
let Predicates = [HasStdExtZfhmin, NoStdExtZfh, IsRV64] in {
624616
// half->[u]int64. Round-to-zero must be used.
625617
def : Pat<(i64 (any_fp_to_sint (f16 FPR16:$rs1))), (FCVT_L_S (FCVT_S_H $rs1, FRM_RNE), FRM_RTZ)>;
626618
def : Pat<(i64 (any_fp_to_uint (f16 FPR16:$rs1))), (FCVT_LU_S (FCVT_S_H $rs1, FRM_RNE), FRM_RTZ)>;
627-
628-
// [u]int->fp. Match GCC and default to using dynamic rounding mode.
629-
def : Pat<(f16 (any_sint_to_fp (i64 GPR:$rs1))), (FCVT_H_S (FCVT_S_L $rs1, FRM_DYN), FRM_DYN)>;
630-
def : Pat<(f16 (any_uint_to_fp (i64 GPR:$rs1))), (FCVT_H_S (FCVT_S_LU $rs1, FRM_DYN), FRM_DYN)>;
631619
} // Predicates = [HasStdExtZfhmin, NoStdExtZfh, IsRV64]
632620

633621
let Predicates = [HasStdExtZhinxmin, NoStdExtZhinx, IsRV64] in {
634622
// half->[u]int64. Round-to-zero must be used.
635623
def : Pat<(i64 (any_fp_to_sint FPR16INX:$rs1)), (FCVT_L_S_INX (FCVT_S_H_INX $rs1, FRM_RNE), FRM_RTZ)>;
636624
def : Pat<(i64 (any_fp_to_uint FPR16INX:$rs1)), (FCVT_LU_S_INX (FCVT_S_H_INX $rs1, FRM_RNE), FRM_RTZ)>;
637-
638-
// [u]int->fp. Match GCC and default to using dynamic rounding mode.
639-
def : Pat<(any_sint_to_fp (i64 GPR:$rs1)), (FCVT_H_S_INX (FCVT_S_L_INX $rs1, FRM_DYN), FRM_DYN)>;
640-
def : Pat<(any_uint_to_fp (i64 GPR:$rs1)), (FCVT_H_S_INX (FCVT_S_LU_INX $rs1, FRM_DYN), FRM_DYN)>;
641625
} // Predicates = [HasStdExtZhinxmin, NoStdExtZhinx, IsRV64]

llvm/test/CodeGen/RISCV/bfloat-convert.ll

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -749,7 +749,7 @@ define bfloat @fcvt_bf16_si(i16 %a) nounwind {
749749
; CHECK64ZFBFMIN: # %bb.0:
750750
; CHECK64ZFBFMIN-NEXT: slli a0, a0, 48
751751
; CHECK64ZFBFMIN-NEXT: srai a0, a0, 48
752-
; CHECK64ZFBFMIN-NEXT: fcvt.s.l fa5, a0
752+
; CHECK64ZFBFMIN-NEXT: fcvt.s.w fa5, a0
753753
; CHECK64ZFBFMIN-NEXT: fcvt.bf16.s fa0, fa5
754754
; CHECK64ZFBFMIN-NEXT: ret
755755
;
@@ -795,7 +795,7 @@ define bfloat @fcvt_bf16_si_signext(i16 signext %a) nounwind {
795795
;
796796
; CHECK64ZFBFMIN-LABEL: fcvt_bf16_si_signext:
797797
; CHECK64ZFBFMIN: # %bb.0:
798-
; CHECK64ZFBFMIN-NEXT: fcvt.s.l fa5, a0
798+
; CHECK64ZFBFMIN-NEXT: fcvt.s.w fa5, a0
799799
; CHECK64ZFBFMIN-NEXT: fcvt.bf16.s fa0, fa5
800800
; CHECK64ZFBFMIN-NEXT: ret
801801
;
@@ -845,7 +845,7 @@ define bfloat @fcvt_bf16_ui(i16 %a) nounwind {
845845
; CHECK64ZFBFMIN: # %bb.0:
846846
; CHECK64ZFBFMIN-NEXT: slli a0, a0, 48
847847
; CHECK64ZFBFMIN-NEXT: srli a0, a0, 48
848-
; CHECK64ZFBFMIN-NEXT: fcvt.s.lu fa5, a0
848+
; CHECK64ZFBFMIN-NEXT: fcvt.s.wu fa5, a0
849849
; CHECK64ZFBFMIN-NEXT: fcvt.bf16.s fa0, fa5
850850
; CHECK64ZFBFMIN-NEXT: ret
851851
;
@@ -891,7 +891,7 @@ define bfloat @fcvt_bf16_ui_zeroext(i16 zeroext %a) nounwind {
891891
;
892892
; CHECK64ZFBFMIN-LABEL: fcvt_bf16_ui_zeroext:
893893
; CHECK64ZFBFMIN: # %bb.0:
894-
; CHECK64ZFBFMIN-NEXT: fcvt.s.lu fa5, a0
894+
; CHECK64ZFBFMIN-NEXT: fcvt.s.wu fa5, a0
895895
; CHECK64ZFBFMIN-NEXT: fcvt.bf16.s fa0, fa5
896896
; CHECK64ZFBFMIN-NEXT: ret
897897
;
@@ -935,8 +935,7 @@ define bfloat @fcvt_bf16_w(i32 %a) nounwind {
935935
;
936936
; CHECK64ZFBFMIN-LABEL: fcvt_bf16_w:
937937
; CHECK64ZFBFMIN: # %bb.0:
938-
; CHECK64ZFBFMIN-NEXT: sext.w a0, a0
939-
; CHECK64ZFBFMIN-NEXT: fcvt.s.l fa5, a0
938+
; CHECK64ZFBFMIN-NEXT: fcvt.s.w fa5, a0
940939
; CHECK64ZFBFMIN-NEXT: fcvt.bf16.s fa0, fa5
941940
; CHECK64ZFBFMIN-NEXT: ret
942941
;
@@ -983,7 +982,7 @@ define bfloat @fcvt_bf16_w_load(ptr %p) nounwind {
983982
; CHECK64ZFBFMIN-LABEL: fcvt_bf16_w_load:
984983
; CHECK64ZFBFMIN: # %bb.0:
985984
; CHECK64ZFBFMIN-NEXT: lw a0, 0(a0)
986-
; CHECK64ZFBFMIN-NEXT: fcvt.s.l fa5, a0
985+
; CHECK64ZFBFMIN-NEXT: fcvt.s.w fa5, a0
987986
; CHECK64ZFBFMIN-NEXT: fcvt.bf16.s fa0, fa5
988987
; CHECK64ZFBFMIN-NEXT: ret
989988
;
@@ -1029,9 +1028,7 @@ define bfloat @fcvt_bf16_wu(i32 %a) nounwind {
10291028
;
10301029
; CHECK64ZFBFMIN-LABEL: fcvt_bf16_wu:
10311030
; CHECK64ZFBFMIN: # %bb.0:
1032-
; CHECK64ZFBFMIN-NEXT: slli a0, a0, 32
1033-
; CHECK64ZFBFMIN-NEXT: srli a0, a0, 32
1034-
; CHECK64ZFBFMIN-NEXT: fcvt.s.lu fa5, a0
1031+
; CHECK64ZFBFMIN-NEXT: fcvt.s.wu fa5, a0
10351032
; CHECK64ZFBFMIN-NEXT: fcvt.bf16.s fa0, fa5
10361033
; CHECK64ZFBFMIN-NEXT: ret
10371034
;
@@ -1078,7 +1075,7 @@ define bfloat @fcvt_bf16_wu_load(ptr %p) nounwind {
10781075
; CHECK64ZFBFMIN-LABEL: fcvt_bf16_wu_load:
10791076
; CHECK64ZFBFMIN: # %bb.0:
10801077
; CHECK64ZFBFMIN-NEXT: lwu a0, 0(a0)
1081-
; CHECK64ZFBFMIN-NEXT: fcvt.s.lu fa5, a0
1078+
; CHECK64ZFBFMIN-NEXT: fcvt.s.wu fa5, a0
10821079
; CHECK64ZFBFMIN-NEXT: fcvt.bf16.s fa0, fa5
10831080
; CHECK64ZFBFMIN-NEXT: ret
10841081
;
@@ -1376,7 +1373,7 @@ define signext i32 @fcvt_bf16_w_demanded_bits(i32 signext %0, ptr %1) nounwind {
13761373
; CHECK64ZFBFMIN-LABEL: fcvt_bf16_w_demanded_bits:
13771374
; CHECK64ZFBFMIN: # %bb.0:
13781375
; CHECK64ZFBFMIN-NEXT: addiw a0, a0, 1
1379-
; CHECK64ZFBFMIN-NEXT: fcvt.s.l fa5, a0
1376+
; CHECK64ZFBFMIN-NEXT: fcvt.s.w fa5, a0
13801377
; CHECK64ZFBFMIN-NEXT: fcvt.bf16.s fa5, fa5
13811378
; CHECK64ZFBFMIN-NEXT: fsh fa5, 0(a1)
13821379
; CHECK64ZFBFMIN-NEXT: ret
@@ -1436,9 +1433,7 @@ define signext i32 @fcvt_bf16_wu_demanded_bits(i32 signext %0, ptr %1) nounwind
14361433
; CHECK64ZFBFMIN-LABEL: fcvt_bf16_wu_demanded_bits:
14371434
; CHECK64ZFBFMIN: # %bb.0:
14381435
; CHECK64ZFBFMIN-NEXT: addiw a0, a0, 1
1439-
; CHECK64ZFBFMIN-NEXT: slli a2, a0, 32
1440-
; CHECK64ZFBFMIN-NEXT: srli a2, a2, 32
1441-
; CHECK64ZFBFMIN-NEXT: fcvt.s.lu fa5, a2
1436+
; CHECK64ZFBFMIN-NEXT: fcvt.s.wu fa5, a0
14421437
; CHECK64ZFBFMIN-NEXT: fcvt.bf16.s fa5, fa5
14431438
; CHECK64ZFBFMIN-NEXT: fsh fa5, 0(a1)
14441439
; CHECK64ZFBFMIN-NEXT: ret

0 commit comments

Comments
 (0)