diff --git a/mlir/include/mlir/IR/Instructions.h b/mlir/include/mlir/IR/Instructions.h index 8d4068b2f04e..6f374bd71b5f 100644 --- a/mlir/include/mlir/IR/Instructions.h +++ b/mlir/include/mlir/IR/Instructions.h @@ -96,21 +96,45 @@ public: ArrayRef attributes, MLIRContext *context); + //===--------------------------------------------------------------------===// + // Operands + //===--------------------------------------------------------------------===// + unsigned getNumOperands() const { return numOperands; } - // TODO: Add a getOperands() custom sequence that provides a value projection - // of the operand list. CFGValue *getOperand(unsigned idx) { return getInstOperand(idx).get(); } const CFGValue *getOperand(unsigned idx) const { return getInstOperand(idx).get(); } - unsigned getNumResults() const { return numResults; } + // Support non-const operand iteration. + using operand_iterator = OperandIterator; - // TODO: Add a getResults() custom sequence that provides a value projection - // of the result list. - CFGValue *getResult(unsigned idx) { return &getInstResult(idx); } - const CFGValue *getResult(unsigned idx) const { return &getInstResult(idx); } + operand_iterator operand_begin() { return operand_iterator(this, 0); } + + operand_iterator operand_end() { + return operand_iterator(this, getNumOperands()); + } + + llvm::iterator_range getOperands() { + return {operand_begin(), operand_end()}; + } + + // Support const operand iteration. + using const_operand_iterator = + OperandIterator; + + const_operand_iterator operand_begin() const { + return const_operand_iterator(this, 0); + } + + const_operand_iterator operand_end() const { + return const_operand_iterator(this, getNumOperands()); + } + + llvm::iterator_range getOperands() const { + return {operand_begin(), operand_end()}; + } ArrayRef getInstOperands() const { return {getTrailingObjects(), numOperands}; @@ -124,17 +148,58 @@ public: return getInstOperands()[idx]; } + //===--------------------------------------------------------------------===// + // Results + //===--------------------------------------------------------------------===// + + unsigned getNumResults() const { return numResults; } + + CFGValue *getResult(unsigned idx) { return &getInstResult(idx); } + const CFGValue *getResult(unsigned idx) const { return &getInstResult(idx); } + + // Support non-const result iteration. + typedef ResultIterator result_iterator; + result_iterator result_begin() { return result_iterator(this, 0); } + result_iterator result_end() { + return result_iterator(this, getNumResults()); + } + llvm::iterator_range getResults() { + return {result_begin(), result_end()}; + } + + // Support const operand iteration. + typedef ResultIterator + const_result_iterator; + const_result_iterator result_begin() const { + return const_result_iterator(this, 0); + } + + const_result_iterator result_end() const { + return const_result_iterator(this, getNumResults()); + } + + llvm::iterator_range getResults() const { + return {result_begin(), result_end()}; + } + ArrayRef getInstResults() const { return {getTrailingObjects(), numResults}; } + MutableArrayRef getInstResults() { return {getTrailingObjects(), numResults}; } + InstResult &getInstResult(unsigned idx) { return getInstResults()[idx]; } + const InstResult &getInstResult(unsigned idx) const { return getInstResults()[idx]; } + //===--------------------------------------------------------------------===// + // Other + //===--------------------------------------------------------------------===// + /// Unlink this instruction from its BasicBlock and delete it. void eraseFromBlock(); @@ -222,13 +287,40 @@ public: unsigned getNumOperands() const { return numOperands; } - // TODO: Add a getOperands() custom sequence that provides a value projection - // of the operand list. CFGValue *getOperand(unsigned idx) { return getInstOperand(idx).get(); } const CFGValue *getOperand(unsigned idx) const { return getInstOperand(idx).get(); } + // Support non-const operand iteration. + using operand_iterator = OperandIterator; + + operand_iterator operand_begin() { return operand_iterator(this, 0); } + + operand_iterator operand_end() { + return operand_iterator(this, getNumOperands()); + } + + llvm::iterator_range getOperands() { + return {operand_begin(), operand_end()}; + } + + // Support const operand iteration. + typedef OperandIterator + const_operand_iterator; + + const_operand_iterator operand_begin() const { + return const_operand_iterator(this, 0); + } + + const_operand_iterator operand_end() const { + return const_operand_iterator(this, getNumOperands()); + } + + llvm::iterator_range getOperands() const { + return {operand_begin(), operand_end()}; + } + ArrayRef getInstOperands() const { return {getTrailingObjects(), numOperands}; } diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index 52216cd23c6e..bf6192f1aabe 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -29,6 +29,9 @@ class AttributeListStorage; class AbstractOperation; template class ConstOpPointer; template class OpPointer; +template class OperandIterator; +template class ResultIterator; +class SSAValue; /// NamedAttribute is a used for operation attribute lists, it holds an /// identifier for the name and a value for the attribute. The attribute @@ -44,8 +47,47 @@ public: /// The name of an operation is the key identifier for it. Identifier getName() const { return nameAndIsInstruction.getPointer(); } - // TODO: Need to have results and operands. + /// Return the number of operands this operation has. + unsigned getNumOperands() const; + SSAValue *getOperand(unsigned idx); + const SSAValue *getOperand(unsigned idx) const { + return const_cast(this)->getOperand(idx); + } + + // Support non-const operand iteration. + using operand_iterator = OperandIterator; + operand_iterator operand_begin(); + operand_iterator operand_end(); + llvm::iterator_range getOperands(); + + // Support const operand iteration. + using const_operand_iterator = + OperandIterator; + const_operand_iterator operand_begin() const; + const_operand_iterator operand_end() const; + llvm::iterator_range getOperands() const; + + /// Return the number of results this operation has. + unsigned getNumResults() const; + + /// Return the indicated result. + SSAValue *getResult(unsigned idx); + const SSAValue *getResult(unsigned idx) const { + return const_cast(this)->getResult(idx); + } + + // Support non-const result iteration. + using result_iterator = ResultIterator; + result_iterator result_begin(); + result_iterator result_end(); + llvm::iterator_range getResults(); + + // Support const operand iteration. + using const_result_iterator = ResultIterator; + const_result_iterator result_begin() const; + const_result_iterator result_end() const; + llvm::iterator_range getResults() const; // Attributes. Operations may optionally carry a list of attributes that // associate constants to names. Attributes may be dynamically added and @@ -141,6 +183,141 @@ private: AttributeListStorage *attrs; }; +/// This is a helper template used to implement an iterator that contains a +/// pointer to some object and an index into it. The iterator moves the +/// index but keeps the object constant. +template +class IndexedAccessorIterator + : public llvm::iterator_facade_base< + ConcreteType, std::random_access_iterator_tag, ElementType *> { +public: + ptrdiff_t operator-(const IndexedAccessorIterator &rhs) const { + assert(object == rhs.object && "incompatible iterators"); + return index - rhs.index; + } + bool operator==(const IndexedAccessorIterator &rhs) const { + return object == rhs.object && index == rhs.index; + } + bool operator<(const IndexedAccessorIterator &rhs) const { + assert(object == rhs.object && "incompatible iterators"); + return index < rhs.index; + } + + ConcreteType &operator+=(ptrdiff_t offset) { + this->index += offset; + return static_cast(*this); + } + ConcreteType &operator-=(ptrdiff_t offset) { + this->index -= offset; + return static_cast(*this); + } + +protected: + IndexedAccessorIterator(ObjectType *object, unsigned index) + : object(object), index(index) {} + ObjectType *object; + unsigned index; +}; + +/// This template implments the operand iterators for the various IR classes +/// in terms of getOperand(idx). +template +class OperandIterator final + : public IndexedAccessorIterator, + ObjectType, ElementType> { +public: + /// Initializes the operand iterator to the specified operand index. + OperandIterator(ObjectType *object, unsigned index) + : IndexedAccessorIterator, + ObjectType, ElementType>(object, index) {} + + /// Support converting to the const variant. This will be a no-op for const + /// variant. + operator OperandIterator() const { + return OperandIterator(this->object, + this->index); + } + + ElementType *operator*() const { + return this->object->getOperand(this->index); + } +}; + +/// This template implments the result iterators for the various IR classes +/// in terms of getResult(idx). +template +class ResultIterator final + : public IndexedAccessorIterator, + ObjectType, ElementType> { +public: + /// Initializes the result iterator to the specified index. + ResultIterator(ObjectType *object, unsigned index) + : IndexedAccessorIterator, + ObjectType, ElementType>(object, index) {} + + /// Support converting to the const variant. This will be a no-op for const + /// variant. + operator ResultIterator() const { + return ResultIterator(this->object, + this->index); + } + + ElementType *operator*() const { + return this->object->getResult(this->index); + } +}; + +// Implement the inline operand iterator methods. +inline auto Operation::operand_begin() -> operand_iterator { + return operand_iterator(this, 0); +} + +inline auto Operation::operand_end() -> operand_iterator { + return operand_iterator(this, getNumOperands()); +} + +inline auto Operation::getOperands() -> llvm::iterator_range { + return {operand_begin(), operand_end()}; +} + +inline auto Operation::operand_begin() const -> const_operand_iterator { + return const_operand_iterator(this, 0); +} + +inline auto Operation::operand_end() const -> const_operand_iterator { + return const_operand_iterator(this, getNumOperands()); +} + +inline auto Operation::getOperands() const + -> llvm::iterator_range { + return {operand_begin(), operand_end()}; +} + +// Implement the inline result iterator methods. +inline auto Operation::result_begin() -> result_iterator { + return result_iterator(this, 0); +} + +inline auto Operation::result_end() -> result_iterator { + return result_iterator(this, getNumResults()); +} + +inline auto Operation::getResults() -> llvm::iterator_range { + return {result_begin(), result_end()}; +} + +inline auto Operation::result_begin() const -> const_result_iterator { + return const_result_iterator(this, 0); +} + +inline auto Operation::result_end() const -> const_result_iterator { + return const_result_iterator(this, getNumResults()); +} + +inline auto Operation::getResults() const + -> llvm::iterator_range { + return {result_begin(), result_end()}; +} } // end namespace mlir #endif diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 0d5610f45d28..afb4e73d2dce 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -131,7 +131,6 @@ void ModuleState::visitExtFunction(const ExtFunction *fn) { void ModuleState::visitCFGFunction(const CFGFunction *fn) { visitType(fn->getType()); - // TODO Visit function body instructions. for (auto &block : *fn) { for (auto &op : block.getOperations()) { visitOperation(&op); @@ -555,13 +554,9 @@ private: void FunctionState::printOperation(const Operation *op) { os << " "; - // TODO: When we have SSAValue version of operands & results wired into - // Operation this check can go away. - if (auto *inst = dyn_cast(op)) { - if (inst->getNumResults()) { - printValueID(inst->getResult(0), /*dontPrintResultNo*/ true); - os << " = "; - } + if (op->getNumResults()) { + printValueID(op->getResult(0), /*dontPrintResultNo*/ true); + os << " = "; } // Check to see if this is a known operation. If so, use the registered @@ -576,14 +571,8 @@ void FunctionState::printOperation(const Operation *op) { // TODO: escape name if necessary. os << "\"" << op->getName().str() << "\"("; - // TODO: When we have SSAValue version of operands & results wired into - // Operation this check can go away. - if (auto *inst = dyn_cast(op)) { - // TODO: Use getOperands() when we have it. - interleaveComma(inst->getInstOperands(), [&](const InstOperand &operand) { - printValueID(operand.get()); - }); - } + interleaveComma(op->getOperands(), + [&](const SSAValue *value) { printValueID(value); }); os << ')'; auto attrs = op->getAttrs(); @@ -596,26 +585,19 @@ void FunctionState::printOperation(const Operation *op) { os << '}'; } - // TODO: When we have SSAValue version of operands & results wired into - // Operation this check can go away. - if (auto *inst = dyn_cast(op)) { - // Print the type signature of the operation. - os << " : ("; - // TODO: Switch to getOperands() when we have it. - interleaveComma(inst->getInstOperands(), - [&](const InstOperand &op) { print(op.get()->getType()); }); - os << ") -> "; + // Print the type signature of the operation. + os << " : ("; + interleaveComma(op->getOperands(), + [&](const SSAValue *value) { print(value->getType()); }); + os << ") -> "; - // TODO: Switch to getResults() when we have it. - if (inst->getNumResults() == 1) { - print(inst->getInstResult(0).getType()); - } else { - os << '('; - interleaveComma(inst->getInstResults(), [&](const InstResult &result) { - print(result.getType()); - }); - os << ')'; - } + if (op->getNumResults() == 1) { + print(op->getResult(0)->getType()); + } else { + os << '('; + interleaveComma(op->getResults(), + [&](const SSAValue *result) { print(result->getType()); }); + os << ')'; } } @@ -733,11 +715,10 @@ void CFGFunctionPrinter::print(const ReturnInst *inst) { if (inst->getNumOperands() != 0) os << ' '; - // TODO: Use getOperands() when we have it. - interleaveComma(inst->getInstOperands(), [&](const InstOperand &operand) { - printValueID(operand.get()); + interleaveComma(inst->getOperands(), [&](const CFGValue *operand) { + printValueID(operand); os << " : "; - ModulePrinter::print(operand.get()->getType()); + ModulePrinter::print(operand->getType()); }); } diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 12c0c1b6309d..765903d1f41f 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -17,6 +17,8 @@ #include "mlir/IR/Operation.h" #include "AttributeListStorage.h" +#include "mlir/IR/Instructions.h" +#include "mlir/IR/Statements.h" using namespace mlir; Operation::Operation(Identifier name, bool isInstruction, @@ -30,7 +32,53 @@ Operation::Operation(Identifier name, bool isInstruction, #endif } -Operation::~Operation() { +Operation::~Operation() {} + +/// Return the number of operands this operation has. +unsigned Operation::getNumOperands() const { + if (auto *inst = dyn_cast(this)) { + return inst->getNumOperands(); + } else { + auto *stmt = cast(this); + (void)stmt; + // TODO: Add operands to OperationStmt. + return 0; + } +} + +SSAValue *Operation::getOperand(unsigned idx) { + if (auto *inst = dyn_cast(this)) { + return inst->getOperand(idx); + } else { + auto *stmt = cast(this); + (void)stmt; + // TODO: Add operands to OperationStmt. + abort(); + } +} + +/// Return the number of results this operation has. +unsigned Operation::getNumResults() const { + if (auto *inst = dyn_cast(this)) { + return inst->getNumResults(); + } else { + auto *stmt = cast(this); + (void)stmt; + // TODO: Add results to OperationStmt. + return 0; + } +} + +/// Return the indicated result. +SSAValue *Operation::getResult(unsigned idx) { + if (auto *inst = dyn_cast(this)) { + return inst->getResult(idx); + } else { + auto *stmt = cast(this); + (void)stmt; + // TODO: Add operands to OperationStmt. + abort(); + } } ArrayRef Operation::getAttrs() const {