@@ -711,6 +711,13 @@ SystemZTargetLowering::SystemZTargetLowering(const TargetMachine &TM,
711711 setOperationAction (ISD ::BITCAST , MVT ::f32 , Custom);
712712 }
713713
714+ // Expand FP16 <=> FP32 conversions to libcalls and handle FP16 loads and
715+ // stores in GPRs.
716+ setOperationAction (ISD ::FP16_TO_FP , MVT ::f32 , Expand);
717+ setOperationAction (ISD ::FP_TO_FP16 , MVT ::f32 , Expand);
718+ setLoadExtAction (ISD ::EXTLOAD , MVT ::f32 , MVT ::f16 , Expand);
719+ setTruncStoreAction (MVT ::f32 , MVT ::f16 , Expand);
720+
714721 // VASTART and VACOPY need to deal with the SystemZ-specific varargs
715722 // structure, but VAEND is a no-op.
716723 setOperationAction (ISD ::VASTART , MVT ::Other, Custom);
@@ -784,6 +791,20 @@ bool SystemZTargetLowering::useSoftFloat() const {
784791 return Subtarget.hasSoftFloat ();
785792}
786793
794+ MVT SystemZTargetLowering::getRegisterTypeForCallingConv (
795+ LLVMContext &Context, CallingConv::ID CC ,
796+ EVT VT ) const {
797+ // 128-bit single-element vector types are passed like other vectors,
798+ // not like their element type.
799+ if (VT .isVector () && VT .getSizeInBits () == 128 &&
800+ VT .getVectorNumElements () == 1 )
801+ return MVT ::v16i8;
802+ // Keep f16 so that they can be recognized and handled.
803+ if (VT == MVT ::f16 )
804+ return MVT ::f16 ;
805+ return TargetLowering::getRegisterTypeForCallingConv (Context, CC , VT );
806+ }
807+
787808EVT SystemZTargetLowering::getSetCCResultType (const DataLayout &DL ,
788809 LLVMContext &, EVT VT ) const {
789810 if (!VT .isVector ())
@@ -1597,6 +1618,15 @@ bool SystemZTargetLowering::splitValueIntoRegisterParts(
15971618 return true ;
15981619 }
15991620
1621+ // Convert f16 to f32 (Out-arg).
1622+ if (PartVT == MVT ::f16 ) {
1623+ assert (NumParts == 1 && " " );
1624+ SDValue I16Val = DAG .getBitcast (MVT ::i16 , Val);
1625+ SDValue I32Val = DAG .getAnyExtOrTrunc (I16Val, DL , MVT ::i32 );
1626+ Parts[0 ] = DAG .getBitcast (MVT ::f32 , I32Val);
1627+ return true ;
1628+ }
1629+
16001630 return false ;
16011631}
16021632
@@ -1612,6 +1642,18 @@ SDValue SystemZTargetLowering::joinRegisterPartsIntoValue(
16121642 return SDValue ();
16131643}
16141644
1645+ // F32Val holds a f16 value in f32, return it as an f16 (In-arg). The
1646+ // CopyFromReg was made into an f32 as required as FP32 registers are used
1647+ // for arguments, now convert it to f16.
1648+ static SDValue convertF32ToF16 (SDValue F32Val, SelectionDAG &DAG ,
1649+ const SDLoc &DL ) {
1650+ assert (F32Val->getOpcode () == ISD ::CopyFromReg &&
1651+ " Only expecting to handle f16 with CopyFromReg here." );
1652+ SDValue I32Val = DAG .getBitcast (MVT ::i32 , F32Val);
1653+ SDValue I16Val = DAG .getAnyExtOrTrunc (I32Val, DL , MVT ::i16 );
1654+ return DAG .getBitcast (MVT ::f16 , I16Val);
1655+ }
1656+
16151657SDValue SystemZTargetLowering::LowerFormalArguments (
16161658 SDValue Chain, CallingConv::ID CallConv, bool IsVarArg,
16171659 const SmallVectorImpl<ISD ::InputArg> &Ins, const SDLoc &DL ,
@@ -1651,6 +1693,7 @@ SDValue SystemZTargetLowering::LowerFormalArguments(
16511693 NumFixedGPRs += 1 ;
16521694 RC = &SystemZ::GR64BitRegClass;
16531695 break ;
1696+ case MVT ::f16 :
16541697 case MVT ::f32 :
16551698 NumFixedFPRs += 1 ;
16561699 RC = &SystemZ::FP32BitRegClass;
@@ -1675,7 +1718,11 @@ SDValue SystemZTargetLowering::LowerFormalArguments(
16751718
16761719 Register VReg = MRI .createVirtualRegister (RC );
16771720 MRI .addLiveIn (VA .getLocReg (), VReg);
1678- ArgValue = DAG .getCopyFromReg (Chain, DL , VReg, LocVT);
1721+ // Special handling is needed for f16.
1722+ MVT ArgVT = VA .getLocVT () == MVT ::f16 ? MVT ::f32 : VA .getLocVT ();
1723+ ArgValue = DAG .getCopyFromReg (Chain, DL , VReg, ArgVT);
1724+ if (VA .getLocVT () == MVT ::f16 )
1725+ ArgValue = convertF32ToF16 (ArgValue, DAG , DL );
16791726 } else {
16801727 assert (VA .isMemLoc () && " Argument not register or memory" );
16811728
@@ -1695,9 +1742,12 @@ SDValue SystemZTargetLowering::LowerFormalArguments(
16951742 // from this parameter. Unpromoted ints and floats are
16961743 // passed as right-justified 8-byte values.
16971744 SDValue FIN = DAG .getFrameIndex (FI , PtrVT);
1698- if (VA .getLocVT () == MVT ::i32 || VA .getLocVT () == MVT ::f32 )
1745+ if (VA .getLocVT () == MVT ::i32 || VA .getLocVT () == MVT ::f32 ||
1746+ VA .getLocVT () == MVT ::f16 ) {
1747+ unsigned SlotOffs = VA .getLocVT () == MVT ::f16 ? 6 : 4 ;
16991748 FIN = DAG .getNode (ISD ::ADD , DL , PtrVT, FIN ,
1700- DAG .getIntPtrConstant (4 , DL ));
1749+ DAG .getIntPtrConstant (SlotOffs, DL ));
1750+ }
17011751 ArgValue = DAG .getLoad (LocVT, DL , Chain, FIN ,
17021752 MachinePointerInfo::getFixedStack (MF , FI ));
17031753 }
@@ -2120,10 +2170,14 @@ SystemZTargetLowering::LowerCall(CallLoweringInfo &CLI,
21202170 // Copy all of the result registers out of their specified physreg.
21212171 for (CCValAssign &VA : RetLocs) {
21222172 // Copy the value out, gluing the copy to the end of the call sequence.
2173+ // Special handling is needed for f16.
2174+ MVT ArgVT = VA .getLocVT () == MVT ::f16 ? MVT ::f32 : VA .getLocVT ();
21232175 SDValue RetValue = DAG .getCopyFromReg (Chain, DL , VA .getLocReg (),
2124- VA . getLocVT () , Glue);
2176+ ArgVT , Glue);
21252177 Chain = RetValue.getValue (1 );
21262178 Glue = RetValue.getValue (2 );
2179+ if (VA .getLocVT () == MVT ::f16 )
2180+ RetValue = convertF32ToF16 (RetValue, DAG , DL );
21272181
21282182 // Convert the value of the return register into the value that's
21292183 // being returned.
0 commit comments