Skip to content

[SDAG] Support expanding FSINCOS to vector library calls #114039

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions llvm/include/llvm/CodeGen/SelectionDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -1599,6 +1599,9 @@ class SelectionDAG {
SDValue getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
SDValue Op2);

/// Expand the specified \c ISD::FSINCOS node as the Legalize pass would.
bool expandFSINCOS(SDNode *Node, SmallVectorImpl<SDValue> &Results);

/// Expand the specified \c ISD::VAARG node as the Legalize pass would.
SDValue expandVAArg(SDNode *Node);

Expand Down
71 changes: 1 addition & 70 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2348,75 +2348,6 @@ static bool useSinCos(SDNode *Node) {
return false;
}

/// Issue libcalls to sincos to compute sin / cos pairs.
void SelectionDAGLegalize::ExpandSinCosLibCall(
SDNode *Node, SmallVectorImpl<SDValue> &Results) {
EVT VT = Node->getValueType(0);
Type *Ty = VT.getTypeForEVT(*DAG.getContext());
RTLIB::Libcall LC = RTLIB::getFSINCOS(VT);

// Find users of the node that store the results (and share input chains). The
// destination pointers can be used instead of creating stack allocations.
SDValue StoresInChain{};
std::array<StoreSDNode *, 2> ResultStores = {nullptr};
for (SDNode *User : Node->uses()) {
if (!ISD::isNormalStore(User))
continue;
auto *ST = cast<StoreSDNode>(User);
if (!ST->isSimple() || ST->getAddressSpace() != 0 ||
ST->getAlign() < DAG.getDataLayout().getABITypeAlign(Ty) ||
(StoresInChain && ST->getChain() != StoresInChain) ||
Node->isPredecessorOf(ST->getChain().getNode()))
continue;
ResultStores[ST->getValue().getResNo()] = ST;
StoresInChain = ST->getChain();
}

TargetLowering::ArgListTy Args;
TargetLowering::ArgListEntry Entry{};

// Pass the argument.
Entry.Node = Node->getOperand(0);
Entry.Ty = Ty;
Args.push_back(Entry);

// Pass the output pointers for sin and cos.
SmallVector<SDValue, 2> ResultPtrs{};
for (StoreSDNode *ST : ResultStores) {
SDValue ResultPtr = ST ? ST->getBasePtr() : DAG.CreateStackTemporary(VT);
Entry.Node = ResultPtr;
Entry.Ty = PointerType::getUnqual(Ty->getContext());
Args.push_back(Entry);
ResultPtrs.push_back(ResultPtr);
}

SDLoc DL(Node);
SDValue InChain = StoresInChain ? StoresInChain : DAG.getEntryNode();
SDValue Callee = DAG.getExternalSymbol(TLI.getLibcallName(LC),
TLI.getPointerTy(DAG.getDataLayout()));
TargetLowering::CallLoweringInfo CLI(DAG);
CLI.setDebugLoc(DL).setChain(InChain).setLibCallee(
TLI.getLibcallCallingConv(LC), Type::getVoidTy(*DAG.getContext()), Callee,
std::move(Args));

auto [Call, OutChain] = TLI.LowerCallTo(CLI);

for (auto [ResNo, ResultPtr] : llvm::enumerate(ResultPtrs)) {
MachinePointerInfo PtrInfo;
if (StoreSDNode *ST = ResultStores[ResNo]) {
// Replace store with the library call.
DAG.ReplaceAllUsesOfValueWith(SDValue(ST, 0), OutChain);
PtrInfo = ST->getPointerInfo();
} else {
PtrInfo = MachinePointerInfo::getFixedStack(
DAG.getMachineFunction(),
cast<FrameIndexSDNode>(ResultPtr)->getIndex());
}
SDValue LoadResult = DAG.getLoad(VT, DL, OutChain, ResultPtr, PtrInfo);
Results.push_back(LoadResult);
}
}

SDValue SelectionDAGLegalize::expandLdexp(SDNode *Node) const {
SDLoc dl(Node);
EVT VT = Node->getValueType(0);
Expand Down Expand Up @@ -4633,7 +4564,7 @@ void SelectionDAGLegalize::ConvertNodeToLibcall(SDNode *Node) {
break;
case ISD::FSINCOS:
// Expand into sincos libcall.
ExpandSinCosLibCall(Node, Results);
(void)DAG.expandFSINCOS(Node, Results);
break;
case ISD::FLOG:
case ISD::STRICT_FLOG:
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1191,6 +1191,11 @@ void VectorLegalizer::Expand(SDNode *Node, SmallVectorImpl<SDValue> &Results) {
RTLIB::REM_PPCF128, Results))
return;

break;
case ISD::FSINCOS:
if (DAG.expandFSINCOS(Node, Results))
return;

break;
case ISD::VECTOR_COMPRESS:
Results.push_back(TLI.expandVECTOR_COMPRESS(Node, DAG));
Expand Down
98 changes: 98 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/BinaryFormat/Dwarf.h"
Expand Down Expand Up @@ -2483,6 +2484,103 @@ SDValue SelectionDAG::getPartialReduceAdd(SDLoc DL, EVT ReducedTy, SDValue Op1,
return Subvectors[0];
}

bool SelectionDAG::expandFSINCOS(SDNode *Node,
SmallVectorImpl<SDValue> &Results) {
EVT VT = Node->getValueType(0);
LLVMContext *Ctx = getContext();
Type *Ty = VT.getTypeForEVT(*Ctx);
RTLIB::Libcall LC =
RTLIB::getFSINCOS(VT.isVector() ? VT.getVectorElementType() : VT);

const char *LCName = TLI->getLibcallName(LC);
if (!LC || !LCName)
return false;

auto getVecDesc = [&]() -> VecDesc const * {
for (bool Masked : {false, true}) {
if (VecDesc const *VD = getLibInfo().getVectorMappingInfo(
LCName, VT.getVectorElementCount(), Masked)) {
return VD;
}
}
return nullptr;
};

VecDesc const *VD = nullptr;
if (VT.isVector() && !(VD = getVecDesc()))
return false;

// Find users of the node that store the results (and share input chains). The
// destination pointers can be used instead of creating stack allocations.
SDValue StoresInChain{};
std::array<StoreSDNode *, 2> ResultStores = {nullptr};
for (SDNode *User : Node->uses()) {
if (!ISD::isNormalStore(User))
continue;
auto *ST = cast<StoreSDNode>(User);
if (!ST->isSimple() || ST->getAddressSpace() != 0 ||
ST->getAlign() < getDataLayout().getABITypeAlign(Ty->getScalarType()) ||
(StoresInChain && ST->getChain() != StoresInChain) ||
Node->isPredecessorOf(ST->getChain().getNode()))
continue;
ResultStores[ST->getValue().getResNo()] = ST;
StoresInChain = ST->getChain();
}

TargetLowering::ArgListTy Args;
TargetLowering::ArgListEntry Entry{};

// Pass the argument.
Entry.Node = Node->getOperand(0);
Entry.Ty = Ty;
Args.push_back(Entry);

// Pass the output pointers for sin and cos.
SmallVector<SDValue, 2> ResultPtrs{};
for (StoreSDNode *ST : ResultStores) {
SDValue ResultPtr = ST ? ST->getBasePtr() : CreateStackTemporary(VT);
Entry.Node = ResultPtr;
Entry.Ty = PointerType::getUnqual(Ty->getContext());
Args.push_back(Entry);
ResultPtrs.push_back(ResultPtr);
}

SDLoc DL(Node);

if (VD && VD->isMasked()) {
EVT MaskVT = TLI->getSetCCResultType(getDataLayout(), *Ctx, VT);
Entry.Node = getBoolConstant(true, DL, MaskVT, VT);
Entry.Ty = MaskVT.getTypeForEVT(*Ctx);
Args.push_back(Entry);
}

SDValue InChain = StoresInChain ? StoresInChain : getEntryNode();
SDValue Callee = getExternalSymbol(VD ? VD->getVectorFnName().data() : LCName,
TLI->getPointerTy(getDataLayout()));
TargetLowering::CallLoweringInfo CLI(*this);
CLI.setDebugLoc(DL).setChain(InChain).setLibCallee(
TLI->getLibcallCallingConv(LC), Type::getVoidTy(*Ctx), Callee,
std::move(Args));

auto [Call, OutChain] = TLI->LowerCallTo(CLI);

for (auto [ResNo, ResultPtr] : llvm::enumerate(ResultPtrs)) {
MachinePointerInfo PtrInfo;
if (StoreSDNode *ST = ResultStores[ResNo]) {
// Replace store with the library call.
ReplaceAllUsesOfValueWith(SDValue(ST, 0), OutChain);
PtrInfo = ST->getPointerInfo();
} else {
PtrInfo = MachinePointerInfo::getFixedStack(
getMachineFunction(), cast<FrameIndexSDNode>(ResultPtr)->getIndex());
}
SDValue LoadResult = getLoad(VT, DL, OutChain, ResultPtr, PtrInfo);
Results.push_back(LoadResult);
}

return true;
}

SDValue SelectionDAG::expandVAArg(SDNode *Node) {
SDLoc dl(Node);
const TargetLowering &TLI = getTargetLoweringInfo();
Expand Down
61 changes: 61 additions & 0 deletions llvm/test/CodeGen/AArch64/veclib-llvm.sincos.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --filter "(bl|ptrue)" --version 5
; RUN: llc -mtriple=aarch64-gnu-linux -mattr=+neon,+sve -vector-library=sleefgnuabi < %s | FileCheck %s -check-prefix=SLEEF
; RUN: llc -mtriple=aarch64-gnu-linux -mattr=+neon,+sve -vector-library=ArmPL < %s | FileCheck %s -check-prefix=ARMPL

define void @test_sincos_v4f32(<4 x float> %x, ptr noalias %out_sin, ptr noalias %out_cos) {
; SLEEF-LABEL: test_sincos_v4f32:
; SLEEF: bl _ZGVnN4vl4l4_sincosf
;
; ARMPL-LABEL: test_sincos_v4f32:
; ARMPL: bl armpl_vsincosq_f32
%result = call { <4 x float>, <4 x float> } @llvm.sincos.v4f32(<4 x float> %x)
%result.0 = extractvalue { <4 x float>, <4 x float> } %result, 0
%result.1 = extractvalue { <4 x float>, <4 x float> } %result, 1
store <4 x float> %result.0, ptr %out_sin, align 4
store <4 x float> %result.1, ptr %out_cos, align 4
ret void
}

define void @test_sincos_v2f64(<2 x double> %x, ptr noalias %out_sin, ptr noalias %out_cos) {
; SLEEF-LABEL: test_sincos_v2f64:
; SLEEF: bl _ZGVnN2vl8l8_sincos
;
; ARMPL-LABEL: test_sincos_v2f64:
; ARMPL: bl armpl_vsincosq_f64
%result = call { <2 x double>, <2 x double> } @llvm.sincos.v2f64(<2 x double> %x)
%result.0 = extractvalue { <2 x double>, <2 x double> } %result, 0
%result.1 = extractvalue { <2 x double>, <2 x double> } %result, 1
store <2 x double> %result.0, ptr %out_sin, align 8
store <2 x double> %result.1, ptr %out_cos, align 8
ret void
}

define void @test_sincos_nxv4f32(<vscale x 4 x float> %x, ptr noalias %out_sin, ptr noalias %out_cos) {
; SLEEF-LABEL: test_sincos_nxv4f32:
; SLEEF: bl _ZGVsNxvl4l4_sincosf
;
; ARMPL-LABEL: test_sincos_nxv4f32:
; ARMPL: ptrue p0.s
; ARMPL: bl armpl_svsincos_f32_x
%result = call { <vscale x 4 x float>, <vscale x 4 x float> } @llvm.sincos.nxv4f32(<vscale x 4 x float> %x)
%result.0 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float> } %result, 0
%result.1 = extractvalue { <vscale x 4 x float>, <vscale x 4 x float> } %result, 1
store <vscale x 4 x float> %result.0, ptr %out_sin, align 4
store <vscale x 4 x float> %result.1, ptr %out_cos, align 4
ret void
}

define void @test_sincos_nxv2f64(<vscale x 2 x double> %x, ptr noalias %out_sin, ptr noalias %out_cos) {
; SLEEF-LABEL: test_sincos_nxv2f64:
; SLEEF: bl _ZGVsNxvl8l8_sincos
;
; ARMPL-LABEL: test_sincos_nxv2f64:
; ARMPL: ptrue p0.d
; ARMPL: bl armpl_svsincos_f64_x
%result = call { <vscale x 2 x double>, <vscale x 2 x double> } @llvm.sincos.nxv2f64(<vscale x 2 x double> %x)
%result.0 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %result, 0
%result.1 = extractvalue { <vscale x 2 x double>, <vscale x 2 x double> } %result, 1
store <vscale x 2 x double> %result.0, ptr %out_sin, align 8
store <vscale x 2 x double> %result.1, ptr %out_cos, align 8
ret void
}
Loading