diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp index 795ddf47c40da..7057cc1fd3024 100644 --- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp @@ -144,6 +144,8 @@ class SPIRVEmitIntrinsics Type *deduceFunParamElementType(Function *F, unsigned OpIdx); Type *deduceFunParamElementType(Function *F, unsigned OpIdx, std::unordered_set &FVisited); + void replaceWithPtrcasted(Instruction *CI, Type *NewElemTy, Type *KnownElemTy, + CallInst *AssignCI); public: static char ID; @@ -475,10 +477,11 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper( if (DemangledName.length() > 0) DemangledName = SPIRV::lookupBuiltinNameHelper(DemangledName); auto AsArgIt = ResTypeByArg.find(DemangledName); - if (AsArgIt != ResTypeByArg.end()) { + if (AsArgIt != ResTypeByArg.end()) Ty = deduceElementTypeHelper(CI->getArgOperand(AsArgIt->second), Visited, UnknownElemTypeI8); - } + else if (Type *KnownRetTy = GR->findDeducedElementType(CalledF)) + Ty = KnownRetTy; } } @@ -808,6 +811,7 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I, CallInst *PtrCastI = B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args); I->setOperand(OpIt.second, PtrCastI); + buildAssignPtr(B, KnownElemTy, PtrCastI); } } } @@ -1706,6 +1710,26 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) { return true; } +void SPIRVEmitIntrinsics::replaceWithPtrcasted(Instruction *CI, Type *NewElemTy, + Type *KnownElemTy, + CallInst *AssignCI) { + updateAssignType(AssignCI, CI, PoisonValue::get(NewElemTy)); + IRBuilder<> B(CI->getContext()); + B.SetInsertPoint(*CI->getInsertionPointAfterDef()); + B.SetCurrentDebugLocation(CI->getDebugLoc()); + Type *OpTy = CI->getType(); + SmallVector Types = {OpTy, OpTy}; + SmallVector Args = {CI, buildMD(PoisonValue::get(KnownElemTy)), + B.getInt32(getPointerAddressSpace(OpTy))}; + CallInst *PtrCasted = + B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args); + SmallVector Users(CI->users()); + for (auto *U : Users) + if (U != AssignCI && U != PtrCasted) + U->replaceUsesOfWith(CI, PtrCasted); + buildAssignPtr(B, KnownElemTy, PtrCasted); +} + // Try to deduce a better type for pointers to untyped ptr. bool SPIRVEmitIntrinsics::postprocessTypes() { bool Changed = false; @@ -1717,6 +1741,18 @@ bool SPIRVEmitIntrinsics::postprocessTypes() { Type *KnownTy = GR->findDeducedElementType(*IB); if (!KnownTy || !AssignCI || !isa(AssignCI->getArgOperand(0))) continue; + // Try to improve the type deduced after all Functions are processed. + if (auto *CI = dyn_cast(*IB)) { + if (Function *CalledF = CI->getCalledFunction()) { + Type *RetElemTy = GR->findDeducedElementType(CalledF); + // Fix inconsistency between known type and function's return type. + if (RetElemTy && RetElemTy != KnownTy) { + replaceWithPtrcasted(CI, RetElemTy, KnownTy, AssignCI); + Changed = true; + continue; + } + } + } Instruction *I = cast(AssignCI->getArgOperand(0)); for (User *U : I->users()) { Instruction *Inst = dyn_cast(U); diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp index f1b10e264781f..83f4b92147a23 100644 --- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp @@ -341,6 +341,17 @@ createNewIdReg(SPIRVType *SpvType, Register SrcReg, MachineRegisterInfo &MRI, return {Reg, GetIdOp}; } +static void setInsertPtAfterDef(MachineIRBuilder &MIB, MachineInstr *Def) { + MachineBasicBlock &MBB = *Def->getParent(); + MachineBasicBlock::iterator DefIt = + Def->getNextNode() ? Def->getNextNode()->getIterator() : MBB.end(); + // Skip all the PHI and debug instructions. + while (DefIt != MBB.end() && + (DefIt->isPHI() || DefIt->isDebugOrPseudoInstr())) + DefIt = std::next(DefIt); + MIB.setInsertPt(MBB, DefIt); +} + // Insert ASSIGN_TYPE instuction between Reg and its definition, set NewReg as // a dst of the definition, assign SPIRVType to both registers. If SpvType is // provided, use it as SPIRVType in ASSIGN_TYPE, otherwise create it from Ty. @@ -350,11 +361,9 @@ namespace llvm { Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpvType, SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB, MachineRegisterInfo &MRI) { - MachineInstr *Def = MRI.getVRegDef(Reg); assert((Ty || SpvType) && "Either LLVM or SPIRV type is expected."); - MIB.setInsertPt(*Def->getParent(), - (Def->getNextNode() ? Def->getNextNode()->getIterator() - : Def->getParent()->end())); + MachineInstr *Def = MRI.getVRegDef(Reg); + setInsertPtAfterDef(MIB, Def); SpvType = SpvType ? SpvType : GR->getOrCreateSPIRVType(Ty, MIB); Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg)); if (auto *RC = MRI.getRegClassOrNull(Reg)) { diff --git a/llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types-rev.ll b/llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types-rev.ll new file mode 100644 index 0000000000000..6fa3f4e53cc59 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types-rev.ll @@ -0,0 +1,55 @@ +; The goal of the test case is to ensure that OpPhi is consistent with respect to operand types. +; -verify-machineinstrs is not available due to mutually exclusive requirements for G_BITCAST and G_PHI. + +; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %} + +; CHECK: %[[#Char:]] = OpTypeInt 8 0 +; CHECK: %[[#PtrChar:]] = OpTypePointer Function %[[#Char]] +; CHECK: %[[#Int:]] = OpTypeInt 32 0 +; CHECK: %[[#PtrInt:]] = OpTypePointer Function %[[#Int]] +; CHECK: %[[#R1:]] = OpFunctionCall %[[#PtrChar]] %[[#]] +; CHECK: %[[#R2:]] = OpFunctionCall %[[#PtrInt]] %[[#]] +; CHECK-DAG: %[[#Casted1:]] = OpBitcast %[[#PtrChar]] %[[#R2]] +; CHECK-DAG: %[[#Casted2:]] = OpBitcast %[[#PtrChar]] %[[#R2]] +; CHECK: OpBranchConditional +; CHECK-DAG: OpPhi %[[#PtrChar]] %[[#R1]] %[[#]] %[[#Casted1]] %[[#]] +; CHECK-DAG: OpPhi %[[#PtrChar]] %[[#R1]] %[[#]] %[[#Casted2]] %[[#]] + +define void @f0(ptr %arg) { +entry: + ret void +} + +define ptr @f1() { +entry: + %p = alloca i8 + store i8 8, ptr %p + ret ptr %p +} + +define ptr @f2() { +entry: + %p = alloca i32 + store i32 32, ptr %p + ret ptr %p +} + +define ptr @foo(i1 %arg) { +entry: + %r1 = tail call ptr @f1() + %r2 = tail call ptr @f2() + br i1 %arg, label %l1, label %l2 + +l1: + br label %exit + +l2: + br label %exit + +exit: + %ret = phi ptr [ %r1, %l1 ], [ %r2, %l2 ] + %ret2 = phi ptr [ %r1, %l1 ], [ %r2, %l2 ] + tail call void @f0(ptr %ret) + ret ptr %ret2 +} diff --git a/llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types.ll b/llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types.ll new file mode 100644 index 0000000000000..4fbaae2556730 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types.ll @@ -0,0 +1,53 @@ +; The goal of the test case is to ensure that OpPhi is consistent with respect to operand types. +; -verify-machineinstrs is not available due to mutually exclusive requirements for G_BITCAST and G_PHI. + +; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %} + +; CHECK: %[[#Char:]] = OpTypeInt 8 0 +; CHECK: %[[#PtrChar:]] = OpTypePointer Function %[[#Char]] +; CHECK: %[[#Int:]] = OpTypeInt 32 0 +; CHECK: %[[#PtrInt:]] = OpTypePointer Function %[[#Int]] +; CHECK: %[[#R1:]] = OpFunctionCall %[[#PtrChar]] %[[#]] +; CHECK: %[[#R2:]] = OpFunctionCall %[[#PtrInt]] %[[#]] +; CHECK: %[[#Casted:]] = OpBitcast %[[#PtrChar]] %[[#R2]] +; CHECK: OpPhi %[[#PtrChar]] %[[#R1]] %[[#]] %[[#Casted]] %[[#]] +; CHECK: OpPhi %[[#PtrChar]] %[[#R1]] %[[#]] %[[#Casted]] %[[#]] + +define ptr @foo(i1 %arg) { +entry: + %r1 = tail call ptr @f1() + %r2 = tail call ptr @f2() + br i1 %arg, label %l1, label %l2 + +l1: + br label %exit + +l2: + br label %exit + +exit: + %ret = phi ptr [ %r1, %l1 ], [ %r2, %l2 ] + %ret2 = phi ptr [ %r1, %l1 ], [ %r2, %l2 ] + tail call void @f0(ptr %ret) + ret ptr %ret2 +} + +define void @f0(ptr %arg) { +entry: + ret void +} + +define ptr @f1() { +entry: + %p = alloca i8 + store i8 8, ptr %p + ret ptr %p +} + +define ptr @f2() { +entry: + %p = alloca i32 + store i32 32, ptr %p + ret ptr %p +}