[HLSL] clang codeGen for HLSLNumThreadsAttr

Translate HLSLNumThreadsAttr into function attribute with name "dx.numthreads" and value format as "x,y,z".

Reviewed By: beanz

Differential Revision: https://reviews.llvm.org/D131799
This commit is contained in:
Xiang Li 2022-08-12 11:50:22 -07:00
parent 1ab2b0075d
commit bad2e6c830
2 changed files with 11 additions and 2 deletions

View File

@ -19,6 +19,7 @@
#include "llvm/IR/IntrinsicsDirectX.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/FormatVariadic.h"
using namespace clang;
using namespace CodeGen;
@ -107,6 +108,13 @@ void clang::CodeGen::CGHLSLRuntime::setHLSLEntryAttributes(
const StringRef ShaderAttrKindStr = "hlsl.shader";
Fn->addFnAttr(ShaderAttrKindStr,
ShaderAttr->ConvertShaderTypeToStr(ShaderAttr->getType()));
if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) {
const StringRef NumThreadsKindStr = "hlsl.numthreads";
std::string NumThreadsStr =
formatv("{0},{1},{2}", NumThreadsAttr->getX(), NumThreadsAttr->getY(),
NumThreadsAttr->getZ());
Fn->addFnAttr(NumThreadsKindStr, NumThreadsStr);
}
}
llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,

View File

@ -4,9 +4,10 @@
// Make sure not mangle entry.
// CHECK:define void @foo()
// Make sure add function attribute.
// Make sure add function attribute and numthreads attribute.
// CHECK:"hlsl.numthreads"="16,8,1"
// CHECK:"hlsl.shader"="compute"
[numthreads(1,1,1)]
[numthreads(16,8,1)]
void foo() {
}