Skip to content

[HLSL] Add Increment/DecrementCounter methods to structured buffers #114148

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Nov 23, 2024

Conversation

hekota
Copy link
Member

@hekota hekota commented Oct 29, 2024

Introduces __builtin_hlsl_buffer_update_counter clang buildin that is used to implement the IncrementCounter and DecrementCounter methods on RWStructuredBuffer and RasterizerOrderedStructuredBuffer (see Note).

The builtin is translated to LLVM intrisic llvm.dx.bufferUpdateCounter or llvm.spv.bufferUpdateCounter.

Introduces BuiltinTypeMethodBuilder helper in HLSLExternalSemaSource that enables adding methods to builtin types using builder pattern like this:

   BuiltinTypeMethodBuilder(Sema, RecordBuilder, "MethodName", ReturnType)
       .addParam("param_name", Type, InOutModifier)
       .callBuiltin("buildin_name", { BuiltinParams })
       .finalizeMethod();

Fixes #113513

Introduces `__builtin_hlsl_buffer_update_counter` clang buildin that is used to implement IncrementCounter and DecrementCounter methods on RWStructuredBuffer and RasterizerOrderedStructuredBuffer. The builtin is translated to LLVM intrisics llvm.dx.bufferUpdateCounter/llvm.spv.bufferUpdateCounter.

Introduces `BuiltinTypeMethodBuilder` helper in `HLSLExternalSemaSource` that allows adding methods to builtin types
using the builder pattern like this:

   BuiltinTypeMethodBuilder(Sema, RecordBuilder, "MethodName", ReturnType)
       .addParam("param_name", Type, InOutModifier)
       .callBuiltin("buildin_name", { BuiltinParams })
       .finalizeMethod();

Fixes llvm#113513
@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:codegen IR generation bugs: mangling, exceptions, etc. backend:DirectX HLSL HLSL Language Support llvm:ir labels Oct 29, 2024
@llvmbot
Copy link
Member

llvmbot commented Oct 29, 2024

@llvm/pr-subscribers-hlsl
@llvm/pr-subscribers-clang

@llvm/pr-subscribers-backend-directx

Author: Helena Kotas (hekota)

Changes

Introduces __builtin_hlsl_buffer_update_counter clang buildin that is used to implement the IncrementCounter and DecrementCounter methods on RWStructuredBuffer and RasterizerOrderedStructuredBuffer (see Note).

The builtin is translated to LLVM intrisic llvm.dx.bufferUpdateCounter or llvm.spv.bufferUpdateCounter.

Introduces BuiltinTypeMethodBuilder helper in HLSLExternalSemaSource that allows adding methods to builtin types using builder pattern like this:

   BuiltinTypeMethodBuilder(Sema, RecordBuilder, "MethodName", ReturnType)
       .addParam("param_name", Type, InOutModifier)
       .callBuiltin("buildin_name", { BuiltinParams })
       .finalizeMethod();

Note: RasterizerOrderedStructuredBuffer does not exist yet, it is being added in PR llvm/llvm-project#113648. After llvm/llvm-project#113648 is merged this PR will be updated to add Increment/DecrementCounter on this buffer type as well.

Fixes #113513


Patch is 25.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/114148.diff

12 Files Affected:

  • (modified) clang/include/clang/Basic/Builtins.td (+6-1)
  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+4)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+8)
  • (modified) clang/lib/CodeGen/CGHLSLRuntime.h (+1)
  • (modified) clang/lib/Sema/HLSLExternalSemaSource.cpp (+247-31)
  • (modified) clang/lib/Sema/SemaExpr.cpp (+4)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+41)
  • (added) clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl (+25)
  • (added) clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-ps.hlsl (+29)
  • (added) clang/test/SemaHLSL/BuiltIns/buffer_update_counter-errors.hlsl (+22)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+3)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+3)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 90475a361bb8f8..72bc2d5e7df23e 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4846,7 +4846,6 @@ def HLSLSaturate : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
-
 def HLSLSelect : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_select"];
   let Attributes = [NoThrow, Const];
@@ -4871,6 +4870,12 @@ def HLSLRadians : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
+def HLSLBufferUpdateCounter : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_buffer_update_counter"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "uint32_t(...)";
+}
+
 // Builtins for XRay.
 def XRayCustomEvent : Builtin {
   let Spellings = ["__xray_customevent"];
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 8e4718008ece72..2aea6bb657578a 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -7256,6 +7256,8 @@ def err_typecheck_illegal_increment_decrement : Error<
   "cannot %select{decrement|increment}1 value of type %0">;
 def err_typecheck_expect_int : Error<
   "used type %0 where integer is required">;
+def err_typecheck_expect_hlsl_resource : Error<
+  "used type %0 where __hlsl_resource_t is required">;
 def err_typecheck_arithmetic_incomplete_or_sizeless_type : Error<
   "arithmetic on a pointer to %select{an incomplete|sizeless}0 type %1">;
 def err_typecheck_pointer_arith_function_type : Error<
@@ -12485,6 +12487,8 @@ def warn_attr_min_eq_max:  Warning<
 
 def err_hlsl_attribute_number_arguments_insufficient_shader_model: Error<
   "attribute %0 with %1 arguments requires shader model %2 or greater">;
+def err_hlsl_expect_arg_const_int_one_or_neg_one: Error<
+  "argument %0 must be constant integer 1 or -1">;
 
 // Layout randomization diagnostics.
 def err_non_designated_init_used : Error<
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index e2d03eff8ab4a0..71273de3400b17 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18959,6 +18959,14 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         CGM.getHLSLRuntime().getRadiansIntrinsic(), ArrayRef<Value *>{Op0},
         nullptr, "hlsl.radians");
   }
+  case Builtin::BI__builtin_hlsl_buffer_update_counter: {
+    Value *ResHandle = EmitScalarExpr(E->getArg(0));
+    Value *Offset = EmitScalarExpr(E->getArg(1));
+    return Builder.CreateIntrinsic(
+        /*ReturnType=*/Offset->getType(),
+        CGM.getHLSLRuntime().getBufferUpdateCounterIntrinsic(),
+        ArrayRef<Value *>{ResHandle, Offset}, nullptr);
+  }
   }
   return nullptr;
 }
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index ff7df41b5c62e7..aac93dfc373ed4 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -93,6 +93,7 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
 
   GENERATE_HLSL_INTRINSIC_FUNCTION(CreateHandleFromBinding, handle_fromBinding)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(BufferUpdateCounter, bufferUpdateCounter)
 
   //===----------------------------------------------------------------------===//
   // End of reserved area for HLSL intrinsic getters.
diff --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp b/clang/lib/Sema/HLSLExternalSemaSource.cpp
index ce8564429b3802..24c3954b134c5f 100644
--- a/clang/lib/Sema/HLSLExternalSemaSource.cpp
+++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp
@@ -22,12 +22,15 @@
 #include "clang/Sema/SemaHLSL.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Frontend/HLSL/HLSLResource.h"
+#include "llvm/Support/ErrorHandling.h"
 
 #include <functional>
 
 using namespace clang;
 using namespace llvm::hlsl;
 
+static DeclRefExpr *lookupBuiltinFunction(Sema &S, StringRef Name);
+
 namespace {
 
 struct TemplateParameterListBuilder;
@@ -121,12 +124,8 @@ struct BuiltinTypeDeclBuilder {
     TypeSourceInfo *ElementTypeInfo = nullptr;
 
     QualType ElemTy = Ctx.Char8Ty;
-    if (Template) {
-      if (const auto *TTD = dyn_cast<TemplateTypeParmDecl>(
-              Template->getTemplateParameters()->getParam(0))) {
-        ElemTy = QualType(TTD->getTypeForDecl(), 0);
-      }
-    }
+    if (Template)
+      ElemTy = getFirstTemplateTypeParam();
     ElementTypeInfo = Ctx.getTrivialTypeSourceInfo(ElemTy, SourceLocation());
 
     // add handle member with resource type attributes
@@ -145,25 +144,6 @@ struct BuiltinTypeDeclBuilder {
     return *this;
   }
 
-  static DeclRefExpr *lookupBuiltinFunction(ASTContext &AST, Sema &S,
-                                            StringRef Name) {
-    IdentifierInfo &II = AST.Idents.get(Name, tok::TokenKind::identifier);
-    DeclarationNameInfo NameInfo =
-        DeclarationNameInfo(DeclarationName(&II), SourceLocation());
-    LookupResult R(S, NameInfo, Sema::LookupOrdinaryName);
-    // AllowBuiltinCreation is false but LookupDirect will create
-    // the builtin when searching the global scope anyways...
-    S.LookupName(R, S.getCurScope());
-    // FIXME: If the builtin function was user-declared in global scope,
-    // this assert *will* fail. Should this call LookupBuiltin instead?
-    assert(R.isSingleResult() &&
-           "Since this is a builtin it should always resolve!");
-    auto *VD = cast<ValueDecl>(R.getFoundDecl());
-    QualType Ty = VD->getType();
-    return DeclRefExpr::Create(AST, NestedNameSpecifierLoc(), SourceLocation(),
-                               VD, false, NameInfo, Ty, VK_PRValue);
-  }
-
   static Expr *emitResourceClassExpr(ASTContext &AST, ResourceClass RC) {
     return IntegerLiteral::Create(
         AST,
@@ -211,12 +191,8 @@ struct BuiltinTypeDeclBuilder {
 
     ASTContext &AST = Record->getASTContext();
     QualType ElemTy = AST.Char8Ty;
-    if (Template) {
-      if (const auto *TTD = dyn_cast<TemplateTypeParmDecl>(
-              Template->getTemplateParameters()->getParam(0))) {
-        ElemTy = QualType(TTD->getTypeForDecl(), 0);
-      }
-    }
+    if (Template)
+      ElemTy = getFirstTemplateTypeParam();
     QualType ReturnTy = ElemTy;
 
     FunctionProtoType::ExtProtoInfo ExtInfo;
@@ -282,6 +258,23 @@ struct BuiltinTypeDeclBuilder {
     return *this;
   }
 
+  FieldDecl *getResourceHandleField() {
+    FieldDecl *FD = Fields["h"];
+    if (FD && FD->getType()->isHLSLAttributedResourceType())
+      return FD;
+    return nullptr;
+  }
+
+  QualType getFirstTemplateTypeParam() {
+    if (Template) {
+      if (const auto *TTD = dyn_cast<TemplateTypeParmDecl>(
+              Template->getTemplateParameters()->getParam(0))) {
+        return QualType(TTD->getTypeForDecl(), 0);
+      }
+    }
+    return QualType();
+  }
+
   BuiltinTypeDeclBuilder &startDefinition() {
     if (Record->isCompleteDefinition())
       return *this;
@@ -302,6 +295,10 @@ struct BuiltinTypeDeclBuilder {
   TemplateParameterListBuilder addTemplateArgumentList(Sema &S);
   BuiltinTypeDeclBuilder &addSimpleTemplateParams(Sema &S,
                                                   ArrayRef<StringRef> Names);
+
+  // Builtin types methods
+  BuiltinTypeDeclBuilder &addIncrementCounterMethod(Sema &S);
+  BuiltinTypeDeclBuilder &addDecrementCounterMethod(Sema &S);
 };
 
 struct TemplateParameterListBuilder {
@@ -359,6 +356,176 @@ struct TemplateParameterListBuilder {
     return Builder;
   }
 };
+
+// Builder for methods of builtin types. Allows adding methods to builtin types
+// using the builder pattern like this:
+//
+//   BuiltinTypeMethodBuilder(Sema, RecordBuilder, "MethodName", ReturnType)
+//       .addParam("param_name", Type, InOutModifier)
+//       .callBuiltin("buildin_name", { BuiltinParams })
+//       .finalizeMethod();
+//
+// The builder needs to have all of the method parameters before it can create
+// a CXXMethodDecl. It collects them in addParam calls and when a first
+// method that builds the body is called it creates the CXXMethodDecl and
+// ParmVarDecls instances. These can then be referenced from the body building
+// methods. Destructor or an explicit call to finalizeMethod() will complete
+// the method definition.
+struct BuiltinTypeMethodBuilder {
+  struct MethodParam {
+    const IdentifierInfo &NameII;
+    QualType Ty;
+    HLSLParamModifierAttr::Spelling Modifier;
+    MethodParam(const IdentifierInfo &NameII, QualType Ty,
+                HLSLParamModifierAttr::Spelling Modifier)
+        : NameII(NameII), Ty(Ty), Modifier(Modifier) {}
+  };
+
+  BuiltinTypeDeclBuilder &DeclBuilder;
+  Sema &S;
+  DeclarationNameInfo NameInfo;
+  QualType ReturnTy;
+  CXXMethodDecl *Method;
+  llvm::SmallVector<MethodParam> Params;
+  llvm::SmallVector<Stmt *> StmtsList;
+
+public:
+  BuiltinTypeMethodBuilder(Sema &S, BuiltinTypeDeclBuilder &DB, StringRef Name,
+                           QualType ReturnTy)
+      : DeclBuilder(DB), S(S), ReturnTy(ReturnTy), Method(nullptr) {
+    const IdentifierInfo &II =
+        S.getASTContext().Idents.get(Name, tok::TokenKind::identifier);
+    NameInfo = DeclarationNameInfo(DeclarationName(&II), SourceLocation());
+  }
+
+  BuiltinTypeMethodBuilder &addParam(StringRef Name, QualType Ty,
+                                     HLSLParamModifierAttr::Spelling Modifier =
+                                         HLSLParamModifierAttr::Keyword_in) {
+    assert(Method == nullptr && "Cannot add param, method already created");
+
+    const IdentifierInfo &II =
+        S.getASTContext().Idents.get(Name, tok::TokenKind::identifier);
+    Params.emplace_back(II, Ty, Modifier);
+    return *this;
+  }
+
+private:
+  void createMethodDecl() {
+    assert(Method == nullptr && "Method already created");
+
+    // create method type
+    ASTContext &AST = S.getASTContext();
+    SmallVector<QualType> ParamTypes;
+    for (auto &MP : Params)
+      ParamTypes.emplace_back(MP.Ty);
+    QualType MethodTy = AST.getFunctionType(ReturnTy, ParamTypes,
+                                            FunctionProtoType::ExtProtoInfo());
+
+    // create method decl
+    auto *TSInfo = AST.getTrivialTypeSourceInfo(MethodTy, SourceLocation());
+    Method =
+        CXXMethodDecl::Create(AST, DeclBuilder.Record, SourceLocation(),
+                              NameInfo, MethodTy, TSInfo, SC_None, false, false,
+                              ConstexprSpecKind::Unspecified, SourceLocation());
+
+    // create params & set them to the function prototype
+    SmallVector<ParmVarDecl *> ParmDecls;
+    auto FnProtoLoc =
+        Method->getTypeSourceInfo()->getTypeLoc().getAs<FunctionProtoTypeLoc>();
+    unsigned i = 0;
+    for (auto &MP : Params) {
+      ParmVarDecl *Parm = ParmVarDecl::Create(
+          AST, Method->getDeclContext(), SourceLocation(), SourceLocation(),
+          &MP.NameII, MP.Ty,
+          AST.getTrivialTypeSourceInfo(MP.Ty, SourceLocation()), SC_None,
+          nullptr);
+      if (MP.Modifier != HLSLParamModifierAttr::Keyword_in) {
+        auto *Mod =
+            HLSLParamModifierAttr::Create(AST, SourceRange(), MP.Modifier);
+        Parm->addAttr(Mod);
+      }
+      ParmDecls.push_back(Parm);
+      FnProtoLoc.setParam(i++, Parm);
+    }
+    Method->setParams({ParmDecls});
+  }
+
+  void addResourceHandleToParms(SmallVector<Expr *> &Parms) {
+    ASTContext &AST = S.getASTContext();
+    FieldDecl *HandleField = DeclBuilder.getResourceHandleField();
+    auto *This = CXXThisExpr::Create(
+        AST, SourceLocation(), Method->getFunctionObjectParameterType(), true);
+    Parms.push_back(MemberExpr::CreateImplicit(AST, This, false, HandleField,
+                                               HandleField->getType(),
+                                               VK_LValue, OK_Ordinary));
+  }
+
+public:
+  ~BuiltinTypeMethodBuilder() { finalizeMethod(); }
+
+  BuiltinTypeMethodBuilder &
+  callBuiltin(StringRef BuiltinName, ArrayRef<Expr *> CallParms,
+              bool AddResourceHandleAsFirstArg = true) {
+    if (!Method)
+      createMethodDecl();
+
+    ASTContext &AST = S.getASTContext();
+    DeclRefExpr *Fn = lookupBuiltinFunction(S, BuiltinName);
+    Expr *Call = nullptr;
+
+    if (AddResourceHandleAsFirstArg) {
+      SmallVector<Expr *> NewCallParms;
+      addResourceHandleToParms(NewCallParms);
+      for (auto *P : CallParms)
+        NewCallParms.push_back(P);
+
+      Call = CallExpr::Create(AST, Fn, NewCallParms, AST.VoidPtrTy, VK_PRValue,
+                              SourceLocation(), FPOptionsOverride());
+    } else {
+      Call = CallExpr::Create(AST, Fn, CallParms, AST.VoidPtrTy, VK_PRValue,
+                              SourceLocation(), FPOptionsOverride());
+    }
+    StmtsList.push_back(Call);
+    return *this;
+  }
+
+  BuiltinTypeMethodBuilder &
+  callBuiltinForwardArgs(StringRef BuiltinName,
+                         bool AddResourceHandleAsFirstArg = true) {
+    // FIXME: Call the buildin with all of the method parameters
+    // plus optional resource handle as the first arg.
+    llvm_unreachable("not yet implemented");
+  }
+
+  BuiltinTypeDeclBuilder &finalizeMethod() {
+    if (DeclBuilder.Record->isCompleteDefinition())
+      return DeclBuilder;
+
+    if (!Method)
+      createMethodDecl();
+
+    if (!Method->hasBody()) {
+      ASTContext &AST = S.getASTContext();
+      if (ReturnTy != AST.VoidTy && !StmtsList.empty()) {
+        if (Expr *LastExpr = dyn_cast<Expr>(StmtsList.back())) {
+          StmtsList.pop_back();
+          StmtsList.push_back(
+              ReturnStmt::Create(AST, SourceLocation(), LastExpr, nullptr));
+        }
+      }
+
+      Method->setBody(CompoundStmt::Create(AST, StmtsList, FPOptionsOverride(),
+                                           SourceLocation(), SourceLocation()));
+      Method->setLexicalDeclContext(DeclBuilder.Record);
+      Method->setAccess(AccessSpecifier::AS_public);
+      Method->addAttr(AlwaysInlineAttr::CreateImplicit(
+          AST, SourceRange(), AlwaysInlineAttr::CXX11_clang_always_inline));
+      DeclBuilder.Record->addDecl(Method);
+    }
+    return DeclBuilder;
+  }
+};
+
 } // namespace
 
 TemplateParameterListBuilder
@@ -375,6 +542,30 @@ BuiltinTypeDeclBuilder::addSimpleTemplateParams(Sema &S,
   return Builder.finalizeTemplateArgs();
 }
 
+BuiltinTypeDeclBuilder &
+BuiltinTypeDeclBuilder::addIncrementCounterMethod(Sema &S) {
+  ASTContext &AST = S.getASTContext();
+  Expr *One =
+      IntegerLiteral::Create(AST, llvm::APInt(AST.getTypeSize(AST.IntTy), 1),
+                             AST.IntTy, SourceLocation());
+  return BuiltinTypeMethodBuilder(S, *this, "IncrementCounter",
+                                  AST.UnsignedIntTy)
+      .callBuiltin("__builtin_hlsl_buffer_update_counter", {One})
+      .finalizeMethod();
+}
+
+BuiltinTypeDeclBuilder &
+BuiltinTypeDeclBuilder::addDecrementCounterMethod(Sema &S) {
+  ASTContext &AST = S.getASTContext();
+  Expr *NegOne =
+      IntegerLiteral::Create(AST, llvm::APInt(AST.getTypeSize(AST.IntTy), -1),
+                             AST.IntTy, SourceLocation());
+  return BuiltinTypeMethodBuilder(S, *this, "DecrementCounter",
+                                  AST.UnsignedIntTy)
+      .callBuiltin("__builtin_hlsl_buffer_update_counter", {NegOne})
+      .finalizeMethod();
+}
+
 HLSLExternalSemaSource::~HLSLExternalSemaSource() {}
 
 void HLSLExternalSemaSource::InitializeSema(Sema &S) {
@@ -528,8 +719,13 @@ void HLSLExternalSemaSource::defineHLSLTypesWithForwardDeclarations() {
                     ResourceKind::TypedBuffer, /*IsROV=*/false,
                     /*RawBuffer=*/true)
         .addArraySubscriptOperators()
+        .addIncrementCounterMethod(*SemaPtr)
+        .addDecrementCounterMethod(*SemaPtr)
         .completeDefinition();
   });
+
+  // FIXME: Also add Increment/DecrementCounter to
+  // RasterizerOrderedStructuredBuffer when llvm/llvm-project/#113648 is merged.
 }
 
 void HLSLExternalSemaSource::onCompletion(CXXRecordDecl *Record,
@@ -552,3 +748,23 @@ void HLSLExternalSemaSource::CompleteType(TagDecl *Tag) {
     return;
   It->second(Record);
 }
+
+static DeclRefExpr *lookupBuiltinFunction(Sema &S, StringRef Name) {
+  IdentifierInfo &II =
+      S.getASTContext().Idents.get(Name, tok::TokenKind::identifier);
+  DeclarationNameInfo NameInfo =
+      DeclarationNameInfo(DeclarationName(&II), SourceLocation());
+  LookupResult R(S, NameInfo, Sema::LookupOrdinaryName);
+  // AllowBuiltinCreation is false but LookupDirect will create
+  // the builtin when searching the global scope anyways...
+  S.LookupName(R, S.getCurScope());
+  // FIXME: If the builtin function was user-declared in global scope,
+  // this assert *will* fail. Should this call LookupBuiltin instead?
+  assert(R.isSingleResult() &&
+         "Since this is a builtin it should always resolve!");
+  auto *VD = cast<ValueDecl>(R.getFoundDecl());
+  QualType Ty = VD->getType();
+  return DeclRefExpr::Create(S.getASTContext(), NestedNameSpecifierLoc(),
+                             SourceLocation(), VD, false, NameInfo, Ty,
+                             VK_PRValue);
+}
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index ff6616901016ab..770bd4a81633e1 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -986,6 +986,10 @@ Sema::VarArgKind Sema::isValidVarArgType(const QualType &Ty) {
   if (getLangOpts().MSVCCompat)
     return VAK_MSVCUndefined;
 
+  if (getLangOpts().HLSL &&
+      Ty->getUnqualifiedDesugaredType()->isHLSLAttributedResourceType())
+    return VAK_Valid;
+
   // FIXME: In C++11, these cases are conditionally-supported, meaning we're
   // permitted to reject them. We should consider doing so.
   return VAK_Undefined;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 1f6c5b8d4561bc..1b7f0456a3e82a 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1860,6 +1860,31 @@ static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
   return false;
 }
 
+static bool CheckResourceHandle(Sema *S, CallExpr *TheCall, unsigned ArgIndex) {
+  assert(TheCall->getNumArgs() >= ArgIndex);
+  QualType ArgType = TheCall->getArg(ArgIndex)->getType();
+  if (!ArgType.getTypePtr()
+           ->getUnqualifiedDesugaredType()
+           ->isHLSLAttributedResourceType()) {
+    S->Diag(TheCall->getArg(0)->getBeginLoc(),
+            diag::err_typecheck_expect_hlsl_resource)
+        << ArgType;
+    return true;
+  }
+  return false;
+}
+
+static bool CheckInt(Sema *S, CallExpr *TheCall, unsigned ArgIndex) {
+  assert(TheCall->getNumArgs() >= ArgIndex);
+  QualType ArgType = TheCall->getArg(ArgIndex)->getType();
+  if (!ArgType->isIntegerType()) {
+    S->Diag(TheCall->getArg(0)->getBeginLoc(), diag::err_typecheck_expect_int)
+        << ArgType;
+    return true;
+  }
+  return false;
+}
+
 // Note: returning true in this case results in CheckBuiltinFunctionCall
 // returning an ExprError
 bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
@@ -2100,6 +2125,22 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_buffer_update_counter: {
+    if (SemaRef.checkArgCount(TheCall, 2) ||
+        CheckResourceHandle(&SemaRef, TheCall, 0) ||
+        CheckInt(&SemaRef, TheCall, 1))
+      return true;
+    Expr *OffsetExpr = TheCall->getArg(1);
+    std::optional<llvm::APSInt> Offset =
+        OffsetExpr->getIntegerConstantExpr(SemaRef.getASTContext());
+    if (!Offset.has_value() || abs(Offset->getExtValue()) != 1) {
+      SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
+                   diag::err_hlsl_expect_arg_const_int_one_or_neg_one)
+          << 1;
+      return true;
+    }
+    break;
+  }
   }
   return false;
 }
diff --git a/clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl b/clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl
new file mode 100644
index 00000000000000..c8ff5d3cd905fb
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl
@@ -0,0 +1,25 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN-DISABLED...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Oct 29, 2024

@llvm/pr-subscribers-clang-codegen

Author: Helena Kotas (hekota)

Changes

Introduces __builtin_hlsl_buffer_update_counter clang buildin that is used to implement the IncrementCounter and DecrementCounter methods on RWStructuredBuffer and RasterizerOrderedStructuredBuffer (see Note).

The builtin is translated to LLVM intrisic llvm.dx.bufferUpdateCounter or llvm.spv.bufferUpdateCounter.

Introduces BuiltinTypeMethodBuilder helper in HLSLExternalSemaSource that allows adding methods to builtin types using builder pattern like this:

   BuiltinTypeMethodBuilder(Sema, RecordBuilder, "MethodName", ReturnType)
       .addParam("param_name", Type, InOutModifier)
       .callBuiltin("buildin_name", { BuiltinParams })
       .finalizeMethod();

Note: RasterizerOrderedStructuredBuffer does not exist yet, it is being added in PR llvm/llvm-project#113648. After llvm/llvm-project#113648 is merged this PR will be updated to add Increment/DecrementCounter on this buffer type as well.

Fixes #113513


Patch is 25.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/114148.diff

12 Files Affected:

  • (modified) clang/include/clang/Basic/Builtins.td (+6-1)
  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+4)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+8)
  • (modified) clang/lib/CodeGen/CGHLSLRuntime.h (+1)
  • (modified) clang/lib/Sema/HLSLExternalSemaSource.cpp (+247-31)
  • (modified) clang/lib/Sema/SemaExpr.cpp (+4)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+41)
  • (added) clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl (+25)
  • (added) clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-ps.hlsl (+29)
  • (added) clang/test/SemaHLSL/BuiltIns/buffer_update_counter-errors.hlsl (+22)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+3)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+3)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 90475a361bb8f8..72bc2d5e7df23e 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4846,7 +4846,6 @@ def HLSLSaturate : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
-
 def HLSLSelect : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_select"];
   let Attributes = [NoThrow, Const];
@@ -4871,6 +4870,12 @@ def HLSLRadians : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
+def HLSLBufferUpdateCounter : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_buffer_update_counter"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "uint32_t(...)";
+}
+
 // Builtins for XRay.
 def XRayCustomEvent : Builtin {
   let Spellings = ["__xray_customevent"];
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 8e4718008ece72..2aea6bb657578a 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -7256,6 +7256,8 @@ def err_typecheck_illegal_increment_decrement : Error<
   "cannot %select{decrement|increment}1 value of type %0">;
 def err_typecheck_expect_int : Error<
   "used type %0 where integer is required">;
+def err_typecheck_expect_hlsl_resource : Error<
+  "used type %0 where __hlsl_resource_t is required">;
 def err_typecheck_arithmetic_incomplete_or_sizeless_type : Error<
   "arithmetic on a pointer to %select{an incomplete|sizeless}0 type %1">;
 def err_typecheck_pointer_arith_function_type : Error<
@@ -12485,6 +12487,8 @@ def warn_attr_min_eq_max:  Warning<
 
 def err_hlsl_attribute_number_arguments_insufficient_shader_model: Error<
   "attribute %0 with %1 arguments requires shader model %2 or greater">;
+def err_hlsl_expect_arg_const_int_one_or_neg_one: Error<
+  "argument %0 must be constant integer 1 or -1">;
 
 // Layout randomization diagnostics.
 def err_non_designated_init_used : Error<
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index e2d03eff8ab4a0..71273de3400b17 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18959,6 +18959,14 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         CGM.getHLSLRuntime().getRadiansIntrinsic(), ArrayRef<Value *>{Op0},
         nullptr, "hlsl.radians");
   }
+  case Builtin::BI__builtin_hlsl_buffer_update_counter: {
+    Value *ResHandle = EmitScalarExpr(E->getArg(0));
+    Value *Offset = EmitScalarExpr(E->getArg(1));
+    return Builder.CreateIntrinsic(
+        /*ReturnType=*/Offset->getType(),
+        CGM.getHLSLRuntime().getBufferUpdateCounterIntrinsic(),
+        ArrayRef<Value *>{ResHandle, Offset}, nullptr);
+  }
   }
   return nullptr;
 }
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index ff7df41b5c62e7..aac93dfc373ed4 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -93,6 +93,7 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
 
   GENERATE_HLSL_INTRINSIC_FUNCTION(CreateHandleFromBinding, handle_fromBinding)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(BufferUpdateCounter, bufferUpdateCounter)
 
   //===----------------------------------------------------------------------===//
   // End of reserved area for HLSL intrinsic getters.
diff --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp b/clang/lib/Sema/HLSLExternalSemaSource.cpp
index ce8564429b3802..24c3954b134c5f 100644
--- a/clang/lib/Sema/HLSLExternalSemaSource.cpp
+++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp
@@ -22,12 +22,15 @@
 #include "clang/Sema/SemaHLSL.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Frontend/HLSL/HLSLResource.h"
+#include "llvm/Support/ErrorHandling.h"
 
 #include <functional>
 
 using namespace clang;
 using namespace llvm::hlsl;
 
+static DeclRefExpr *lookupBuiltinFunction(Sema &S, StringRef Name);
+
 namespace {
 
 struct TemplateParameterListBuilder;
@@ -121,12 +124,8 @@ struct BuiltinTypeDeclBuilder {
     TypeSourceInfo *ElementTypeInfo = nullptr;
 
     QualType ElemTy = Ctx.Char8Ty;
-    if (Template) {
-      if (const auto *TTD = dyn_cast<TemplateTypeParmDecl>(
-              Template->getTemplateParameters()->getParam(0))) {
-        ElemTy = QualType(TTD->getTypeForDecl(), 0);
-      }
-    }
+    if (Template)
+      ElemTy = getFirstTemplateTypeParam();
     ElementTypeInfo = Ctx.getTrivialTypeSourceInfo(ElemTy, SourceLocation());
 
     // add handle member with resource type attributes
@@ -145,25 +144,6 @@ struct BuiltinTypeDeclBuilder {
     return *this;
   }
 
-  static DeclRefExpr *lookupBuiltinFunction(ASTContext &AST, Sema &S,
-                                            StringRef Name) {
-    IdentifierInfo &II = AST.Idents.get(Name, tok::TokenKind::identifier);
-    DeclarationNameInfo NameInfo =
-        DeclarationNameInfo(DeclarationName(&II), SourceLocation());
-    LookupResult R(S, NameInfo, Sema::LookupOrdinaryName);
-    // AllowBuiltinCreation is false but LookupDirect will create
-    // the builtin when searching the global scope anyways...
-    S.LookupName(R, S.getCurScope());
-    // FIXME: If the builtin function was user-declared in global scope,
-    // this assert *will* fail. Should this call LookupBuiltin instead?
-    assert(R.isSingleResult() &&
-           "Since this is a builtin it should always resolve!");
-    auto *VD = cast<ValueDecl>(R.getFoundDecl());
-    QualType Ty = VD->getType();
-    return DeclRefExpr::Create(AST, NestedNameSpecifierLoc(), SourceLocation(),
-                               VD, false, NameInfo, Ty, VK_PRValue);
-  }
-
   static Expr *emitResourceClassExpr(ASTContext &AST, ResourceClass RC) {
     return IntegerLiteral::Create(
         AST,
@@ -211,12 +191,8 @@ struct BuiltinTypeDeclBuilder {
 
     ASTContext &AST = Record->getASTContext();
     QualType ElemTy = AST.Char8Ty;
-    if (Template) {
-      if (const auto *TTD = dyn_cast<TemplateTypeParmDecl>(
-              Template->getTemplateParameters()->getParam(0))) {
-        ElemTy = QualType(TTD->getTypeForDecl(), 0);
-      }
-    }
+    if (Template)
+      ElemTy = getFirstTemplateTypeParam();
     QualType ReturnTy = ElemTy;
 
     FunctionProtoType::ExtProtoInfo ExtInfo;
@@ -282,6 +258,23 @@ struct BuiltinTypeDeclBuilder {
     return *this;
   }
 
+  FieldDecl *getResourceHandleField() {
+    FieldDecl *FD = Fields["h"];
+    if (FD && FD->getType()->isHLSLAttributedResourceType())
+      return FD;
+    return nullptr;
+  }
+
+  QualType getFirstTemplateTypeParam() {
+    if (Template) {
+      if (const auto *TTD = dyn_cast<TemplateTypeParmDecl>(
+              Template->getTemplateParameters()->getParam(0))) {
+        return QualType(TTD->getTypeForDecl(), 0);
+      }
+    }
+    return QualType();
+  }
+
   BuiltinTypeDeclBuilder &startDefinition() {
     if (Record->isCompleteDefinition())
       return *this;
@@ -302,6 +295,10 @@ struct BuiltinTypeDeclBuilder {
   TemplateParameterListBuilder addTemplateArgumentList(Sema &S);
   BuiltinTypeDeclBuilder &addSimpleTemplateParams(Sema &S,
                                                   ArrayRef<StringRef> Names);
+
+  // Builtin types methods
+  BuiltinTypeDeclBuilder &addIncrementCounterMethod(Sema &S);
+  BuiltinTypeDeclBuilder &addDecrementCounterMethod(Sema &S);
 };
 
 struct TemplateParameterListBuilder {
@@ -359,6 +356,176 @@ struct TemplateParameterListBuilder {
     return Builder;
   }
 };
+
+// Builder for methods of builtin types. Allows adding methods to builtin types
+// using the builder pattern like this:
+//
+//   BuiltinTypeMethodBuilder(Sema, RecordBuilder, "MethodName", ReturnType)
+//       .addParam("param_name", Type, InOutModifier)
+//       .callBuiltin("buildin_name", { BuiltinParams })
+//       .finalizeMethod();
+//
+// The builder needs to have all of the method parameters before it can create
+// a CXXMethodDecl. It collects them in addParam calls and when a first
+// method that builds the body is called it creates the CXXMethodDecl and
+// ParmVarDecls instances. These can then be referenced from the body building
+// methods. Destructor or an explicit call to finalizeMethod() will complete
+// the method definition.
+struct BuiltinTypeMethodBuilder {
+  struct MethodParam {
+    const IdentifierInfo &NameII;
+    QualType Ty;
+    HLSLParamModifierAttr::Spelling Modifier;
+    MethodParam(const IdentifierInfo &NameII, QualType Ty,
+                HLSLParamModifierAttr::Spelling Modifier)
+        : NameII(NameII), Ty(Ty), Modifier(Modifier) {}
+  };
+
+  BuiltinTypeDeclBuilder &DeclBuilder;
+  Sema &S;
+  DeclarationNameInfo NameInfo;
+  QualType ReturnTy;
+  CXXMethodDecl *Method;
+  llvm::SmallVector<MethodParam> Params;
+  llvm::SmallVector<Stmt *> StmtsList;
+
+public:
+  BuiltinTypeMethodBuilder(Sema &S, BuiltinTypeDeclBuilder &DB, StringRef Name,
+                           QualType ReturnTy)
+      : DeclBuilder(DB), S(S), ReturnTy(ReturnTy), Method(nullptr) {
+    const IdentifierInfo &II =
+        S.getASTContext().Idents.get(Name, tok::TokenKind::identifier);
+    NameInfo = DeclarationNameInfo(DeclarationName(&II), SourceLocation());
+  }
+
+  BuiltinTypeMethodBuilder &addParam(StringRef Name, QualType Ty,
+                                     HLSLParamModifierAttr::Spelling Modifier =
+                                         HLSLParamModifierAttr::Keyword_in) {
+    assert(Method == nullptr && "Cannot add param, method already created");
+
+    const IdentifierInfo &II =
+        S.getASTContext().Idents.get(Name, tok::TokenKind::identifier);
+    Params.emplace_back(II, Ty, Modifier);
+    return *this;
+  }
+
+private:
+  void createMethodDecl() {
+    assert(Method == nullptr && "Method already created");
+
+    // create method type
+    ASTContext &AST = S.getASTContext();
+    SmallVector<QualType> ParamTypes;
+    for (auto &MP : Params)
+      ParamTypes.emplace_back(MP.Ty);
+    QualType MethodTy = AST.getFunctionType(ReturnTy, ParamTypes,
+                                            FunctionProtoType::ExtProtoInfo());
+
+    // create method decl
+    auto *TSInfo = AST.getTrivialTypeSourceInfo(MethodTy, SourceLocation());
+    Method =
+        CXXMethodDecl::Create(AST, DeclBuilder.Record, SourceLocation(),
+                              NameInfo, MethodTy, TSInfo, SC_None, false, false,
+                              ConstexprSpecKind::Unspecified, SourceLocation());
+
+    // create params & set them to the function prototype
+    SmallVector<ParmVarDecl *> ParmDecls;
+    auto FnProtoLoc =
+        Method->getTypeSourceInfo()->getTypeLoc().getAs<FunctionProtoTypeLoc>();
+    unsigned i = 0;
+    for (auto &MP : Params) {
+      ParmVarDecl *Parm = ParmVarDecl::Create(
+          AST, Method->getDeclContext(), SourceLocation(), SourceLocation(),
+          &MP.NameII, MP.Ty,
+          AST.getTrivialTypeSourceInfo(MP.Ty, SourceLocation()), SC_None,
+          nullptr);
+      if (MP.Modifier != HLSLParamModifierAttr::Keyword_in) {
+        auto *Mod =
+            HLSLParamModifierAttr::Create(AST, SourceRange(), MP.Modifier);
+        Parm->addAttr(Mod);
+      }
+      ParmDecls.push_back(Parm);
+      FnProtoLoc.setParam(i++, Parm);
+    }
+    Method->setParams({ParmDecls});
+  }
+
+  void addResourceHandleToParms(SmallVector<Expr *> &Parms) {
+    ASTContext &AST = S.getASTContext();
+    FieldDecl *HandleField = DeclBuilder.getResourceHandleField();
+    auto *This = CXXThisExpr::Create(
+        AST, SourceLocation(), Method->getFunctionObjectParameterType(), true);
+    Parms.push_back(MemberExpr::CreateImplicit(AST, This, false, HandleField,
+                                               HandleField->getType(),
+                                               VK_LValue, OK_Ordinary));
+  }
+
+public:
+  ~BuiltinTypeMethodBuilder() { finalizeMethod(); }
+
+  BuiltinTypeMethodBuilder &
+  callBuiltin(StringRef BuiltinName, ArrayRef<Expr *> CallParms,
+              bool AddResourceHandleAsFirstArg = true) {
+    if (!Method)
+      createMethodDecl();
+
+    ASTContext &AST = S.getASTContext();
+    DeclRefExpr *Fn = lookupBuiltinFunction(S, BuiltinName);
+    Expr *Call = nullptr;
+
+    if (AddResourceHandleAsFirstArg) {
+      SmallVector<Expr *> NewCallParms;
+      addResourceHandleToParms(NewCallParms);
+      for (auto *P : CallParms)
+        NewCallParms.push_back(P);
+
+      Call = CallExpr::Create(AST, Fn, NewCallParms, AST.VoidPtrTy, VK_PRValue,
+                              SourceLocation(), FPOptionsOverride());
+    } else {
+      Call = CallExpr::Create(AST, Fn, CallParms, AST.VoidPtrTy, VK_PRValue,
+                              SourceLocation(), FPOptionsOverride());
+    }
+    StmtsList.push_back(Call);
+    return *this;
+  }
+
+  BuiltinTypeMethodBuilder &
+  callBuiltinForwardArgs(StringRef BuiltinName,
+                         bool AddResourceHandleAsFirstArg = true) {
+    // FIXME: Call the buildin with all of the method parameters
+    // plus optional resource handle as the first arg.
+    llvm_unreachable("not yet implemented");
+  }
+
+  BuiltinTypeDeclBuilder &finalizeMethod() {
+    if (DeclBuilder.Record->isCompleteDefinition())
+      return DeclBuilder;
+
+    if (!Method)
+      createMethodDecl();
+
+    if (!Method->hasBody()) {
+      ASTContext &AST = S.getASTContext();
+      if (ReturnTy != AST.VoidTy && !StmtsList.empty()) {
+        if (Expr *LastExpr = dyn_cast<Expr>(StmtsList.back())) {
+          StmtsList.pop_back();
+          StmtsList.push_back(
+              ReturnStmt::Create(AST, SourceLocation(), LastExpr, nullptr));
+        }
+      }
+
+      Method->setBody(CompoundStmt::Create(AST, StmtsList, FPOptionsOverride(),
+                                           SourceLocation(), SourceLocation()));
+      Method->setLexicalDeclContext(DeclBuilder.Record);
+      Method->setAccess(AccessSpecifier::AS_public);
+      Method->addAttr(AlwaysInlineAttr::CreateImplicit(
+          AST, SourceRange(), AlwaysInlineAttr::CXX11_clang_always_inline));
+      DeclBuilder.Record->addDecl(Method);
+    }
+    return DeclBuilder;
+  }
+};
+
 } // namespace
 
 TemplateParameterListBuilder
@@ -375,6 +542,30 @@ BuiltinTypeDeclBuilder::addSimpleTemplateParams(Sema &S,
   return Builder.finalizeTemplateArgs();
 }
 
+BuiltinTypeDeclBuilder &
+BuiltinTypeDeclBuilder::addIncrementCounterMethod(Sema &S) {
+  ASTContext &AST = S.getASTContext();
+  Expr *One =
+      IntegerLiteral::Create(AST, llvm::APInt(AST.getTypeSize(AST.IntTy), 1),
+                             AST.IntTy, SourceLocation());
+  return BuiltinTypeMethodBuilder(S, *this, "IncrementCounter",
+                                  AST.UnsignedIntTy)
+      .callBuiltin("__builtin_hlsl_buffer_update_counter", {One})
+      .finalizeMethod();
+}
+
+BuiltinTypeDeclBuilder &
+BuiltinTypeDeclBuilder::addDecrementCounterMethod(Sema &S) {
+  ASTContext &AST = S.getASTContext();
+  Expr *NegOne =
+      IntegerLiteral::Create(AST, llvm::APInt(AST.getTypeSize(AST.IntTy), -1),
+                             AST.IntTy, SourceLocation());
+  return BuiltinTypeMethodBuilder(S, *this, "DecrementCounter",
+                                  AST.UnsignedIntTy)
+      .callBuiltin("__builtin_hlsl_buffer_update_counter", {NegOne})
+      .finalizeMethod();
+}
+
 HLSLExternalSemaSource::~HLSLExternalSemaSource() {}
 
 void HLSLExternalSemaSource::InitializeSema(Sema &S) {
@@ -528,8 +719,13 @@ void HLSLExternalSemaSource::defineHLSLTypesWithForwardDeclarations() {
                     ResourceKind::TypedBuffer, /*IsROV=*/false,
                     /*RawBuffer=*/true)
         .addArraySubscriptOperators()
+        .addIncrementCounterMethod(*SemaPtr)
+        .addDecrementCounterMethod(*SemaPtr)
         .completeDefinition();
   });
+
+  // FIXME: Also add Increment/DecrementCounter to
+  // RasterizerOrderedStructuredBuffer when llvm/llvm-project/#113648 is merged.
 }
 
 void HLSLExternalSemaSource::onCompletion(CXXRecordDecl *Record,
@@ -552,3 +748,23 @@ void HLSLExternalSemaSource::CompleteType(TagDecl *Tag) {
     return;
   It->second(Record);
 }
+
+static DeclRefExpr *lookupBuiltinFunction(Sema &S, StringRef Name) {
+  IdentifierInfo &II =
+      S.getASTContext().Idents.get(Name, tok::TokenKind::identifier);
+  DeclarationNameInfo NameInfo =
+      DeclarationNameInfo(DeclarationName(&II), SourceLocation());
+  LookupResult R(S, NameInfo, Sema::LookupOrdinaryName);
+  // AllowBuiltinCreation is false but LookupDirect will create
+  // the builtin when searching the global scope anyways...
+  S.LookupName(R, S.getCurScope());
+  // FIXME: If the builtin function was user-declared in global scope,
+  // this assert *will* fail. Should this call LookupBuiltin instead?
+  assert(R.isSingleResult() &&
+         "Since this is a builtin it should always resolve!");
+  auto *VD = cast<ValueDecl>(R.getFoundDecl());
+  QualType Ty = VD->getType();
+  return DeclRefExpr::Create(S.getASTContext(), NestedNameSpecifierLoc(),
+                             SourceLocation(), VD, false, NameInfo, Ty,
+                             VK_PRValue);
+}
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index ff6616901016ab..770bd4a81633e1 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -986,6 +986,10 @@ Sema::VarArgKind Sema::isValidVarArgType(const QualType &Ty) {
   if (getLangOpts().MSVCCompat)
     return VAK_MSVCUndefined;
 
+  if (getLangOpts().HLSL &&
+      Ty->getUnqualifiedDesugaredType()->isHLSLAttributedResourceType())
+    return VAK_Valid;
+
   // FIXME: In C++11, these cases are conditionally-supported, meaning we're
   // permitted to reject them. We should consider doing so.
   return VAK_Undefined;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 1f6c5b8d4561bc..1b7f0456a3e82a 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1860,6 +1860,31 @@ static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
   return false;
 }
 
+static bool CheckResourceHandle(Sema *S, CallExpr *TheCall, unsigned ArgIndex) {
+  assert(TheCall->getNumArgs() >= ArgIndex);
+  QualType ArgType = TheCall->getArg(ArgIndex)->getType();
+  if (!ArgType.getTypePtr()
+           ->getUnqualifiedDesugaredType()
+           ->isHLSLAttributedResourceType()) {
+    S->Diag(TheCall->getArg(0)->getBeginLoc(),
+            diag::err_typecheck_expect_hlsl_resource)
+        << ArgType;
+    return true;
+  }
+  return false;
+}
+
+static bool CheckInt(Sema *S, CallExpr *TheCall, unsigned ArgIndex) {
+  assert(TheCall->getNumArgs() >= ArgIndex);
+  QualType ArgType = TheCall->getArg(ArgIndex)->getType();
+  if (!ArgType->isIntegerType()) {
+    S->Diag(TheCall->getArg(0)->getBeginLoc(), diag::err_typecheck_expect_int)
+        << ArgType;
+    return true;
+  }
+  return false;
+}
+
 // Note: returning true in this case results in CheckBuiltinFunctionCall
 // returning an ExprError
 bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
@@ -2100,6 +2125,22 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_buffer_update_counter: {
+    if (SemaRef.checkArgCount(TheCall, 2) ||
+        CheckResourceHandle(&SemaRef, TheCall, 0) ||
+        CheckInt(&SemaRef, TheCall, 1))
+      return true;
+    Expr *OffsetExpr = TheCall->getArg(1);
+    std::optional<llvm::APSInt> Offset =
+        OffsetExpr->getIntegerConstantExpr(SemaRef.getASTContext());
+    if (!Offset.has_value() || abs(Offset->getExtValue()) != 1) {
+      SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
+                   diag::err_hlsl_expect_arg_const_int_one_or_neg_one)
+          << 1;
+      return true;
+    }
+    break;
+  }
   }
   return false;
 }
diff --git a/clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl b/clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl
new file mode 100644
index 00000000000000..c8ff5d3cd905fb
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl
@@ -0,0 +1,25 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN-DISABLED...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Oct 29, 2024

@llvm/pr-subscribers-llvm-ir

Author: Helena Kotas (hekota)

Changes

Introduces __builtin_hlsl_buffer_update_counter clang buildin that is used to implement the IncrementCounter and DecrementCounter methods on RWStructuredBuffer and RasterizerOrderedStructuredBuffer (see Note).

The builtin is translated to LLVM intrisic llvm.dx.bufferUpdateCounter or llvm.spv.bufferUpdateCounter.

Introduces BuiltinTypeMethodBuilder helper in HLSLExternalSemaSource that allows adding methods to builtin types using builder pattern like this:

   BuiltinTypeMethodBuilder(Sema, RecordBuilder, "MethodName", ReturnType)
       .addParam("param_name", Type, InOutModifier)
       .callBuiltin("buildin_name", { BuiltinParams })
       .finalizeMethod();

Note: RasterizerOrderedStructuredBuffer does not exist yet, it is being added in PR llvm/llvm-project#113648. After llvm/llvm-project#113648 is merged this PR will be updated to add Increment/DecrementCounter on this buffer type as well.

Fixes #113513


Patch is 25.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/114148.diff

12 Files Affected:

  • (modified) clang/include/clang/Basic/Builtins.td (+6-1)
  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+4)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+8)
  • (modified) clang/lib/CodeGen/CGHLSLRuntime.h (+1)
  • (modified) clang/lib/Sema/HLSLExternalSemaSource.cpp (+247-31)
  • (modified) clang/lib/Sema/SemaExpr.cpp (+4)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+41)
  • (added) clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl (+25)
  • (added) clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-ps.hlsl (+29)
  • (added) clang/test/SemaHLSL/BuiltIns/buffer_update_counter-errors.hlsl (+22)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+3)
  • (modified) llvm/include/llvm/IR/IntrinsicsSPIRV.td (+3)
diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td
index 90475a361bb8f8..72bc2d5e7df23e 100644
--- a/clang/include/clang/Basic/Builtins.td
+++ b/clang/include/clang/Basic/Builtins.td
@@ -4846,7 +4846,6 @@ def HLSLSaturate : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
-
 def HLSLSelect : LangBuiltin<"HLSL_LANG"> {
   let Spellings = ["__builtin_hlsl_select"];
   let Attributes = [NoThrow, Const];
@@ -4871,6 +4870,12 @@ def HLSLRadians : LangBuiltin<"HLSL_LANG"> {
   let Prototype = "void(...)";
 }
 
+def HLSLBufferUpdateCounter : LangBuiltin<"HLSL_LANG"> {
+  let Spellings = ["__builtin_hlsl_buffer_update_counter"];
+  let Attributes = [NoThrow, Const];
+  let Prototype = "uint32_t(...)";
+}
+
 // Builtins for XRay.
 def XRayCustomEvent : Builtin {
   let Spellings = ["__xray_customevent"];
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 8e4718008ece72..2aea6bb657578a 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -7256,6 +7256,8 @@ def err_typecheck_illegal_increment_decrement : Error<
   "cannot %select{decrement|increment}1 value of type %0">;
 def err_typecheck_expect_int : Error<
   "used type %0 where integer is required">;
+def err_typecheck_expect_hlsl_resource : Error<
+  "used type %0 where __hlsl_resource_t is required">;
 def err_typecheck_arithmetic_incomplete_or_sizeless_type : Error<
   "arithmetic on a pointer to %select{an incomplete|sizeless}0 type %1">;
 def err_typecheck_pointer_arith_function_type : Error<
@@ -12485,6 +12487,8 @@ def warn_attr_min_eq_max:  Warning<
 
 def err_hlsl_attribute_number_arguments_insufficient_shader_model: Error<
   "attribute %0 with %1 arguments requires shader model %2 or greater">;
+def err_hlsl_expect_arg_const_int_one_or_neg_one: Error<
+  "argument %0 must be constant integer 1 or -1">;
 
 // Layout randomization diagnostics.
 def err_non_designated_init_used : Error<
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index e2d03eff8ab4a0..71273de3400b17 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -18959,6 +18959,14 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: {
         CGM.getHLSLRuntime().getRadiansIntrinsic(), ArrayRef<Value *>{Op0},
         nullptr, "hlsl.radians");
   }
+  case Builtin::BI__builtin_hlsl_buffer_update_counter: {
+    Value *ResHandle = EmitScalarExpr(E->getArg(0));
+    Value *Offset = EmitScalarExpr(E->getArg(1));
+    return Builder.CreateIntrinsic(
+        /*ReturnType=*/Offset->getType(),
+        CGM.getHLSLRuntime().getBufferUpdateCounterIntrinsic(),
+        ArrayRef<Value *>{ResHandle, Offset}, nullptr);
+  }
   }
   return nullptr;
 }
diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h
index ff7df41b5c62e7..aac93dfc373ed4 100644
--- a/clang/lib/CodeGen/CGHLSLRuntime.h
+++ b/clang/lib/CodeGen/CGHLSLRuntime.h
@@ -93,6 +93,7 @@ class CGHLSLRuntime {
   GENERATE_HLSL_INTRINSIC_FUNCTION(WaveReadLaneAt, wave_readlane)
 
   GENERATE_HLSL_INTRINSIC_FUNCTION(CreateHandleFromBinding, handle_fromBinding)
+  GENERATE_HLSL_INTRINSIC_FUNCTION(BufferUpdateCounter, bufferUpdateCounter)
 
   //===----------------------------------------------------------------------===//
   // End of reserved area for HLSL intrinsic getters.
diff --git a/clang/lib/Sema/HLSLExternalSemaSource.cpp b/clang/lib/Sema/HLSLExternalSemaSource.cpp
index ce8564429b3802..24c3954b134c5f 100644
--- a/clang/lib/Sema/HLSLExternalSemaSource.cpp
+++ b/clang/lib/Sema/HLSLExternalSemaSource.cpp
@@ -22,12 +22,15 @@
 #include "clang/Sema/SemaHLSL.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Frontend/HLSL/HLSLResource.h"
+#include "llvm/Support/ErrorHandling.h"
 
 #include <functional>
 
 using namespace clang;
 using namespace llvm::hlsl;
 
+static DeclRefExpr *lookupBuiltinFunction(Sema &S, StringRef Name);
+
 namespace {
 
 struct TemplateParameterListBuilder;
@@ -121,12 +124,8 @@ struct BuiltinTypeDeclBuilder {
     TypeSourceInfo *ElementTypeInfo = nullptr;
 
     QualType ElemTy = Ctx.Char8Ty;
-    if (Template) {
-      if (const auto *TTD = dyn_cast<TemplateTypeParmDecl>(
-              Template->getTemplateParameters()->getParam(0))) {
-        ElemTy = QualType(TTD->getTypeForDecl(), 0);
-      }
-    }
+    if (Template)
+      ElemTy = getFirstTemplateTypeParam();
     ElementTypeInfo = Ctx.getTrivialTypeSourceInfo(ElemTy, SourceLocation());
 
     // add handle member with resource type attributes
@@ -145,25 +144,6 @@ struct BuiltinTypeDeclBuilder {
     return *this;
   }
 
-  static DeclRefExpr *lookupBuiltinFunction(ASTContext &AST, Sema &S,
-                                            StringRef Name) {
-    IdentifierInfo &II = AST.Idents.get(Name, tok::TokenKind::identifier);
-    DeclarationNameInfo NameInfo =
-        DeclarationNameInfo(DeclarationName(&II), SourceLocation());
-    LookupResult R(S, NameInfo, Sema::LookupOrdinaryName);
-    // AllowBuiltinCreation is false but LookupDirect will create
-    // the builtin when searching the global scope anyways...
-    S.LookupName(R, S.getCurScope());
-    // FIXME: If the builtin function was user-declared in global scope,
-    // this assert *will* fail. Should this call LookupBuiltin instead?
-    assert(R.isSingleResult() &&
-           "Since this is a builtin it should always resolve!");
-    auto *VD = cast<ValueDecl>(R.getFoundDecl());
-    QualType Ty = VD->getType();
-    return DeclRefExpr::Create(AST, NestedNameSpecifierLoc(), SourceLocation(),
-                               VD, false, NameInfo, Ty, VK_PRValue);
-  }
-
   static Expr *emitResourceClassExpr(ASTContext &AST, ResourceClass RC) {
     return IntegerLiteral::Create(
         AST,
@@ -211,12 +191,8 @@ struct BuiltinTypeDeclBuilder {
 
     ASTContext &AST = Record->getASTContext();
     QualType ElemTy = AST.Char8Ty;
-    if (Template) {
-      if (const auto *TTD = dyn_cast<TemplateTypeParmDecl>(
-              Template->getTemplateParameters()->getParam(0))) {
-        ElemTy = QualType(TTD->getTypeForDecl(), 0);
-      }
-    }
+    if (Template)
+      ElemTy = getFirstTemplateTypeParam();
     QualType ReturnTy = ElemTy;
 
     FunctionProtoType::ExtProtoInfo ExtInfo;
@@ -282,6 +258,23 @@ struct BuiltinTypeDeclBuilder {
     return *this;
   }
 
+  FieldDecl *getResourceHandleField() {
+    FieldDecl *FD = Fields["h"];
+    if (FD && FD->getType()->isHLSLAttributedResourceType())
+      return FD;
+    return nullptr;
+  }
+
+  QualType getFirstTemplateTypeParam() {
+    if (Template) {
+      if (const auto *TTD = dyn_cast<TemplateTypeParmDecl>(
+              Template->getTemplateParameters()->getParam(0))) {
+        return QualType(TTD->getTypeForDecl(), 0);
+      }
+    }
+    return QualType();
+  }
+
   BuiltinTypeDeclBuilder &startDefinition() {
     if (Record->isCompleteDefinition())
       return *this;
@@ -302,6 +295,10 @@ struct BuiltinTypeDeclBuilder {
   TemplateParameterListBuilder addTemplateArgumentList(Sema &S);
   BuiltinTypeDeclBuilder &addSimpleTemplateParams(Sema &S,
                                                   ArrayRef<StringRef> Names);
+
+  // Builtin types methods
+  BuiltinTypeDeclBuilder &addIncrementCounterMethod(Sema &S);
+  BuiltinTypeDeclBuilder &addDecrementCounterMethod(Sema &S);
 };
 
 struct TemplateParameterListBuilder {
@@ -359,6 +356,176 @@ struct TemplateParameterListBuilder {
     return Builder;
   }
 };
+
+// Builder for methods of builtin types. Allows adding methods to builtin types
+// using the builder pattern like this:
+//
+//   BuiltinTypeMethodBuilder(Sema, RecordBuilder, "MethodName", ReturnType)
+//       .addParam("param_name", Type, InOutModifier)
+//       .callBuiltin("buildin_name", { BuiltinParams })
+//       .finalizeMethod();
+//
+// The builder needs to have all of the method parameters before it can create
+// a CXXMethodDecl. It collects them in addParam calls and when a first
+// method that builds the body is called it creates the CXXMethodDecl and
+// ParmVarDecls instances. These can then be referenced from the body building
+// methods. Destructor or an explicit call to finalizeMethod() will complete
+// the method definition.
+struct BuiltinTypeMethodBuilder {
+  struct MethodParam {
+    const IdentifierInfo &NameII;
+    QualType Ty;
+    HLSLParamModifierAttr::Spelling Modifier;
+    MethodParam(const IdentifierInfo &NameII, QualType Ty,
+                HLSLParamModifierAttr::Spelling Modifier)
+        : NameII(NameII), Ty(Ty), Modifier(Modifier) {}
+  };
+
+  BuiltinTypeDeclBuilder &DeclBuilder;
+  Sema &S;
+  DeclarationNameInfo NameInfo;
+  QualType ReturnTy;
+  CXXMethodDecl *Method;
+  llvm::SmallVector<MethodParam> Params;
+  llvm::SmallVector<Stmt *> StmtsList;
+
+public:
+  BuiltinTypeMethodBuilder(Sema &S, BuiltinTypeDeclBuilder &DB, StringRef Name,
+                           QualType ReturnTy)
+      : DeclBuilder(DB), S(S), ReturnTy(ReturnTy), Method(nullptr) {
+    const IdentifierInfo &II =
+        S.getASTContext().Idents.get(Name, tok::TokenKind::identifier);
+    NameInfo = DeclarationNameInfo(DeclarationName(&II), SourceLocation());
+  }
+
+  BuiltinTypeMethodBuilder &addParam(StringRef Name, QualType Ty,
+                                     HLSLParamModifierAttr::Spelling Modifier =
+                                         HLSLParamModifierAttr::Keyword_in) {
+    assert(Method == nullptr && "Cannot add param, method already created");
+
+    const IdentifierInfo &II =
+        S.getASTContext().Idents.get(Name, tok::TokenKind::identifier);
+    Params.emplace_back(II, Ty, Modifier);
+    return *this;
+  }
+
+private:
+  void createMethodDecl() {
+    assert(Method == nullptr && "Method already created");
+
+    // create method type
+    ASTContext &AST = S.getASTContext();
+    SmallVector<QualType> ParamTypes;
+    for (auto &MP : Params)
+      ParamTypes.emplace_back(MP.Ty);
+    QualType MethodTy = AST.getFunctionType(ReturnTy, ParamTypes,
+                                            FunctionProtoType::ExtProtoInfo());
+
+    // create method decl
+    auto *TSInfo = AST.getTrivialTypeSourceInfo(MethodTy, SourceLocation());
+    Method =
+        CXXMethodDecl::Create(AST, DeclBuilder.Record, SourceLocation(),
+                              NameInfo, MethodTy, TSInfo, SC_None, false, false,
+                              ConstexprSpecKind::Unspecified, SourceLocation());
+
+    // create params & set them to the function prototype
+    SmallVector<ParmVarDecl *> ParmDecls;
+    auto FnProtoLoc =
+        Method->getTypeSourceInfo()->getTypeLoc().getAs<FunctionProtoTypeLoc>();
+    unsigned i = 0;
+    for (auto &MP : Params) {
+      ParmVarDecl *Parm = ParmVarDecl::Create(
+          AST, Method->getDeclContext(), SourceLocation(), SourceLocation(),
+          &MP.NameII, MP.Ty,
+          AST.getTrivialTypeSourceInfo(MP.Ty, SourceLocation()), SC_None,
+          nullptr);
+      if (MP.Modifier != HLSLParamModifierAttr::Keyword_in) {
+        auto *Mod =
+            HLSLParamModifierAttr::Create(AST, SourceRange(), MP.Modifier);
+        Parm->addAttr(Mod);
+      }
+      ParmDecls.push_back(Parm);
+      FnProtoLoc.setParam(i++, Parm);
+    }
+    Method->setParams({ParmDecls});
+  }
+
+  void addResourceHandleToParms(SmallVector<Expr *> &Parms) {
+    ASTContext &AST = S.getASTContext();
+    FieldDecl *HandleField = DeclBuilder.getResourceHandleField();
+    auto *This = CXXThisExpr::Create(
+        AST, SourceLocation(), Method->getFunctionObjectParameterType(), true);
+    Parms.push_back(MemberExpr::CreateImplicit(AST, This, false, HandleField,
+                                               HandleField->getType(),
+                                               VK_LValue, OK_Ordinary));
+  }
+
+public:
+  ~BuiltinTypeMethodBuilder() { finalizeMethod(); }
+
+  BuiltinTypeMethodBuilder &
+  callBuiltin(StringRef BuiltinName, ArrayRef<Expr *> CallParms,
+              bool AddResourceHandleAsFirstArg = true) {
+    if (!Method)
+      createMethodDecl();
+
+    ASTContext &AST = S.getASTContext();
+    DeclRefExpr *Fn = lookupBuiltinFunction(S, BuiltinName);
+    Expr *Call = nullptr;
+
+    if (AddResourceHandleAsFirstArg) {
+      SmallVector<Expr *> NewCallParms;
+      addResourceHandleToParms(NewCallParms);
+      for (auto *P : CallParms)
+        NewCallParms.push_back(P);
+
+      Call = CallExpr::Create(AST, Fn, NewCallParms, AST.VoidPtrTy, VK_PRValue,
+                              SourceLocation(), FPOptionsOverride());
+    } else {
+      Call = CallExpr::Create(AST, Fn, CallParms, AST.VoidPtrTy, VK_PRValue,
+                              SourceLocation(), FPOptionsOverride());
+    }
+    StmtsList.push_back(Call);
+    return *this;
+  }
+
+  BuiltinTypeMethodBuilder &
+  callBuiltinForwardArgs(StringRef BuiltinName,
+                         bool AddResourceHandleAsFirstArg = true) {
+    // FIXME: Call the buildin with all of the method parameters
+    // plus optional resource handle as the first arg.
+    llvm_unreachable("not yet implemented");
+  }
+
+  BuiltinTypeDeclBuilder &finalizeMethod() {
+    if (DeclBuilder.Record->isCompleteDefinition())
+      return DeclBuilder;
+
+    if (!Method)
+      createMethodDecl();
+
+    if (!Method->hasBody()) {
+      ASTContext &AST = S.getASTContext();
+      if (ReturnTy != AST.VoidTy && !StmtsList.empty()) {
+        if (Expr *LastExpr = dyn_cast<Expr>(StmtsList.back())) {
+          StmtsList.pop_back();
+          StmtsList.push_back(
+              ReturnStmt::Create(AST, SourceLocation(), LastExpr, nullptr));
+        }
+      }
+
+      Method->setBody(CompoundStmt::Create(AST, StmtsList, FPOptionsOverride(),
+                                           SourceLocation(), SourceLocation()));
+      Method->setLexicalDeclContext(DeclBuilder.Record);
+      Method->setAccess(AccessSpecifier::AS_public);
+      Method->addAttr(AlwaysInlineAttr::CreateImplicit(
+          AST, SourceRange(), AlwaysInlineAttr::CXX11_clang_always_inline));
+      DeclBuilder.Record->addDecl(Method);
+    }
+    return DeclBuilder;
+  }
+};
+
 } // namespace
 
 TemplateParameterListBuilder
@@ -375,6 +542,30 @@ BuiltinTypeDeclBuilder::addSimpleTemplateParams(Sema &S,
   return Builder.finalizeTemplateArgs();
 }
 
+BuiltinTypeDeclBuilder &
+BuiltinTypeDeclBuilder::addIncrementCounterMethod(Sema &S) {
+  ASTContext &AST = S.getASTContext();
+  Expr *One =
+      IntegerLiteral::Create(AST, llvm::APInt(AST.getTypeSize(AST.IntTy), 1),
+                             AST.IntTy, SourceLocation());
+  return BuiltinTypeMethodBuilder(S, *this, "IncrementCounter",
+                                  AST.UnsignedIntTy)
+      .callBuiltin("__builtin_hlsl_buffer_update_counter", {One})
+      .finalizeMethod();
+}
+
+BuiltinTypeDeclBuilder &
+BuiltinTypeDeclBuilder::addDecrementCounterMethod(Sema &S) {
+  ASTContext &AST = S.getASTContext();
+  Expr *NegOne =
+      IntegerLiteral::Create(AST, llvm::APInt(AST.getTypeSize(AST.IntTy), -1),
+                             AST.IntTy, SourceLocation());
+  return BuiltinTypeMethodBuilder(S, *this, "DecrementCounter",
+                                  AST.UnsignedIntTy)
+      .callBuiltin("__builtin_hlsl_buffer_update_counter", {NegOne})
+      .finalizeMethod();
+}
+
 HLSLExternalSemaSource::~HLSLExternalSemaSource() {}
 
 void HLSLExternalSemaSource::InitializeSema(Sema &S) {
@@ -528,8 +719,13 @@ void HLSLExternalSemaSource::defineHLSLTypesWithForwardDeclarations() {
                     ResourceKind::TypedBuffer, /*IsROV=*/false,
                     /*RawBuffer=*/true)
         .addArraySubscriptOperators()
+        .addIncrementCounterMethod(*SemaPtr)
+        .addDecrementCounterMethod(*SemaPtr)
         .completeDefinition();
   });
+
+  // FIXME: Also add Increment/DecrementCounter to
+  // RasterizerOrderedStructuredBuffer when llvm/llvm-project/#113648 is merged.
 }
 
 void HLSLExternalSemaSource::onCompletion(CXXRecordDecl *Record,
@@ -552,3 +748,23 @@ void HLSLExternalSemaSource::CompleteType(TagDecl *Tag) {
     return;
   It->second(Record);
 }
+
+static DeclRefExpr *lookupBuiltinFunction(Sema &S, StringRef Name) {
+  IdentifierInfo &II =
+      S.getASTContext().Idents.get(Name, tok::TokenKind::identifier);
+  DeclarationNameInfo NameInfo =
+      DeclarationNameInfo(DeclarationName(&II), SourceLocation());
+  LookupResult R(S, NameInfo, Sema::LookupOrdinaryName);
+  // AllowBuiltinCreation is false but LookupDirect will create
+  // the builtin when searching the global scope anyways...
+  S.LookupName(R, S.getCurScope());
+  // FIXME: If the builtin function was user-declared in global scope,
+  // this assert *will* fail. Should this call LookupBuiltin instead?
+  assert(R.isSingleResult() &&
+         "Since this is a builtin it should always resolve!");
+  auto *VD = cast<ValueDecl>(R.getFoundDecl());
+  QualType Ty = VD->getType();
+  return DeclRefExpr::Create(S.getASTContext(), NestedNameSpecifierLoc(),
+                             SourceLocation(), VD, false, NameInfo, Ty,
+                             VK_PRValue);
+}
diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp
index ff6616901016ab..770bd4a81633e1 100644
--- a/clang/lib/Sema/SemaExpr.cpp
+++ b/clang/lib/Sema/SemaExpr.cpp
@@ -986,6 +986,10 @@ Sema::VarArgKind Sema::isValidVarArgType(const QualType &Ty) {
   if (getLangOpts().MSVCCompat)
     return VAK_MSVCUndefined;
 
+  if (getLangOpts().HLSL &&
+      Ty->getUnqualifiedDesugaredType()->isHLSLAttributedResourceType())
+    return VAK_Valid;
+
   // FIXME: In C++11, these cases are conditionally-supported, meaning we're
   // permitted to reject them. We should consider doing so.
   return VAK_Undefined;
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index 1f6c5b8d4561bc..1b7f0456a3e82a 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -1860,6 +1860,31 @@ static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) {
   return false;
 }
 
+static bool CheckResourceHandle(Sema *S, CallExpr *TheCall, unsigned ArgIndex) {
+  assert(TheCall->getNumArgs() >= ArgIndex);
+  QualType ArgType = TheCall->getArg(ArgIndex)->getType();
+  if (!ArgType.getTypePtr()
+           ->getUnqualifiedDesugaredType()
+           ->isHLSLAttributedResourceType()) {
+    S->Diag(TheCall->getArg(0)->getBeginLoc(),
+            diag::err_typecheck_expect_hlsl_resource)
+        << ArgType;
+    return true;
+  }
+  return false;
+}
+
+static bool CheckInt(Sema *S, CallExpr *TheCall, unsigned ArgIndex) {
+  assert(TheCall->getNumArgs() >= ArgIndex);
+  QualType ArgType = TheCall->getArg(ArgIndex)->getType();
+  if (!ArgType->isIntegerType()) {
+    S->Diag(TheCall->getArg(0)->getBeginLoc(), diag::err_typecheck_expect_int)
+        << ArgType;
+    return true;
+  }
+  return false;
+}
+
 // Note: returning true in this case results in CheckBuiltinFunctionCall
 // returning an ExprError
 bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
@@ -2100,6 +2125,22 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
       return true;
     break;
   }
+  case Builtin::BI__builtin_hlsl_buffer_update_counter: {
+    if (SemaRef.checkArgCount(TheCall, 2) ||
+        CheckResourceHandle(&SemaRef, TheCall, 0) ||
+        CheckInt(&SemaRef, TheCall, 1))
+      return true;
+    Expr *OffsetExpr = TheCall->getArg(1);
+    std::optional<llvm::APSInt> Offset =
+        OffsetExpr->getIntegerConstantExpr(SemaRef.getASTContext());
+    if (!Offset.has_value() || abs(Offset->getExtValue()) != 1) {
+      SemaRef.Diag(TheCall->getArg(1)->getBeginLoc(),
+                   diag::err_hlsl_expect_arg_const_int_one_or_neg_one)
+          << 1;
+      return true;
+    }
+    break;
+  }
   }
   return false;
 }
diff --git a/clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl b/clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl
new file mode 100644
index 00000000000000..c8ff5d3cd905fb
--- /dev/null
+++ b/clang/test/CodeGenHLSL/builtins/StructuredBuffers-methods-lib.hlsl
@@ -0,0 +1,25 @@
+// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL
+// RUN-DISABLED...
[truncated]

Copy link
Collaborator

@llvm-beanz llvm-beanz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly looking good. I really like the additions to the builder API. Can you please add an AST test to verify the shape of the new AST nodes and their instantiations?

Copy link

github-actions bot commented Nov 5, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Collaborator

@llvm-beanz llvm-beanz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A couple small comment take or leave, but looks good.

Copy link
Contributor

@pow2clk pow2clk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#113648 is merged. Should probably update the description as I think the corresponding inclusion of ROV buffers went in too.

I haven't finished my review, but I thought I'd get these two points in quickly.

AccessSpecifier Access = AccessSpecifier::AS_private) {
if (Record->isCompleteDefinition())
return *this;
assert(!Record->isCompleteDefinition() && "record is already complete");

ASTContext &Ctx = S.getASTContext();
TypeSourceInfo *ElementTypeInfo = nullptr;

QualType ElemTy = Ctx.Char8Ty;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pre-existing side-note: Does this mean that ByteAddressBuffer will have an element type of i8? Shouldn't it be void, to clearly differentiate this case?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i8 seems like a good default to me. We can decide if we want to use void instead when we implement the ByteAddressBuffer.

// method that builds the body is called it creates the CXXMethodDecl and
// ParmVarDecls instances. These can then be referenced from the body building
// methods. Destructor or an explicit call to finalizeMethod() will complete
// the method definition.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some built-in behaviors/assumptions, some of which probably should be mentioned in the documentation comment.

Here are some oddities I noticed:

  • The value returned from the last (statement) expression (builtin call only for now) will be used as the return value for the method (type from builtin must match ReturnType). I get the convenience for this use case, and maybe we never need anything more, but it does seem like we should point this out in the comment.
  • It looks like it's trying to be flexible to allow multiple calls to be added, but I don't know any way this architecture would allow you to capture returned values from those calls, reference those in subsequent calls, reference method args as mentioned earlier, etc... It seems like the interface has been designed tightly around the one use case, but has hints of incompletely designed flexibility (which would likely require changes to this design to take advantage of).
  • callBuiltin(): Default passing of MemberExpr for the resource handle of this from RecordBuilder as the first argument. I feel like this could be made more explicit, such as by adding a method BuiltinTypeDeclBuilder::getResourceHandleExpr(), and a modification of BuiltinParams specified when used below to: {getResourceHandleExpr(), One}.

Copy link
Member Author

@hekota hekota Nov 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am trying to keep the same builder pattern for the interface as we already have in the BuiltinTypeDeclBuilder. I did have more than just one use case in mind I designed it, but at this moment I can only add building methods that can be tested as part of task #113513. We are going to add more building methods as needed as we implement the rest of the HLSL methods.

For example this is how the Append(T value) method on the AppendStructuredBuffer would look like:

BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addAppendMethod() {
  return BuiltinTypeMethodBuilder(S, *this, "Append", S.getASTContext().VoidTy)
      .addParam("value", getFirstTemplateTypeParam())
      .callBuiltin("__builtin_hlsl_buffer_update_counter", {getConstantIntExpr(1)})
      .callBuiltinForwardParams("__builtin_hlsl_buffer_store") 
      .finalizeMethod();
}

The callBuiltinForwardParams passes in the resource handle as the first argument of the builtin (unless that is not desired and the method takes a bool to disable that) and then it adds/forwards all of the Append method params to the builtin call (in this the case the value). The addParam & callBuiltinForwardParams combo will probably be the most commonly used case when we implement other HLSL methods.

Without the implicit behavior it might look like this:

BuiltinTypeDeclBuilder &BuiltinTypeDeclBuilder::addAppendMethod() {
  BuiltinTypeMethodBuilder MB =
      BuiltinTypeMethodBuilder(S, *this, "Append", S.getASTContext().UnsignedIntTy);
  return MB
      .addParam("value", getFirstTemplateTypeParam())
      .addStmt(callBuiltin(S, "__builtin_hlsl_buffer_update_counter",
                           {MB.getResourceHandleExpr(), getConstantIntExpr(1)}))
      .returnExpr(callBuiltin(S, "__builtin_hlsl_buffer_store",
                              {MB.getResourceHandleExpr(), MB.getMethodArg(0)}))
}

I think the first case looks much cleaner. As long as the implicit behavior is well documented and used often enough, I would prefer to keep it like that. I will add more comments pointing out the implicit behavior.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tex3d - Justin is looking into having callBuiltin take placeholder values (something like std::placeholders) that would be used to reference the handle and method arguments.

@bogner FYI

SourceLocation());
return BuiltinTypeMethodBuilder(S, *this, "IncrementCounter",
AST.UnsignedIntTy)
.callBuiltin("__builtin_hlsl_buffer_update_counter", {One})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think if we were to generalize this a bit more, we'd end up with something like:
.returnExpr(CallBuiltin(S, "__builtin_hlsl_buffer_update_counter", {getResourceHandleExpr() One})

Then we wouldn't have to depend on extra implicit functionality in 'callBuiltin' and finalizeMethod.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See my comments above.

// AllowBuiltinCreation is false but LookupDirect will create
// the builtin when searching the global scope anyways...
S.LookupName(R, S.getCurScope());
// FIXME: If the builtin function was user-declared in global scope,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't that require doing so here inside HLSLExternalSemaSource? How else would this find something before we even start parsing the built-in HLSL header, let alone the source file?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is just moving a pre-existing function, feel free to ignore this comment for the purpose of completing the PR. I'm just wondering out loud.

Few changes to minimize conflicts with another PR in progress.
Update handle field name and tests.
Copy link
Contributor

@bogner bogner left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a few nitpicks - this looks great, thanks!

Comment on lines +12533 to +12536
def err_hlsl_expect_arg_const_int_one_or_neg_one: Error<
"argument %0 must be constant integer 1 or -1">;
def err_invalid_hlsl_resource_type: Error<
"invalid __hlsl_resource_t type attributes">;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These messages feel like they're generic enough that we could reuse some existing diagnostic, but I don't really see anything that works. These are probably fine, but we should make sure to pay attention to adding too many new diagnostics if/when we don't need to.

@@ -37,7 +37,7 @@ def int_dx_typedBufferStore
: DefaultAttrsIntrinsic<[], [llvm_any_ty, llvm_i32_ty, llvm_anyvector_ty],
[IntrWriteMem]>;

def int_dx_updateCounter
def int_dx_bufferUpdateCounter
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We'll need to revisit this once we've agreed on llvm/wg-hlsl#99, so I guess no changes needed for now, but I think we'll want to rename this int_dx_resource_updatecounter later.

@hekota hekota merged commit 94bde8c into llvm:main Nov 23, 2024
8 checks passed
hekota added a commit to hekota/llvm-project that referenced this pull request Nov 23, 2024
hekota added a commit that referenced this pull request Nov 23, 2024
hekota added a commit that referenced this pull request Nov 26, 2024
…rs (#117608)

Introduces `__builtin_hlsl_buffer_update_counter` clang buildin that is
used to implement the `IncrementCounter` and `DecrementCounter` methods
on `RWStructuredBuffer` and `RasterizerOrderedStructuredBuffer` (see
Note).

The builtin is translated to LLVM intrisic `llvm.dx.bufferUpdateCounter`
or `llvm.spv.bufferUpdateCounter`.

Introduces `BuiltinTypeMethodBuilder` helper in `HLSLExternalSemaSource`
that enables adding methods to builtin types using builder pattern like
this:
```
   BuiltinTypeMethodBuilder(Sema, RecordBuilder, "MethodName", ReturnType)
       .addParam("param_name", Type, InOutModifier)
       .callBuiltin("buildin_name", { BuiltinParams })
       .finalizeMethod();
```

Fixes #113513

[First version](#114148) of this PR was reverted
because of build break.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:DirectX clang:codegen IR generation bugs: mangling, exceptions, etc. clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category HLSL HLSL Language Support llvm:ir
Projects
Archived in project
Development

Successfully merging this pull request may close these issues.

[HLSL] Implement IncrementCounter/DecrementCounter on structured buffers
6 participants