Skip to content

[SPIR-V] Fix inconsistency between previously deduced element type of a pointer and function's return type #109660

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

VyacheslavLevytskyy
Copy link
Contributor

This PR improves type inference and fixes inconsistency between previously deduced element type of a pointer and function's return type. It fixes #109401 by ensuring that OpPhi is consistent with respect to operand types.

@llvmbot
Copy link
Member

llvmbot commented Sep 23, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

Changes

This PR improves type inference and fixes inconsistency between previously deduced element type of a pointer and function's return type. It fixes #109401 by ensuring that OpPhi is consistent with respect to operand types.


Full diff: https://github.com/llvm/llvm-project/pull/109660.diff

4 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp (+38-2)
  • (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+13-4)
  • (added) llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types-rev.ll (+55)
  • (added) llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types.ll (+53)
diff --git a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
index 795ddf47c40dab..7057cc1fd30242 100644
--- a/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
@@ -144,6 +144,8 @@ class SPIRVEmitIntrinsics
   Type *deduceFunParamElementType(Function *F, unsigned OpIdx);
   Type *deduceFunParamElementType(Function *F, unsigned OpIdx,
                                   std::unordered_set<Function *> &FVisited);
+  void replaceWithPtrcasted(Instruction *CI, Type *NewElemTy, Type *KnownElemTy,
+                            CallInst *AssignCI);
 
 public:
   static char ID;
@@ -475,10 +477,11 @@ Type *SPIRVEmitIntrinsics::deduceElementTypeHelper(
       if (DemangledName.length() > 0)
         DemangledName = SPIRV::lookupBuiltinNameHelper(DemangledName);
       auto AsArgIt = ResTypeByArg.find(DemangledName);
-      if (AsArgIt != ResTypeByArg.end()) {
+      if (AsArgIt != ResTypeByArg.end())
         Ty = deduceElementTypeHelper(CI->getArgOperand(AsArgIt->second),
                                      Visited, UnknownElemTypeI8);
-      }
+      else if (Type *KnownRetTy = GR->findDeducedElementType(CalledF))
+        Ty = KnownRetTy;
     }
   }
 
@@ -808,6 +811,7 @@ void SPIRVEmitIntrinsics::deduceOperandElementType(Instruction *I,
       CallInst *PtrCastI =
           B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
       I->setOperand(OpIt.second, PtrCastI);
+      buildAssignPtr(B, KnownElemTy, PtrCastI);
     }
   }
 }
@@ -1706,6 +1710,26 @@ bool SPIRVEmitIntrinsics::runOnFunction(Function &Func) {
   return true;
 }
 
+void SPIRVEmitIntrinsics::replaceWithPtrcasted(Instruction *CI, Type *NewElemTy,
+                                               Type *KnownElemTy,
+                                               CallInst *AssignCI) {
+  updateAssignType(AssignCI, CI, PoisonValue::get(NewElemTy));
+  IRBuilder<> B(CI->getContext());
+  B.SetInsertPoint(*CI->getInsertionPointAfterDef());
+  B.SetCurrentDebugLocation(CI->getDebugLoc());
+  Type *OpTy = CI->getType();
+  SmallVector<Type *, 2> Types = {OpTy, OpTy};
+  SmallVector<Value *, 2> Args = {CI, buildMD(PoisonValue::get(KnownElemTy)),
+                                  B.getInt32(getPointerAddressSpace(OpTy))};
+  CallInst *PtrCasted =
+      B.CreateIntrinsic(Intrinsic::spv_ptrcast, {Types}, Args);
+  SmallVector<User *> Users(CI->users());
+  for (auto *U : Users)
+    if (U != AssignCI && U != PtrCasted)
+      U->replaceUsesOfWith(CI, PtrCasted);
+  buildAssignPtr(B, KnownElemTy, PtrCasted);
+}
+
 // Try to deduce a better type for pointers to untyped ptr.
 bool SPIRVEmitIntrinsics::postprocessTypes() {
   bool Changed = false;
@@ -1717,6 +1741,18 @@ bool SPIRVEmitIntrinsics::postprocessTypes() {
     Type *KnownTy = GR->findDeducedElementType(*IB);
     if (!KnownTy || !AssignCI || !isa<Instruction>(AssignCI->getArgOperand(0)))
       continue;
+    // Try to improve the type deduced after all Functions are processed.
+    if (auto *CI = dyn_cast<CallInst>(*IB)) {
+      if (Function *CalledF = CI->getCalledFunction()) {
+        Type *RetElemTy = GR->findDeducedElementType(CalledF);
+        // Fix inconsistency between known type and function's return type.
+        if (RetElemTy && RetElemTy != KnownTy) {
+          replaceWithPtrcasted(CI, RetElemTy, KnownTy, AssignCI);
+          Changed = true;
+          continue;
+        }
+      }
+    }
     Instruction *I = cast<Instruction>(AssignCI->getArgOperand(0));
     for (User *U : I->users()) {
       Instruction *Inst = dyn_cast<Instruction>(U);
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index f1b10e264781f2..83f4b92147a231 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -341,6 +341,17 @@ createNewIdReg(SPIRVType *SpvType, Register SrcReg, MachineRegisterInfo &MRI,
   return {Reg, GetIdOp};
 }
 
+static void setInsertPtAfterDef(MachineIRBuilder &MIB, MachineInstr *Def) {
+  MachineBasicBlock &MBB = *Def->getParent();
+  MachineBasicBlock::iterator DefIt =
+      Def->getNextNode() ? Def->getNextNode()->getIterator() : MBB.end();
+  // Skip all the PHI and debug instructions.
+  while (DefIt != MBB.end() &&
+         (DefIt->isPHI() || DefIt->isDebugOrPseudoInstr()))
+    DefIt = std::next(DefIt);
+  MIB.setInsertPt(MBB, DefIt);
+}
+
 // Insert ASSIGN_TYPE instuction between Reg and its definition, set NewReg as
 // a dst of the definition, assign SPIRVType to both registers. If SpvType is
 // provided, use it as SPIRVType in ASSIGN_TYPE, otherwise create it from Ty.
@@ -350,11 +361,9 @@ namespace llvm {
 Register insertAssignInstr(Register Reg, Type *Ty, SPIRVType *SpvType,
                            SPIRVGlobalRegistry *GR, MachineIRBuilder &MIB,
                            MachineRegisterInfo &MRI) {
-  MachineInstr *Def = MRI.getVRegDef(Reg);
   assert((Ty || SpvType) && "Either LLVM or SPIRV type is expected.");
-  MIB.setInsertPt(*Def->getParent(),
-                  (Def->getNextNode() ? Def->getNextNode()->getIterator()
-                                      : Def->getParent()->end()));
+  MachineInstr *Def = MRI.getVRegDef(Reg);
+  setInsertPtAfterDef(MIB, Def);
   SpvType = SpvType ? SpvType : GR->getOrCreateSPIRVType(Ty, MIB);
   Register NewReg = MRI.createGenericVirtualRegister(MRI.getType(Reg));
   if (auto *RC = MRI.getRegClassOrNull(Reg)) {
diff --git a/llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types-rev.ll b/llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types-rev.ll
new file mode 100644
index 00000000000000..6fa3f4e53cc598
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types-rev.ll
@@ -0,0 +1,55 @@
+; The goal of the test case is to ensure that OpPhi is consistent with respect to operand types.
+; -verify-machineinstrs is not available due to mutually exclusive requirements for G_BITCAST and G_PHI.
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK: %[[#Char:]] = OpTypeInt 8 0
+; CHECK: %[[#PtrChar:]] = OpTypePointer Function %[[#Char]]
+; CHECK: %[[#Int:]] = OpTypeInt 32 0
+; CHECK: %[[#PtrInt:]] = OpTypePointer Function %[[#Int]]
+; CHECK: %[[#R1:]] = OpFunctionCall %[[#PtrChar]] %[[#]]
+; CHECK: %[[#R2:]] = OpFunctionCall %[[#PtrInt]] %[[#]]
+; CHECK-DAG: %[[#Casted1:]] = OpBitcast %[[#PtrChar]] %[[#R2]]
+; CHECK-DAG: %[[#Casted2:]] = OpBitcast %[[#PtrChar]] %[[#R2]]
+; CHECK: OpBranchConditional
+; CHECK-DAG: OpPhi %[[#PtrChar]] %[[#R1]] %[[#]] %[[#Casted1]] %[[#]]
+; CHECK-DAG: OpPhi %[[#PtrChar]] %[[#R1]] %[[#]] %[[#Casted2]] %[[#]]
+
+define void @f0(ptr %arg) {
+entry:
+  ret void
+}
+
+define ptr @f1() {
+entry:
+  %p = alloca i8
+  store i8 8, ptr %p
+  ret ptr %p
+}
+
+define ptr @f2() {
+entry:
+  %p = alloca i32
+  store i32 32, ptr %p
+  ret ptr %p
+}
+
+define ptr @foo(i1 %arg) {
+entry:
+  %r1 = tail call ptr @f1()
+  %r2 = tail call ptr @f2()
+  br i1 %arg, label %l1, label %l2
+
+l1:
+  br label %exit
+
+l2:
+  br label %exit
+
+exit:
+  %ret = phi ptr [ %r1, %l1 ], [ %r2, %l2 ]
+  %ret2 = phi ptr [ %r1, %l1 ], [ %r2, %l2 ]
+  tail call void @f0(ptr %ret)
+  ret ptr %ret2
+}
diff --git a/llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types.ll b/llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types.ll
new file mode 100644
index 00000000000000..4fbaae25567300
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/phi-valid-operand-types.ll
@@ -0,0 +1,53 @@
+; The goal of the test case is to ensure that OpPhi is consistent with respect to operand types.
+; -verify-machineinstrs is not available due to mutually exclusive requirements for G_BITCAST and G_PHI.
+
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK: %[[#Char:]] = OpTypeInt 8 0
+; CHECK: %[[#PtrChar:]] = OpTypePointer Function %[[#Char]]
+; CHECK: %[[#Int:]] = OpTypeInt 32 0
+; CHECK: %[[#PtrInt:]] = OpTypePointer Function %[[#Int]]
+; CHECK: %[[#R1:]] = OpFunctionCall %[[#PtrChar]] %[[#]]
+; CHECK: %[[#R2:]] = OpFunctionCall %[[#PtrInt]] %[[#]]
+; CHECK: %[[#Casted:]] = OpBitcast %[[#PtrChar]] %[[#R2]]
+; CHECK: OpPhi %[[#PtrChar]] %[[#R1]] %[[#]] %[[#Casted]] %[[#]]
+; CHECK: OpPhi %[[#PtrChar]] %[[#R1]] %[[#]] %[[#Casted]] %[[#]]
+
+define ptr @foo(i1 %arg) {
+entry:
+  %r1 = tail call ptr @f1()
+  %r2 = tail call ptr @f2()
+  br i1 %arg, label %l1, label %l2
+
+l1:
+  br label %exit
+
+l2:
+  br label %exit
+
+exit:
+  %ret = phi ptr [ %r1, %l1 ], [ %r2, %l2 ]
+  %ret2 = phi ptr [ %r1, %l1 ], [ %r2, %l2 ]
+  tail call void @f0(ptr %ret)
+  ret ptr %ret2
+}
+
+define void @f0(ptr %arg) {
+entry:
+  ret void
+}
+
+define ptr @f1() {
+entry:
+  %p = alloca i8
+  store i8 8, ptr %p
+  ret ptr %p
+}
+
+define ptr @f2() {
+entry:
+  %p = alloca i32
+  store i32 32, ptr %p
+  ret ptr %p
+}

@VyacheslavLevytskyy VyacheslavLevytskyy merged commit 8bc8b84 into llvm:main Oct 1, 2024
11 checks passed
Sterling-Augustine pushed a commit to Sterling-Augustine/llvm-project that referenced this pull request Oct 3, 2024
… a pointer and function's return type (llvm#109660)

This PR improves type inference and fixes inconsistency between
previously deduced element type of a pointer and function's return type.
It fixes llvm#109401 by ensuring
that OpPhi is consistent with respect to operand types.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[SPIR-V] Backend emits invalid types in OpPhi
4 participants