[HLSL] Re-implement countbits with the correct return type #113189

Merged 5 commits on Oct 29, 2024
124 changes: 73 additions & 51 deletions clang/lib/Headers/hlsl/hlsl_intrinsics.h
@@ -723,66 +723,88 @@ float4 cosh(float4);

#ifdef __HLSL_ENABLE_16_BIT
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
int16_t countbits(int16_t);
Contributor:

Sorry I missed the earlier PR where this was added, but I don't see how signed integers work here. We only support unsigned in HLSL. Is that a change for Clang?
If so, tests are missing as well.

Contributor Author:

When I asked Justin about this he pointed to this link:
https://github.com/microsoft/DirectXShaderCompiler/blob/main/utils/hct/gen_intrin_main.txt#L114
and suggested it meant both signed and unsigned were supported.

Contributor:

Ah, of course our documentation is wrong. countbits of a signed value is kind of illogical, but we should definitely add tests regardless.

const inline uint countbits(int16_t x) {
return __builtin_elementwise_popcount(x);
}
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
int16_t2 countbits(int16_t2);
const inline uint2 countbits(int16_t2 x) {
return __builtin_elementwise_popcount(x);
}
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
int16_t3 countbits(int16_t3);
const inline uint3 countbits(int16_t3 x) {
return __builtin_elementwise_popcount(x);
}
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
int16_t4 countbits(int16_t4);
const inline uint4 countbits(int16_t4 x) {
return __builtin_elementwise_popcount(x);
}
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
uint16_t countbits(uint16_t);
const inline uint countbits(uint16_t x) {
return __builtin_elementwise_popcount(x);
}
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
uint16_t2 countbits(uint16_t2);
const inline uint2 countbits(uint16_t2 x) {
return __builtin_elementwise_popcount(x);
}
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
uint16_t3 countbits(uint16_t3);
const inline uint3 countbits(uint16_t3 x) {
return __builtin_elementwise_popcount(x);
}
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
uint16_t4 countbits(uint16_t4);
const inline uint4 countbits(uint16_t4 x) {
return __builtin_elementwise_popcount(x);
}
#endif

_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
int countbits(int);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
int2 countbits(int2);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
int3 countbits(int3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
int4 countbits(int4);

_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
uint countbits(uint);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
uint2 countbits(uint2);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
uint3 countbits(uint3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
uint4 countbits(uint4);

_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
int64_t countbits(int64_t);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
int64_t2 countbits(int64_t2);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
int64_t3 countbits(int64_t3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
int64_t4 countbits(int64_t4);

_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
uint64_t countbits(uint64_t);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
uint64_t2 countbits(uint64_t2);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
uint64_t3 countbits(uint64_t3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount)
uint64_t4 countbits(uint64_t4);
const inline uint countbits(int x) { return __builtin_elementwise_popcount(x); }
const inline uint2 countbits(int2 x) {
return __builtin_elementwise_popcount(x);
}
const inline uint3 countbits(int3 x) {
return __builtin_elementwise_popcount(x);
}
const inline uint4 countbits(int4 x) {
return __builtin_elementwise_popcount(x);
}

const inline uint countbits(uint x) {
return __builtin_elementwise_popcount(x);
}
const inline uint2 countbits(uint2 x) {
return __builtin_elementwise_popcount(x);
}
const inline uint3 countbits(uint3 x) {
return __builtin_elementwise_popcount(x);
}
const inline uint4 countbits(uint4 x) {
return __builtin_elementwise_popcount(x);
}

const inline uint countbits(int64_t x) {
return __builtin_elementwise_popcount(x);
}
const inline uint2 countbits(int64_t2 x) {
return __builtin_elementwise_popcount(x);
}
const inline uint3 countbits(int64_t3 x) {
return __builtin_elementwise_popcount(x);
}
const inline uint4 countbits(int64_t4 x) {
return __builtin_elementwise_popcount(x);
}

const inline uint countbits(uint64_t x) {
return __builtin_elementwise_popcount(x);
}
const inline uint2 countbits(uint64_t2 x) {
return __builtin_elementwise_popcount(x);
}
const inline uint3 countbits(uint64_t3 x) {
return __builtin_elementwise_popcount(x);
}
const inline uint4 countbits(uint64_t4 x) {
return __builtin_elementwise_popcount(x);
}
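
For readers less familiar with HLSL, the contract of the re-implemented overloads (the `const inline uint... countbits(...)` definitions) can be modeled in plain C++. This is only a sketch under assumptions: `countbits_model` is a hypothetical name, and `std::bitset` stands in for `__builtin_elementwise_popcount`. It shows that the result is a 32-bit unsigned count whatever the operand width or signedness.

```cpp
#include <bitset>
#include <climits>
#include <cstdint>
#include <type_traits>

// Hypothetical model of the new countbits contract: any integer width in,
// a 32-bit unsigned count out. Not the HLSL header's implementation.
template <typename T>
std::uint32_t countbits_model(T x) {
  static_assert(std::is_integral<T>::value, "integer operands only");
  using U = typename std::make_unsigned<T>::type;
  // Reinterpret the operand as unsigned so only its own bit pattern is counted.
  const std::bitset<sizeof(T) * CHAR_BIT> bits(static_cast<U>(x));
  return static_cast<std::uint32_t>(bits.count());
}
```

Under this model, a 16-bit `-1` counts 16 set bits and a 64-bit `-1` counts 64, and both results have the same 32-bit type.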

//===----------------------------------------------------------------------===//
// degrees builtins
62 changes: 45 additions & 17 deletions clang/test/CodeGenHLSL/builtins/countbits.hlsl
@@ -4,34 +4,51 @@

#ifdef __HLSL_ENABLE_16_BIT
// CHECK-LABEL: test_countbits_ushort
// CHECK: call i16 @llvm.ctpop.i16
uint16_t test_countbits_ushort(uint16_t p0)
// CHECK: [[A:%.*]] = call i16 @llvm.ctpop.i16
// CHECK-NEXT: zext i16 [[A]] to i32
uint test_countbits_ushort(uint16_t p0)
{
return countbits(p0);
}
// CHECK-LABEL: test_countbits_short
// CHECK: [[A:%.*]] = call i16 @llvm.ctpop.i16
// CHECK-NEXT: sext i16 [[A]] to i32
uint test_countbits_short(int16_t p0)
{
return countbits(p0);
}
// CHECK-LABEL: test_countbits_ushort2
// CHECK: call <2 x i16> @llvm.ctpop.v2i16
uint16_t2 test_countbits_ushort2(uint16_t2 p0)
// CHECK: [[A:%.*]] = call <2 x i16> @llvm.ctpop.v2i16
// CHECK-NEXT: zext <2 x i16> [[A]] to <2 x i32>
uint2 test_countbits_ushort2(uint16_t2 p0)
{
return countbits(p0);
}
// CHECK-LABEL: test_countbits_ushort3
// CHECK: call <3 x i16> @llvm.ctpop.v3i16
uint16_t3 test_countbits_ushort3(uint16_t3 p0)
// CHECK: [[A:%.*]] = call <3 x i16> @llvm.ctpop.v3i16
// CHECK-NEXT: zext <3 x i16> [[A]] to <3 x i32>
uint3 test_countbits_ushort3(uint16_t3 p0)
{
return countbits(p0);
}
// CHECK-LABEL: test_countbits_ushort4
// CHECK: call <4 x i16> @llvm.ctpop.v4i16
uint16_t4 test_countbits_ushort4(uint16_t4 p0)
// CHECK: [[A:%.*]] = call <4 x i16> @llvm.ctpop.v4i16
// CHECK-NEXT: zext <4 x i16> [[A]] to <4 x i32>
uint4 test_countbits_ushort4(uint16_t4 p0)
{
return countbits(p0);
}
#endif

// CHECK-LABEL: test_countbits_uint
// CHECK: call i32 @llvm.ctpop.i32
int test_countbits_uint(uint p0)
uint test_countbits_uint(uint p0)
{
return countbits(p0);
}
// CHECK-LABEL: test_countbits_int
// CHECK: call i32 @llvm.ctpop.i32
uint test_countbits_int(int p0)
{
return countbits(p0);
}
@@ -55,26 +72,37 @@ uint4 test_countbits_uint4(uint4 p0)
}

// CHECK-LABEL: test_countbits_long
// CHECK: call i64 @llvm.ctpop.i64
uint64_t test_countbits_long(uint64_t p0)
// CHECK: [[A:%.*]] = call i64 @llvm.ctpop.i64
// CHECK-NEXT: trunc i64 [[A]] to i32
uint test_countbits_long(uint64_t p0)
{
return countbits(p0);
}
// CHECK-LABEL: test_countbits_slong
// CHECK: [[A:%.*]] = call i64 @llvm.ctpop.i64
// CHECK-NEXT: trunc i64 [[A]] to i32
uint test_countbits_slong(int64_t p0)
{
return countbits(p0);
}
// CHECK-LABEL: test_countbits_long2
// CHECK: call <2 x i64> @llvm.ctpop.v2i64
uint64_t2 test_countbits_long2(uint64_t2 p0)
// CHECK: [[A:%.*]] = call <2 x i64> @llvm.ctpop.v2i64
// CHECK-NEXT: trunc <2 x i64> [[A]] to <2 x i32>
uint2 test_countbits_long2(uint64_t2 p0)
{
return countbits(p0);
}
// CHECK-LABEL: test_countbits_long3
// CHECK: call <3 x i64> @llvm.ctpop.v3i64
uint64_t3 test_countbits_long3(uint64_t3 p0)
// CHECK: [[A:%.*]] = call <3 x i64> @llvm.ctpop.v3i64
// CHECK-NEXT: trunc <3 x i64> [[A]] to <3 x i32>
uint3 test_countbits_long3(uint64_t3 p0)
{
return countbits(p0);
}
// CHECK-LABEL: test_countbits_long4
// CHECK: call <4 x i64> @llvm.ctpop.v4i64
uint64_t4 test_countbits_long4(uint64_t4 p0)
// CHECK: [[A:%.*]] = call <4 x i64> @llvm.ctpop.v4i64
// CHECK-NEXT: trunc <4 x i64> [[A]] to <4 x i32>
uint4 test_countbits_long4(uint64_t4 p0)
{
return countbits(p0);
}
14 changes: 5 additions & 9 deletions clang/test/SemaHLSL/BuiltIns/countbits-errors.hlsl
@@ -1,6 +1,4 @@
// RUN: %clang_cc1 -finclude-default-header
// -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only
// -disable-llvm-passes -verify -verify-ignore-unexpected
// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify -verify-ignore-unexpected


double test_int_builtin(double p0) {
@@ -9,13 +7,11 @@ double test_int_builtin(double p0) {
}

double2 test_int_builtin_2(double2 p0) {
return __builtin_elementwise_popcount(p0);
// expected-error@-1 {{1st argument must be a vector of integers
// (was 'double2' (aka 'vector<double, 2>'))}}
return countbits(p0);
// expected-error@-1 {{call to 'countbits' is ambiguous}}
}

double test_int_builtin_3(float p0) {
return __builtin_elementwise_popcount(p0);
// expected-error@-1 {{1st argument must be a vector of integers
// (was 'float')}}
return countbits(p0);
// expected-error@-1 {{call to 'countbits' is ambiguous}}
}
5 changes: 2 additions & 3 deletions llvm/lib/Target/DirectX/DXIL.td
@@ -554,11 +554,10 @@ def Rbits : DXILOp<30, unary> {
let attributes = [Attributes<DXIL1_0, [ReadNone]>];
}

def CBits : DXILOp<31, unary> {
def CountBits : DXILOp<31, unaryBits> {
let Doc = "Returns the number of 1 bits in the specified value.";
let LLVMIntrinsic = int_ctpop;
let arguments = [OverloadTy];
let result = OverloadTy;
let result = Int32Ty;
let overloads =
[Overloads<DXIL1_0, [Int16Ty, Int32Ty, Int64Ty]>];
let stages = [Stages<DXIL1_0, [all_stages]>];
70 changes: 70 additions & 0 deletions llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -505,6 +505,73 @@ class OpLowerer {
});
}

[[nodiscard]] bool lowerCtpopToCountBits(Function &F) {
IRBuilder<> &IRB = OpBuilder.getIRB();
Type *Int32Ty = IRB.getInt32Ty();

return replaceFunction(F, [&](CallInst *CI) -> Error {
IRB.SetInsertPoint(CI);
SmallVector<Value *> Args;
Args.append(CI->arg_begin(), CI->arg_end());

Type *RetTy = Int32Ty;
Type *FRT = F.getReturnType();
if (const auto *VT = dyn_cast<VectorType>(FRT))
RetTy = VectorType::get(RetTy, VT);

Expected<CallInst *> OpCall = OpBuilder.tryCreateOp(
dxil::OpCode::CountBits, Args, CI->getName(), RetTy);
if (Error E = OpCall.takeError())
return E;

// If the result type is 32 bits we can do a direct replacement.
if (FRT->isIntOrIntVectorTy(32)) {
CI->replaceAllUsesWith(*OpCall);
CI->eraseFromParent();
return Error::success();
}

unsigned CastOp;
unsigned CastOp2;
if (FRT->isIntOrIntVectorTy(16)) {
CastOp = Instruction::ZExt;
Contributor:

Related to signed vs. unsigned: this is where I think things get confusing. An int16_t would almost always be sign-extended, not zero-extended. But sign-extending here doesn't make sense, since you would count 16 extra bits for any negative int16_t.
I think this is why std::bitset basically ignores the type. It's a forcing function for the author to say "yes, I'm purposely counting a signed value, now go away."
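
The concern about counting 16 extra bits can be made concrete with a small C++ sketch. The helper names `popcount16` and `popcount16_sext_first` are hypothetical and not part of the PR; the sketch only compares counting the 16-bit pattern directly against counting after the operand has been sign-extended to 32 bits.

```cpp
#include <bitset>
#include <cstdint>

// Counts the set bits of the 16-bit pattern itself.
inline unsigned popcount16(std::int16_t x) {
  return static_cast<unsigned>(
      std::bitset<16>(static_cast<std::uint16_t>(x)).count());
}

// Counts the set bits after sign-extending the operand to 32 bits first.
// For negative inputs the 16 copies of the sign bit are counted as well.
inline unsigned popcount16_sext_first(std::int16_t x) {
  const std::int32_t wide = x;  // implicit sign extension
  return static_cast<unsigned>(
      std::bitset<32>(static_cast<std::uint32_t>(wide)).count());
}
```

For non-negative inputs the two helpers agree; for any negative input the second one reports exactly 16 more bits than the first.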

Contributor:

Sorry if I'm being confusing. I am not saying the code here is wrong; I'm only noting that this zero extension (which I think is correct) is what made my thoughts drift to the topic.

Contributor Author:

The extension should only be on the return value which is always unsigned. Is this code wrong with that in mind?

Contributor:

Oh, I see. This is overriding the cast op below. Is a zero/sign extension needed? I notice the return type is "int16_t".
Although I guess that's another question - should the return types all be 32 bits? i.e., "int_t"

Contributor (@bfavela, Oct 21, 2024):

Disregard. I see you actually updated that correctly! (the whole point of this PR haha)

CastOp2 = Instruction::SExt;
} else { // must be 64 bits
assert(FRT->isIntOrIntVectorTy(64) &&
"Currently only lowering 16, 32, or 64 bit ctpop to CountBits \
is supported.");
CastOp = Instruction::Trunc;
CastOp2 = Instruction::Trunc;
}

// It is correct to replace the ctpop with the dxil op and
// remove all casts to i32
bool NeedsCast = false;
for (User *User : make_early_inc_range(CI->users())) {
Instruction *I = dyn_cast<Instruction>(User);
if (I && (I->getOpcode() == CastOp || I->getOpcode() == CastOp2) &&
I->getType() == RetTy) {
I->replaceAllUsesWith(*OpCall);
I->eraseFromParent();
} else
NeedsCast = true;
}

// It is correct to replace a ctpop with the dxil op and
// a cast from i32 to the return type of the ctpop
// the cast is emitted here if there is a non-cast to i32
// instr which uses the ctpop
if (NeedsCast) {
Value *Cast =
IRB.CreateZExtOrTrunc(*OpCall, F.getReturnType(), "ctpop.cast");
CI->replaceAllUsesWith(Cast);
}

CI->eraseFromParent();
return Error::success();
});
}
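
The comments in `lowerCtpopToCountBits` state that it is correct to drop the casts to i32 on the ctpop result. One plausible way to see why, offered as a reading of the code rather than something the PR spells out: a population count of an N-bit value lies in [0, N], so for N = 16 the result is never negative as an i16 (zext and sext agree), and for N = 64 the result always fits in 32 bits (trunc loses nothing). The sketch below models the three casts in C++; all function names are hypothetical.

```cpp
#include <cstdint>

// Model of `zext i16 -> i32`.
inline std::uint32_t zext16(std::uint16_t v) { return v; }

// Model of `sext i16 -> i32`. The conversion to int16_t relies on
// two's-complement wraparound (guaranteed since C++20, universal in practice).
inline std::uint32_t sext16(std::uint16_t v) {
  return static_cast<std::uint32_t>(
      static_cast<std::int32_t>(static_cast<std::int16_t>(v)));
}

// Model of `trunc i64 -> i32`.
inline std::uint32_t trunc64(std::uint64_t v) {
  return static_cast<std::uint32_t>(v);
}

// True when every cast yields the same i32 for every possible ctpop result.
inline bool casts_agree_on_counts() {
  for (std::uint16_t c = 0; c <= 16; ++c)
    if (zext16(c) != c || sext16(c) != c) return false;
  for (std::uint64_t c = 0; c <= 64; ++c)
    if (trunc64(c) != c) return false;
  return true;
}
```

The casts differ on arbitrary values (for example `0x8000`), so the folding is only value-preserving because the operand is known to be a bit count.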

bool lowerIntrinsics() {
bool Updated = false;
bool HasErrors = false;
@@ -543,6 +610,9 @@ class OpLowerer {
return replaceSplitDoubleCallUsages(CI, Op);
});
break;
case Intrinsic::ctpop:
HasErrors |= lowerCtpopToCountBits(F);
break;
}
Updated = true;
}