Skip to content

[InstCombine] Fix miscompilation in PR83947 #83993

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 5, 2024
Merged

Conversation

dtcxzyw
Copy link
Member

@dtcxzyw dtcxzyw commented Mar 5, 2024

// If the mask is all zeros, a scatter does nothing.
if (ConstMask->isNullValue())
return eraseInstFromFunction(II);
// Vector splat address -> scalar store
if (auto *SplatPtr = getSplatValue(II.getArgOperand(1))) {
// scatter(splat(value), splat(ptr), non-zero-mask) -> store value, ptr
if (auto *SplatValue = getSplatValue(II.getArgOperand(0))) {
Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
StoreInst *S =
new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false, Alignment);
S->copyMetadata(II);
return S;
}

Comment from @topperc:

This transforms assumes the mask is a non-zero splat. We only know its a splat and not provably all 0s. The mask is a constexpr that includes the address of the global variable. We can't resolve the constant expression to an exact value.

Fixes #83947.

@llvmbot
Copy link
Member

llvmbot commented Mar 5, 2024

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-llvm-analysis

Author: Yingwei Zheng (dtcxzyw)

Changes

// If the mask is all zeros, a scatter does nothing.
if (ConstMask->isNullValue())
return eraseInstFromFunction(II);
// Vector splat address -> scalar store
if (auto *SplatPtr = getSplatValue(II.getArgOperand(1))) {
// scatter(splat(value), splat(ptr), non-zero-mask) -> store value, ptr
if (auto *SplatValue = getSplatValue(II.getArgOperand(0))) {
Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
StoreInst *S =
new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false, Alignment);
S->copyMetadata(II);
return S;
}

Comment from @topperc:
> This transforms assumes the mask is a non-zero splat. We only know its a splat and not provably all 0s. The mask is a constexpr that includes the address of the global variable. We can't resolve the constant expression to an exact value.

Fixes #83947.


Full diff: https://github.com/llvm/llvm-project/pull/83993.diff

4 Files Affected:

  • (modified) llvm/include/llvm/Analysis/VectorUtils.h (+5)
  • (modified) llvm/lib/Analysis/VectorUtils.cpp (+25)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp (+8-5)
  • (added) llvm/test/Transforms/InstCombine/pr83947.ll (+67)
diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
index 7a92e62b53c53d..c6eb66cc9660ca 100644
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -406,6 +406,11 @@ bool maskIsAllZeroOrUndef(Value *Mask);
 /// lanes can be assumed active.
 bool maskIsAllOneOrUndef(Value *Mask);
 
+/// Given a mask vector of i1, Return true if any of the elements of this
+/// predicate mask are known to be true or undef.  That is, return true if at
+/// least one lane can be assumed active.
+bool maskContainsAllOneOrUndef(Value *Mask);
+
 /// Given a mask vector of the form <Y x i1>, return an APInt (of bitwidth Y)
 /// for each lane which may be active.
 APInt possiblyDemandedEltsInMask(Value *Mask);
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
index 73facc76a92b2c..bf7bc0ba84a033 100644
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -1012,6 +1012,31 @@ bool llvm::maskIsAllOneOrUndef(Value *Mask) {
   return true;
 }
 
+bool llvm::maskContainsAllOneOrUndef(Value *Mask) {
+  assert(isa<VectorType>(Mask->getType()) &&
+         isa<IntegerType>(Mask->getType()->getScalarType()) &&
+         cast<IntegerType>(Mask->getType()->getScalarType())->getBitWidth() ==
+             1 &&
+         "Mask must be a vector of i1");
+
+  auto *ConstMask = dyn_cast<Constant>(Mask);
+  if (!ConstMask)
+    return false;
+  if (ConstMask->isAllOnesValue() || isa<UndefValue>(ConstMask))
+    return true;
+  if (isa<ScalableVectorType>(ConstMask->getType()))
+    return false;
+  for (unsigned
+           I = 0,
+           E = cast<FixedVectorType>(ConstMask->getType())->getNumElements();
+       I != E; ++I) {
+    if (auto *MaskElt = ConstMask->getAggregateElement(I))
+      if (MaskElt->isAllOnesValue() || isa<UndefValue>(MaskElt))
+        return true;
+  }
+  return false;
+}
+
 /// TODO: This is a lot like known bits, but for
 /// vectors.  Is there something we can common this with?
 APInt llvm::possiblyDemandedEltsInMask(Value *Mask) {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 50c0f9a913f32a..89ca40f1c20d53 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -399,11 +399,14 @@ Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
   if (auto *SplatPtr = getSplatValue(II.getArgOperand(1))) {
     // scatter(splat(value), splat(ptr), non-zero-mask) -> store value, ptr
     if (auto *SplatValue = getSplatValue(II.getArgOperand(0))) {
-      Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
-      StoreInst *S =
-          new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false, Alignment);
-      S->copyMetadata(II);
-      return S;
+      if (maskContainsAllOneOrUndef(ConstMask)) {
+        Align Alignment =
+            cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
+        StoreInst *S = new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false,
+                                     Alignment);
+        S->copyMetadata(II);
+        return S;
+      }
     }
     // scatter(vector, splat(ptr), splat(true)) -> store extract(vector,
     // lastlane), ptr
diff --git a/llvm/test/Transforms/InstCombine/pr83947.ll b/llvm/test/Transforms/InstCombine/pr83947.ll
new file mode 100644
index 00000000000000..c1d601ff637183
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/pr83947.ll
@@ -0,0 +1,67 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 4
+; RUN: opt -S -passes=instcombine < %s | FileCheck %s
+
+@c = global i32 0, align 4
+@b = global i32 0, align 4
+
+define void @masked_scatter1() {
+; CHECK-LABEL: define void @masked_scatter1() {
+; CHECK-NEXT:    call void @llvm.masked.scatter.nxv4i32.nxv4p0(<vscale x 4 x i32> zeroinitializer, <vscale x 4 x ptr> shufflevector (<vscale x 4 x ptr> insertelement (<vscale x 4 x ptr> poison, ptr @c, i64 0), <vscale x 4 x ptr> poison, <vscale x 4 x i32> zeroinitializer), i32 4, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 icmp eq (ptr getelementptr inbounds (i32, ptr @b, i64 1), ptr @c), i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.masked.scatter.nxv4i32.nxv4p0(<vscale x 4 x i32> zeroinitializer, <vscale x 4 x ptr> splat (ptr @c), i32 4, <vscale x 4 x i1> splat (i1 icmp eq (ptr getelementptr (i32, ptr @b, i64 1), ptr @c)))
+  ret void
+}
+
+define void @masked_scatter2() {
+; CHECK-LABEL: define void @masked_scatter2() {
+; CHECK-NEXT:    store i32 0, ptr @c, align 4
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> splat (i1 true))
+  ret void
+}
+
+define void @masked_scatter3() {
+; CHECK-LABEL: define void @masked_scatter3() {
+; CHECK-NEXT:    store i32 0, ptr @c, align 4
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> undef)
+  ret void
+}
+
+define void @masked_scatter4() {
+; CHECK-LABEL: define void @masked_scatter4() {
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> splat (i1 false))
+  ret void
+}
+
+define void @masked_scatter5() {
+; CHECK-LABEL: define void @masked_scatter5() {
+; CHECK-NEXT:    store i32 0, ptr @c, align 4
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> <i1 true, i1 false>)
+  ret void
+}
+
+define void @masked_scatter6() {
+; CHECK-LABEL: define void @masked_scatter6() {
+; CHECK-NEXT:    store i32 0, ptr @c, align 4
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> <i1 undef, i1 false>)
+  ret void
+}
+
+define void @masked_scatter7() {
+; CHECK-LABEL: define void @masked_scatter7() {
+; CHECK-NEXT:    call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> <ptr @c, ptr @c>, i32 4, <2 x i1> <i1 icmp eq (ptr getelementptr inbounds (i32, ptr @b, i64 1), ptr @c), i1 icmp eq (ptr getelementptr inbounds (i32, ptr @b, i64 1), ptr @c)>)
+; CHECK-NEXT:    ret void
+;
+  call void @llvm.masked.scatter.v2i32.v2p0(<2 x i32> zeroinitializer, <2 x ptr> splat (ptr @c), i32 4, <2 x i1> splat (i1 icmp eq (ptr getelementptr (i32, ptr @b, i64 1), ptr @c)))
+  ret void
+}

Copy link
Contributor

@nikic nikic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, though the need to the new helper is a bit unfortunate...

@dtcxzyw dtcxzyw merged commit a1a590e into llvm:main Mar 5, 2024
@dtcxzyw dtcxzyw deleted the fix-pr83947 branch March 5, 2024 14:34
@dtcxzyw dtcxzyw added this to the LLVM 18.X Release milestone Mar 5, 2024
@dtcxzyw
Copy link
Member Author

dtcxzyw commented Mar 5, 2024

/cherry-pick a1a590e

llvmbot pushed a commit to llvmbot/llvm-project that referenced this pull request Mar 5, 2024
https://github.com/llvm/llvm-project/blob/762f762504967efbe159db5c737154b989afc9bb/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp#L394-L407

Comment from @topperc:
> This transforms assumes the mask is a non-zero splat. We only know its
a splat and not provably all 0s. The mask is a constexpr that includes
the address of the global variable. We can't resolve the constant
expression to an exact value.

Fixes llvm#83947.

(cherry picked from commit a1a590e)
@llvmbot
Copy link
Member

llvmbot commented Mar 5, 2024

/pull-request #84021

dtcxzyw added a commit to dtcxzyw/llvm-project that referenced this pull request Mar 7, 2024
https://github.com/llvm/llvm-project/blob/762f762504967efbe159db5c737154b989afc9bb/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp#L394-L407

Comment from @topperc:
> This transforms assumes the mask is a non-zero splat. We only know its
a splat and not provably all 0s. The mask is a constexpr that includes
the address of the global variable. We can't resolve the constant
expression to an exact value.

Fixes llvm#83947.
tstellar pushed a commit to dtcxzyw/llvm-project that referenced this pull request Mar 13, 2024
https://github.com/llvm/llvm-project/blob/762f762504967efbe159db5c737154b989afc9bb/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp#L394-L407

Comment from @topperc:
> This transforms assumes the mask is a non-zero splat. We only know its
a splat and not provably all 0s. The mask is a constexpr that includes
the address of the global variable. We can't resolve the constant
expression to an exact value.

Fixes llvm#83947.
@pointhex pointhex mentioned this pull request May 7, 2024
xgupta pushed a commit to xgupta/llvm-project that referenced this pull request Sep 9, 2024
https://github.com/llvm/llvm-project/blob/762f762504967efbe159db5c737154b989afc9bb/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp#L394-L407

Comment from @topperc:
> This transforms assumes the mask is a non-zero splat. We only know its
a splat and not provably all 0s. The mask is a constexpr that includes
the address of the global variable. We can't resolve the constant
expression to an exact value.

Fixes llvm#83947.
xgupta pushed a commit to xgupta/llvm-project that referenced this pull request Oct 10, 2024
https://github.com/llvm/llvm-project/blob/762f762504967efbe159db5c737154b989afc9bb/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp#L394-L407

Comment from @topperc:
> This transforms assumes the mask is a non-zero splat. We only know its
a splat and not provably all 0s. The mask is a constexpr that includes
the address of the global variable. We can't resolve the constant
expression to an exact value.

Fixes llvm#83947.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Development

Successfully merging this pull request may close these issues.

[RISC-V] Miscompile at -O2
3 participants