From bc96fd8563ecc742a49b45cbc2b6cc89da570fcc Mon Sep 17 00:00:00 2001 From: Jules-Kong Date: Sun, 8 Dec 2024 13:25:28 +0800 Subject: [PATCH 1/2] [VENTUS][NFC] Add build options --- build-ventus.sh | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/build-ventus.sh b/build-ventus.sh index b02a91c6d3c7..e9fedf34f8c2 100755 --- a/build-ventus.sh +++ b/build-ventus.sh @@ -10,7 +10,8 @@ PROGRAMS_TOBUILD=(llvm ocl-icd libclc spike driver pocl rodinia test-pocl) help() { cat < - Chosen programs to build : llvm, ocl-icd, libclc, spike, driver, pocl + Chosen programs to build : llvm, ocl-icd, libclc, spike, driver, pocl, rodinia, test-pocl Option format : "llvm;pocl", string are seperated by semicolon - Default : "llvm;ocl-icd;libclc;spike;driver;pocl" + Default : "llvm;ocl-icd;libclc;spike;driver;pocl;rodinia;test-pocl" 'BUILD_TYPE' is default 'Release' which can be changed by enviroment variable --help | -h From 61eafc4714c6edc69a6d0d4201f78cf4189a5a04 Mon Sep 17 00:00:00 2001 From: Jules-Kong Date: Mon, 9 Dec 2024 12:25:57 +0800 Subject: [PATCH 2/2] [VENTUS][Printf] Add opencl printf pass --- libclc/riscv32/lib/workitem/workitem.S | 9 + llvm/lib/Target/RISCV/CMakeLists.txt | 1 + llvm/lib/Target/RISCV/RISCV.h | 10 + llvm/lib/Target/RISCV/RISCVTargetMachine.cpp | 24 +- llvm/lib/Target/RISCV/RISCVTargetMachine.h | 2 + .../RISCV/VentusPrintfRuntimeBinding.cpp | 601 ++++++++++++++++++ llvm/test/CodeGen/RISCV/VentusGPGPU/printf.ll | 24 + llvm/tools/opt/opt.cpp | 3 +- 8 files changed, 672 insertions(+), 2 deletions(-) create mode 100644 llvm/lib/Target/RISCV/VentusPrintfRuntimeBinding.cpp create mode 100644 llvm/test/CodeGen/RISCV/VentusGPGPU/printf.ll diff --git a/libclc/riscv32/lib/workitem/workitem.S b/libclc/riscv32/lib/workitem/workitem.S index ab763d1a6cd1..8385ee8888d8 100644 --- a/libclc/riscv32/lib/workitem/workitem.S +++ b/libclc/riscv32/lib/workitem/workitem.S @@ -397,3 +397,12 @@ __builtin_riscv_work_dim: lw t0, KNL_WORK_DIM(a0) # Get work_dim 
vmv.v.x v0, t0 ret + + + .section .text.__builtin_riscv_printf_alloc,"ax",@progbits + .global __builtin_riscv_printf_alloc + .type __builtin_riscv_printf_alloc, @function +__builtin_riscv_printf_alloc: + csrr a0, CSR_PRINT # Get printf buffer + vadd.vx v0, v0, a0 + ret diff --git a/llvm/lib/Target/RISCV/CMakeLists.txt b/llvm/lib/Target/RISCV/CMakeLists.txt index 4316d6793df4..d5239a545bfd 100644 --- a/llvm/lib/Target/RISCV/CMakeLists.txt +++ b/llvm/lib/Target/RISCV/CMakeLists.txt @@ -47,6 +47,7 @@ add_llvm_target(RISCVCodeGen GISel/RISCVInstructionSelector.cpp GISel/RISCVLegalizerInfo.cpp GISel/RISCVRegisterBankInfo.cpp + VentusPrintfRuntimeBinding.cpp LINK_COMPONENTS Analysis diff --git a/llvm/lib/Target/RISCV/RISCV.h b/llvm/lib/Target/RISCV/RISCV.h index 318a8dc9518b..37fb84438dc3 100644 --- a/llvm/lib/Target/RISCV/RISCV.h +++ b/llvm/lib/Target/RISCV/RISCV.h @@ -14,6 +14,7 @@ #ifndef LLVM_LIB_TARGET_RISCV_RISCV_H #define LLVM_LIB_TARGET_RISCV_RISCV_H +#include "llvm/Pass.h" #include "MCTargetDesc/RISCVBaseInfo.h" #include "llvm/Target/TargetMachine.h" @@ -81,6 +82,15 @@ void initializeVentusInsertJoinToVBranchPass(PassRegistry &); InstructionSelector *createRISCVInstructionSelector(const RISCVTargetMachine &, RISCVSubtarget &, RISCVRegisterBankInfo &); + +ModulePass *createVentusPrintfRuntimeBinding(); +void initializeVentusPrintfRuntimeBindingPass(PassRegistry&); +extern char &VentusPrintfRuntimeBindingID; + +struct VentusPrintfRuntimeBindingPass + : PassInfoMixin { + PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM); +}; } diff --git a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp index b29a2e34943c..3cbe9be8f75a 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp +++ b/llvm/lib/Target/RISCV/RISCVTargetMachine.cpp @@ -30,8 +30,10 @@ #include "llvm/CodeGen/TargetLoweringObjectFileImpl.h" #include "llvm/CodeGen/TargetPassConfig.h" #include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/PassManager.h" 
#include "llvm/InitializePasses.h" #include "llvm/MC/TargetRegistry.h" +#include "llvm/Passes/PassBuilder.h" #include "llvm/Support/FormattedStream.h" #include "llvm/Target/TargetOptions.h" #include "llvm/Transforms/Scalar.h" @@ -65,6 +67,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeRISCVTarget() { initializeRISCVSExtWRemovalPass(*PR); initializeRISCVPreRAExpandPseudoPass(*PR); initializeRISCVExpandPseudoPass(*PR); + initializeVentusPrintfRuntimeBindingPass(*PR); } static StringRef computeDataLayout(const Triple &TT, StringRef CPU) { @@ -134,6 +137,23 @@ RISCVTargetMachine::getSubtargetImpl(const Function &F) const { return I.get(); } +void RISCVTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) { + PB.registerPipelineParsingCallback( + [this](StringRef PassName, ModulePassManager &PM, + ArrayRef) { + if (PassName == "ventus-printf-runtime-binding") { + PM.addPass(VentusPrintfRuntimeBindingPass()); + return true; + } + return false; + }); + + PB.registerPipelineEarlySimplificationEPCallback( + [this](ModulePassManager &PM, OptimizationLevel Level) { + PM.addPass(VentusPrintfRuntimeBindingPass()); + }); +} + TargetTransformInfo RISCVTargetMachine::getTargetTransformInfo(const Function &F) const { return TargetTransformInfo(RISCVTTIImpl(this, F)); @@ -201,6 +221,8 @@ TargetPassConfig *RISCVTargetMachine::createPassConfig(PassManagerBase &PM) { } void RISCVPassConfig::addIRPasses() { + addPass(createVentusPrintfRuntimeBinding()); + if (getOptLevel() != CodeGenOpt::None) { addPass(createSROAPass()); addPass(createInferAddressSpacesPass()); @@ -359,4 +381,4 @@ RISCVTargetMachine::getAddressSpaceForPseudoSourceKind(unsigned Kind) const { return RISCVAS::CONSTANT_ADDRESS; } return RISCVAS::FLAT_ADDRESS; -} \ No newline at end of file +} diff --git a/llvm/lib/Target/RISCV/RISCVTargetMachine.h b/llvm/lib/Target/RISCV/RISCVTargetMachine.h index 78ffdad8ab18..699f8d873e59 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetMachine.h +++ 
b/llvm/lib/Target/RISCV/RISCVTargetMachine.h @@ -44,6 +44,8 @@ public: return TLOF.get(); } + void registerPassBuilderCallbacks(PassBuilder &PB) override; + TargetTransformInfo getTargetTransformInfo(const Function &F) const override; bool isNoopAddrSpaceCast(unsigned SrcAS, unsigned DstAS) const override; diff --git a/llvm/lib/Target/RISCV/VentusPrintfRuntimeBinding.cpp b/llvm/lib/Target/RISCV/VentusPrintfRuntimeBinding.cpp new file mode 100644 index 000000000000..f9a7f177b085 --- /dev/null +++ b/llvm/lib/Target/RISCV/VentusPrintfRuntimeBinding.cpp @@ -0,0 +1,601 @@ +//=== VentusPrintfRuntimeBinding.cpp - OpenCL printf implementation -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// \file +// +// The pass binds printfs to a kernel arg pointer that will be bound to a buffer +// later by the runtime. 
+// +// This pass traverses the functions in the module and converts +// each call to printf to a sequence of operations that +// store the following into the printf buffer: +// - format string (passed as a module's metadata unique ID) +// - bitwise copies of printf arguments +// The backend passes will need to store metadata in the kernel +//===----------------------------------------------------------------------===// + +#include "RISCV.h" +#include "llvm/ADT/Triple.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" +#include "llvm/InitializePasses.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" + +using namespace llvm; + +#define DEBUG_TYPE "printfForVentus" +#define DWORD_ALIGN 4 + +namespace { +class VentusPrintfRuntimeBinding final : public ModulePass { + +public: + static char ID; + + explicit VentusPrintfRuntimeBinding(); + +private: + bool runOnModule(Module &M) override; + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired(); + AU.addRequired(); + } +}; + +class VentusPrintfRuntimeBindingImpl { +public: + VentusPrintfRuntimeBindingImpl( + function_ref GetDT, + function_ref GetTLI) + : GetDT(GetDT), GetTLI(GetTLI) {} + bool run(Module &M); + +private: + void getConversionSpecifiers(SmallVectorImpl &OpConvSpecifiers, + StringRef fmt, size_t num_ops) const; + + bool shouldPrintAsStr(char Specifier, Type *OpType) const; + bool lowerPrintfForGpu(Module &M); + + Value *simplify(Instruction *I, const TargetLibraryInfo *TLI, + const DominatorTree *DT) { + return simplifyInstruction(I, {*TD, TLI, DT}); + } + + const DataLayout *TD; + function_ref GetDT; + function_ref GetTLI; + SmallVector Printfs; +}; +} // namespace + +char VentusPrintfRuntimeBinding::ID = 0; + +INITIALIZE_PASS_BEGIN(VentusPrintfRuntimeBinding, + "ventus-printf-runtime-binding", "Ventus Printf lowering", + false, false) 
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_END(VentusPrintfRuntimeBinding, "ventus-printf-runtime-binding", + "Ventus Printf lowering", false, false) + +char &llvm::VentusPrintfRuntimeBindingID = VentusPrintfRuntimeBinding::ID; + +namespace llvm { +ModulePass *createVentusPrintfRuntimeBinding() { + return new VentusPrintfRuntimeBinding(); +} +} // namespace llvm + +VentusPrintfRuntimeBinding::VentusPrintfRuntimeBinding() : ModulePass(ID) { + initializeVentusPrintfRuntimeBindingPass(*PassRegistry::getPassRegistry()); +} + +void VentusPrintfRuntimeBindingImpl::getConversionSpecifiers( + SmallVectorImpl &OpConvSpecifiers, StringRef Fmt, + size_t NumOps) const { + // not all format characters are collected. + // At this time the format characters of interest + // are %p and %s, which are used to know whether we + // are either storing a literal string or a + // pointer to the printf buffer. + static const char ConvSpecifiers[] = "cdieEfgGaosuxXp"; + size_t CurFmtSpecifierIdx = 0; + size_t PrevFmtSpecifierIdx = 0; + + while ((CurFmtSpecifierIdx = Fmt.find_first_of( + ConvSpecifiers, CurFmtSpecifierIdx)) != StringRef::npos) { + bool ArgDump = false; + StringRef CurFmt = Fmt.substr(PrevFmtSpecifierIdx, + CurFmtSpecifierIdx - PrevFmtSpecifierIdx); + size_t pTag = CurFmt.find_last_of("%"); + if (pTag != StringRef::npos) { + ArgDump = true; + while (pTag && CurFmt[--pTag] == '%') { + ArgDump = !ArgDump; + } + } + + if (ArgDump) + OpConvSpecifiers.push_back(Fmt[CurFmtSpecifierIdx]); + + PrevFmtSpecifierIdx = ++CurFmtSpecifierIdx; + } +} + +bool VentusPrintfRuntimeBindingImpl::shouldPrintAsStr(char Specifier, + Type *OpType) const { + if (Specifier != 's') + return false; + const PointerType *PT = dyn_cast(OpType); + if (!PT || PT->getAddressSpace() != RISCVAS::CONSTANT_ADDRESS) + return false; + Type *ElemType = PT->getContainedType(0); + if (ElemType->getTypeID() != Type::IntegerTyID) + return false; 
+ IntegerType *ElemIType = cast(ElemType); + return ElemIType->getBitWidth() == 8; +} + +bool VentusPrintfRuntimeBindingImpl::lowerPrintfForGpu(Module &M) { + LLVMContext &Ctx = M.getContext(); + IRBuilder<> Builder(Ctx); + Type *I32Ty = Type::getInt32Ty(Ctx); + unsigned UniqID = 0; + // NB: This is important for this string size to be divisible by 4 + const char NonLiteralStr[4] = "???"; + + for (auto *CI : Printfs) { + unsigned NumOps = CI->arg_size(); + + SmallString<16> OpConvSpecifiers; + Value *Op = CI->getArgOperand(0); + + if (auto LI = dyn_cast(Op)) { + Op = LI->getPointerOperand(); + for (auto *Use : Op->users()) { + if (auto SI = dyn_cast(Use)) { + Op = SI->getValueOperand(); + break; + } + } + } + + if (auto I = dyn_cast(Op)) { + Value *Op_simplified = + simplify(I, &GetTLI(*I->getFunction()), &GetDT(*I->getFunction())); + if (Op_simplified) + Op = Op_simplified; + } + + ConstantExpr *ConstExpr = dyn_cast(Op); + GlobalVariable *GVarOp = dyn_cast(Op); + + if (ConstExpr || GVarOp) { + GlobalVariable *GVar = nullptr; + if (ConstExpr) { + GVar = dyn_cast(ConstExpr->getOperand(0)); + } else { + GVar = GVarOp; + } + + StringRef Str("unknown"); + if (GVar && GVar->hasInitializer()) { + auto *Init = GVar->getInitializer(); + if (auto *CA = dyn_cast(Init)) { + if (CA->isString()) + Str = CA->getAsCString(); + } else if (isa(Init)) { + Str = ""; + } + // + // we need this call to ascertain + // that we are printing a string + // or a pointer. 
It takes out the + // specifiers and fills up the first + // arg + getConversionSpecifiers(OpConvSpecifiers, Str, NumOps - 1); + } + // Add metadata for the string + std::string AStreamHolder; + raw_string_ostream Sizes(AStreamHolder); + int Sum = DWORD_ALIGN; + Sizes << CI->arg_size() - 1; + Sizes << ':'; + for (unsigned ArgCount = 1; + ArgCount < CI->arg_size() && ArgCount <= OpConvSpecifiers.size(); + ArgCount++) { + Value *Arg = CI->getArgOperand(ArgCount); + Type *ArgType = Arg->getType(); + unsigned ArgSize = TD->getTypeAllocSizeInBits(ArgType); + ArgSize = ArgSize / 8; + // + // ArgSize by design should be a multiple of DWORD_ALIGN, + // expand the arguments that do not follow this rule. + // + if (ArgSize % DWORD_ALIGN != 0) { + llvm::Type *ResType = llvm::Type::getInt32Ty(Ctx); + auto *LLVMVecType = llvm::dyn_cast(ArgType); + int NumElem = LLVMVecType ? LLVMVecType->getNumElements() : 1; + if (LLVMVecType && NumElem > 1) + ResType = llvm::FixedVectorType::get(ResType, NumElem); + Builder.SetInsertPoint(CI); + Builder.SetCurrentDebugLocation(CI->getDebugLoc()); + if (OpConvSpecifiers[ArgCount - 1] == 'x' || + OpConvSpecifiers[ArgCount - 1] == 'X' || + OpConvSpecifiers[ArgCount - 1] == 'u' || + OpConvSpecifiers[ArgCount - 1] == 'o') + Arg = Builder.CreateZExt(Arg, ResType); + else + Arg = Builder.CreateSExt(Arg, ResType); + ArgType = Arg->getType(); + ArgSize = TD->getTypeAllocSizeInBits(ArgType); + ArgSize = ArgSize / 8; + CI->setOperand(ArgCount, Arg); + } + if (OpConvSpecifiers[ArgCount - 1] == 'f') { + ConstantFP *FpCons = dyn_cast(Arg); + if (FpCons) + ArgSize = 4; + else { + FPExtInst *FpExt = dyn_cast(Arg); + if (FpExt && FpExt->getType()->isDoubleTy() && + FpExt->getOperand(0)->getType()->isFloatTy()) + ArgSize = 4; + } + } + if (shouldPrintAsStr(OpConvSpecifiers[ArgCount - 1], ArgType)) { + if (auto *ConstExpr = dyn_cast(Arg)) { + auto *GV = dyn_cast(ConstExpr->getOperand(0)); + if (GV && GV->hasInitializer()) { + Constant *Init = 
GV->getInitializer(); + bool IsZeroValue = Init->isZeroValue(); + auto *CA = dyn_cast(Init); + if (IsZeroValue || (CA && CA->isString())) { + size_t SizeStr = + IsZeroValue ? 1 : (strlen(CA->getAsCString().data()) + 1); + size_t Rem = SizeStr % DWORD_ALIGN; + size_t NSizeStr = 0; + LLVM_DEBUG(dbgs() << "Printf string original size = " << SizeStr + << '\n'); + if (Rem) { + NSizeStr = SizeStr + (DWORD_ALIGN - Rem); + } else { + NSizeStr = SizeStr; + } + ArgSize = NSizeStr; + } + } else { + ArgSize = sizeof(NonLiteralStr); + } + } else { + ArgSize = sizeof(NonLiteralStr); + } + } + LLVM_DEBUG(dbgs() << "Printf ArgSize (in buffer) = " << ArgSize + << " for type: " << *ArgType << '\n'); + Sizes << ArgSize << ':'; + Sum += ArgSize; + } + LLVM_DEBUG(dbgs() << "Printf format string in source = " << Str.str() + << '\n'); + for (char C : Str) { + // Rest of the C escape sequences (e.g. \') are handled correctly + // by the MDParser + switch (C) { + case '\a': + Sizes << "\\a"; + break; + case '\b': + Sizes << "\\b"; + break; + case '\f': + Sizes << "\\f"; + break; + case '\n': + Sizes << "\\n"; + break; + case '\r': + Sizes << "\\r"; + break; + case '\v': + Sizes << "\\v"; + break; + case ':': + // ':' cannot be scanned by Flex, as it is defined as a delimiter + // Replace it with its octal representation \72 + Sizes << "\\72"; + break; + default: + Sizes << C; + break; + } + } + + // Insert the printf_alloc call + Builder.SetInsertPoint(CI); + Builder.SetCurrentDebugLocation(CI->getDebugLoc()); + + AttributeList Attr = AttributeList::get(Ctx, AttributeList::FunctionIndex, + Attribute::NoUnwind); + + Type *SizetTy = Type::getInt32Ty(Ctx); + + Type *Tys_alloc[1] = {SizetTy}; + Type *I8Ty = Type::getInt8Ty(Ctx); + Type *I8Ptr = PointerType::get(I8Ty, 1); + FunctionType *FTy_alloc = FunctionType::get(I8Ptr, Tys_alloc, false); + FunctionCallee PrintfAllocFn = + M.getOrInsertFunction(StringRef("__builtin_riscv_printf_alloc"), FTy_alloc, Attr); + + LLVM_DEBUG(dbgs() << "Printf 
metadata = " << Sizes.str() << '\n'); + std::string fmtstr = itostr(++UniqID) + ":" + Sizes.str(); + MDString *fmtStrArray = MDString::get(Ctx, fmtstr); + + // Instead of creating global variables, the + // printf format strings are extracted + // and passed as metadata. This avoids + // polluting llvm's symbol tables in this module. + // Metadata is going to be extracted + // by the backend passes and inserted + // into the OpenCL binary as appropriate. + StringRef ventus_printf("llvm.printf.fmts"); + NamedMDNode *metaD = M.getOrInsertNamedMetadata(ventus_printf); + MDNode *myMD = MDNode::get(Ctx, fmtStrArray); + metaD->addOperand(myMD); + Value *sumC = ConstantInt::get(SizetTy, Sum, false); + SmallVector alloc_args; + alloc_args.push_back(sumC); + CallInst *pcall = + CallInst::Create(PrintfAllocFn, alloc_args, "printf_alloc_fn", CI); + + // + // Insert code to split basicblock with a + // piece of hammock code. + // basicblock splits after buffer overflow check + // + ConstantPointerNull *zeroIntPtr = + ConstantPointerNull::get(PointerType::get(I8Ty, 1)); + auto *cmp = cast(Builder.CreateICmpNE(pcall, zeroIntPtr, "")); + if (!CI->use_empty()) { + Value *result = + Builder.CreateSExt(Builder.CreateNot(cmp), I32Ty, "printf_res"); + CI->replaceAllUsesWith(result); + } + SplitBlock(CI->getParent(), cmp); + Instruction *Brnch = + SplitBlockAndInsertIfThen(cmp, cmp->getNextNode(), false); + + Builder.SetInsertPoint(Brnch); + + // store unique printf id in the buffer + // + GetElementPtrInst *BufferIdx = GetElementPtrInst::Create( + I8Ty, pcall, ConstantInt::get(Ctx, APInt(32, 0)), "PrintBuffID", + Brnch); + + Type *idPointer = PointerType::get(I32Ty, RISCVAS::GLOBAL_ADDRESS); + Value *id_gep_cast = + new BitCastInst(BufferIdx, idPointer, "PrintBuffIdCast", Brnch); + + new StoreInst(ConstantInt::get(I32Ty, UniqID), id_gep_cast, Brnch); + + // 1st 4 bytes hold the printf_id + // the following GEP is the buffer pointer + BufferIdx = GetElementPtrInst::Create( + I8Ty, 
pcall, ConstantInt::get(Ctx, APInt(32, 4)), "PrintBuffGep", + Brnch); + + Type *Int32Ty = Type::getInt32Ty(Ctx); + Type *Int64Ty = Type::getInt64Ty(Ctx); + for (unsigned ArgCount = 1; + ArgCount < CI->arg_size() && ArgCount <= OpConvSpecifiers.size(); + ArgCount++) { + Value *Arg = CI->getArgOperand(ArgCount); + Type *ArgType = Arg->getType(); + SmallVector WhatToStore; + if (ArgType->isFPOrFPVectorTy() && !isa(ArgType)) { + Type *IType = (ArgType->isFloatTy()) ? Int32Ty : Int64Ty; + if (OpConvSpecifiers[ArgCount - 1] == 'f') { + if (auto *FpCons = dyn_cast(Arg)) { + APFloat Val(FpCons->getValueAPF()); + bool Lost = false; + Val.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, + &Lost); + Arg = ConstantFP::get(Ctx, Val); + IType = Int32Ty; + } else if (auto *FpExt = dyn_cast(Arg)) { + if (FpExt->getType()->isDoubleTy() && + FpExt->getOperand(0)->getType()->isFloatTy()) { + Arg = FpExt->getOperand(0); + IType = Int32Ty; + } + } + } + Arg = new BitCastInst(Arg, IType, "PrintArgFP", Brnch); + WhatToStore.push_back(Arg); + } else if (ArgType->getTypeID() == Type::PointerTyID) { + if (shouldPrintAsStr(OpConvSpecifiers[ArgCount - 1], ArgType)) { + const char *S = NonLiteralStr; + if (auto *ConstExpr = dyn_cast(Arg)) { + auto *GV = dyn_cast(ConstExpr->getOperand(0)); + if (GV && GV->hasInitializer()) { + Constant *Init = GV->getInitializer(); + bool IsZeroValue = Init->isZeroValue(); + auto *CA = dyn_cast(Init); + if (IsZeroValue || (CA && CA->isString())) { + S = IsZeroValue ? 
"" : CA->getAsCString().data(); + } + } + } + size_t SizeStr = strlen(S) + 1; + size_t Rem = SizeStr % DWORD_ALIGN; + size_t NSizeStr = 0; + if (Rem) { + NSizeStr = SizeStr + (DWORD_ALIGN - Rem); + } else { + NSizeStr = SizeStr; + } + if (S[0]) { + char *MyNewStr = new char[NSizeStr](); + strcpy(MyNewStr, S); + int NumInts = NSizeStr / 4; + int CharC = 0; + while (NumInts) { + int ANum = *(int *)(MyNewStr + CharC); + CharC += 4; + NumInts--; + Value *ANumV = ConstantInt::get(Int32Ty, ANum, false); + WhatToStore.push_back(ANumV); + } + delete[] MyNewStr; + } else { + // Empty string, give a hint to RT it is no NULL + Value *ANumV = ConstantInt::get(Int32Ty, 0xFFFFFF00, false); + WhatToStore.push_back(ANumV); + } + } else { + uint64_t Size = TD->getTypeAllocSizeInBits(ArgType); + assert((Size == 32 || Size == 64) && "unsupported size"); + Type *DstType = (Size == 32) ? Int32Ty : Int64Ty; + Arg = new PtrToIntInst(Arg, DstType, "PrintArgPtr", Brnch); + WhatToStore.push_back(Arg); + } + } else if (isa(ArgType)) { + Type *IType = nullptr; + uint32_t EleCount = cast(ArgType)->getNumElements(); + uint32_t EleSize = ArgType->getScalarSizeInBits(); + uint32_t TotalSize = EleCount * EleSize; + if (EleCount == 3) { + ShuffleVectorInst *Shuffle = + new ShuffleVectorInst(Arg, Arg, ArrayRef{0, 1, 2, 2}); + Shuffle->insertBefore(Brnch); + Arg = Shuffle; + ArgType = Arg->getType(); + TotalSize += EleSize; + } + switch (EleSize) { + default: + EleCount = TotalSize / 64; + IType = Type::getInt64Ty(ArgType->getContext()); + break; + case 8: + if (EleCount >= 8) { + EleCount = TotalSize / 64; + IType = Type::getInt64Ty(ArgType->getContext()); + } else if (EleCount >= 3) { + EleCount = 1; + IType = Type::getInt32Ty(ArgType->getContext()); + } else { + EleCount = 1; + IType = Type::getInt16Ty(ArgType->getContext()); + } + break; + case 16: + if (EleCount >= 3) { + EleCount = TotalSize / 64; + IType = Type::getInt64Ty(ArgType->getContext()); + } else { + EleCount = 1; + IType = 
Type::getInt32Ty(ArgType->getContext()); + } + break; + } + if (EleCount > 1) { + IType = FixedVectorType::get(IType, EleCount); + } + Arg = new BitCastInst(Arg, IType, "PrintArgVect", Brnch); + WhatToStore.push_back(Arg); + } else { + WhatToStore.push_back(Arg); + } + for (unsigned I = 0, E = WhatToStore.size(); I != E; ++I) { + Value *TheBtCast = WhatToStore[I]; + unsigned ArgSize = + TD->getTypeAllocSizeInBits(TheBtCast->getType()) / 8; + SmallVector BuffOffset; + BuffOffset.push_back(ConstantInt::get(I32Ty, ArgSize)); + + Type *ArgPointer = PointerType::get(TheBtCast->getType(), 1); + Value *CastedGEP = + new BitCastInst(BufferIdx, ArgPointer, "PrintBuffPtrCast", Brnch); + StoreInst *StBuff = new StoreInst(TheBtCast, CastedGEP, Brnch); + LLVM_DEBUG(dbgs() << "inserting store to printf buffer:\n" + << *StBuff << '\n'); + (void)StBuff; + if (I + 1 == E && ArgCount + 1 == CI->arg_size()) + break; + BufferIdx = GetElementPtrInst::Create(I8Ty, BufferIdx, BuffOffset, + "PrintBuffNextPtr", Brnch); + LLVM_DEBUG(dbgs() << "inserting gep to the printf buffer:\n" + << *BufferIdx << '\n'); + } + } + } + } + + // erase the printf calls + for (auto *CI : Printfs) + CI->eraseFromParent(); + + Printfs.clear(); + return true; +} + +bool VentusPrintfRuntimeBindingImpl::run(Module &M) { + Triple TT(M.getTargetTriple()); + if (TT.getArch() != Triple::riscv32) + return false; + + auto PrintfFunction = M.getFunction("printf"); + if (!PrintfFunction) + return false; + + for (auto &U : PrintfFunction->uses()) { + if (auto *CI = dyn_cast(U.getUser())) { + if (CI->isCallee(&U) && + CI->getParent()->getParent()->getCallingConv() == CallingConv::VENTUS_KERNEL) + Printfs.push_back(CI); + } + } + + if (Printfs.empty()) + return false; + + TD = &M.getDataLayout(); + + return lowerPrintfForGpu(M); +} + +bool VentusPrintfRuntimeBinding::runOnModule(Module &M) { + auto GetDT = [this](Function &F) -> DominatorTree & { + return this->getAnalysis(F).getDomTree(); + }; + auto GetTLI = 
[this](Function &F) -> TargetLibraryInfo & { + return this->getAnalysis().getTLI(F); + }; + + return VentusPrintfRuntimeBindingImpl(GetDT, GetTLI).run(M); +} + +PreservedAnalyses +VentusPrintfRuntimeBindingPass::run(Module &M, ModuleAnalysisManager &AM) { + FunctionAnalysisManager &FAM = + AM.getResult(M).getManager(); + auto GetDT = [&FAM](Function &F) -> DominatorTree & { + return FAM.getResult(F); + }; + auto GetTLI = [&FAM](Function &F) -> TargetLibraryInfo & { + return FAM.getResult(F); + }; + bool Changed = VentusPrintfRuntimeBindingImpl(GetDT, GetTLI).run(M); + return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); +} diff --git a/llvm/test/CodeGen/RISCV/VentusGPGPU/printf.ll b/llvm/test/CodeGen/RISCV/VentusGPGPU/printf.ll new file mode 100644 index 000000000000..9a3c29ee3523 --- /dev/null +++ b/llvm/test/CodeGen/RISCV/VentusGPGPU/printf.ll @@ -0,0 +1,24 @@ +; RUN: opt -mtriple=riscv32 -ventus-printf-runtime-binding -mcpu=ventus-gpgpu -S < %s | FileCheck --check-prefix=VENTUS %s +; RUN: opt -mtriple=riscv32 -passes=ventus-printf-runtime-binding -mcpu=ventus-gpgpu -S < %s | FileCheck --check-prefix=VENTUS %s + +; VENTUS-LABEL: @test_kernel( +; VENTUS-LABEL: entry +; VENTUS: call ptr addrspace(1) @__builtin_riscv_printf_alloc +; VENTUS-LABEL: entry.split +; VENTUS: icmp ne ptr addrspace(1) %printf_alloc_fn, null +; VENTUS: %PrintBuffID = getelementptr i8, ptr addrspace(1) %printf_alloc_fn, i32 0 +; VENTUS: %PrintBuffIdCast = bitcast ptr addrspace(1) %PrintBuffID to ptr addrspace(1) +; VENTUS: store i32 1, ptr addrspace(1) %PrintBuffIdCast +; VENTUS: %PrintBuffGep = getelementptr i8, ptr addrspace(1) %printf_alloc_fn, i32 4 +; VENTUS: %PrintBuffPtrCast = bitcast ptr addrspace(1) %PrintBuffGep to ptr addrspace(1) +; VENTUS: store i32 10, ptr addrspace(1) %PrintBuffPtrCast + +@.str = private unnamed_addr addrspace(4) constant [5 x i8] c"%5d\0A\00", align 1 + +define ventus_kernel void @test_kernel() { +entry: + %call = call i32 (ptr addrspace(4), 
...) @printf(ptr addrspace(4) @.str, i32 noundef 10) + ret void +} + +declare i32 @printf(ptr addrspace(4), ...) diff --git a/llvm/tools/opt/opt.cpp b/llvm/tools/opt/opt.cpp index 749bc671c43e..f15c01d5a229 100644 --- a/llvm/tools/opt/opt.cpp +++ b/llvm/tools/opt/opt.cpp @@ -352,7 +352,8 @@ static bool shouldPinPassToLegacyPM(StringRef Pass) { "amdgpu-propagate-attributes-late", "amdgpu-unify-metadata", "amdgpu-printf-runtime-binding", - "amdgpu-always-inline"}; + "amdgpu-always-inline", + "ventus-printf-runtime-binding"}; if (llvm::is_contained(PassNameExactToIgnore, Pass)) return false;