262 lines
9.7 KiB
C++
262 lines
9.7 KiB
C++
//===-- SPIRVRegularizer.cpp - regularize IR for SPIR-V ---------*- C++ -*-===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This pass implements regularization of LLVM IR for SPIR-V. The prototype of
|
|
// the pass was taken from SPIRV-LLVM translator.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "SPIRV.h"
|
|
#include "SPIRVTargetMachine.h"
|
|
#include "llvm/Demangle/Demangle.h"
|
|
#include "llvm/IR/InstIterator.h"
|
|
#include "llvm/IR/InstVisitor.h"
|
|
#include "llvm/IR/PassManager.h"
|
|
#include "llvm/Transforms/Utils/Cloning.h"
|
|
|
|
#include <list>
|
|
|
|
#define DEBUG_TYPE "spirv-regularizer"
|
|
|
|
using namespace llvm;
|
|
|
|
namespace llvm {
|
|
void initializeSPIRVRegularizerPass(PassRegistry &);
|
|
}
|
|
|
|
namespace {
|
|
struct SPIRVRegularizer : public FunctionPass, InstVisitor<SPIRVRegularizer> {
|
|
DenseMap<Function *, Function *> Old2NewFuncs;
|
|
|
|
public:
|
|
static char ID;
|
|
SPIRVRegularizer() : FunctionPass(ID) {
|
|
initializeSPIRVRegularizerPass(*PassRegistry::getPassRegistry());
|
|
}
|
|
bool runOnFunction(Function &F) override;
|
|
StringRef getPassName() const override { return "SPIR-V Regularizer"; }
|
|
|
|
void getAnalysisUsage(AnalysisUsage &AU) const override {
|
|
FunctionPass::getAnalysisUsage(AU);
|
|
}
|
|
void visitCallInst(CallInst &CI);
|
|
|
|
private:
|
|
void visitCallScalToVec(CallInst *CI, StringRef MangledName,
|
|
StringRef DemangledName);
|
|
void runLowerConstExpr(Function &F);
|
|
};
|
|
} // namespace
|
|
|
|
char SPIRVRegularizer::ID = 0;
|
|
|
|
INITIALIZE_PASS(SPIRVRegularizer, DEBUG_TYPE, "SPIR-V Regularizer", false,
|
|
false)
|
|
|
|
// Since SPIR-V cannot represent constant expression, constant expressions
|
|
// in LLVM IR need to be lowered to instructions. For each function,
|
|
// the constant expressions used by instructions of the function are replaced
|
|
// by instructions placed in the entry block since it dominates all other BBs.
|
|
// Each constant expression only needs to be lowered once in each function
|
|
// and all uses of it by instructions in that function are replaced by
|
|
// one instruction.
|
|
// TODO: remove redundant instructions for common subexpression.
|
|
void SPIRVRegularizer::runLowerConstExpr(Function &F) {
|
|
LLVMContext &Ctx = F.getContext();
|
|
std::list<Instruction *> WorkList;
|
|
for (auto &II : instructions(F))
|
|
WorkList.push_back(&II);
|
|
|
|
auto FBegin = F.begin();
|
|
while (!WorkList.empty()) {
|
|
Instruction *II = WorkList.front();
|
|
|
|
auto LowerOp = [&II, &FBegin, &F](Value *V) -> Value * {
|
|
if (isa<Function>(V))
|
|
return V;
|
|
auto *CE = cast<ConstantExpr>(V);
|
|
LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] " << *CE);
|
|
auto ReplInst = CE->getAsInstruction();
|
|
auto InsPoint = II->getParent() == &*FBegin ? II : &FBegin->back();
|
|
ReplInst->insertBefore(InsPoint);
|
|
LLVM_DEBUG(dbgs() << " -> " << *ReplInst << '\n');
|
|
std::vector<Instruction *> Users;
|
|
// Do not replace use during iteration of use. Do it in another loop.
|
|
for (auto U : CE->users()) {
|
|
LLVM_DEBUG(dbgs() << "[lowerConstantExpressions] Use: " << *U << '\n');
|
|
auto InstUser = dyn_cast<Instruction>(U);
|
|
// Only replace users in scope of current function.
|
|
if (InstUser && InstUser->getParent()->getParent() == &F)
|
|
Users.push_back(InstUser);
|
|
}
|
|
for (auto &User : Users) {
|
|
if (ReplInst->getParent() == User->getParent() &&
|
|
User->comesBefore(ReplInst))
|
|
ReplInst->moveBefore(User);
|
|
User->replaceUsesOfWith(CE, ReplInst);
|
|
}
|
|
return ReplInst;
|
|
};
|
|
|
|
WorkList.pop_front();
|
|
auto LowerConstantVec = [&II, &LowerOp, &WorkList,
|
|
&Ctx](ConstantVector *Vec,
|
|
unsigned NumOfOp) -> Value * {
|
|
if (std::all_of(Vec->op_begin(), Vec->op_end(), [](Value *V) {
|
|
return isa<ConstantExpr>(V) || isa<Function>(V);
|
|
})) {
|
|
// Expand a vector of constexprs and construct it back with
|
|
// series of insertelement instructions.
|
|
std::list<Value *> OpList;
|
|
std::transform(Vec->op_begin(), Vec->op_end(),
|
|
std::back_inserter(OpList),
|
|
[LowerOp](Value *V) { return LowerOp(V); });
|
|
Value *Repl = nullptr;
|
|
unsigned Idx = 0;
|
|
auto *PhiII = dyn_cast<PHINode>(II);
|
|
Instruction *InsPoint =
|
|
PhiII ? &PhiII->getIncomingBlock(NumOfOp)->back() : II;
|
|
std::list<Instruction *> ReplList;
|
|
for (auto V : OpList) {
|
|
if (auto *Inst = dyn_cast<Instruction>(V))
|
|
ReplList.push_back(Inst);
|
|
Repl = InsertElementInst::Create(
|
|
(Repl ? Repl : PoisonValue::get(Vec->getType())), V,
|
|
ConstantInt::get(Type::getInt32Ty(Ctx), Idx++), "", InsPoint);
|
|
}
|
|
WorkList.splice(WorkList.begin(), ReplList);
|
|
return Repl;
|
|
}
|
|
return nullptr;
|
|
};
|
|
for (unsigned OI = 0, OE = II->getNumOperands(); OI != OE; ++OI) {
|
|
auto *Op = II->getOperand(OI);
|
|
if (auto *Vec = dyn_cast<ConstantVector>(Op)) {
|
|
Value *ReplInst = LowerConstantVec(Vec, OI);
|
|
if (ReplInst)
|
|
II->replaceUsesOfWith(Op, ReplInst);
|
|
} else if (auto CE = dyn_cast<ConstantExpr>(Op)) {
|
|
WorkList.push_front(cast<Instruction>(LowerOp(CE)));
|
|
} else if (auto MDAsVal = dyn_cast<MetadataAsValue>(Op)) {
|
|
auto ConstMD = dyn_cast<ConstantAsMetadata>(MDAsVal->getMetadata());
|
|
if (!ConstMD)
|
|
continue;
|
|
Constant *C = ConstMD->getValue();
|
|
Value *ReplInst = nullptr;
|
|
if (auto *Vec = dyn_cast<ConstantVector>(C))
|
|
ReplInst = LowerConstantVec(Vec, OI);
|
|
if (auto *CE = dyn_cast<ConstantExpr>(C))
|
|
ReplInst = LowerOp(CE);
|
|
if (!ReplInst)
|
|
continue;
|
|
Metadata *RepMD = ValueAsMetadata::get(ReplInst);
|
|
Value *RepMDVal = MetadataAsValue::get(Ctx, RepMD);
|
|
II->setOperand(OI, RepMDVal);
|
|
WorkList.push_front(cast<Instruction>(ReplInst));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// It fixes calls to OCL builtins that accept vector arguments and one of them
|
|
// is actually a scalar splat.
|
|
void SPIRVRegularizer::visitCallInst(CallInst &CI) {
|
|
auto F = CI.getCalledFunction();
|
|
if (!F)
|
|
return;
|
|
|
|
auto MangledName = F->getName();
|
|
size_t n;
|
|
int status;
|
|
char *NameStr = itaniumDemangle(F->getName().data(), nullptr, &n, &status);
|
|
StringRef DemangledName(NameStr);
|
|
|
|
// TODO: add support for other builtins.
|
|
if (DemangledName.startswith("fmin") || DemangledName.startswith("fmax") ||
|
|
DemangledName.startswith("min") || DemangledName.startswith("max"))
|
|
visitCallScalToVec(&CI, MangledName, DemangledName);
|
|
free(NameStr);
|
|
}
|
|
|
|
// Rewrites a call of the form f(<N x T> %v, T %s) so that the scalar second
// argument is splatted to the vector type of the first one, and redirects the
// call to a clone of the callee whose second parameter has that vector type.
// The clone is created once per callee and cached in Old2NewFuncs; the
// original callee is later retired by runOnFunction.
//
// Calls whose arguments are all vectors or all scalars are left untouched, as
// is any other shape (scalar first argument, or a call/callee with other than
// two arguments): the splat below is built from operand 1 into the type of
// operand 0, and the clone has exactly two parameters, so those shapes would
// otherwise hit cast<VectorType> on a scalar or map old arguments past the end
// of the clone's parameter list.
//
// MangledName and DemangledName are currently unused.
void SPIRVRegularizer::visitCallScalToVec(CallInst *CI, StringRef MangledName,
                                          StringRef DemangledName) {
  // Check if all arguments have the same type - it's simple case.
  auto Uniform = true;
  Type *Arg0Ty = CI->getOperand(0)->getType();
  auto IsArg0Vector = isa<VectorType>(Arg0Ty);
  for (unsigned I = 1, E = CI->arg_size(); Uniform && (I != E); ++I)
    Uniform = isa<VectorType>(CI->getOperand(I)->getType()) == IsArg0Vector;
  if (Uniform)
    return;

  auto *OldF = CI->getCalledFunction();
  // Only the (vector, scalar) two-argument form is supported; see above.
  if (!IsArg0Vector || CI->arg_size() != 2 || OldF->arg_size() != 2)
    return;

  Function *NewF = nullptr;
  auto Cached = Old2NewFuncs.find(OldF);
  if (Cached != Old2NewFuncs.end()) {
    NewF = Cached->second;
  } else {
    AttributeList Attrs = OldF->getAttributes();
    SmallVector<Type *, 2> ArgTypes = {OldF->getArg(0)->getType(), Arg0Ty};
    auto *NewFTy =
        FunctionType::get(OldF->getReturnType(), ArgTypes, OldF->isVarArg());
    NewF = Function::Create(NewFTy, OldF->getLinkage(), OldF->getName(),
                            *OldF->getParent());
    ValueToValueMapTy VMap;
    auto NewFArgIt = NewF->arg_begin();
    for (auto &Arg : OldF->args()) {
      auto ArgName = Arg.getName();
      NewFArgIt->setName(ArgName);
      VMap[&Arg] = &(*NewFArgIt++);
    }
    SmallVector<ReturnInst *, 8> Returns;
    CloneFunctionInto(NewF, OldF, VMap,
                      CloneFunctionChangeType::LocalChangesOnly, Returns);
    NewF->setAttributes(Attrs);
    Old2NewFuncs[OldF] = NewF;
  }
  assert(NewF);

  // This produces an instruction sequence that implements a splat of
  // CI->getOperand(1) to a vector Arg0Ty. However, we use InsertElementInst
  // and ShuffleVectorInst to generate the same code as the SPIR-V translator.
  // For instance (transcoding/OpMin.ll), this call
  //   call spir_func <2 x i32> @_Z3minDv2_ii(<2 x i32> <i32 1, i32 10>, i32 5)
  // is translated to
  //    %8 = OpUndef %v2uint
  //   %14 = OpConstantComposite %v2uint %uint_1 %uint_10
  //   ...
  //   %10 = OpCompositeInsert %v2uint %uint_5 %8 0
  //   %11 = OpVectorShuffle %v2uint %10 %8 0 0
  // %call = OpExtInst %v2uint %1 s_min %14 %11
  auto ConstInt = ConstantInt::get(IntegerType::get(CI->getContext(), 32), 0);
  PoisonValue *PVal = PoisonValue::get(Arg0Ty);
  Instruction *Inst =
      InsertElementInst::Create(PVal, CI->getOperand(1), ConstInt, "", CI);
  ElementCount VecElemCount = cast<VectorType>(Arg0Ty)->getElementCount();
  Constant *ConstVec = ConstantVector::getSplat(VecElemCount, ConstInt);
  Value *NewVec = new ShuffleVectorInst(Inst, PVal, ConstVec, "", CI);
  CI->setOperand(1, NewVec);
  CI->replaceUsesOfWith(OldF, NewF);
  CI->mutateFunctionType(NewF->getFunctionType());
}
|
|
|
|
// Lowers constant expressions in F, fixes scalar/vector builtin calls, and then
// retires every replaced builtin whose uses have all been redirected: the clone
// takes over the original name and the original function is erased.
//
// Old2NewFuncs lives across runOnFunction invocations, so:
//  * retired entries are removed from the map; leaving them would make the
//    next run call takeName/eraseFromParent on already-erased functions;
//  * an original that is still referenced (e.g. called from a function this
//    pass has not visited yet) is kept in the map, so later runs reuse the
//    same clone and retire the original once its last use is rewritten.
bool SPIRVRegularizer::runOnFunction(Function &F) {
  runLowerConstExpr(F);
  visit(F);

  SmallVector<Function *, 4> Retired;
  for (auto &OldNew : Old2NewFuncs) {
    Function *OldF = OldNew.first;
    Function *NewF = OldNew.second;
    if (!OldF->use_empty())
      continue;
    NewF->takeName(OldF);
    Retired.push_back(OldF);
  }
  // Mutate the map only after iterating over it.
  for (Function *OldF : Retired) {
    Old2NewFuncs.erase(OldF);
    OldF->eraseFromParent();
  }
  return true;
}
|
|
|
|
// Factory for the SPIR-V regularizer pass. Returns a newly heap-allocated
// instance; presumably owned by the legacy pass manager it is added to.
FunctionPass *llvm::createSPIRVRegularizerPass() {
  auto *RegularizerPass = new SPIRVRegularizer();
  return RegularizerPass;
}
|