Implement StmtBlocks support for arguments and pred/succ iteration. This isn't

tested yet, but will when stuff starts switching over to it.  This is part 3/n of merging CFGFunctions and MLFunctions.

PiperOrigin-RevId: 226794787
This commit is contained in:
Chris Lattner 2018-12-24 18:01:01 -08:00 committed by jpienaar
parent 87ce4cc501
commit eadaa1101c
11 changed files with 385 additions and 26 deletions

View File

@ -47,7 +47,7 @@ public:
CFGFunction *getParent() { return function; }
//===--------------------------------------------------------------------===//
// Block arguments management
// Block argument management
//===--------------------------------------------------------------------===//
// This is the list of arguments to the block.

View File

@ -51,6 +51,7 @@ public:
return true;
case SSAValueKind::MLFuncArgument:
case SSAValueKind::BlockArgument:
case SSAValueKind::StmtResult:
case SSAValueKind::ForStmt:
return false;

View File

@ -25,16 +25,18 @@
#include "mlir/IR/SSAValue.h"
namespace mlir {
class MLValue;
class Statement;
class MLFunction;
class ForStmt;
class MLValue;
class MLFunction;
class Statement;
class StmtBlock;
/// This enum contains all of the SSA value kinds that are valid in an ML
/// function. This should be kept as a proper subtype of SSAValueKind,
/// including having all of the values of the enumerators align.
enum class MLValueKind {
MLFuncArgument = (int)SSAValueKind::MLFuncArgument,
BlockArgument = (int)SSAValueKind::BlockArgument,
StmtResult = (int)SSAValueKind::StmtResult,
ForStmt = (int)SSAValueKind::ForStmt,
};
@ -54,6 +56,7 @@ public:
static bool classof(const SSAValue *value) {
switch (value->getKind()) {
case SSAValueKind::MLFuncArgument:
case SSAValueKind::BlockArgument:
case SSAValueKind::StmtResult:
case SSAValueKind::ForStmt:
return true;
@ -102,6 +105,33 @@ private:
MLFunction *const owner;
};
/// Block arguments are ML Values.
class BlockArgument : public MLValue {
public:
static bool classof(const SSAValue *value) {
return value->getKind() == SSAValueKind::BlockArgument;
}
/// Return the function that this argument is defined in.
MLFunction *getFunction();
const MLFunction *getFunction() const {
return const_cast<BlockArgument *>(this)->getFunction();
}
StmtBlock *getOwner() { return owner; }
const StmtBlock *getOwner() const { return owner; }
private:
friend class StmtBlock; // For access to private constructor.
BlockArgument(Type type, StmtBlock *owner)
: MLValue(MLValueKind::BlockArgument, type), owner(owner) {}
/// The owner of this operand.
/// TODO: can encode this more efficiently to avoid the space hit of this
/// through bitpacking shenanigans.
StmtBlock *const owner;
};
/// This is a value defined by a result of an operation instruction.
class StmtResult : public MLValue {
public:

View File

@ -37,6 +37,7 @@ enum class SSAValueKind {
BBArgument, // basic block argument
InstResult, // instruction result
MLFuncArgument, // ML function argument
BlockArgument, // Block argument
StmtResult, // statement result
ForStmt, // for statement induction variable
};

View File

@ -34,6 +34,10 @@ class MLFunction;
class StmtBlock;
class ForStmt;
class MLIRContext;
/// The operand of a Terminator contains a StmtBlock.
using StmtBlockOperand = IROperandImpl<StmtBlock, OperationStmt>;
} // namespace mlir
//===----------------------------------------------------------------------===//

View File

@ -24,10 +24,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLValue.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/StmtBlock.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/TrailingObjects.h"
namespace mlir {
@ -36,9 +33,6 @@ class IntegerSet;
class AffineCondition;
class OperationStmt;
/// The operand of a Terminator contains a StmtBlock.
using StmtBlockOperand = IROperandImpl<StmtBlock, OperationStmt>;
/// Operation statements represent operations inside ML functions.
class OperationStmt final
: public Operation,

View File

@ -29,6 +29,11 @@ class MLFunction;
class IfStmt;
class MLValue;
// TODO(clattner): drop the Stmt prefixes on these once BasicBlock's versions of
// these go away.
template <typename BlockType> class StmtPredecessorIterator;
template <typename BlockType> class StmtSuccessorIterator;
/// Statement block represents an ordered list of statements, with the order
/// being the contiguous lexical order in which the statements appear as
/// children of a parent statement in the ML Function.
@ -40,7 +45,7 @@ public:
IfClause // IfClause
};
~StmtBlock() { clear(); }
~StmtBlock();
void clear() {
// Clear statements in the reverse order so that uses are destroyed
@ -63,6 +68,36 @@ public:
/// The function is determined by traversing the chain of parent statements.
MLFunction *findFunction() const;
//===--------------------------------------------------------------------===//
// Block argument management
//===--------------------------------------------------------------------===//
// This is the list of arguments to the block.
using BlockArgListType = ArrayRef<BlockArgument *>;
BlockArgListType getArguments() const { return arguments; }
using args_iterator = BlockArgListType::iterator;
using reverse_args_iterator = BlockArgListType::reverse_iterator;
args_iterator args_begin() const { return getArguments().begin(); }
args_iterator args_end() const { return getArguments().end(); }
reverse_args_iterator args_rbegin() const { return getArguments().rbegin(); }
reverse_args_iterator args_rend() const { return getArguments().rend(); }
bool args_empty() const { return arguments.empty(); }
/// Add one value to the argument list.
BlockArgument *addArgument(Type type);
/// Add one argument to the argument list for each type specified in the list.
llvm::iterator_range<args_iterator> addArguments(ArrayRef<Type> types);
/// Erase the argument at 'index' and remove it from the argument list.
void eraseArgument(unsigned index);
unsigned getNumArguments() const { return arguments.size(); }
BlockArgument *getArgument(unsigned i) { return arguments[i]; }
const BlockArgument *getArgument(unsigned i) const { return arguments[i]; }
//===--------------------------------------------------------------------===//
// Statement list management
//===--------------------------------------------------------------------===//
@ -100,19 +135,10 @@ public:
return const_cast<StmtBlock *>(this)->front();
}
/// getSublistAccess() - Returns pointer to member of statement list
static StmtListType StmtBlock::*getSublistAccess(Statement *) {
return &StmtBlock::statements;
}
/// These have unconventional names to avoid derive class ambiguities.
void printBlock(raw_ostream &os) const;
void dumpBlock() const;
/// Returns the statement's position in this block or -1 if the statement is
/// not present.
int findStmtPosInBlock(const Statement &stmt) const {
unsigned j = 0;
int64_t findStmtPosInBlock(const Statement &stmt) const {
int64_t j = 0;
for (const auto &s : statements) {
if (&s == &stmt)
return j;
@ -129,6 +155,75 @@ public:
return const_cast<Statement *>(findAncestorStmtInBlock(*stmt));
}
//===--------------------------------------------------------------------===//
// Terminator management
//===--------------------------------------------------------------------===//
/// Get the terminator instruction of this block, or null if the block is
/// malformed.
OperationStmt *getTerminator();
const OperationStmt *getTerminator() const {
return const_cast<StmtBlock *>(this)->getTerminator();
}
//===--------------------------------------------------------------------===//
// Predecessors and successors.
//===--------------------------------------------------------------------===//
// Predecessor iteration.
using const_pred_iterator = StmtPredecessorIterator<const StmtBlock>;
const_pred_iterator pred_begin() const;
const_pred_iterator pred_end() const;
llvm::iterator_range<const_pred_iterator> getPredecessors() const;
using pred_iterator = StmtPredecessorIterator<StmtBlock>;
pred_iterator pred_begin();
pred_iterator pred_end();
llvm::iterator_range<pred_iterator> getPredecessors();
/// Return true if this block has no predecessors.
bool hasNoPredecessors() const;
/// If this block has exactly one predecessor, return it. Otherwise, return
/// null.
///
/// Note that if a block has duplicate predecessors from a single block (e.g.
/// if you have a conditional branch with the same block as the true/false
/// destinations) is not considered to be a single predecessor.
StmtBlock *getSinglePredecessor();
const StmtBlock *getSinglePredecessor() const {
return const_cast<StmtBlock *>(this)->getSinglePredecessor();
}
// Indexed successor access.
unsigned getNumSuccessors() const;
const StmtBlock *getSuccessor(unsigned i) const {
return const_cast<StmtBlock *>(this)->getSuccessor(i);
}
StmtBlock *getSuccessor(unsigned i);
// Successor iteration.
using const_succ_iterator = StmtSuccessorIterator<const StmtBlock>;
const_succ_iterator succ_begin() const;
const_succ_iterator succ_end() const;
llvm::iterator_range<const_succ_iterator> getSuccessors() const;
using succ_iterator = StmtSuccessorIterator<StmtBlock>;
succ_iterator succ_begin();
succ_iterator succ_end();
llvm::iterator_range<succ_iterator> getSuccessors();
/// getSublistAccess() - Returns pointer to member of statement list
static StmtListType StmtBlock::*getSublistAccess(Statement *) {
return &StmtBlock::statements;
}
/// These have unconventional names to avoid derive class ambiguities.
void printBlock(raw_ostream &os) const;
void dumpBlock() const;
protected:
StmtBlock(StmtBlockKind kind) : kind(kind) {}
@ -137,9 +232,142 @@ private:
/// This is the list of statements in the block.
StmtListType statements;
/// This is the list of arguments to the block.
std::vector<BlockArgument *> arguments;
StmtBlock(const StmtBlock &) = delete;
void operator=(const StmtBlock &) = delete;
};
//===----------------------------------------------------------------------===//
// Predecessors
//===----------------------------------------------------------------------===//
/// Implement a predecessor iterator as a forward iterator. This works by
/// walking the use lists of the blocks. The entries on this list are the
/// StmtBlockOperands that are embedded into terminator instructions. From the
/// operand, we can get the terminator that contains it, and it's parent block
/// is the predecessor.
template <typename BlockType>
class StmtPredecessorIterator
: public llvm::iterator_facade_base<StmtPredecessorIterator<BlockType>,
std::forward_iterator_tag,
BlockType *> {
public:
StmtPredecessorIterator(StmtBlockOperand *firstOperand)
: bbUseIterator(firstOperand) {}
StmtPredecessorIterator &operator=(const StmtPredecessorIterator &rhs) {
bbUseIterator = rhs.bbUseIterator;
}
bool operator==(const StmtPredecessorIterator &rhs) const {
return bbUseIterator == rhs.bbUseIterator;
}
BlockType *operator*() const {
// The use iterator points to an operand of a terminator. The predecessor
// we return is the block that the terminator is embedded into.
return bbUseIterator.getUser()->getBlock();
}
StmtPredecessorIterator &operator++() {
++bbUseIterator;
return *this;
}
/// Get the successor number in the predecessor terminator.
unsigned getSuccessorIndex() const {
return bbUseIterator->getOperandNumber();
}
private:
using BBUseIterator = SSAValueUseIterator<StmtBlockOperand, OperationStmt>;
BBUseIterator bbUseIterator;
};
inline auto StmtBlock::pred_begin() const -> const_pred_iterator {
return const_pred_iterator((StmtBlockOperand *)getFirstUse());
}
inline auto StmtBlock::pred_end() const -> const_pred_iterator {
return const_pred_iterator(nullptr);
}
inline auto StmtBlock::getPredecessors() const
-> llvm::iterator_range<const_pred_iterator> {
return {pred_begin(), pred_end()};
}
inline auto StmtBlock::pred_begin() -> pred_iterator {
return pred_iterator((StmtBlockOperand *)getFirstUse());
}
inline auto StmtBlock::pred_end() -> pred_iterator {
return pred_iterator(nullptr);
}
inline auto StmtBlock::getPredecessors()
-> llvm::iterator_range<pred_iterator> {
return {pred_begin(), pred_end()};
}
//===----------------------------------------------------------------------===//
// Successors
//===----------------------------------------------------------------------===//
/// This template implments the successor iterators for StmtBlock.
template <typename BlockType>
class StmtSuccessorIterator final
: public IndexedAccessorIterator<StmtSuccessorIterator<BlockType>,
BlockType, BlockType> {
public:
/// Initializes the result iterator to the specified index.
StmtSuccessorIterator(BlockType *object, unsigned index)
: IndexedAccessorIterator<StmtSuccessorIterator<BlockType>, BlockType,
BlockType>(object, index) {}
StmtSuccessorIterator(const StmtSuccessorIterator &other)
: StmtSuccessorIterator(other.object, other.index) {}
/// Support converting to the const variant. This will be a no-op for const
/// variant.
operator StmtSuccessorIterator<const BlockType>() const {
return StmtSuccessorIterator<const BlockType>(this->object, this->index);
}
BlockType *operator*() const {
return this->object->getSuccessor(this->index);
}
/// Get the successor number in the terminator.
unsigned getSuccessorIndex() const { return this->index; }
};
inline auto StmtBlock::succ_begin() const -> const_succ_iterator {
return const_succ_iterator(this, 0);
}
inline auto StmtBlock::succ_end() const -> const_succ_iterator {
return const_succ_iterator(this, getNumSuccessors());
}
inline auto StmtBlock::getSuccessors() const
-> llvm::iterator_range<const_succ_iterator> {
return {succ_begin(), succ_end()};
}
inline auto StmtBlock::succ_begin() -> succ_iterator {
return succ_iterator(this, 0);
}
inline auto StmtBlock::succ_end() -> succ_iterator {
return succ_iterator(this, getNumSuccessors());
}
inline auto StmtBlock::getSuccessors() -> llvm::iterator_range<succ_iterator> {
return {succ_begin(), succ_end()};
}
} //end namespace mlir
#endif // MLIR_IR_STMTBLOCK_H

View File

@ -1011,6 +1011,7 @@ protected:
}
// Otherwise number it normally.
LLVM_FALLTHROUGH;
case SSAValueKind::BlockArgument:
case SSAValueKind::InstResult:
case SSAValueKind::StmtResult:
// This is an uninteresting result, give it a boring number and be
@ -1576,8 +1577,9 @@ void IntegerSet::print(raw_ostream &os) const {
void SSAValue::print(raw_ostream &os) const {
switch (getKind()) {
case SSAValueKind::BBArgument:
case SSAValueKind::BlockArgument:
// TODO: Improve this.
os << "<bb argument>\n";
os << "<block argument>\n";
return;
case SSAValueKind::InstResult:
return getDefiningInst()->print(os);

View File

@ -56,6 +56,8 @@ Function *SSAValue::getFunction() {
return getDefiningInst()->getFunction();
case SSAValueKind::MLFuncArgument:
return cast<MLFuncArgument>(this)->getFunction();
case SSAValueKind::BlockArgument:
return cast<BlockArgument>(this)->getFunction();
case SSAValueKind::StmtResult:
return getDefiningStmt()->findFunction();
case SSAValueKind::ForStmt:
@ -113,3 +115,14 @@ CFGFunction *BBArgument::getFunction() {
MLFunction *MLValue::getFunction() {
return cast<MLFunction>(static_cast<SSAValue *>(this)->getFunction());
}
//===----------------------------------------------------------------------===//
// BlockArgument implementation.
//===----------------------------------------------------------------------===//
/// Return the function that this argument is defined in.
MLFunction *BlockArgument::getFunction() {
if (auto *owner = getOwner())
return owner->findFunction();
return nullptr;
}

View File

@ -47,6 +47,11 @@ template <> unsigned StmtOperand::getOperandNumber() const {
return this - &getOwner()->getStmtOperands()[0];
}
/// Return which operand this is in the operand list.
template <> unsigned StmtBlockOperand::getOperandNumber() const {
return this - &getOwner()->getBlockOperands()[0];
}
//===----------------------------------------------------------------------===//
// Statement
//===----------------------------------------------------------------------===//

View File

@ -20,9 +20,11 @@
#include "mlir/IR/Statements.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// Statement block
//===----------------------------------------------------------------------===//
StmtBlock::~StmtBlock() {
clear();
llvm::DeleteContainerPointers(arguments);
}
Statement *StmtBlock::getContainingStmt() {
switch (kind) {
@ -62,3 +64,82 @@ StmtBlock::findAncestorStmtInBlock(const Statement &stmt) const {
}
return currStmt;
}
//===----------------------------------------------------------------------===//
// Argument list management.
//===----------------------------------------------------------------------===//
BlockArgument *StmtBlock::addArgument(Type type) {
auto *arg = new BlockArgument(type, this);
arguments.push_back(arg);
return arg;
}
/// Add one argument to the argument list for each type specified in the list.
auto StmtBlock::addArguments(ArrayRef<Type> types)
-> llvm::iterator_range<args_iterator> {
arguments.reserve(arguments.size() + types.size());
auto initialSize = arguments.size();
for (auto type : types) {
addArgument(type);
}
return {arguments.data() + initialSize, arguments.data() + arguments.size()};
}
void StmtBlock::eraseArgument(unsigned index) {
assert(index < arguments.size());
// Delete the argument.
delete arguments[index];
arguments.erase(arguments.begin() + index);
// Erase this argument from each of the predecessor's terminator.
for (auto predIt = pred_begin(), predE = pred_end(); predIt != predE;
++predIt) {
auto *predTerminator = (*predIt)->getTerminator();
predTerminator->eraseSuccessorOperand(predIt.getSuccessorIndex(), index);
}
}
//===----------------------------------------------------------------------===//
// Terminator management
//===----------------------------------------------------------------------===//
OperationStmt *StmtBlock::getTerminator() {
if (empty())
return nullptr;
// Check if the last instruction is a terminator.
auto &backInst = statements.back();
auto *opStmt = dyn_cast<OperationStmt>(&backInst);
if (!opStmt || !opStmt->isTerminator())
return nullptr;
return opStmt;
}
/// Return true if this block has no predecessors.
bool StmtBlock::hasNoPredecessors() const { return pred_begin() == pred_end(); }
// Indexed successor access.
unsigned StmtBlock::getNumSuccessors() const {
return getTerminator()->getNumSuccessors();
}
StmtBlock *StmtBlock::getSuccessor(unsigned i) {
return getTerminator()->getSuccessor(i);
}
/// If this block has exactly one predecessor, return it. Otherwise, return
/// null.
///
/// Note that multiple edges from a single block (e.g. if you have a cond
/// branch with the same block as the true/false destinations) is not
/// considered to be a single predecessor.
StmtBlock *StmtBlock::getSinglePredecessor() {
auto it = pred_begin();
if (it == pred_end())
return nullptr;
auto *firstPred = *it;
++it;
return it == pred_end() ? firstPred : nullptr;
}