
[DAG] Failure to fold select(x, sub(x, c), m) -> sub(x, and(c,m)) #66101


Closed
RKSimon opened this issue Sep 12, 2023 · 12 comments · Fixed by #83640

@RKSimon
Collaborator

RKSimon commented Sep 12, 2023

https://godbolt.org/z/a1PczEM8a

If we're selecting between a value and that value minus a non-constant, we fold the select into an AND:

#include <x86intrin.h>
auto masked_select(__m128i a, __m128i b, __m128i x, __m128i y) {
    return _mm_blendv_epi8(a, _mm_sub_epi32(a, b), _mm_cmpgt_epi32(x,y));
}
masked_select(long long __vector(2), long long __vector(2), long long __vector(2), long long __vector(2)): # @masked_select(long long __vector(2), long long __vector(2), long long __vector(2), long long __vector(2))
  pcmpgtd %xmm3, %xmm2
  pand %xmm1, %xmm2
  psubd %xmm2, %xmm0
  retq

But for constants the fold fails, which on x86 can result in a BLENDV instruction, and BLENDV is never faster than an AND:

#include <x86intrin.h>
auto masked_select_const(__m128i a, __m128i x, __m128i y) {
    __m128i b = _mm_set1_epi32(24);
    return _mm_blendv_epi8(a, _mm_sub_epi32(a, b), _mm_cmpgt_epi32(x,y));
}
masked_select_const(long long __vector(2), long long __vector(2), long long __vector(2)): # @masked_select_const(long long __vector(2), long long __vector(2), long long __vector(2))
  movdqa %xmm0, %xmm3
  movdqa .LCPI3_0(%rip), %xmm4 # xmm4 = [4294967272,4294967272,4294967272,4294967272]
  paddd %xmm0, %xmm4
  pcmpgtd %xmm2, %xmm1
  movdqa %xmm1, %xmm0
  blendvps %xmm0, %xmm4, %xmm3
  movaps %xmm3, %xmm0
  retq
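
For readers who want to convince themselves that the requested fold is sound, here is a minimal scalar sketch of one 32-bit lane. It is not LLVM code: `select_sub` and `sub_and` are hypothetical helper names, and the mask is assumed to be all-ones or all-zeros, as a vector compare like pcmpgtd produces.

```cpp
#include <cstdint>

// Scalar model of a single 32-bit lane (illustrative only, not the DAG code).
// Assumption: `mask` is either 0xFFFFFFFF or 0, like a pcmpgtd result.

// select(mask, sub(x, c), x): the unfolded form.
constexpr std::uint32_t select_sub(std::uint32_t mask, std::uint32_t x, std::uint32_t c) {
    return mask ? x - c : x;
}

// sub(x, and(c, mask)): the folded form. A clear mask zeroes the constant,
// so the subtraction becomes a no-op for that lane.
constexpr std::uint32_t sub_and(std::uint32_t mask, std::uint32_t x, std::uint32_t c) {
    return x - (c & mask);
}

// Both forms agree for a set mask and for a clear mask.
static_assert(select_sub(0xFFFFFFFFu, 100u, 24u) == sub_and(0xFFFFFFFFu, 100u, 24u), "mask set");
static_assert(select_sub(0u, 100u, 24u) == sub_and(0u, 100u, 24u), "mask clear");
```

The equivalence only holds when every lane of the mask is all-ones or all-zeros; an arbitrary bit pattern would AND away only part of the constant.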
@llvmbot
Member

llvmbot commented Sep 12, 2023

@llvm/issue-subscribers-backend-x86


@Endilll Endilll removed the new issue label Sep 15, 2023
@RKSimon
Collaborator Author

RKSimon commented Oct 1, 2023

CC @elhewaty

@elhewaty
Member

elhewaty commented Oct 1, 2023

Assign me, please.

@elhewaty
Member

elhewaty commented Oct 3, 2023

@RKSimon Is there any source I can use to understand DAG internals?

@RKSimon
Collaborator Author

RKSimon commented Oct 3, 2023

I'd start by looking at the difference between the IR being fed to DAG from masked_select vs masked_select_const - you will probably need to remove a lot of unnecessary bitcasts. Then step through the DAGCombine stages of running llc in a debugger - add breakpoints at the start of visitADD/visitSUB/visitVSELECT and see what's happening.

You can also use "llc --debug" (using a debug assertion build) to dump out everything llc has done: https://rust.godbolt.org/z/szYv5G8n9

@elhewaty
Member

elhewaty commented Feb 7, 2024

Hello @RKSimon.

// select X, sub(X, C), m --> sub (X, and(C, m))
  if (N1.getOpcode() == ISD::SUB && N1.getOperand(0) == N0 && N1.hasOneUse()) {
    if (dyn_cast<ConstantSDNode>(N1.getOperand(1)))
      return DAG.getNode(ISD::SUB, DL, N0.getValueType(), N0,
                         DAG.getNode(ISD::AND, DL, N2.getValueType(),
                                     N1.getOperand(1), N2));
  }

Here's what I've reached so far: I tried to match the pattern in the visitSELECT function.
Is this logic correct?

@RKSimon
Collaborator Author

RKSimon commented Feb 7, 2024

Yes, that looks about right - you should use isConstantIntBuildVectorOrConstantInt instead of dyn_cast<ConstantSDNode> so it can match vector constants as well.

@RKSimon
Collaborator Author

RKSimon commented Feb 7, 2024

Also, you need to sort out the argument order (sorry, when I reported this I was thinking of the _mm_blendv_epi8 operand order, not the select IR order).
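
To make the operand-order difference concrete, here is a small scalar sketch. `blendv_model` and `select_model` are hypothetical names, and the lane-wide model assumes an all-ones or all-zeros mask (the real _mm_blendv_epi8 decides per byte, using the top bit of each mask byte).

```cpp
#include <cstdint>

// Illustrative scalar models of one 32-bit lane; not the intrinsic or the IR.

// _mm_blendv_epi8-style order: (value if mask clear, value if mask set, mask).
constexpr std::uint32_t blendv_model(std::uint32_t a, std::uint32_t b, std::uint32_t mask) {
    return (a & ~mask) | (b & mask);
}

// IR select-style order: (condition, value if true, value if false).
constexpr std::uint32_t select_model(bool cond, std::uint32_t t, std::uint32_t f) {
    return cond ? t : f;
}

// blendv(a, a - c, mask) corresponds to select(mask, a - c, a):
// the condition moves from last to first and the two values swap places.
static_assert(blendv_model(100u, 76u, 0xFFFFFFFFu) == select_model(true, 76u, 100u), "mask set");
static_assert(blendv_model(100u, 76u, 0u) == select_model(false, 76u, 100u), "mask clear");
```

Judging from the snippets in this thread, the combine sees the condition as N0, the true value as N1 and the false value as N2, so the pattern to match is select m, sub(X, C), X.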

@elhewaty
Member

elhewaty commented Feb 8, 2024

@RKSimon, I used the following test case:

define <2 x i64> @masked_select_const(<2 x i64> %a, <2 x i64> %x, <2 x i64> %y) {
  %bit_a = bitcast <2 x i64> %a to <4 x i32>
  %sub.i = add <4 x i32> %bit_a, <i32 -24, i32 -24, i32 -24, i32 -24>
  %bit_x = bitcast <2 x i64> %x to <4 x i32>
  %bit_y = bitcast <2 x i64> %y to <4 x i32>
  %cmp.i = icmp sgt <4 x i32> %bit_x, %bit_y
  %sel = select <4 x i1> %cmp.i, <4 x i32> %sub.i, <4 x i32> %bit_a
  %bit_sel = bitcast <4 x i32> %sel to <2 x i64>
  ret <2 x i64> %bit_sel
}

The following code fails to match the select:

// select m, sub(X, C), X --> sub (X, and(C, m))
  if (N1.getOpcode() == ISD::SUB && N1.getOperand(0) == N2 && N1->hasOneUse() &&
      DAG.isConstantIntBuildVectorOrConstantInt(N1.getOperand(1))) {
    return DAG.getNode(ISD::SUB, DL, N1.getValueType(), N2,
                       DAG.getNode(ISD::AND, DL, N0.getValueType(), N1.getOperand(1),
                                   N0));
  }

Any hint?

@elhewaty
Member

@RKSimon ping

@RKSimon
Collaborator Author

RKSimon commented Feb 19, 2024

Sorry I missed your ping.

In many cases DAG will try to fold (sub x, c) -> (add x, -c) so you will need to do this in terms of ADD:

  // select (sext m), (add X, C), X --> (add X, (and C, (sext m)))
  if (N1.getOpcode() == ISD::ADD && N1.getOperand(0) == N2 && N1->hasOneUse() &&
      DAG.isConstantIntBuildVectorOrConstantInt(N1.getOperand(1)) && 
      N0.getScalarValueSizeInBits() == N1.getScalarValueSizeInBits()) {
    return DAG.getNode(ISD::ADD, DL, N1.getValueType(), N2,
                       DAG.getNode(ISD::AND, DL, N0.getValueType(), N1.getOperand(1),
                                   N0));
  }

Note that you need to ensure the N0 condition has the same scalar width as the True/False operands; otherwise you might affect targets with predicate mask types (AVX512 etc.).
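
As a rough illustration of why the ADD form is equivalent, here is a scalar sketch of one 32-bit lane. The helper names `negate`, `sext_i1` and `add_and` are hypothetical and the code only models the arithmetic, not the DAG nodes.

```cpp
#include <cstdint>

// Illustrative scalar model of one 32-bit lane; not LLVM code.

// sub(x, c) is canonicalised to add(x, -c); in wrapping 32-bit arithmetic
// the negated constant is 2^32 - c.
constexpr std::uint32_t negate(std::uint32_t c) { return 0u - c; }

// Sign-extend a 1-bit condition to a full-width lane mask.
constexpr std::uint32_t sext_i1(bool m) { return m ? 0xFFFFFFFFu : 0u; }

// add(x, and(c, sext m)): the folded form in terms of ADD.
constexpr std::uint32_t add_and(bool m, std::uint32_t x, std::uint32_t c) {
    return x + (c & sext_i1(m));
}

// -24 as an unsigned 32-bit value is the constant loaded before paddd
// in the masked_select_const assembly in the original report.
static_assert(negate(24u) == 4294967272u, "negated constant");
static_assert(add_and(true, 100u, negate(24u)) == 76u, "condition true: x - 24");
static_assert(add_and(false, 100u, negate(24u)) == 100u, "condition false: x unchanged");
```

The AND only makes sense when the mask lane is as wide as the constant lane, which is what the getScalarValueSizeInBits comparison guards; a 1-bit predicate mask has no full-width bit pattern to AND with.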

@RKSimon
Collaborator Author

RKSimon commented Feb 20, 2024

@elhewaty Do you have a PR (draft or active) anywhere with your work so far?

RKSimon pushed a commit that referenced this issue Mar 5, 2024
…83640)

- [DAG][X86] Add tests for Folding select m, add(X, C), X --> add (X, and(C, m))(NFC)
- [DAG][X86] Fold select (sext m), (add X, C), X --> (add X, (and C, (sext m))))
- Fixes: #66101
5 participants