diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h index b2124c6106198..3411163549de2 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfo.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -882,6 +882,8 @@ class TargetTransformInfo { /// should use coldcc calling convention. bool useColdCCForColdCall(Function &F) const; + bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const; + /// Estimate the overhead of scalarizing an instruction. Insert and Extract /// are set if the demanded result elements need to be inserted and/or /// extracted from vectors. @@ -1928,6 +1930,7 @@ class TargetTransformInfo::Concept { virtual bool shouldBuildLookupTablesForConstant(Constant *C) = 0; virtual bool shouldBuildRelLookupTables() = 0; virtual bool useColdCCForColdCall(Function &F) = 0; + virtual bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) = 0; virtual InstructionCost getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract, @@ -2467,7 +2470,9 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept { bool useColdCCForColdCall(Function &F) override { return Impl.useColdCCForColdCall(F); } - + bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) override { + return Impl.isTargetIntrinsicTriviallyScalarizable(ID); + } InstructionCost getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract, diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h index 90eef93a2a54d..2819af30cd170 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -373,6 +373,10 @@ class TargetTransformInfoImplBase { bool useColdCCForColdCall(Function &F) const { return false; } + bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const { + return false; + } + InstructionCost getScalarizationOverhead(VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract, diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h index caa3a57ebabc2..2f2a6a09ffc44 100644 --- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h +++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h @@ -789,6 +789,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase { return Cost; } + bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const { + return false; + } + /// Helper wrapper for the DemandedElts variant of getScalarizationOverhead. InstructionCost getScalarizationOverhead(VectorType *InTy, bool Insert, bool Extract, diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp index 2c26493bd3f1c..67b626f300a10 100644 --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -587,6 +587,11 @@ bool TargetTransformInfo::useColdCCForColdCall(Function &F) const { return TTIImpl->useColdCCForColdCall(F); } +bool TargetTransformInfo::isTargetIntrinsicTriviallyScalarizable( + Intrinsic::ID ID) const { + return TTIImpl->isTargetIntrinsicTriviallyScalarizable(ID); +} + InstructionCost TargetTransformInfo::getScalarizationOverhead( VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract, TTI::TargetCostKind CostKind) const { diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt index f7ae09957996b..a9c5d81391b8d 100644 --- a/llvm/lib/Target/DirectX/CMakeLists.txt +++ b/llvm/lib/Target/DirectX/CMakeLists.txt @@ -18,6 +18,7 @@ add_llvm_target(DirectXCodeGen DirectXRegisterInfo.cpp DirectXSubtarget.cpp DirectXTargetMachine.cpp + DirectXTargetTransformInfo.cpp DXContainerGlobals.cpp DXILFinalizeLinkage.cpp DXILIntrinsicExpansion.cpp diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp new file mode 100644 index 0000000000000..1a59f04b21404 --- /dev/null +++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp @@ -0,0 +1,25 @@ +//===- DirectXTargetTransformInfo.cpp - DirectX TTI ---------------*- C++ +//-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +//===----------------------------------------------------------------------===// + +#include "DirectXTargetTransformInfo.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/IntrinsicsDirectX.h" + +bool llvm::DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable( + Intrinsic::ID ID) const { + switch (ID) { + case Intrinsic::dx_frac: + case Intrinsic::dx_rsqrt: + return true; + default: + return false; + } +} diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h index ed98355fad002..48414549f8349 100644 --- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h +++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.h @@ -34,6 +34,7 @@ class DirectXTTIImpl : public BasicTTIImplBase { : BaseT(TM, F.getDataLayout()), ST(TM->getSubtargetImpl(F)), TLI(ST->getTargetLowering()) {} unsigned getMinVectorRegisterBitWidth() const { return 32; } + bool isTargetIntrinsicTriviallyScalarizable(Intrinsic::ID ID) const; }; } // namespace llvm diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp index 01d24335df226..d464e49990b3b 100644 --- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp +++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp @@ -18,6 +18,7 @@ #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Twine.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/Argument.h" #include "llvm/IR/BasicBlock.h" @@ -281,10 +282,11 @@ T getWithDefaultOverride(const cl::opt &ClOption, class ScalarizerVisitor : public InstVisitor { public: - ScalarizerVisitor(DominatorTree *DT, ScalarizerPassOptions Options) - : DT(DT), ScalarizeVariableInsertExtract(getWithDefaultOverride( - ClScalarizeVariableInsertExtract, - Options.ScalarizeVariableInsertExtract)), + ScalarizerVisitor(DominatorTree *DT, const TargetTransformInfo *TTI, + ScalarizerPassOptions Options) + : DT(DT), TTI(TTI), ScalarizeVariableInsertExtract(getWithDefaultOverride( + ClScalarizeVariableInsertExtract, + Options.ScalarizeVariableInsertExtract)), ScalarizeLoadStore(getWithDefaultOverride(ClScalarizeLoadStore, Options.ScalarizeLoadStore)), ScalarizeMinBits(getWithDefaultOverride(ClScalarizeMinBits, @@ -292,6 +294,8 @@ class ScalarizerVisitor : public InstVisitor { bool visit(Function &F); + bool isTriviallyScalarizable(Intrinsic::ID ID); + // InstVisitor methods. They return true if the instruction was scalarized, // false if nothing changed. bool visitInstruction(Instruction &I) { return false; } @@ -335,6 +339,7 @@ class ScalarizerVisitor : public InstVisitor { SmallVector PotentiallyDeadInstrs; DominatorTree *DT; + const TargetTransformInfo *TTI; const bool ScalarizeVariableInsertExtract; const bool ScalarizeLoadStore; @@ -358,6 +363,7 @@ ScalarizerLegacyPass::ScalarizerLegacyPass(const ScalarizerPassOptions &Options) void ScalarizerLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.addRequired(); + AU.addRequired(); AU.addPreserved(); } @@ -445,7 +451,9 @@ bool ScalarizerLegacyPass::runOnFunction(Function &F) { return false; DominatorTree *DT = &getAnalysis().getDomTree(); - ScalarizerVisitor Impl(DT, Options); + const TargetTransformInfo *TTI = + &getAnalysis().getTTI(F); + ScalarizerVisitor Impl(DT, TTI, Options); return Impl.visit(F); } @@ -689,8 +697,11 @@ bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) { return true; } -static bool isTriviallyScalariable(Intrinsic::ID ID) { - return isTriviallyVectorizable(ID); +bool ScalarizerVisitor::isTriviallyScalarizable(Intrinsic::ID ID) { + if (isTriviallyVectorizable(ID)) + return true; + return Function::isTargetIntrinsic(ID) && + TTI->isTargetIntrinsicTriviallyScalarizable(ID); } /// If a call to a vector typed intrinsic function, split into a scalar call per @@ -705,7 +716,8 @@ bool ScalarizerVisitor::splitCall(CallInst &CI) { return false; Intrinsic::ID ID = F->getIntrinsicID(); - if (ID == Intrinsic::not_intrinsic || !isTriviallyScalariable(ID)) + + if (ID == Intrinsic::not_intrinsic || !isTriviallyScalarizable(ID)) return false; // unsigned NumElems = VT->getNumElements(); @@ -1249,7 +1261,8 @@ bool ScalarizerVisitor::finish() { PreservedAnalyses ScalarizerPass::run(Function &F, FunctionAnalysisManager &AM) { DominatorTree *DT = &AM.getResult(F); - ScalarizerVisitor Impl(DT, Options); + const TargetTransformInfo *TTI = &AM.getResult(F); + ScalarizerVisitor Impl(DT, TTI, Options); bool Changed = Impl.visit(F); PreservedAnalyses PA; PA.preserve(); diff --git a/llvm/test/CodeGen/DirectX/frac.ll b/llvm/test/CodeGen/DirectX/frac.ll index ae86fe06654da..ef24527ce837b 100644 --- a/llvm/test/CodeGen/DirectX/frac.ll +++ b/llvm/test/CodeGen/DirectX/frac.ll @@ -1,31 +1,55 @@ -; RUN: opt -S -dxil-op-lower < %s | FileCheck %s +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +; RUN: opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s ; Make sure dxil operation function calls for frac are generated for float and half. -; CHECK:call float @dx.op.unary.f32(i32 22, float %{{.*}}) -; CHECK:call half @dx.op.unary.f16(i32 22, half %{{.*}}) -target datalayout = "e-m:e-p:32:32-i1:32-i8:8-i16:16-i32:32-i64:64-f16:16-f32:32-f64:64-n8:16:32:64" -target triple = "dxil-pc-shadermodel6.7-library" +define noundef half @frac_half(half noundef %a) { +; CHECK-LABEL: define noundef half @frac_half( +; CHECK-SAME: half noundef [[A:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[DX_FRAC1:%.*]] = call half @dx.op.unary.f16(i32 22, half [[A]]) +; CHECK-NEXT: ret half [[DX_FRAC1]] +; +entry: + %dx.frac = call half @llvm.dx.frac.f16(half %a) + ret half %dx.frac +} -; Function Attrs: noinline nounwind optnone define noundef float @frac_float(float noundef %a) #0 { +; CHECK-LABEL: define noundef float @frac_float( +; CHECK-SAME: float noundef [[A:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[DX_FRAC1:%.*]] = call float @dx.op.unary.f32(i32 22, float [[A]]) +; CHECK-NEXT: ret float [[DX_FRAC1]] +; entry: - %a.addr = alloca float, align 4 - store float %a, ptr %a.addr, align 4 - %0 = load float, ptr %a.addr, align 4 - %dx.frac = call float @llvm.dx.frac.f32(float %0) + %dx.frac = call float @llvm.dx.frac.f32(float %a) ret float %dx.frac } -; Function Attrs: nocallback nofree nosync nounwind readnone speculatable willreturn -declare float @llvm.dx.frac.f32(float) #1 - -; Function Attrs: noinline nounwind optnone -define noundef half @frac_half(half noundef %a) #0 { +define noundef <4 x float> @frac_float4(<4 x float> noundef %a) #0 { +; CHECK-LABEL: define noundef <4 x float> @frac_float4( +; CHECK-SAME: <4 x float> noundef [[A:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[A_I0:%.*]] = extractelement <4 x float> [[A]], i64 0 +; CHECK-NEXT: [[DOTI04:%.*]] = call float @dx.op.unary.f32(i32 22, float [[A_I0]]) +; CHECK-NEXT: [[A_I1:%.*]] = extractelement <4 x float> [[A]], i64 1 +; CHECK-NEXT: [[DOTI13:%.*]] = call float @dx.op.unary.f32(i32 22, float [[A_I1]]) +; CHECK-NEXT: [[A_I2:%.*]] = extractelement <4 x float> [[A]], i64 2 +; CHECK-NEXT: [[DOTI22:%.*]] = call float @dx.op.unary.f32(i32 22, float [[A_I2]]) +; CHECK-NEXT: [[A_I3:%.*]] = extractelement <4 x float> [[A]], i64 3 +; CHECK-NEXT: [[DOTI31:%.*]] = call float @dx.op.unary.f32(i32 22, float [[A_I3]]) +; CHECK-NEXT: [[DOTUPTO0:%.*]] = insertelement <4 x float> poison, float [[DOTI04]], i64 0 +; CHECK-NEXT: [[DOTUPTO1:%.*]] = insertelement <4 x float> [[DOTUPTO0]], float [[DOTI13]], i64 1 +; CHECK-NEXT: [[DOTUPTO2:%.*]] = insertelement <4 x float> [[DOTUPTO1]], float [[DOTI22]], i64 2 +; CHECK-NEXT: [[TMP0:%.*]] = insertelement <4 x float> [[DOTUPTO2]], float [[DOTI31]], i64 3 +; CHECK-NEXT: ret <4 x float> [[TMP0]] +; entry: - %a.addr = alloca half, align 2 - store half %a, ptr %a.addr, align 2 - %0 = load half, ptr %a.addr, align 2 - %dx.frac = call half @llvm.dx.frac.f16(half %0) - ret half %dx.frac + %2 = call <4 x float> @llvm.dx.frac.v4f32(<4 x float> %a) + ret <4 x float> %2 } + +declare half @llvm.dx.frac.f16(half) +declare float @llvm.dx.frac.f32(float) +declare <4 x float> @llvm.dx.frac.v4f32(<4 x float>) diff --git a/llvm/test/CodeGen/DirectX/llc-pipeline.ll b/llvm/test/CodeGen/DirectX/llc-pipeline.ll index 52bd891aee7d4..46326d6917587 100644 --- a/llvm/test/CodeGen/DirectX/llc-pipeline.ll +++ b/llvm/test/CodeGen/DirectX/llc-pipeline.ll @@ -5,6 +5,7 @@ ; CHECK-LABEL: Pass Arguments: ; CHECK-NEXT: Target Library Information +; CHECK-NEXT: Target Transform Information ; CHECK-NEXT: ModulePass Manager ; CHECK-NEXT: DXIL Intrinsic Expansion ; CHECK-NEXT: FunctionPass Manager diff --git a/llvm/test/CodeGen/DirectX/rsqrt.ll b/llvm/test/CodeGen/DirectX/rsqrt.ll index 054c84483ef82..26b22e19635af 100644 --- a/llvm/test/CodeGen/DirectX/rsqrt.ll +++ b/llvm/test/CodeGen/DirectX/rsqrt.ll @@ -1,28 +1,56 @@ -; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5 +; RUN: opt -S -scalarizer -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s ; Make sure dxil operation function calls for rsqrt are generated for float and half. ; CHECK-LABEL: rsqrt_float -; CHECK: call float @dx.op.unary.f32(i32 25, float %{{.*}}) define noundef float @rsqrt_float(float noundef %a) { +; CHECK-SAME: float noundef [[A:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[DX_RSQRT1:%.*]] = call float @dx.op.unary.f32(i32 25, float [[A]]) +; CHECK-NEXT: ret float [[DX_RSQRT1]] +; entry: - %a.addr = alloca float, align 4 - store float %a, ptr %a.addr, align 4 - %0 = load float, ptr %a.addr, align 4 - %dx.rsqrt = call float @llvm.dx.rsqrt.f32(float %0) + %dx.rsqrt = call float @llvm.dx.rsqrt.f32(float %a) ret float %dx.rsqrt } ; CHECK-LABEL: rsqrt_half -; CHECK: call half @dx.op.unary.f16(i32 25, half %{{.*}}) define noundef half @rsqrt_half(half noundef %a) { +; CHECK-SAME: half noundef [[A:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[DX_RSQRT1:%.*]] = call half @dx.op.unary.f16(i32 25, half [[A]]) +; CHECK-NEXT: ret half [[DX_RSQRT1]] +; entry: - %a.addr = alloca half, align 2 - store half %a, ptr %a.addr, align 2 - %0 = load half, ptr %a.addr, align 2 - %dx.rsqrt = call half @llvm.dx.rsqrt.f16(half %0) + %dx.rsqrt = call half @llvm.dx.rsqrt.f16(half %a) ret half %dx.rsqrt } +define noundef <4 x float> @rsqrt_float4(<4 x float> noundef %a) #0 { +; CHECK-LABEL: define noundef <4 x float> @rsqrt_float4( +; CHECK-SAME: <4 x float> noundef [[A:%.*]]) { +; CHECK-NEXT: [[ENTRY:.*:]] +; CHECK-NEXT: [[A_I0:%.*]] = extractelement <4 x float> [[A]], i64 0 +; CHECK-NEXT: [[DOTI04:%.*]] = call float @dx.op.unary.f32(i32 25, float [[A_I0]]) +; CHECK-NEXT: [[A_I1:%.*]] = extractelement <4 x float> [[A]], i64 1 +; CHECK-NEXT: [[DOTI13:%.*]] = call float @dx.op.unary.f32(i32 25, float [[A_I1]]) +; CHECK-NEXT: [[A_I2:%.*]] = extractelement <4 x float> [[A]], i64 2 +; CHECK-NEXT: [[DOTI22:%.*]] = call float @dx.op.unary.f32(i32 25, float [[A_I2]]) +; CHECK-NEXT: [[A_I3:%.*]] = extractelement <4 x float> [[A]], i64 3 +; CHECK-NEXT: [[DOTI31:%.*]] = call float @dx.op.unary.f32(i32 25, float [[A_I3]]) +; CHECK-NEXT: [[DOTUPTO0:%.*]] = insertelement <4 x float> poison, float [[DOTI04]], i64 0 +; CHECK-NEXT: [[DOTUPTO1:%.*]] = insertelement <4 x float> [[DOTUPTO0]], float [[DOTI13]], i64 1 +; CHECK-NEXT: [[DOTUPTO2:%.*]] = insertelement <4 x float> [[DOTUPTO1]], float [[DOTI22]], i64 2 +; CHECK-NEXT: [[TMP0:%.*]] = insertelement <4 x float> [[DOTUPTO2]], float [[DOTI31]], i64 3 +; CHECK-NEXT: ret <4 x float> [[TMP0]] +; +entry: + %2 = call <4 x float> @llvm.dx.rsqrt.v4f32(<4 x float> %a) + ret <4 x float> %2 +} + + declare half @llvm.dx.rsqrt.f16(half) declare float @llvm.dx.rsqrt.f32(float) +declare <4 x float> @llvm.dx.rsqrt.v4f32(<4 x float>)