From 3de07e5c530fc583de1293c096961776379ca54c Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Sun, 22 Jul 2018 21:02:26 -0700 Subject: [PATCH] Implement generic operand/result iterators that map through our implementation details, returning things in terms of values (which is what most clients want). Implement support for operands and results on Operation, and simplify the asmprinter to use it. PiperOrigin-RevId: 205608853 --- mlir/include/mlir/IR/Instructions.h | 110 +++++++++++++++-- mlir/include/mlir/IR/Operation.h | 179 +++++++++++++++++++++++++++- mlir/lib/IR/AsmPrinter.cpp | 59 ++++----- mlir/lib/IR/Operation.cpp | 50 +++++++- 4 files changed, 348 insertions(+), 50 deletions(-) diff --git a/mlir/include/mlir/IR/Instructions.h b/mlir/include/mlir/IR/Instructions.h index 8d4068b2f04e..6f374bd71b5f 100644 --- a/mlir/include/mlir/IR/Instructions.h +++ b/mlir/include/mlir/IR/Instructions.h @@ -96,21 +96,45 @@ public: ArrayRef attributes, MLIRContext *context); + //===--------------------------------------------------------------------===// + // Operands + //===--------------------------------------------------------------------===// + unsigned getNumOperands() const { return numOperands; } - // TODO: Add a getOperands() custom sequence that provides a value projection - // of the operand list. CFGValue *getOperand(unsigned idx) { return getInstOperand(idx).get(); } const CFGValue *getOperand(unsigned idx) const { return getInstOperand(idx).get(); } - unsigned getNumResults() const { return numResults; } + // Support non-const operand iteration. + using operand_iterator = OperandIterator; - // TODO: Add a getResults() custom sequence that provides a value projection - // of the result list. - CFGValue *getResult(unsigned idx) { return &getInstResult(idx); } - const CFGValue *getResult(unsigned idx) const { return &getInstResult(idx); } + operand_iterator operand_begin() { return operand_iterator(this, 0); } + + operand_iterator operand_end() { + return operand_iterator(this, getNumOperands()); + } + + llvm::iterator_range getOperands() { + return {operand_begin(), operand_end()}; + } + + // Support const operand iteration. + using const_operand_iterator = + OperandIterator; + + const_operand_iterator operand_begin() const { + return const_operand_iterator(this, 0); + } + + const_operand_iterator operand_end() const { + return const_operand_iterator(this, getNumOperands()); + } + + llvm::iterator_range getOperands() const { + return {operand_begin(), operand_end()}; + } ArrayRef getInstOperands() const { return {getTrailingObjects(), numOperands}; @@ -124,17 +148,58 @@ public: return getInstOperands()[idx]; } + //===--------------------------------------------------------------------===// + // Results + //===--------------------------------------------------------------------===// + + unsigned getNumResults() const { return numResults; } + + CFGValue *getResult(unsigned idx) { return &getInstResult(idx); } + const CFGValue *getResult(unsigned idx) const { return &getInstResult(idx); } + + // Support non-const result iteration. + typedef ResultIterator result_iterator; + result_iterator result_begin() { return result_iterator(this, 0); } + result_iterator result_end() { + return result_iterator(this, getNumResults()); + } + llvm::iterator_range getResults() { + return {result_begin(), result_end()}; + } + + // Support const operand iteration. + typedef ResultIterator + const_result_iterator; + const_result_iterator result_begin() const { + return const_result_iterator(this, 0); + } + + const_result_iterator result_end() const { + return const_result_iterator(this, getNumResults()); + } + + llvm::iterator_range getResults() const { + return {result_begin(), result_end()}; + } + ArrayRef getInstResults() const { return {getTrailingObjects(), numResults}; } + MutableArrayRef getInstResults() { return {getTrailingObjects(), numResults}; } + InstResult &getInstResult(unsigned idx) { return getInstResults()[idx]; } + const InstResult &getInstResult(unsigned idx) const { return getInstResults()[idx]; } + //===--------------------------------------------------------------------===// + // Other + //===--------------------------------------------------------------------===// + /// Unlink this instruction from its BasicBlock and delete it. void eraseFromBlock(); @@ -222,13 +287,40 @@ public: unsigned getNumOperands() const { return numOperands; } - // TODO: Add a getOperands() custom sequence that provides a value projection - // of the operand list. CFGValue *getOperand(unsigned idx) { return getInstOperand(idx).get(); } const CFGValue *getOperand(unsigned idx) const { return getInstOperand(idx).get(); } + // Support non-const operand iteration. + using operand_iterator = OperandIterator; + + operand_iterator operand_begin() { return operand_iterator(this, 0); } + + operand_iterator operand_end() { + return operand_iterator(this, getNumOperands()); + } + + llvm::iterator_range getOperands() { + return {operand_begin(), operand_end()}; + } + + // Support const operand iteration. + typedef OperandIterator + const_operand_iterator; + + const_operand_iterator operand_begin() const { + return const_operand_iterator(this, 0); + } + + const_operand_iterator operand_end() const { + return const_operand_iterator(this, getNumOperands()); + } + + llvm::iterator_range getOperands() const { + return {operand_begin(), operand_end()}; + } + ArrayRef getInstOperands() const { return {getTrailingObjects(), numOperands}; } diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index 52216cd23c6e..bf6192f1aabe 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -29,6 +29,9 @@ class AttributeListStorage; class AbstractOperation; template class ConstOpPointer; template class OpPointer; +template class OperandIterator; +template class ResultIterator; +class SSAValue; /// NamedAttribute is a used for operation attribute lists, it holds an /// identifier for the name and a value for the attribute. The attribute @@ -44,8 +47,47 @@ public: /// The name of an operation is the key identifier for it. Identifier getName() const { return nameAndIsInstruction.getPointer(); } - // TODO: Need to have results and operands. + /// Return the number of operands this operation has. + unsigned getNumOperands() const; + SSAValue *getOperand(unsigned idx); + const SSAValue *getOperand(unsigned idx) const { + return const_cast(this)->getOperand(idx); + } + + // Support non-const operand iteration. + using operand_iterator = OperandIterator; + operand_iterator operand_begin(); + operand_iterator operand_end(); + llvm::iterator_range getOperands(); + + // Support const operand iteration. + using const_operand_iterator = + OperandIterator; + const_operand_iterator operand_begin() const; + const_operand_iterator operand_end() const; + llvm::iterator_range getOperands() const; + + /// Return the number of results this operation has. + unsigned getNumResults() const; + + /// Return the indicated result. + SSAValue *getResult(unsigned idx); + const SSAValue *getResult(unsigned idx) const { + return const_cast(this)->getResult(idx); + } + + // Support non-const result iteration. + using result_iterator = ResultIterator; + result_iterator result_begin(); + result_iterator result_end(); + llvm::iterator_range getResults(); + + // Support const operand iteration. + using const_result_iterator = ResultIterator; + const_result_iterator result_begin() const; + const_result_iterator result_end() const; + llvm::iterator_range getResults() const; // Attributes. Operations may optionally carry a list of attributes that // associate constants to names. Attributes may be dynamically added and @@ -141,6 +183,141 @@ private: AttributeListStorage *attrs; }; +/// This is a helper template used to implement an iterator that contains a +/// pointer to some object and an index into it. The iterator moves the +/// index but keeps the object constant. +template +class IndexedAccessorIterator + : public llvm::iterator_facade_base< + ConcreteType, std::random_access_iterator_tag, ElementType *> { +public: + ptrdiff_t operator-(const IndexedAccessorIterator &rhs) const { + assert(object == rhs.object && "incompatible iterators"); + return index - rhs.index; + } + bool operator==(const IndexedAccessorIterator &rhs) const { + return object == rhs.object && index == rhs.index; + } + bool operator<(const IndexedAccessorIterator &rhs) const { + assert(object == rhs.object && "incompatible iterators"); + return index < rhs.index; + } + + ConcreteType &operator+=(ptrdiff_t offset) { + this->index += offset; + return static_cast(*this); + } + ConcreteType &operator-=(ptrdiff_t offset) { + this->index -= offset; + return static_cast(*this); + } + +protected: + IndexedAccessorIterator(ObjectType *object, unsigned index) + : object(object), index(index) {} + ObjectType *object; + unsigned index; +}; + +/// This template implments the operand iterators for the various IR classes +/// in terms of getOperand(idx). +template +class OperandIterator final + : public IndexedAccessorIterator, + ObjectType, ElementType> { +public: + /// Initializes the operand iterator to the specified operand index. + OperandIterator(ObjectType *object, unsigned index) + : IndexedAccessorIterator, + ObjectType, ElementType>(object, index) {} + + /// Support converting to the const variant. This will be a no-op for const + /// variant. + operator OperandIterator() const { + return OperandIterator(this->object, + this->index); + } + + ElementType *operator*() const { + return this->object->getOperand(this->index); + } +}; + +/// This template implments the result iterators for the various IR classes +/// in terms of getResult(idx). +template +class ResultIterator final + : public IndexedAccessorIterator, + ObjectType, ElementType> { +public: + /// Initializes the result iterator to the specified index. + ResultIterator(ObjectType *object, unsigned index) + : IndexedAccessorIterator, + ObjectType, ElementType>(object, index) {} + + /// Support converting to the const variant. This will be a no-op for const + /// variant. + operator ResultIterator() const { + return ResultIterator(this->object, + this->index); + } + + ElementType *operator*() const { + return this->object->getResult(this->index); + } +}; + +// Implement the inline operand iterator methods. +inline auto Operation::operand_begin() -> operand_iterator { + return operand_iterator(this, 0); +} + +inline auto Operation::operand_end() -> operand_iterator { + return operand_iterator(this, getNumOperands()); +} + +inline auto Operation::getOperands() -> llvm::iterator_range { + return {operand_begin(), operand_end()}; +} + +inline auto Operation::operand_begin() const -> const_operand_iterator { + return const_operand_iterator(this, 0); +} + +inline auto Operation::operand_end() const -> const_operand_iterator { + return const_operand_iterator(this, getNumOperands()); +} + +inline auto Operation::getOperands() const + -> llvm::iterator_range { + return {operand_begin(), operand_end()}; +} + +// Implement the inline result iterator methods. +inline auto Operation::result_begin() -> result_iterator { + return result_iterator(this, 0); +} + +inline auto Operation::result_end() -> result_iterator { + return result_iterator(this, getNumResults()); +} + +inline auto Operation::getResults() -> llvm::iterator_range { + return {result_begin(), result_end()}; +} + +inline auto Operation::result_begin() const -> const_result_iterator { + return const_result_iterator(this, 0); +} + +inline auto Operation::result_end() const -> const_result_iterator { + return const_result_iterator(this, getNumResults()); +} + +inline auto Operation::getResults() const + -> llvm::iterator_range { + return {result_begin(), result_end()}; +} } // end namespace mlir #endif diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 0d5610f45d28..afb4e73d2dce 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -131,7 +131,6 @@ void ModuleState::visitExtFunction(const ExtFunction *fn) { void ModuleState::visitCFGFunction(const CFGFunction *fn) { visitType(fn->getType()); - // TODO Visit function body instructions. for (auto &block : *fn) { for (auto &op : block.getOperations()) { visitOperation(&op); @@ -555,13 +554,9 @@ private: void FunctionState::printOperation(const Operation *op) { os << " "; - // TODO: When we have SSAValue version of operands & results wired into - // Operation this check can go away. - if (auto *inst = dyn_cast(op)) { - if (inst->getNumResults()) { - printValueID(inst->getResult(0), /*dontPrintResultNo*/ true); - os << " = "; - } + if (op->getNumResults()) { + printValueID(op->getResult(0), /*dontPrintResultNo*/ true); + os << " = "; } // Check to see if this is a known operation. If so, use the registered @@ -576,14 +571,8 @@ void FunctionState::printOperation(const Operation *op) { // TODO: escape name if necessary. os << "\"" << op->getName().str() << "\"("; - // TODO: When we have SSAValue version of operands & results wired into - // Operation this check can go away. - if (auto *inst = dyn_cast(op)) { - // TODO: Use getOperands() when we have it. - interleaveComma(inst->getInstOperands(), [&](const InstOperand &operand) { - printValueID(operand.get()); - }); - } + interleaveComma(op->getOperands(), + [&](const SSAValue *value) { printValueID(value); }); os << ')'; auto attrs = op->getAttrs(); @@ -596,26 +585,19 @@ void FunctionState::printOperation(const Operation *op) { os << '}'; } - // TODO: When we have SSAValue version of operands & results wired into - // Operation this check can go away. - if (auto *inst = dyn_cast(op)) { - // Print the type signature of the operation. - os << " : ("; - // TODO: Switch to getOperands() when we have it. - interleaveComma(inst->getInstOperands(), - [&](const InstOperand &op) { print(op.get()->getType()); }); - os << ") -> "; + // Print the type signature of the operation. + os << " : ("; + interleaveComma(op->getOperands(), + [&](const SSAValue *value) { print(value->getType()); }); + os << ") -> "; - // TODO: Switch to getResults() when we have it. - if (inst->getNumResults() == 1) { - print(inst->getInstResult(0).getType()); - } else { - os << '('; - interleaveComma(inst->getInstResults(), [&](const InstResult &result) { - print(result.getType()); - }); - os << ')'; - } + if (op->getNumResults() == 1) { + print(op->getResult(0)->getType()); + } else { + os << '('; + interleaveComma(op->getResults(), + [&](const SSAValue *result) { print(result->getType()); }); + os << ')'; } } @@ -733,11 +715,10 @@ void CFGFunctionPrinter::print(const ReturnInst *inst) { if (inst->getNumOperands() != 0) os << ' '; - // TODO: Use getOperands() when we have it. - interleaveComma(inst->getInstOperands(), [&](const InstOperand &operand) { - printValueID(operand.get()); + interleaveComma(inst->getOperands(), [&](const CFGValue *operand) { + printValueID(operand); os << " : "; - ModulePrinter::print(operand.get()->getType()); + ModulePrinter::print(operand->getType()); }); } diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 12c0c1b6309d..765903d1f41f 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -17,6 +17,8 @@ #include "mlir/IR/Operation.h" #include "AttributeListStorage.h" +#include "mlir/IR/Instructions.h" +#include "mlir/IR/Statements.h" using namespace mlir; Operation::Operation(Identifier name, bool isInstruction, @@ -30,7 +32,53 @@ Operation::Operation(Identifier name, bool isInstruction, #endif } -Operation::~Operation() { +Operation::~Operation() {} + +/// Return the number of operands this operation has. +unsigned Operation::getNumOperands() const { + if (auto *inst = dyn_cast(this)) { + return inst->getNumOperands(); + } else { + auto *stmt = cast(this); + (void)stmt; + // TODO: Add operands to OperationStmt. + return 0; + } +} + +SSAValue *Operation::getOperand(unsigned idx) { + if (auto *inst = dyn_cast(this)) { + return inst->getOperand(idx); + } else { + auto *stmt = cast(this); + (void)stmt; + // TODO: Add operands to OperationStmt. + abort(); + } +} + +/// Return the number of results this operation has. +unsigned Operation::getNumResults() const { + if (auto *inst = dyn_cast(this)) { + return inst->getNumResults(); + } else { + auto *stmt = cast(this); + (void)stmt; + // TODO: Add results to OperationStmt. + return 0; + } +} + +/// Return the indicated result. +SSAValue *Operation::getResult(unsigned idx) { + if (auto *inst = dyn_cast(this)) { + return inst->getResult(idx); + } else { + auto *stmt = cast(this); + (void)stmt; + // TODO: Add operands to OperationStmt. + abort(); + } } ArrayRef Operation::getAttrs() const {