llvm-project/mlir/lib/Transforms/ConvertToCFG.cpp

642 lines
27 KiB
C++

//===- ConvertToCFG.cpp - ML function to CFG function conversion ----------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements APIs to convert ML functions into CFG functions.
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Support/Functional.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/Support/CommandLine.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// ML function converter
//===----------------------------------------------------------------------===//
namespace {
// Generates CFG function equivalent to the given ML function.
//
// The converter is a statement visitor: `convert` seeds the value remapping
// from the ML function arguments and visits the statements of the function
// body, while the `visit*` hooks emit the CFG counterpart of each statement
// kind through `builder`.
class FunctionConverter : public StmtVisitor<FunctionConverter> {
public:
  // Creates a converter that emits into `cfgFunc`. The constructor is explicit
  // so that a CFGFunction pointer is never implicitly turned into a converter.
  explicit FunctionConverter(CFGFunction *cfgFunc)
      : cfgFunc(cfgFunc), builder(cfgFunc) {}

  // Converts the body of `mlFunc` and returns the CFG function passed to the
  // constructor.
  CFGFunction *convert(MLFunction *mlFunc);

  // Visitor hooks, one per statement kind.
  void visitForStmt(ForStmt *forStmt);
  void visitIfStmt(IfStmt *ifStmt);
  void visitOperationStmt(OperationStmt *opStmt);

private:
  // Emits a constant of index type holding `value`.
  CFGValue *getConstantIndexValue(int64_t value);

  // Visits every statement of `stmtBlock` in order.
  void visitStmtBlock(StmtBlock *stmtBlock);

  // Emits a linear cmpi/select chain reducing `values` with `predicate`.
  CFGValue *buildMinMaxReductionSeq(
      Location loc, CmpIPredicate predicate,
      llvm::iterator_range<Operation::result_iterator> values);

  // Function under construction.
  CFGFunction *cfgFunc;
  // Builder used to create blocks and instructions in `cfgFunc`.
  CFGFuncBuilder builder;

  // Mapping between original MLValues and lowered CFGValues.
  llvm::DenseMap<const MLValue *, CFGValue *> valueRemapping;
};
} // end anonymous namespace
// Collect the operands of the given statement as SSAValues. Every operand is
// an MLValue; its CFGValue counterpart is fetched from the valueRemapping
// table, which must already contain an entry for it.
static llvm::SmallVector<SSAValue *, 4>
operandsAs(Statement *opStmt,
           const llvm::DenseMap<const MLValue *, CFGValue *> &valueRemapping) {
  llvm::SmallVector<SSAValue *, 4> remapped;
  for (const MLValue *mlOperand : opStmt->getOperands()) {
    // An operand without an image means it is used before being converted.
    assert(valueRemapping.count(mlOperand) != 0 && "operand is not defined");
    CFGValue *cfgOperand = valueRemapping.lookup(mlOperand);
    remapped.push_back(cfgOperand);
  }
  return remapped;
}
// Lower an operation statement to an operation instruction.
//
// Name, operand count, result types and attributes are carried over as is;
// only the operands are swapped for their CFGValue images. The new instruction
// goes to the current insertion block (end of the SESE region) and its results
// are recorded in the MLValue->CFGValue mapping for later users.
void FunctionConverter::visitOperationStmt(OperationStmt *opStmt) {
  // Basic operation state: context, location, name.
  OperationState state(cfgFunc->getContext(), opStmt->getLoc(),
                       opStmt->getName());

  // Operands, remapped to the values already created in the CFG function.
  state.addOperands(operandsAs(opStmt, valueRemapping));

  // Result types. The result SSAValues themselves only exist once the
  // operation has been created.
  auto resultTypes = functional::map(
      [](SSAValue *result) { return result->getType(); }, opStmt->getResults());
  state.addTypes(resultTypes);

  // Attributes are copied verbatim.
  for (auto attr : opStmt->getAttrs())
    state.addAttribute(attr.first.strref(), attr.second);

  auto opInst = builder.createOperation(state);

  // Publish the results so that subsequent operations can refer to them.
  unsigned numResults = opInst->getNumResults();
  assert(numResults == opStmt->getNumResults());
  for (unsigned idx = 0; idx != numResults; ++idx) {
    valueRemapping.insert(
        std::make_pair(opStmt->getResult(idx), opInst->getResult(idx)));
  }
}
// Materialize the integer constant `value` as a CFGValue of index type. The
// constant operation is created through the builder with an unknown location.
CFGValue *FunctionConverter::getConstantIndexValue(int64_t value) {
  auto constantOp =
      builder.create<ConstantIndexOp>(builder.getUnknownLoc(), value);
  auto *result = cast<CFGValue>(constantOp->getResult());
  return result;
}
// Dispatch the visitor on each statement of the given block, in order.
void FunctionConverter::visitStmtBlock(StmtBlock *stmtBlock) {
  for (auto &child : *stmtBlock) {
    visit(&child);
  }
}
// Emit the code that reduces a non-empty range of values to their "min" or
// "max". The comparison predicate selects the flavor: a less-than predicate
// yields "min", a greater-than predicate yields "max". Each step is a `cmpi`
// followed by a `select`:
//
//   %cond = cmpi "predicate" %v0, %v1
//   %result = select %cond, %v0, %v1
//
// The values are folded one after another. Compared to a tree reduction this
// introduces extra data dependences, but the linear chain is easier for
// subsequent passes to recognize as a reduction.
CFGValue *FunctionConverter::buildMinMaxReductionSeq(
    Location loc, CmpIPredicate predicate,
    llvm::iterator_range<Operation::result_iterator> values) {
  assert(!llvm::empty(values) && "empty min/max chain");

  auto it = values.begin();
  auto end = values.end();

  // The first value seeds the running result.
  CFGValue *accumulator = cast<CFGValue>(*it);
  ++it;

  // Fold every remaining value into the running result.
  while (it != end) {
    auto *candidate = *it;
    auto cmpOp = builder.create<CmpIOp>(loc, predicate, accumulator, candidate);
    auto selectOp = builder.create<SelectOp>(loc, cmpOp->getResult(),
                                             accumulator, candidate);
    accumulator = cast<CFGValue>(selectOp->getResult());
    ++it;
  }
  return accumulator;
}
// Convert a "for" loop to a flow of basic blocks.
//
// Create an SESE region for the loop (including its body) and append it to the
// end of the current region. The loop region consists of the initialization
// block that sets up the initial value of the loop induction variable (%iv) and
// computes the loop bounds that are loop-invariant in MLFunctions; the
// condition block that checks the exit condition of the loop; the body SESE
// region; and the end block that post-dominates the loop. The end block of the
// loop becomes the new end of the current SESE region. The body of the loop is
// constructed recursively after starting a new region (it may be, for example,
// a nested loop). Induction variable modification is appended to the body SESE
// region that always loops back to the condition block.
//
//      +--------------------------------+
//      | <end of current SESE region>   |
//      | <current insertion point>      |
//      | br init                        |
//      +--------------------------------+
//             |
//             v
//      +--------------------------------+
//      | init:                          |
//      |   <start of loop SESE region>  |
//      |   <compute initial %iv value>  |
//      | br cond(%iv)                   |
//      +--------------------------------+
//             |
//      -------|  |
//      |      v  v
//      |   +--------------------------------+
//      |   | cond(%iv):                     |
//      |   |   <compare %iv to upper bound> |
//      |   |   cond_br %r, body, end        |
//      |   +--------------------------------+
//      |          |               |
//      |          |               -------------|
//      |          v                            |
//      |   +--------------------------------+  |
//      |   | body:                          |  |
//      |   |   <body SESE region start>     |  |
//      |   |   <...>                        |  |
//      |   +--------------------------------+  |
//      |          |                            |
//      |         ...  <SESE region of the body>|
//      |          |                            |
//      |          v                            |
//      |   +--------------------------------+  |
//      |   | body-end:                      |  |
//      |   |   <body SESE region end>       |  |
//      |   |   %new_iv =<add step to %iv>   |  |
//      |   |   br cond(%new_iv)             |  |
//      |   +--------------------------------+  |
//      |          |                            |
//      |-----------        |--------------------
//                          v
//      +--------------------------------+
//      | end:                           |
//      |   <end of loop SESE region>    |
//      |   <new insertion point>        |
//      +--------------------------------+
//
void FunctionConverter::visitForStmt(ForStmt *forStmt) {
  // First, store the loop insertion location so that we can go back to it after
  // creating the new blocks (block creation updates the insertion point).
  BasicBlock *loopInsertionPoint = builder.getInsertionBlock();

  // Create blocks so that they appear in more human-readable order in the
  // output.
  BasicBlock *loopInitBlock = builder.createBlock();
  BasicBlock *loopConditionBlock = builder.createBlock();
  BasicBlock *loopBodyFirstBlock = builder.createBlock();

  // At the loop insertion location, branch immediately to the loop init block.
  builder.setInsertionPoint(loopInsertionPoint);
  builder.create<BranchOp>(builder.getUnknownLoc(), loopInitBlock);

  // The loop condition block has an argument for loop induction variable.
  // Create it upfront and make the loop induction variable -> basic block
  // argument remapping available to the following instructions.  ForStatement
  // is-a MLValue corresponding to the loop induction variable.
  builder.setInsertionPoint(loopConditionBlock);
  CFGValue *iv = loopConditionBlock->addArgument(builder.getIndexType());
  valueRemapping.insert(std::make_pair(forStmt, iv));

  // Recursively construct loop body region.
  // Walking manually because we need custom logic before and after traversing
  // the list of children.
  builder.setInsertionPoint(loopBodyFirstBlock);
  visitStmtBlock(forStmt->getBody());

  // Builder point is currently at the last block of the loop body.  Append the
  // induction variable stepping to this block and branch back to the exit
  // condition block.  Construct an affine map f : (x -> x+step) and apply this
  // map to the induction variable.
  auto affStep = builder.getAffineConstantExpr(forStmt->getStep());
  auto affDim = builder.getAffineDimExpr(0);
  auto affStepMap = builder.getAffineMap(1, 0, {affDim + affStep}, {});
  auto stepOp =
      builder.create<AffineApplyOp>(forStmt->getLoc(), affStepMap, iv);
  CFGValue *nextIvValue = cast<CFGValue>(stepOp->getResult(0));
  builder.create<BranchOp>(builder.getUnknownLoc(), loopConditionBlock,
                           nextIvValue);

  // Create post-loop block here so that it appears after all loop body blocks.
  BasicBlock *postLoopBlock = builder.createBlock();

  builder.setInsertionPoint(loopInitBlock);
  // Compute loop bounds using affine_apply after remapping its operands.
  // Bound operands are expected to be MLValues that already have a CFGValue
  // image.  Use `cast` (not `dyn_cast`) and check the mapping so that a
  // violation is reported here instead of silently producing a null operand,
  // mirroring the check done in operandsAs.
  auto remapOperands = [this](const SSAValue *value) -> SSAValue * {
    const MLValue *mlValue = cast<MLValue>(value);
    assert(valueRemapping.count(mlValue) != 0 && "operand is not defined");
    return valueRemapping.lookup(mlValue);
  };
  auto operands =
      functional::map(remapOperands, forStmt->getLowerBoundOperands());
  auto lbAffineApply = builder.create<AffineApplyOp>(
      forStmt->getLoc(), forStmt->getLowerBoundMap(), operands);
  // The lower bound is the maximum of the lower bound map results.
  CFGValue *lowerBound = buildMinMaxReductionSeq(
      forStmt->getLoc(), CmpIPredicate::SGT, lbAffineApply->getResults());
  operands = functional::map(remapOperands, forStmt->getUpperBoundOperands());
  auto ubAffineApply = builder.create<AffineApplyOp>(
      forStmt->getLoc(), forStmt->getUpperBoundMap(), operands);
  // The upper bound is the minimum of the upper bound map results.
  CFGValue *upperBound = buildMinMaxReductionSeq(
      forStmt->getLoc(), CmpIPredicate::SLT, ubAffineApply->getResults());
  builder.create<BranchOp>(builder.getUnknownLoc(), loopConditionBlock,
                           lowerBound);

  // The condition block continues into the body while %iv < upper bound and
  // leaves the loop otherwise.
  builder.setInsertionPoint(loopConditionBlock);
  auto comparisonOp = builder.create<CmpIOp>(
      forStmt->getLoc(), CmpIPredicate::SLT, iv, upperBound);
  auto comparisonResult = cast<CFGValue>(comparisonOp->getResult());
  builder.create<CondBranchOp>(builder.getUnknownLoc(), comparisonResult,
                               loopBodyFirstBlock, ArrayRef<SSAValue *>(),
                               postLoopBlock, ArrayRef<SSAValue *>());

  // Finally, make sure building can continue by setting the post-loop block
  // (end of loop SESE region) as the insertion point.
  builder.setInsertionPoint(postLoopBlock);
}
// Convert an "if" statement into a flow of basic blocks.
//
// Create an SESE region for the if statement (including its "then" and optional
// "else" statement blocks) and append it to the end of the current region. The
// conditional region consists of a sequence of condition-checking blocks that
// implement the short-circuit scheme, followed by a "then" SESE region and an
// "else" SESE region, and the continuation block that post-dominates all blocks
// of the "if" statement. The flow of blocks that correspond to the "then" and
// "else" clauses are constructed recursively, enabling easy nesting of "if"
// statements and if-then-else-if chains.
//
//      +--------------------------------+
//      | <end of current SESE region>   |
//      | <current insertion point>      |
//      | %zero = constant 0 : index     |
//      | %v = affine_apply #expr1(%ops) |
//      | %c = cmpi "sge" %v, %zero      |
//      | cond_br %c, %next, %else       |
//      +--------------------------------+
//             |              |
//             |              --------------|
//             v                            |
//      +--------------------------------+  |
//      | next:                          |  |
//      |   <repeat the check for expr2> |  |
//      |   cond_br %c, %next2, %else    |  |
//      +--------------------------------+  |
//             |              |             |
//            ...             --------------|
//             |   <Per-expression checks>  |
//             v                            |
//      +--------------------------------+  |
//      | last:                          |  |
//      |   <repeat the check for exprN> |  |
//      |   cond_br %c, %then, %else     |  |
//      +--------------------------------+  |
//             |              |             |
//             |              --------------|
//             v                            |
//      +--------------------------------+  |
//      | then:                          |  |
//      |   <then SESE region>           |  |
//      +--------------------------------+  |
//             |                            |
//            ... <SESE region of "then">   |
//             |                            |
//             v                            |
//      +--------------------------------+  |
//      | then_end:                      |  |
//      |   <then SESE region end>       |  |
//      |   br continue                  |  |
//      +--------------------------------+  |
//             |                            |
//             |----------        |-------------
//             |                  V
//             |               +--------------------------------+
//             |               | else:                          |
//             |               |   <else SESE region>           |
//             |               +--------------------------------+
//             |                  |
//             |                 ... <SESE region of "else">
//             |                  |
//             |                  v
//             |               +--------------------------------+
//             |               | else_end:                      |
//             |               |   <else SESE region end>       |
//             |               |   br continue                  |
//             |               +--------------------------------+
//             |                  |
//             ------|            |
//                   v            v
//      +--------------------------------+
//      | continue:                      |
//      |   <end of "if" SESE region>    |
//      |   <new insertion point>        |
//      +--------------------------------+
//
void FunctionConverter::visitIfStmt(IfStmt *ifStmt) {
  assert(ifStmt != nullptr);

  auto integerSet = ifStmt->getCondition().getIntegerSet();

  // Create basic blocks for the 'then' block and for the 'else' block.
  // Although 'else' block may be empty in absence of an 'else' clause, create
  // it anyway for the sake of consistency and output IR readability.  Also
  // create extra blocks for condition checking to prepare for short-circuit
  // logic: conditions in the 'if' statement are conjunctive, so we can jump to
  // the false branch as soon as one condition fails.  `cond_br` requires
  // another block as a target when the condition is true, and that block will
  // contain the next condition.
  BasicBlock *ifInsertionBlock = builder.getInsertionBlock();
  SmallVector<BasicBlock *, 4> ifConditionExtraBlocks;
  unsigned numConstraints = integerSet.getNumConstraints();
  // The lowering needs at least one constraint: the first check is emitted in
  // the current block and each further one gets its own extra block.  Without
  // this guard, `numConstraints - 1` would wrap around for an empty set.
  assert(numConstraints > 0 &&
         "'if' condition must have at least one constraint");
  // One slot per constraint: numConstraints - 1 extra condition blocks plus
  // the 'then' block that is temporarily appended below.
  ifConditionExtraBlocks.reserve(numConstraints);
  // Starting at 1 creates numConstraints - 1 blocks without unsigned underflow.
  for (unsigned i = 1; i < numConstraints; ++i) {
    ifConditionExtraBlocks.push_back(builder.createBlock());
  }
  BasicBlock *thenBlock = builder.createBlock();
  BasicBlock *elseBlock = builder.createBlock();
  builder.setInsertionPoint(ifInsertionBlock);

  // Implement short-circuit logic.  For each affine expression in the 'if'
  // condition, convert it into an affine map and call `affine_apply` to obtain
  // the resulting value.  Perform the equality or the greater-than-or-equality
  // test between this value and zero depending on the equality flag of the
  // condition.  If the test fails, jump immediately to the 'else' block, which
  // always exists here and is empty when there is no 'else' clause.  If the
  // test succeeds, jump to the next block testing the next conjunct of the
  // condition in the similar way.  When all conjuncts have been handled, jump
  // to the 'then' block instead.
  SSAValue *zeroConstant = getConstantIndexValue(0);
  // Temporarily append the 'then' block so that the last constraint has a
  // "next" block to branch to; it is removed again after the loop.
  ifConditionExtraBlocks.push_back(thenBlock);
  for (auto tuple :
       llvm::zip(integerSet.getConstraints(), integerSet.getEqFlags(),
                 ifConditionExtraBlocks)) {
    AffineExpr constraintExpr = std::get<0>(tuple);
    bool isEquality = std::get<1>(tuple);
    BasicBlock *nextBlock = std::get<2>(tuple);

    // Build and apply an affine map.
    auto affineMap =
        builder.getAffineMap(integerSet.getNumDims(),
                             integerSet.getNumSymbols(), constraintExpr, {});
    auto affineApplyOp = builder.create<AffineApplyOp>(
        ifStmt->getLoc(), affineMap, operandsAs(ifStmt, valueRemapping));
    SSAValue *affResult = affineApplyOp->getResult(0);

    // Compare the result of the apply and branch.
    auto comparisonOp = builder.create<CmpIOp>(
        ifStmt->getLoc(), isEquality ? CmpIPredicate::EQ : CmpIPredicate::SGE,
        affResult, zeroConstant);
    builder.create<CondBranchOp>(ifStmt->getLoc(), comparisonOp->getResult(),
                                 nextBlock, /*trueArgs*/ ArrayRef<SSAValue *>(),
                                 elseBlock,
                                 /*falseArgs*/ ArrayRef<SSAValue *>());
    builder.setInsertionPoint(nextBlock);
  }
  ifConditionExtraBlocks.pop_back();

  // Recursively traverse the 'then' block.
  builder.setInsertionPoint(thenBlock);
  visitStmtBlock(ifStmt->getThen());
  BasicBlock *lastThenBlock = builder.getInsertionBlock();

  // Recursively traverse the 'else' block if present.
  builder.setInsertionPoint(elseBlock);
  if (ifStmt->hasElse())
    visitStmtBlock(ifStmt->getElse());
  BasicBlock *lastElseBlock = builder.getInsertionBlock();

  // Create the continuation block here so that it appears lexically after the
  // 'then' and 'else' blocks, branch from end of 'then' and 'else' SESE regions
  // to the continuation block.
  BasicBlock *continuationBlock = builder.createBlock();
  builder.setInsertionPoint(lastThenBlock);
  builder.create<BranchOp>(ifStmt->getLoc(), continuationBlock);
  builder.setInsertionPoint(lastElseBlock);
  builder.create<BranchOp>(ifStmt->getLoc(), continuationBlock);

  // Make sure building can continue by setting up the continuation block as the
  // insertion point.
  builder.setInsertionPoint(continuationBlock);
}
// Entry point of the function converter.
//
// The conversion recursively visits the statements of an MLFunction. It
// reasons in terms of single-entry single-exit (SESE) regions, which are never
// materialized in the code: since the first block of a region does not change
// after its creation, it suffices to track the last block of the region, and
// that is exactly the insertion point of the IR builder. "Block" statements
// such as loops and branches open new SESE regions for their bodies and wrap
// them in additional basic blocks that implement the control flow. Plain
// operations are appended to the last basic block of the current region. The
// SESE invariant makes nested structures of arbitrary complexity easy to
// handle.
//
// Throughout the conversion, a mapping is maintained from the MLValues of the
// original function to their CFGValue images in the function being built. A
// use of an MLValue is replaced with the previously defined CFGValue. The
// value flow starts with the function arguments, which become basic block
// arguments.
CFGFunction *FunctionConverter::convert(MLFunction *mlFunc) {
  auto entryBlock = builder.createBlock();

  // A CFGFunction has no explicit arguments; the arguments of its first basic
  // block play that role. Create them from the MLFunction signature and record
  // the argument-to-argument remapping.
  entryBlock->addArguments(mlFunc->getType().getInputs());
  unsigned numArguments = mlFunc->getNumArguments();
  assert(numArguments == entryBlock->getNumArguments());
  for (unsigned idx = 0; idx != numArguments; ++idx) {
    const MLValue *mlArg = mlFunc->getArgument(idx);
    CFGValue *blockArg = entryBlock->getArgument(idx);
    valueRemapping.insert(std::make_pair(mlArg, blockArg));
  }

  // Convert the statements of the body in order.
  for (auto &stmt : *mlFunc->getBody())
    visit(&stmt);

  return cfgFunc;
}
//===----------------------------------------------------------------------===//
// Module converter
//===----------------------------------------------------------------------===//
namespace {
// Module pass that performs the CFG conversion for the whole module.
class ModuleConverter : public ModulePass {
public:
  static char passID;

  explicit ModuleConverter() : ModulePass(&ModuleConverter::passID) {}

  PassResult runOnModule(Module *m) override;

private:
  // Conversion steps.

  // Produces a CFG function for every ML function of the module.
  void convertMLFunctions();
  // Produces the CFG function corresponding to the given ML function.
  CFGFunction *convert(MLFunction *mlFunc);

  // Reference patching.

  // Redirects every reference to an ML function in the module to the CFG
  // function generated for it.
  void replaceReferences();
  // Redirects function references inside the given function.
  // NOTE(review): no definition of this overload is visible in this file.
  void replaceReferences(CFGFunction *cfgFunc);

  // Swaps each MLFunction of the module for its CFG counterpart.
  void replaceFunctions();

  // ML functions and the CFG functions generated from them.
  llvm::DenseMap<MLFunction *, CFGFunction *> generatedFuncs;
  // Module being processed; set by runOnModule.
  Module *module = nullptr;
};
} // end anonymous namespace
// Out-of-line definition of the static pass identifier. Its address is handed
// to the ModulePass base-class constructor by ModuleConverter's constructor.
char ModuleConverter::passID = 0;
// Iterates over all functions in the module generating CFG functions
// equivalent to ML functions and replacing references to ML functions
// with references to the generated CFG functions.  The names of the converted
// functions match those of the original functions to avoid breaking any
// external references to the current module.  Therefore, converted functions
// are added to the module at the end of the pass, after removing the original
// functions to avoid name clashes.  Conversion procedure has access to the
// module as member of ModuleConverter and must not rely on the converted
// function to belong to the module.
//
// Always returns success().
PassResult ModuleConverter::runOnModule(Module *m) {
  module = m;
  // Start from a clean slate: entries left over from a previous run on this
  // pass instance would refer to ML functions that replaceFunctions has already
  // erased from their module's function list.
  generatedFuncs.clear();

  // The order matters: all CFG functions must exist before references can be
  // redirected to them, and the ML functions are only swapped out at the end.
  convertMLFunctions();
  replaceReferences();
  replaceFunctions();

  return success();
}
// Walk the functions of the module and record, for each ML function, the CFG
// function generated from it. Functions of other kinds are skipped.
void ModuleConverter::convertMLFunctions() {
  for (Function &fn : *module) {
    auto *mlFunc = dyn_cast<MLFunction>(&fn);
    if (!mlFunc)
      continue;
    CFGFunction *cfgFunc = convert(mlFunc);
    generatedFuncs[mlFunc] = cfgFunc;
  }
}
// Build the CFG function equivalent to the given ML function. The new function
// reuses the location, name, type and attributes of the ML function. It is
// deliberately not added to the module here, since that would make two
// functions with the same name collide.
CFGFunction *ModuleConverter::convert(MLFunction *mlFunc) {
  auto funcName = mlFunc->getName().str();
  auto *cfgFunc = new CFGFunction(mlFunc->getLoc(), funcName,
                                  mlFunc->getType(), mlFunc->getAttrs());

  // Fill in the body of the CFG function.
  FunctionConverter converter(cfgFunc);
  return converter.convert(mlFunc);
}
// Replace references to MLFunctions with the references to the converted
// CFGFunctions. Since this all MLFunctions are converted at this point, it is
// unnecessary to replace references in the MLFunctions that are going to be
// removed anyway. However, it is necessary to replace the references in the
// converted CFGFunctions that have not been added to the module yet.
void ModuleConverter::replaceReferences() {
// Build the remapping between function attributes pointing to ML functions
// and the newly created function attributes pointing to the converted CFG
// functions.
llvm::DenseMap<Attribute, FunctionAttr> remappingTable;
for (const Function &fn : *module) {
const auto *mlFunc = dyn_cast<MLFunction>(&fn);
if (!mlFunc)
continue;
CFGFunction *convertedFunc = generatedFuncs.lookup(mlFunc);
assert(convertedFunc && "ML function was not converted");
MLIRContext *context = module->getContext();
auto mlFuncAttr = FunctionAttr::get(mlFunc, context);
auto cfgFuncAttr = FunctionAttr::get(convertedFunc, module->getContext());
remappingTable.insert({mlFuncAttr, cfgFuncAttr});
}
// Remap in existing functions.
remapFunctionAttrs(*module, remappingTable);
// Remap in generated functions.
for (auto pair : generatedFuncs) {
remapFunctionAttrs(*pair.second, remappingTable);
}
}
// Replace the value of a function attribute named "name" attached to the
// operation "op" and containing an MLFunction-typed value with the result of
// converting "func" to a CFGFunction.
//
// Does nothing when "func" is not an MLFunction.  Otherwise "generatedFuncs"
// must hold the converted counterpart of "func"; this is asserted, in line
// with the check in ModuleConverter::replaceReferences, so that a missing
// entry is reported instead of producing an attribute for a null function.
static inline void replaceMLFunctionAttr(
    Operation &op, Identifier name, const Function *func,
    const llvm::DenseMap<MLFunction *, CFGFunction *> &generatedFuncs) {
  const auto *mlFunc = dyn_cast<MLFunction>(func);
  if (!mlFunc)
    return;

  Builder b(op.getContext());
  auto cfgFunc = generatedFuncs.lookup(mlFunc);
  assert(cfgFunc && "ML function was not converted");
  op.setAttr(name, b.getFunctionAttr(cfgFunc));
}
// The CFG and ML functions have the same name. First, erase the MLFunction.
// Then insert the CFGFunction at the same place.
void ModuleConverter::replaceFunctions() {
for (auto pair : generatedFuncs) {
auto &functions = module->getFunctions();
auto it = functions.erase(pair.first);
functions.insert(it, pair.second);
}
}
//===----------------------------------------------------------------------===//
// Entry point method
//===----------------------------------------------------------------------===//
/// Creates the pass that replaces every ML function of a module with an
/// equivalent CFG function. References to the replaced functions are patched
/// to point to the newly generated CFG functions. A converted function keeps
/// the name of the function it was generated from to preserve module linking.
ModulePass *mlir::createConvertToCFGPass() {
  ModulePass *converterPass = new ModuleConverter();
  return converterPass;
}
// Static registration object for ModuleConverter, constructed with the name
// "convert-to-cfg" and a human-readable description.
// NOTE(review): presumably this is what makes the pass selectable by name in
// the pass-driver tools -- confirm against PassRegistration.
static PassRegistration<ModuleConverter>
pass("convert-to-cfg",
"Convert all ML functions in the module to CFG ones");