diff --git a/mlir/include/mlir/IR/BasicBlock.h b/mlir/include/mlir/IR/BasicBlock.h index 487c1bb63a4d..1b38a5eaa042 100644 --- a/mlir/include/mlir/IR/BasicBlock.h +++ b/mlir/include/mlir/IR/BasicBlock.h @@ -47,7 +47,7 @@ public: CFGFunction *getParent() { return function; } //===--------------------------------------------------------------------===// - // Block arguments management + // Block argument management //===--------------------------------------------------------------------===// // This is the list of arguments to the block. diff --git a/mlir/include/mlir/IR/CFGValue.h b/mlir/include/mlir/IR/CFGValue.h index d0140f3a3674..8dd3f96d5641 100644 --- a/mlir/include/mlir/IR/CFGValue.h +++ b/mlir/include/mlir/IR/CFGValue.h @@ -51,6 +51,7 @@ public: return true; case SSAValueKind::MLFuncArgument: + case SSAValueKind::BlockArgument: case SSAValueKind::StmtResult: case SSAValueKind::ForStmt: return false; diff --git a/mlir/include/mlir/IR/MLValue.h b/mlir/include/mlir/IR/MLValue.h index 0c6c0b226963..847bd1bccf0d 100644 --- a/mlir/include/mlir/IR/MLValue.h +++ b/mlir/include/mlir/IR/MLValue.h @@ -25,16 +25,18 @@ #include "mlir/IR/SSAValue.h" namespace mlir { -class MLValue; -class Statement; -class MLFunction; class ForStmt; +class MLValue; +class MLFunction; +class Statement; +class StmtBlock; /// This enum contains all of the SSA value kinds that are valid in an ML /// function. This should be kept as a proper subtype of SSAValueKind, /// including having all of the values of the enumerators align. enum class MLValueKind { MLFuncArgument = (int)SSAValueKind::MLFuncArgument, + BlockArgument = (int)SSAValueKind::BlockArgument, StmtResult = (int)SSAValueKind::StmtResult, ForStmt = (int)SSAValueKind::ForStmt, }; @@ -54,6 +56,7 @@ public: static bool classof(const SSAValue *value) { switch (value->getKind()) { case SSAValueKind::MLFuncArgument: + case SSAValueKind::BlockArgument: case SSAValueKind::StmtResult: case SSAValueKind::ForStmt: return true; @@ -102,6 +105,33 @@ private: MLFunction *const owner; }; +/// Block arguments are ML Values. +class BlockArgument : public MLValue { +public: + static bool classof(const SSAValue *value) { + return value->getKind() == SSAValueKind::BlockArgument; + } + + /// Return the function that this argument is defined in. + MLFunction *getFunction(); + const MLFunction *getFunction() const { + return const_cast(this)->getFunction(); + } + + StmtBlock *getOwner() { return owner; } + const StmtBlock *getOwner() const { return owner; } + +private: + friend class StmtBlock; // For access to private constructor. + BlockArgument(Type type, StmtBlock *owner) + : MLValue(MLValueKind::BlockArgument, type), owner(owner) {} + + /// The owner of this operand. + /// TODO: can encode this more efficiently to avoid the space hit of this + /// through bitpacking shenanigans. + StmtBlock *const owner; +}; + /// This is a value defined by a result of an operation instruction. class StmtResult : public MLValue { public: diff --git a/mlir/include/mlir/IR/SSAValue.h b/mlir/include/mlir/IR/SSAValue.h index 29bc6f034c50..919429b209a5 100644 --- a/mlir/include/mlir/IR/SSAValue.h +++ b/mlir/include/mlir/IR/SSAValue.h @@ -37,6 +37,7 @@ enum class SSAValueKind { BBArgument, // basic block argument InstResult, // instruction result MLFuncArgument, // ML function argument + BlockArgument, // Block argument StmtResult, // statement result ForStmt, // for statement induction variable }; diff --git a/mlir/include/mlir/IR/Statement.h b/mlir/include/mlir/IR/Statement.h index 8b5dce7aad81..d6576d32486e 100644 --- a/mlir/include/mlir/IR/Statement.h +++ b/mlir/include/mlir/IR/Statement.h @@ -34,6 +34,10 @@ class MLFunction; class StmtBlock; class ForStmt; class MLIRContext; + +/// The operand of a Terminator contains a StmtBlock. +using StmtBlockOperand = IROperandImpl; + } // namespace mlir //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/Statements.h b/mlir/include/mlir/IR/Statements.h index 6b8aa0355c5a..2e07bfb8d32b 100644 --- a/mlir/include/mlir/IR/Statements.h +++ b/mlir/include/mlir/IR/Statements.h @@ -24,10 +24,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/IntegerSet.h" -#include "mlir/IR/MLValue.h" -#include "mlir/IR/Operation.h" #include "mlir/IR/StmtBlock.h" -#include "mlir/Support/LLVM.h" #include "llvm/Support/TrailingObjects.h" namespace mlir { @@ -36,9 +33,6 @@ class IntegerSet; class AffineCondition; class OperationStmt; -/// The operand of a Terminator contains a StmtBlock. -using StmtBlockOperand = IROperandImpl; - /// Operation statements represent operations inside ML functions. class OperationStmt final : public Operation, diff --git a/mlir/include/mlir/IR/StmtBlock.h b/mlir/include/mlir/IR/StmtBlock.h index c9f2638d56d2..65e0f19066e3 100644 --- a/mlir/include/mlir/IR/StmtBlock.h +++ b/mlir/include/mlir/IR/StmtBlock.h @@ -29,6 +29,11 @@ class MLFunction; class IfStmt; class MLValue; +// TODO(clattner): drop the Stmt prefixes on these once BasicBlock's versions of +// these go away. +template class StmtPredecessorIterator; +template class StmtSuccessorIterator; + /// Statement block represents an ordered list of statements, with the order /// being the contiguous lexical order in which the statements appear as /// children of a parent statement in the ML Function. @@ -40,7 +45,7 @@ public: IfClause // IfClause }; - ~StmtBlock() { clear(); } + ~StmtBlock(); void clear() { // Clear statements in the reverse order so that uses are destroyed @@ -63,6 +68,36 @@ public: /// The function is determined by traversing the chain of parent statements. MLFunction *findFunction() const; + //===--------------------------------------------------------------------===// + // Block argument management + //===--------------------------------------------------------------------===// + + // This is the list of arguments to the block. + using BlockArgListType = ArrayRef; + BlockArgListType getArguments() const { return arguments; } + + using args_iterator = BlockArgListType::iterator; + using reverse_args_iterator = BlockArgListType::reverse_iterator; + args_iterator args_begin() const { return getArguments().begin(); } + args_iterator args_end() const { return getArguments().end(); } + reverse_args_iterator args_rbegin() const { return getArguments().rbegin(); } + reverse_args_iterator args_rend() const { return getArguments().rend(); } + + bool args_empty() const { return arguments.empty(); } + + /// Add one value to the argument list. + BlockArgument *addArgument(Type type); + + /// Add one argument to the argument list for each type specified in the list. + llvm::iterator_range addArguments(ArrayRef types); + + /// Erase the argument at 'index' and remove it from the argument list. + void eraseArgument(unsigned index); + + unsigned getNumArguments() const { return arguments.size(); } + BlockArgument *getArgument(unsigned i) { return arguments[i]; } + const BlockArgument *getArgument(unsigned i) const { return arguments[i]; } + //===--------------------------------------------------------------------===// // Statement list management //===--------------------------------------------------------------------===// @@ -100,19 +135,10 @@ public: return const_cast(this)->front(); } - /// getSublistAccess() - Returns pointer to member of statement list - static StmtListType StmtBlock::*getSublistAccess(Statement *) { - return &StmtBlock::statements; - } - - /// These have unconventional names to avoid derive class ambiguities. - void printBlock(raw_ostream &os) const; - void dumpBlock() const; - /// Returns the statement's position in this block or -1 if the statement is /// not present. - int findStmtPosInBlock(const Statement &stmt) const { - unsigned j = 0; + int64_t findStmtPosInBlock(const Statement &stmt) const { + int64_t j = 0; for (const auto &s : statements) { if (&s == &stmt) return j; @@ -129,6 +155,75 @@ public: return const_cast(findAncestorStmtInBlock(*stmt)); } + //===--------------------------------------------------------------------===// + // Terminator management + //===--------------------------------------------------------------------===// + + /// Get the terminator instruction of this block, or null if the block is + /// malformed. + OperationStmt *getTerminator(); + + const OperationStmt *getTerminator() const { + return const_cast(this)->getTerminator(); + } + + //===--------------------------------------------------------------------===// + // Predecessors and successors. + //===--------------------------------------------------------------------===// + + // Predecessor iteration. + using const_pred_iterator = StmtPredecessorIterator; + const_pred_iterator pred_begin() const; + const_pred_iterator pred_end() const; + llvm::iterator_range getPredecessors() const; + + using pred_iterator = StmtPredecessorIterator; + pred_iterator pred_begin(); + pred_iterator pred_end(); + llvm::iterator_range getPredecessors(); + + /// Return true if this block has no predecessors. + bool hasNoPredecessors() const; + + /// If this block has exactly one predecessor, return it. Otherwise, return + /// null. + /// + /// Note that if a block has duplicate predecessors from a single block (e.g. + /// if you have a conditional branch with the same block as the true/false + /// destinations) is not considered to be a single predecessor. + StmtBlock *getSinglePredecessor(); + + const StmtBlock *getSinglePredecessor() const { + return const_cast(this)->getSinglePredecessor(); + } + + // Indexed successor access. + unsigned getNumSuccessors() const; + const StmtBlock *getSuccessor(unsigned i) const { + return const_cast(this)->getSuccessor(i); + } + StmtBlock *getSuccessor(unsigned i); + + // Successor iteration. + using const_succ_iterator = StmtSuccessorIterator; + const_succ_iterator succ_begin() const; + const_succ_iterator succ_end() const; + llvm::iterator_range getSuccessors() const; + + using succ_iterator = StmtSuccessorIterator; + succ_iterator succ_begin(); + succ_iterator succ_end(); + llvm::iterator_range getSuccessors(); + + /// getSublistAccess() - Returns pointer to member of statement list + static StmtListType StmtBlock::*getSublistAccess(Statement *) { + return &StmtBlock::statements; + } + + /// These have unconventional names to avoid derive class ambiguities. + void printBlock(raw_ostream &os) const; + void dumpBlock() const; + protected: StmtBlock(StmtBlockKind kind) : kind(kind) {} @@ -137,9 +232,142 @@ private: /// This is the list of statements in the block. StmtListType statements; + /// This is the list of arguments to the block. + std::vector arguments; + StmtBlock(const StmtBlock &) = delete; void operator=(const StmtBlock &) = delete; }; +//===----------------------------------------------------------------------===// +// Predecessors +//===----------------------------------------------------------------------===// + +/// Implement a predecessor iterator as a forward iterator. This works by +/// walking the use lists of the blocks. The entries on this list are the +/// StmtBlockOperands that are embedded into terminator instructions. From the +/// operand, we can get the terminator that contains it, and it's parent block +/// is the predecessor. +template +class StmtPredecessorIterator + : public llvm::iterator_facade_base, + std::forward_iterator_tag, + BlockType *> { +public: + StmtPredecessorIterator(StmtBlockOperand *firstOperand) + : bbUseIterator(firstOperand) {} + + StmtPredecessorIterator &operator=(const StmtPredecessorIterator &rhs) { + bbUseIterator = rhs.bbUseIterator; + } + + bool operator==(const StmtPredecessorIterator &rhs) const { + return bbUseIterator == rhs.bbUseIterator; + } + + BlockType *operator*() const { + // The use iterator points to an operand of a terminator. The predecessor + // we return is the block that the terminator is embedded into. + return bbUseIterator.getUser()->getBlock(); + } + + StmtPredecessorIterator &operator++() { + ++bbUseIterator; + return *this; + } + + /// Get the successor number in the predecessor terminator. + unsigned getSuccessorIndex() const { + return bbUseIterator->getOperandNumber(); + } + +private: + using BBUseIterator = SSAValueUseIterator; + BBUseIterator bbUseIterator; +}; + +inline auto StmtBlock::pred_begin() const -> const_pred_iterator { + return const_pred_iterator((StmtBlockOperand *)getFirstUse()); +} + +inline auto StmtBlock::pred_end() const -> const_pred_iterator { + return const_pred_iterator(nullptr); +} + +inline auto StmtBlock::getPredecessors() const + -> llvm::iterator_range { + return {pred_begin(), pred_end()}; +} + +inline auto StmtBlock::pred_begin() -> pred_iterator { + return pred_iterator((StmtBlockOperand *)getFirstUse()); +} + +inline auto StmtBlock::pred_end() -> pred_iterator { + return pred_iterator(nullptr); +} + +inline auto StmtBlock::getPredecessors() + -> llvm::iterator_range { + return {pred_begin(), pred_end()}; +} + +//===----------------------------------------------------------------------===// +// Successors +//===----------------------------------------------------------------------===// + +/// This template implments the successor iterators for StmtBlock. +template +class StmtSuccessorIterator final + : public IndexedAccessorIterator, + BlockType, BlockType> { +public: + /// Initializes the result iterator to the specified index. + StmtSuccessorIterator(BlockType *object, unsigned index) + : IndexedAccessorIterator, BlockType, + BlockType>(object, index) {} + + StmtSuccessorIterator(const StmtSuccessorIterator &other) + : StmtSuccessorIterator(other.object, other.index) {} + + /// Support converting to the const variant. This will be a no-op for const + /// variant. + operator StmtSuccessorIterator() const { + return StmtSuccessorIterator(this->object, this->index); + } + + BlockType *operator*() const { + return this->object->getSuccessor(this->index); + } + + /// Get the successor number in the terminator. + unsigned getSuccessorIndex() const { return this->index; } +}; + +inline auto StmtBlock::succ_begin() const -> const_succ_iterator { + return const_succ_iterator(this, 0); +} + +inline auto StmtBlock::succ_end() const -> const_succ_iterator { + return const_succ_iterator(this, getNumSuccessors()); +} + +inline auto StmtBlock::getSuccessors() const + -> llvm::iterator_range { + return {succ_begin(), succ_end()}; +} + +inline auto StmtBlock::succ_begin() -> succ_iterator { + return succ_iterator(this, 0); +} + +inline auto StmtBlock::succ_end() -> succ_iterator { + return succ_iterator(this, getNumSuccessors()); +} + +inline auto StmtBlock::getSuccessors() -> llvm::iterator_range { + return {succ_begin(), succ_end()}; +} + } //end namespace mlir #endif // MLIR_IR_STMTBLOCK_H diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 58f34af60f50..2de17563d932 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1011,6 +1011,7 @@ protected: } // Otherwise number it normally. LLVM_FALLTHROUGH; + case SSAValueKind::BlockArgument: case SSAValueKind::InstResult: case SSAValueKind::StmtResult: // This is an uninteresting result, give it a boring number and be @@ -1576,8 +1577,9 @@ void IntegerSet::print(raw_ostream &os) const { void SSAValue::print(raw_ostream &os) const { switch (getKind()) { case SSAValueKind::BBArgument: + case SSAValueKind::BlockArgument: // TODO: Improve this. - os << "\n"; + os << "\n"; return; case SSAValueKind::InstResult: return getDefiningInst()->print(os); diff --git a/mlir/lib/IR/SSAValue.cpp b/mlir/lib/IR/SSAValue.cpp index 14fd4a459d3e..2179ba0fc11d 100644 --- a/mlir/lib/IR/SSAValue.cpp +++ b/mlir/lib/IR/SSAValue.cpp @@ -56,6 +56,8 @@ Function *SSAValue::getFunction() { return getDefiningInst()->getFunction(); case SSAValueKind::MLFuncArgument: return cast(this)->getFunction(); + case SSAValueKind::BlockArgument: + return cast(this)->getFunction(); case SSAValueKind::StmtResult: return getDefiningStmt()->findFunction(); case SSAValueKind::ForStmt: @@ -113,3 +115,14 @@ CFGFunction *BBArgument::getFunction() { MLFunction *MLValue::getFunction() { return cast(static_cast(this)->getFunction()); } + +//===----------------------------------------------------------------------===// +// BlockArgument implementation. +//===----------------------------------------------------------------------===// + +/// Return the function that this argument is defined in. +MLFunction *BlockArgument::getFunction() { + if (auto *owner = getOwner()) + return owner->findFunction(); + return nullptr; +} diff --git a/mlir/lib/IR/Statement.cpp b/mlir/lib/IR/Statement.cpp index 0f7e90bec320..d4c26d5a7553 100644 --- a/mlir/lib/IR/Statement.cpp +++ b/mlir/lib/IR/Statement.cpp @@ -47,6 +47,11 @@ template <> unsigned StmtOperand::getOperandNumber() const { return this - &getOwner()->getStmtOperands()[0]; } +/// Return which operand this is in the operand list. +template <> unsigned StmtBlockOperand::getOperandNumber() const { + return this - &getOwner()->getBlockOperands()[0]; +} + //===----------------------------------------------------------------------===// // Statement //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/StmtBlock.cpp b/mlir/lib/IR/StmtBlock.cpp index 898dd7bc337b..8ecb903d21d5 100644 --- a/mlir/lib/IR/StmtBlock.cpp +++ b/mlir/lib/IR/StmtBlock.cpp @@ -20,9 +20,11 @@ #include "mlir/IR/Statements.h" using namespace mlir; -//===----------------------------------------------------------------------===// -// Statement block -//===----------------------------------------------------------------------===// +StmtBlock::~StmtBlock() { + clear(); + + llvm::DeleteContainerPointers(arguments); +} Statement *StmtBlock::getContainingStmt() { switch (kind) { @@ -62,3 +64,82 @@ StmtBlock::findAncestorStmtInBlock(const Statement &stmt) const { } return currStmt; } + +//===----------------------------------------------------------------------===// +// Argument list management. +//===----------------------------------------------------------------------===// + +BlockArgument *StmtBlock::addArgument(Type type) { + auto *arg = new BlockArgument(type, this); + arguments.push_back(arg); + return arg; +} + +/// Add one argument to the argument list for each type specified in the list. +auto StmtBlock::addArguments(ArrayRef types) + -> llvm::iterator_range { + arguments.reserve(arguments.size() + types.size()); + auto initialSize = arguments.size(); + for (auto type : types) { + addArgument(type); + } + return {arguments.data() + initialSize, arguments.data() + arguments.size()}; +} + +void StmtBlock::eraseArgument(unsigned index) { + assert(index < arguments.size()); + + // Delete the argument. + delete arguments[index]; + arguments.erase(arguments.begin() + index); + + // Erase this argument from each of the predecessor's terminator. + for (auto predIt = pred_begin(), predE = pred_end(); predIt != predE; + ++predIt) { + auto *predTerminator = (*predIt)->getTerminator(); + predTerminator->eraseSuccessorOperand(predIt.getSuccessorIndex(), index); + } +} + +//===----------------------------------------------------------------------===// +// Terminator management +//===----------------------------------------------------------------------===// + +OperationStmt *StmtBlock::getTerminator() { + if (empty()) + return nullptr; + + // Check if the last instruction is a terminator. + auto &backInst = statements.back(); + auto *opStmt = dyn_cast(&backInst); + if (!opStmt || !opStmt->isTerminator()) + return nullptr; + return opStmt; +} + +/// Return true if this block has no predecessors. +bool StmtBlock::hasNoPredecessors() const { return pred_begin() == pred_end(); } + +// Indexed successor access. +unsigned StmtBlock::getNumSuccessors() const { + return getTerminator()->getNumSuccessors(); +} + +StmtBlock *StmtBlock::getSuccessor(unsigned i) { + return getTerminator()->getSuccessor(i); +} + +/// If this block has exactly one predecessor, return it. Otherwise, return +/// null. +/// +/// Note that multiple edges from a single block (e.g. if you have a cond +/// branch with the same block as the true/false destinations) is not +/// considered to be a single predecessor. +StmtBlock *StmtBlock::getSinglePredecessor() { + auto it = pred_begin(); + if (it == pred_end()) + return nullptr; + auto *firstPred = *it; + ++it; + return it == pred_end() ? firstPred : nullptr; +}