Skip to content

Commit 6418461

Browse files
committed
[HLSL] Implement SV_GroupID semantic
Support SV_GroupID attribute. Translate it into dx.group.id in clang codeGen. Fixes: #70120
1 parent 5098b56 commit 6418461

File tree

11 files changed

+116
-6
lines changed

11 files changed

+116
-6
lines changed

clang/include/clang/Basic/Attr.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4587,6 +4587,13 @@ def HLSLNumThreads: InheritableAttr {
45874587
let Documentation = [NumThreadsDocs];
45884588
}
45894589

4590+
def HLSLSV_GroupID: HLSLAnnotationAttr {
4591+
let Spellings = [HLSLAnnotation<"SV_GroupID">];
4592+
let Subjects = SubjectList<[ParmVar, Field]>;
4593+
let LangOpts = [HLSL];
4594+
let Documentation = [HLSLSV_GroupIDDocs];
4595+
}
4596+
45904597
def HLSLSV_GroupIndex: HLSLAnnotationAttr {
45914598
let Spellings = [HLSLAnnotation<"SV_GroupIndex">];
45924599
let Subjects = SubjectList<[ParmVar, GlobalVar]>;

clang/include/clang/Basic/AttrDocs.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7816,6 +7816,17 @@ randomized.
78167816
}];
78177817
}
78187818

7819+
def HLSLSV_GroupIDDocs : Documentation {
7820+
let Category = DocCatFunction;
7821+
let Content = [{
7822+
The ``SV_GroupID`` semantic, when applied to an input parameter, specifies a
7823+
data binding to map the group id to the specified parameter. This attribute is
7824+
only supported in compute shaders.
7825+
7826+
The full documentation is available here: https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sv-groupid
7827+
}];
7828+
}
7829+
78197830
def HLSLSV_GroupIndexDocs : Documentation {
78207831
let Category = DocCatFunction;
78217832
let Content = [{

clang/include/clang/Sema/SemaHLSL.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ class SemaHLSL : public SemaBase {
119119
void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
120120
void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
121121
void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);
122+
void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL);
122123
void handlePackOffsetAttr(Decl *D, const ParsedAttr &AL);
123124
void handleShaderAttr(Decl *D, const ParsedAttr &AL);
124125
void handleResourceBindingAttr(Decl *D, const ParsedAttr &AL);

clang/lib/CodeGen/CGHLSLRuntime.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,10 @@ llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
389389
CGM.getIntrinsic(getThreadIdIntrinsic());
390390
return buildVectorInput(B, ThreadIDIntrinsic, Ty);
391391
}
392+
if (D.hasAttr<HLSLSV_GroupIDAttr>()) {
393+
llvm::Function *GroupIDIntrinsic = CGM.getIntrinsic(Intrinsic::dx_group_id);
394+
return buildVectorInput(B, GroupIDIntrinsic, Ty);
395+
}
392396
assert(false && "Unhandled parameter attribute");
393397
return nullptr;
394398
}

clang/lib/Parse/ParseHLSL.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ void Parser::ParseHLSLAnnotations(ParsedAttributes &Attrs,
280280
case ParsedAttr::UnknownAttribute:
281281
Diag(Loc, diag::err_unknown_hlsl_semantic) << II;
282282
return;
283+
case ParsedAttr::AT_HLSLSV_GroupID:
283284
case ParsedAttr::AT_HLSLSV_GroupIndex:
284285
case ParsedAttr::AT_HLSLSV_DispatchThreadID:
285286
break;

clang/lib/Sema/SemaDeclAttr.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6990,6 +6990,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
69906990
case ParsedAttr::AT_HLSLWaveSize:
69916991
S.HLSL().handleWaveSizeAttr(D, AL);
69926992
break;
6993+
case ParsedAttr::AT_HLSLSV_GroupID:
6994+
S.HLSL().handleSV_GroupIDAttr(D, AL);
6995+
break;
69936996
case ParsedAttr::AT_HLSLSV_GroupIndex:
69946997
handleSimpleAttribute<HLSLSV_GroupIndexAttr>(S, D, AL);
69956998
break;

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,7 @@ void SemaHLSL::CheckSemanticAnnotation(
434434
switch (AnnotationAttr->getKind()) {
435435
case attr::HLSLSV_DispatchThreadID:
436436
case attr::HLSLSV_GroupIndex:
437+
case attr::HLSLSV_GroupID:
437438
if (ST == llvm::Triple::Compute)
438439
return;
439440
DiagnoseAttrStageMismatch(AnnotationAttr, ST, {llvm::Triple::Compute});
@@ -764,7 +765,7 @@ void SemaHLSL::handleWaveSizeAttr(Decl *D, const ParsedAttr &AL) {
764765
D->addAttr(NewAttr);
765766
}
766767

767-
static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) {
768+
static bool isLegalTypeForHLSLSV_ThreadOrGroupID(QualType T) {
768769
if (!T->hasUnsignedIntegerRepresentation())
769770
return false;
770771
if (const auto *VT = T->getAs<VectorType>())
@@ -774,7 +775,7 @@ static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) {
774775

775776
void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) {
776777
auto *VD = cast<ValueDecl>(D);
777-
if (!isLegalTypeForHLSLSV_DispatchThreadID(VD->getType())) {
778+
if (!isLegalTypeForHLSLSV_ThreadOrGroupID(VD->getType())) {
778779
Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
779780
<< AL << "uint/uint2/uint3";
780781
return;
@@ -784,6 +785,17 @@ void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) {
784785
HLSLSV_DispatchThreadIDAttr(getASTContext(), AL));
785786
}
786787

788+
void SemaHLSL::handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL) {
789+
auto *VD = cast<ValueDecl>(D);
790+
if (!isLegalTypeForHLSLSV_ThreadOrGroupID(VD->getType())) {
791+
Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type)
792+
<< AL << "uint/uint2/uint3";
793+
return;
794+
}
795+
796+
D->addAttr(::new (getASTContext()) HLSLSV_GroupIDAttr(getASTContext(), AL));
797+
}
798+
787799
void SemaHLSL::handlePackOffsetAttr(Decl *D, const ParsedAttr &AL) {
788800
if (!isa<VarDecl>(D) || !isa<HLSLBufferDecl>(D->getDeclContext())) {
789801
Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_ast_node)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s
2+
3+
// Make sure SV_GroupID translated into dx.group.id.
4+
5+
// CHECK: define void @foo()
6+
// CHECK: %[[#ID:]] = call i32 @llvm.dx.group.id(i32 0)
7+
// CHECK: call void @{{.*}}foo{{.*}}(i32 %[[#ID]])
8+
[shader("compute")]
9+
[numthreads(8,8,1)]
10+
void foo(uint Idx : SV_GroupID) {}
11+
12+
// CHECK: define void @bar()
13+
// CHECK: %[[#ID_X:]] = call i32 @llvm.dx.group.id(i32 0)
14+
// CHECK: %[[#ID_X_:]] = insertelement <2 x i32> poison, i32 %[[#ID_X]], i64 0
15+
// CHECK: %[[#ID_Y:]] = call i32 @llvm.dx.group.id(i32 1)
16+
// CHECK: %[[#ID_XY:]] = insertelement <2 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1
17+
// CHECK: call void @{{.*}}bar{{.*}}(<2 x i32> %[[#ID_XY]])
18+
[shader("compute")]
19+
[numthreads(8,8,1)]
20+
void bar(uint2 Idx : SV_GroupID) {}
21+

clang/test/SemaHLSL/Semantics/entry_parameter.hlsl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@
22
// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-mesh -hlsl-entry CSMain -x hlsl -finclude-default-header -verify -o - %s
33

44
[numthreads(8,8,1)]
5-
// expected-error@+2 {{attribute 'SV_GroupIndex' is unsupported in 'mesh' shaders, requires compute}}
6-
// expected-error@+1 {{attribute 'SV_DispatchThreadID' is unsupported in 'mesh' shaders, requires compute}}
7-
void CSMain(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID) {
8-
// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int, uint)'
5+
// expected-error@+3 {{attribute 'SV_GroupIndex' is unsupported in 'mesh' shaders, requires compute}}
6+
// expected-error@+2 {{attribute 'SV_DispatchThreadID' is unsupported in 'mesh' shaders, requires compute}}
7+
// expected-error@+1 {{attribute 'SV_GroupID' is unsupported in 'mesh' shaders, requires compute}}
8+
void CSMain(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID, uint GID : SV_GroupID) {
9+
// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int, uint, uint)'
910
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:17 GI 'int'
1011
// CHECK-NEXT: HLSLSV_GroupIndexAttr
1112
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:42 ID 'uint'
1213
// CHECK-NEXT: HLSLSV_DispatchThreadIDAttr
14+
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:73 GID 'uint'
15+
// CHECK-NEXT: HLSLSV_GroupIDAttr
1316
}

clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,25 @@ struct ST2 {
2727
static uint X : SV_DispatchThreadID;
2828
uint s : SV_DispatchThreadID;
2929
};
30+
31+
[numthreads(8,8,1)]
32+
// expected-error@+1 {{attribute 'SV_GroupID' only applies to a field or parameter of type 'uint/uint2/uint3'}}
33+
void CSMain_GID(float ID : SV_GroupID) {
34+
}
35+
36+
[numthreads(8,8,1)]
37+
// expected-error@+1 {{attribute 'SV_GroupID' only applies to a field or parameter of type 'uint/uint2/uint3'}}
38+
void CSMain2_GID(ST GID : SV_GroupID) {
39+
40+
}
41+
42+
void foo_GID() {
43+
// expected-warning@+1 {{'SV_GroupID' attribute only applies to parameters and non-static data members}}
44+
uint GIS : SV_GroupID;
45+
}
46+
47+
struct ST2_GID {
48+
// expected-warning@+1 {{'SV_GroupID' attribute only applies to parameters and non-static data members}}
49+
static uint GID : SV_GroupID;
50+
uint s_gid : SV_GroupID;
51+
};

clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,28 @@ void CSMain3(uint3 : SV_DispatchThreadID) {
2424
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:20 'uint3'
2525
// CHECK-NEXT: HLSLSV_DispatchThreadIDAttr
2626
}
27+
28+
[numthreads(8,8,1)]
29+
void CSMain_GID(uint ID : SV_GroupID) {
30+
// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain_GID 'void (uint)'
31+
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:22 ID 'uint'
32+
// CHECK-NEXT: HLSLSV_GroupIDAttr
33+
}
34+
[numthreads(8,8,1)]
35+
void CSMain1_GID(uint2 ID : SV_GroupID) {
36+
// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain1_GID 'void (uint2)'
37+
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:24 ID 'uint2'
38+
// CHECK-NEXT: HLSLSV_GroupIDAttr
39+
}
40+
[numthreads(8,8,1)]
41+
void CSMain2_GID(uint3 ID : SV_GroupID) {
42+
// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain2_GID 'void (uint3)'
43+
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:24 ID 'uint3'
44+
// CHECK-NEXT: HLSLSV_GroupIDAttr
45+
}
46+
[numthreads(8,8,1)]
47+
void CSMain3_GID(uint3 : SV_GroupID) {
48+
// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain3_GID 'void (uint3)'
49+
// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:24 'uint3'
50+
// CHECK-NEXT: HLSLSV_GroupIDAttr
51+
}

0 commit comments

Comments
 (0)