diff --git a/mlir/include/mlir/AffineOps/AffineOps.h b/mlir/include/mlir/AffineOps/AffineOps.h deleted file mode 100644 index d511f628c3c2..000000000000 --- a/mlir/include/mlir/AffineOps/AffineOps.h +++ /dev/null @@ -1,91 +0,0 @@ -//===- AffineOps.h - MLIR Affine Operations -------------------------------===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -// -// This file defines convenience types for working with Affine operations -// in the MLIR instruction set. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_AFFINEOPS_AFFINEOPS_H -#define MLIR_AFFINEOPS_AFFINEOPS_H - -#include "mlir/IR/Dialect.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/StandardTypes.h" - -namespace mlir { - -class AffineOpsDialect : public Dialect { -public: - AffineOpsDialect(MLIRContext *context); -}; - -/// The "if" operation represents an if–then–else construct for conditionally -/// executing two regions of code. The operands to an if operation are an -/// IntegerSet condition and a set of symbol/dimension operands to the -/// condition set. The operation produces no results. For example: -/// -/// if #set(%i) { -/// ... -/// } else { -/// ... -/// } -/// -/// The 'else' blocks to the if operation are optional, and may be omitted. For -/// example: -/// -/// if #set(%i) { -/// ... -/// } -/// -class AffineIfOp - : public Op { -public: - // Hooks to customize behavior of this op. - static void build(Builder *builder, OperationState *result, - IntegerSet condition, ArrayRef conditionOperands); - - static StringRef getOperationName() { return "if"; } - static StringRef getConditionAttrName() { return "condition"; } - - IntegerSet getIntegerSet() const; - void setIntegerSet(IntegerSet newSet); - - /// Returns the list of 'then' blocks. - BlockList &getThenBlocks(); - const BlockList &getThenBlocks() const { - return const_cast(this)->getThenBlocks(); - } - - /// Returns the list of 'else' blocks. - BlockList &getElseBlocks(); - const BlockList &getElseBlocks() const { - return const_cast(this)->getElseBlocks(); - } - - bool verify() const; - static bool parse(OpAsmParser *parser, OperationState *result); - void print(OpAsmPrinter *p) const; - -private: - friend class OperationInst; - explicit AffineIfOp(const OperationInst *state) : Op(state) {} -}; - -} // end namespace mlir - -#endif diff --git a/mlir/include/mlir/Analysis/NestedMatcher.h b/mlir/include/mlir/Analysis/NestedMatcher.h index 161bb217a10c..c205d55488e6 100644 --- a/mlir/include/mlir/Analysis/NestedMatcher.h +++ b/mlir/include/mlir/Analysis/NestedMatcher.h @@ -128,6 +128,7 @@ private: void matchOne(Instruction *elem); void visitForInst(ForInst *forInst) { matchOne(forInst); } + void visitIfInst(IfInst *ifInst) { matchOne(ifInst); } void visitOperationInst(OperationInst *opInst) { matchOne(opInst); } /// POD paylod. diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h index e85ea772d0b8..1b14d925d32d 100644 --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -26,6 +26,7 @@ #include "llvm/ADT/PointerUnion.h" namespace mlir { +class IfInst; class BlockList; class BlockAndValueMapping; @@ -61,7 +62,7 @@ public: } /// Returns the function that this block is part of, even if the block is - /// nested under an OperationInst or ForInst. + /// nested under an IfInst or ForInst. Function *getFunction(); const Function *getFunction() const { return const_cast(this)->getFunction(); @@ -324,7 +325,7 @@ private: namespace mlir { /// This class contains a list of basic blocks and has a notion of the object it -/// is part of - a Function or OperationInst or ForInst. +/// is part of - a Function or IfInst or ForInst. class BlockList { public: explicit BlockList(Function *container); @@ -364,14 +365,14 @@ public: return &BlockList::blocks; } - /// A BlockList is part of a Function or and OperationInst/ForInst. If it is - /// part of an OperationInst/ForInst, then return it, otherwise return null. + /// A BlockList is part of a Function or and IfInst/ForInst. If it is + /// part of an IfInst/ForInst, then return it, otherwise return null. Instruction *getContainingInst(); const Instruction *getContainingInst() const { return const_cast(this)->getContainingInst(); } - /// A BlockList is part of a Function or and OperationInst/ForInst. If it is + /// A BlockList is part of a Function or and IfInst/ForInst. If it is /// part of a Function, then return it, otherwise return null. Function *getContainingFunction(); const Function *getContainingFunction() const { diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 3271c12afde4..156bd02bb52a 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -286,6 +286,10 @@ public: // Default step is 1. ForInst *createFor(Location loc, int64_t lb, int64_t ub, int64_t step = 1); + /// Creates if instruction. + IfInst *createIf(Location location, ArrayRef operands, + IntegerSet set); + private: Function *function; Block *block = nullptr; diff --git a/mlir/include/mlir/IR/InstVisitor.h b/mlir/include/mlir/IR/InstVisitor.h index 78810da909d1..b6a759e76f5a 100644 --- a/mlir/include/mlir/IR/InstVisitor.h +++ b/mlir/include/mlir/IR/InstVisitor.h @@ -44,7 +44,7 @@ // lc.walk(function); // numLoops = lc.numLoops; // -// There are 'visit' methods for OperationInst, ForInst, and +// There are 'visit' methods for OperationInst, ForInst, IfInst, and // Function, which recursively process all contained instructions. // // Note that if you don't implement visitXXX for some instruction type, @@ -85,6 +85,8 @@ public: switch (s->getKind()) { case Instruction::Kind::For: return static_cast(this)->visitForInst(cast(s)); + case Instruction::Kind::If: + return static_cast(this)->visitIfInst(cast(s)); case Instruction::Kind::OperationInst: return static_cast(this)->visitOperationInst( cast(s)); @@ -102,6 +104,7 @@ public: // When visiting a for inst, if inst, or an operation inst directly, these // methods get called to indicate when transitioning into a new unit. void visitForInst(ForInst *forInst) {} + void visitIfInst(IfInst *ifInst) {} void visitOperationInst(OperationInst *opInst) {} }; @@ -163,6 +166,23 @@ public: static_cast(this)->visitForInst(forInst); } + void walkIfInst(IfInst *ifInst) { + static_cast(this)->visitIfInst(ifInst); + static_cast(this)->walk(ifInst->getThen()->begin(), + ifInst->getThen()->end()); + if (auto *elseBlock = ifInst->getElse()) + static_cast(this)->walk(elseBlock->begin(), elseBlock->end()); + } + + void walkIfInstPostOrder(IfInst *ifInst) { + static_cast(this)->walkPostOrder(ifInst->getThen()->begin(), + ifInst->getThen()->end()); + if (auto *elseBlock = ifInst->getElse()) + static_cast(this)->walkPostOrder(elseBlock->begin(), + elseBlock->end()); + static_cast(this)->visitIfInst(ifInst); + } + // Function to walk a instruction. RetTy walk(Instruction *s) { static_assert(std::is_base_of::value, @@ -173,6 +193,8 @@ public: switch (s->getKind()) { case Instruction::Kind::For: return static_cast(this)->walkForInst(cast(s)); + case Instruction::Kind::If: + return static_cast(this)->walkIfInst(cast(s)); case Instruction::Kind::OperationInst: return static_cast(this)->walkOpInst(cast(s)); } @@ -188,6 +210,9 @@ public: case Instruction::Kind::For: return static_cast(this)->walkForInstPostOrder( cast(s)); + case Instruction::Kind::If: + return static_cast(this)->walkIfInstPostOrder( + cast(s)); case Instruction::Kind::OperationInst: return static_cast(this)->walkOpInstPostOrder( cast(s)); @@ -206,6 +231,7 @@ public: // processing their descendants in some way. When using RetTy, all of these // need to be overridden. void visitForInst(ForInst *forInst) {} + void visitIfInst(IfInst *ifInst) {} void visitOperationInst(OperationInst *opInst) {} void visitInstruction(Instruction *inst) {} }; diff --git a/mlir/include/mlir/IR/Instruction.h b/mlir/include/mlir/IR/Instruction.h index 3dc1e76dd20d..6a296b7348eb 100644 --- a/mlir/include/mlir/IR/Instruction.h +++ b/mlir/include/mlir/IR/Instruction.h @@ -75,6 +75,7 @@ public: enum class Kind { OperationInst = (int)IROperandOwner::Kind::OperationInst, For = (int)IROperandOwner::Kind::ForInst, + If = (int)IROperandOwner::Kind::IfInst, }; Kind getKind() const { return (Kind)IROperandOwner::getKind(); } diff --git a/mlir/include/mlir/IR/Instructions.h b/mlir/include/mlir/IR/Instructions.h index fb6b1b97ca08..71d832b8b90d 100644 --- a/mlir/include/mlir/IR/Instructions.h +++ b/mlir/include/mlir/IR/Instructions.h @@ -794,6 +794,130 @@ private: friend class ForInst; }; + +/// If instruction restricts execution to a subset of the loop iteration space. +class IfInst : public Instruction { +public: + static IfInst *create(Location location, ArrayRef operands, + IntegerSet set); + ~IfInst(); + + //===--------------------------------------------------------------------===// + // Then, else, condition. + //===--------------------------------------------------------------------===// + + Block *getThen() { return &thenClause.front(); } + const Block *getThen() const { return &thenClause.front(); } + Block *getElse() { return elseClause ? &elseClause->front() : nullptr; } + const Block *getElse() const { + return elseClause ? &elseClause->front() : nullptr; + } + bool hasElse() const { return elseClause != nullptr; } + + Block *createElse() { + assert(elseClause == nullptr && "already has an else clause!"); + elseClause = new BlockList(this); + elseClause->push_back(new Block()); + return &elseClause->front(); + } + + const AffineCondition getCondition() const; + + IntegerSet getIntegerSet() const { return set; } + void setIntegerSet(IntegerSet newSet) { + assert(newSet.getNumOperands() == operands.size()); + set = newSet; + } + + //===--------------------------------------------------------------------===// + // Operands + //===--------------------------------------------------------------------===// + + /// Operand iterators. + using operand_iterator = OperandIterator; + using const_operand_iterator = OperandIterator; + + /// Operand iterator range. + using operand_range = llvm::iterator_range; + using const_operand_range = llvm::iterator_range; + + unsigned getNumOperands() const { return operands.size(); } + + Value *getOperand(unsigned idx) { return getInstOperand(idx).get(); } + const Value *getOperand(unsigned idx) const { + return getInstOperand(idx).get(); + } + void setOperand(unsigned idx, Value *value) { + getInstOperand(idx).set(value); + } + + operand_iterator operand_begin() { return operand_iterator(this, 0); } + operand_iterator operand_end() { + return operand_iterator(this, getNumOperands()); + } + + const_operand_iterator operand_begin() const { + return const_operand_iterator(this, 0); + } + const_operand_iterator operand_end() const { + return const_operand_iterator(this, getNumOperands()); + } + + ArrayRef getInstOperands() const { return operands; } + MutableArrayRef getInstOperands() { return operands; } + InstOperand &getInstOperand(unsigned idx) { return getInstOperands()[idx]; } + const InstOperand &getInstOperand(unsigned idx) const { + return getInstOperands()[idx]; + } + + //===--------------------------------------------------------------------===// + // Other + //===--------------------------------------------------------------------===// + + MLIRContext *getContext() const; + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool classof(const IROperandOwner *ptr) { + return ptr->getKind() == IROperandOwner::Kind::IfInst; + } + +private: + // it is always present. + BlockList thenClause; + // 'else' clause of the if instruction. 'nullptr' if there is no else clause. + BlockList *elseClause; + + // The integer set capturing the conditional guard. + IntegerSet set; + + // Condition operands. + std::vector operands; + + explicit IfInst(Location location, unsigned numOperands, IntegerSet set); +}; + +/// AffineCondition represents a condition of the 'if' instruction. +/// Its life span should not exceed that of the objects it refers to. +/// AffineCondition does not provide its own methods for iterating over +/// the operands since the iterators of the if instruction accomplish +/// the same purpose. +/// +/// AffineCondition is trivially copyable, so it should be passed by value. +class AffineCondition { +public: + const IfInst *getIfInst() const { return &inst; } + IntegerSet getIntegerSet() const { return set; } + +private: + // 'if' instruction that contains this affine condition. + const IfInst &inst; + // Integer set for this affine condition. + IntegerSet set; + + AffineCondition(const IfInst &inst, IntegerSet set) : inst(inst), set(set) {} + + friend class IfInst; +}; } // end namespace mlir #endif // MLIR_IR_INSTRUCTIONS_H diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index d3a5d35427f5..1e319db35710 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -89,9 +89,6 @@ public: /// Print the entire operation with the default generic assembly form. virtual void printGenericOp(const OperationInst *op) = 0; - /// Prints a block list. - virtual void printBlockList(const BlockList &blocks) = 0; - private: OpAsmPrinter(const OpAsmPrinter &) = delete; void operator=(const OpAsmPrinter &) = delete; @@ -198,19 +195,7 @@ public: virtual bool parseColonTypeList(SmallVectorImpl &result) = 0; /// Parse a keyword followed by a type. - bool parseKeywordType(const char *keyword, Type &result) { - return parseKeyword(keyword) || parseType(result); - } - - /// Parse a keyword. - bool parseKeyword(const char *keyword) { - if (parseOptionalKeyword(keyword)) - return emitError(getNameLoc(), "expected '" + Twine(keyword) + "'"); - return false; - } - - /// If a keyword is present, then parse it. - virtual bool parseOptionalKeyword(const char *keyword) = 0; + virtual bool parseKeywordType(const char *keyword, Type &result) = 0; /// Add the specified type to the end of the specified type list and return /// false. This is a helper designed to allow parse methods to be simple and @@ -311,10 +296,6 @@ public: int requiredOperandCount = -1, Delimiter delimiter = Delimiter::None) = 0; - /// Parses a block list. Any parsed blocks are filled in to the - /// operation's block lists after the operation is created. - virtual bool parseBlockList() = 0; - //===--------------------------------------------------------------------===// // Methods for interacting with the parser //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/UseDefLists.h b/mlir/include/mlir/IR/UseDefLists.h index 80cd21362ceb..053d3520103c 100644 --- a/mlir/include/mlir/IR/UseDefLists.h +++ b/mlir/include/mlir/IR/UseDefLists.h @@ -81,9 +81,10 @@ public: enum class Kind { OperationInst, ForInst, + IfInst, /// These enums define ranges used for classof implementations. - INST_LAST = ForInst, + INST_LAST = IfInst, }; Kind getKind() const { return locationAndKind.getInt(); } diff --git a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h index 00c6577240cd..978fa45ab232 100644 --- a/mlir/include/mlir/Transforms/MLPatternLoweringPass.h +++ b/mlir/include/mlir/Transforms/MLPatternLoweringPass.h @@ -93,7 +93,7 @@ using OwningMLLoweringPatternList = /// next _original_ operation is considered. /// In other words, for each operation, the pass applies the first matching /// rewriter in the list and advances to the (lexically) next operation. -/// Non-operation instructions (ForInst) are ignored. +/// Non-operation instructions (ForInst and IfInst) are ignored. /// This is similar to greedy worklist-based pattern rewriter, except that this /// operates on ML functions using an ML builder and does not maintain the work /// list. Note that, as of the time of writing, worklist-based rewriter did not diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp deleted file mode 100644 index 5b29467fc443..000000000000 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ /dev/null @@ -1,151 +0,0 @@ -//===- AffineOps.cpp - MLIR Affine Operations -----------------------------===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -#include "mlir/AffineOps/AffineOps.h" -#include "mlir/IR/Block.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/OpImplementation.h" -using namespace mlir; - -//===----------------------------------------------------------------------===// -// AffineOpsDialect -//===----------------------------------------------------------------------===// - -AffineOpsDialect::AffineOpsDialect(MLIRContext *context) - : Dialect(/*namePrefix=*/"", context) { - addOperations(); -} - -//===----------------------------------------------------------------------===// -// AffineIfOp -//===----------------------------------------------------------------------===// - -void AffineIfOp::build(Builder *builder, OperationState *result, - IntegerSet condition, - ArrayRef conditionOperands) { - result->addAttribute(getConditionAttrName(), IntegerSetAttr::get(condition)); - result->addOperands(conditionOperands); - - // Reserve 2 block lists, one for the 'then' and one for the 'else' regions. - result->reserveBlockLists(2); -} - -bool AffineIfOp::verify() const { - // Verify that we have a condition attribute. - auto conditionAttr = getAttrOfType(getConditionAttrName()); - if (!conditionAttr) - return emitOpError("requires an integer set attribute named 'condition'"); - - // Verify that the operands are valid dimension/symbols. - IntegerSet condition = conditionAttr.getValue(); - for (unsigned i = 0, e = getNumOperands(); i != e; ++i) { - const Value *operand = getOperand(i); - if (i < condition.getNumDims() && !operand->isValidDim()) - return emitOpError("operand cannot be used as a dimension id"); - if (i >= condition.getNumDims() && !operand->isValidSymbol()) - return emitOpError("operand cannot be used as a symbol"); - } - - // Verify that the entry of each child blocklist does not have arguments. - for (const auto &blockList : getInstruction()->getBlockLists()) { - if (blockList.empty()) - continue; - - // TODO(riverriddle) We currently do not allow multiple blocks in child - // block lists. - if (std::next(blockList.begin()) != blockList.end()) - return emitOpError( - "expects only one block per 'if' or 'else' block list"); - if (blockList.front().getTerminator()) - return emitOpError("expects region block to not have a terminator"); - - for (const auto &b : blockList) - if (b.getNumArguments() != 0) - return emitOpError( - "requires that child entry blocks have no arguments"); - } - return false; -} - -bool AffineIfOp::parse(OpAsmParser *parser, OperationState *result) { - // Parse the condition attribute set. - IntegerSetAttr conditionAttr; - unsigned numDims; - if (parser->parseAttribute(conditionAttr, getConditionAttrName().data(), - result->attributes) || - parseDimAndSymbolList(parser, result->operands, numDims)) - return true; - - // Verify the condition operands. - auto set = conditionAttr.getValue(); - if (set.getNumDims() != numDims) - return parser->emitError( - parser->getNameLoc(), - "dim operand count and integer set dim count must match"); - if (numDims + set.getNumSymbols() != result->operands.size()) - return parser->emitError( - parser->getNameLoc(), - "symbol operand count and integer set symbol count must match"); - - // Parse the 'then' block list. - if (parser->parseBlockList()) - return true; - - // If we find an 'else' keyword then parse the else block list. - if (!parser->parseOptionalKeyword("else")) { - if (parser->parseBlockList()) - return true; - } - - // Reserve 2 block lists, one for the 'then' and one for the 'else' regions. - result->reserveBlockLists(2); - return false; -} - -void AffineIfOp::print(OpAsmPrinter *p) const { - auto conditionAttr = getAttrOfType(getConditionAttrName()); - *p << "if " << conditionAttr; - printDimAndSymbolList(operand_begin(), operand_end(), - conditionAttr.getValue().getNumDims(), p); - p->printBlockList(getInstruction()->getBlockList(0)); - - // Print the 'else' block list if it has any blocks. - const auto &elseBlockList = getInstruction()->getBlockList(1); - if (!elseBlockList.empty()) { - *p << " else"; - p->printBlockList(elseBlockList); - } -} - -IntegerSet AffineIfOp::getIntegerSet() const { - return getAttrOfType(getConditionAttrName()).getValue(); -} -void AffineIfOp::setIntegerSet(IntegerSet newSet) { - setAttr( - Identifier::get(getConditionAttrName(), getInstruction()->getContext()), - IntegerSetAttr::get(newSet)); -} - -/// Returns the list of 'then' blocks. -BlockList &AffineIfOp::getThenBlocks() { - return getInstruction()->getBlockList(0); -} - -/// Returns the list of 'else' blocks. -BlockList &AffineIfOp::getElseBlocks() { - return getInstruction()->getBlockList(1); -} diff --git a/mlir/lib/AffineOps/DialectRegistration.cpp b/mlir/lib/AffineOps/DialectRegistration.cpp deleted file mode 100644 index 0afb32c1bd61..000000000000 --- a/mlir/lib/AffineOps/DialectRegistration.cpp +++ /dev/null @@ -1,22 +0,0 @@ -//===- DialectRegistration.cpp - Register Affine Op dialect ---------------===// -// -// Copyright 2019 The MLIR Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= - -#include "mlir/AffineOps/AffineOps.h" -using namespace mlir; - -// Static initialization for Affine op dialect registration. -static DialectRegistration StandardOps; diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index 07c903a66132..219f356807ad 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -21,7 +21,6 @@ #include "mlir/Analysis/LoopAnalysis.h" -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/NestedMatcher.h" @@ -247,16 +246,6 @@ static bool isVectorizableLoopWithCond(const ForInst &loop, return false; } - // No vectorization across unknown regions. - auto regions = matcher::Op([](const Instruction &inst) -> bool { - auto &opInst = cast(inst); - return opInst.getNumBlockLists() != 0 && !opInst.isa(); - }); - auto regionsMatched = regions.match(forInst); - if (!regionsMatched.empty()) { - return false; - } - auto vectorTransfers = matcher::Op(isVectorTransferReadOrWrite); auto vectorTransfersMatched = vectorTransfers.match(forInst); if (!vectorTransfersMatched.empty()) { diff --git a/mlir/lib/Analysis/NestedMatcher.cpp b/mlir/lib/Analysis/NestedMatcher.cpp index 491a9bef1b9f..4f32e9b22f4e 100644 --- a/mlir/lib/Analysis/NestedMatcher.cpp +++ b/mlir/lib/Analysis/NestedMatcher.cpp @@ -16,7 +16,6 @@ // ============================================================================= #include "mlir/Analysis/NestedMatcher.h" -#include "mlir/AffineOps/AffineOps.h" #include "mlir/StandardOps/StandardOps.h" #include "llvm/ADT/ArrayRef.h" @@ -187,11 +186,6 @@ FilterFunctionType NestedPattern::getFilterFunction() { return storage->filter; } -static bool isAffineIfOp(const Instruction &inst) { - return isa(inst) && - cast(inst).isa(); -} - namespace mlir { namespace matcher { @@ -200,22 +194,16 @@ NestedPattern Op(FilterFunctionType filter) { } NestedPattern If(NestedPattern child) { - return NestedPattern(Instruction::Kind::OperationInst, child, isAffineIfOp); + return NestedPattern(Instruction::Kind::If, child, defaultFilterFunction); } NestedPattern If(FilterFunctionType filter, NestedPattern child) { - return NestedPattern(Instruction::Kind::OperationInst, child, - [filter](const Instruction &inst) { - return isAffineIfOp(inst) && filter(inst); - }); + return NestedPattern(Instruction::Kind::If, child, filter); } NestedPattern If(ArrayRef nested) { - return NestedPattern(Instruction::Kind::OperationInst, nested, isAffineIfOp); + return NestedPattern(Instruction::Kind::If, nested, defaultFilterFunction); } NestedPattern If(FilterFunctionType filter, ArrayRef nested) { - return NestedPattern(Instruction::Kind::OperationInst, nested, - [filter](const Instruction &inst) { - return isAffineIfOp(inst) && filter(inst); - }); + return NestedPattern(Instruction::Kind::If, nested, filter); } NestedPattern For(NestedPattern child) { diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 0e77d4d9084b..939a2ede618e 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -22,7 +22,6 @@ #include "mlir/Analysis/Utils.h" -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/IR/Builders.h" @@ -44,7 +43,7 @@ void mlir::getLoopIVs(const Instruction &inst, // Traverse up the hierarchy collecing all 'for' instruction while skipping // over 'if' instructions. while (currInst && ((currForInst = dyn_cast(currInst)) || - cast(currInst)->isa())) { + isa(currInst))) { if (currForInst) loops->push_back(currForInst); currInst = currInst->getParentInst(); @@ -360,12 +359,21 @@ static Instruction *getInstAtPosition(ArrayRef positions, if (auto *childForInst = dyn_cast(&inst)) return getInstAtPosition(positions, level + 1, childForInst->getBody()); - for (auto &blockList : cast(&inst)->getBlockLists()) { - for (auto &b : blockList) - if (auto *ret = getInstAtPosition(positions, level + 1, &b)) - return ret; + if (auto *ifInst = dyn_cast(&inst)) { + auto *ret = getInstAtPosition(positions, level + 1, ifInst->getThen()); + if (ret != nullptr) + return ret; + if (auto *elseClause = ifInst->getElse()) + return getInstAtPosition(positions, level + 1, elseClause); + } + if (auto *opInst = dyn_cast(&inst)) { + for (auto &blockList : opInst->getBlockLists()) { + for (auto &b : blockList) + if (auto *ret = getInstAtPosition(positions, level + 1, &b)) + return ret; + } + return nullptr; } - return nullptr; } return nullptr; } diff --git a/mlir/lib/Analysis/Verifier.cpp b/mlir/lib/Analysis/Verifier.cpp index 474eeb2a28e3..383a4878c35f 100644 --- a/mlir/lib/Analysis/Verifier.cpp +++ b/mlir/lib/Analysis/Verifier.cpp @@ -73,6 +73,7 @@ public: bool verifyBlock(const Block &block, bool isTopLevel); bool verifyOperation(const OperationInst &op); bool verifyForInst(const ForInst &forInst); + bool verifyIfInst(const IfInst &ifInst); bool verifyDominance(const Block &block); bool verifyInstDominance(const Instruction &inst); @@ -179,6 +180,10 @@ bool FuncVerifier::verifyBlock(const Block &block, bool isTopLevel) { if (verifyForInst(cast(inst))) return true; break; + case Instruction::Kind::If: + if (verifyIfInst(cast(inst))) + return true; + break; } } @@ -245,6 +250,18 @@ bool FuncVerifier::verifyForInst(const ForInst &forInst) { return verifyBlock(*forInst.getBody(), /*isTopLevel=*/false); } +bool FuncVerifier::verifyIfInst(const IfInst &ifInst) { + // TODO: check that if conditions are properly formed. + if (verifyBlock(*ifInst.getThen(), /*isTopLevel*/ false)) + return true; + + if (auto *elseClause = ifInst.getElse()) + if (verifyBlock(*elseClause, /*isTopLevel*/ false)) + return true; + + return false; +} + bool FuncVerifier::verifyDominance(const Block &block) { for (auto &inst : block) { // Check that all operands on the instruction are ok. @@ -266,6 +283,14 @@ bool FuncVerifier::verifyDominance(const Block &block) { if (verifyDominance(*cast(inst).getBody())) return true; break; + case Instruction::Kind::If: + auto &ifInst = cast(inst); + if (verifyDominance(*ifInst.getThen())) + return true; + if (auto *elseClause = ifInst.getElse()) + if (verifyDominance(*elseClause)) + return true; + break; } } return false; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index cb4c1f0edcee..21bc3b824b12 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -145,6 +145,7 @@ private: // Visit functions. void visitInstruction(const Instruction *inst); void visitForInst(const ForInst *forInst); + void visitIfInst(const IfInst *ifInst); void visitOperationInst(const OperationInst *opInst); void visitType(Type type); void visitAttribute(Attribute attr); @@ -196,6 +197,10 @@ void ModuleState::visitAttribute(Attribute attr) { } } +void ModuleState::visitIfInst(const IfInst *ifInst) { + recordIntegerSetReference(ifInst->getIntegerSet()); +} + void ModuleState::visitForInst(const ForInst *forInst) { AffineMap lbMap = forInst->getLowerBoundMap(); if (!hasCustomForm(lbMap)) @@ -220,6 +225,8 @@ void ModuleState::visitOperationInst(const OperationInst *op) { void ModuleState::visitInstruction(const Instruction *inst) { switch (inst->getKind()) { + case Instruction::Kind::If: + return visitIfInst(cast(inst)); case Instruction::Kind::For: return visitForInst(cast(inst)); case Instruction::Kind::OperationInst: @@ -1070,6 +1077,7 @@ public: void print(const Instruction *inst); void print(const OperationInst *inst); void print(const ForInst *inst); + void print(const IfInst *inst); void print(const Block *block, bool printBlockArgs = true); void printOperation(const OperationInst *op); @@ -1117,9 +1125,6 @@ public: unsigned index) override; /// Print a block list. - void printBlockList(const BlockList &blocks) override { - printBlockList(blocks, /*printEntryBlockArgs=*/true); - } void printBlockList(const BlockList &blocks, bool printEntryBlockArgs) { os << " {\n"; if (!blocks.empty()) { @@ -1209,6 +1214,12 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) { // Recursively number the stuff in the body. numberValuesInBlock(*cast(&inst)->getBody()); break; + case Instruction::Kind::If: { + auto *ifInst = cast(&inst); + numberValuesInBlock(*ifInst->getThen()); + if (auto *elseBlock = ifInst->getElse()) + numberValuesInBlock(*elseBlock); + } } } } @@ -1349,7 +1360,8 @@ void FunctionPrinter::printFunctionSignature() { } void FunctionPrinter::print(const Block *block, bool printBlockArgs) { - // Print the block label and argument list if requested. + // Print the block label and argument list, unless this is the first block of + // the function, or the first block of an IfInst/ForInst with no arguments. if (printBlockArgs) { os.indent(currentIndent); printBlockName(block); @@ -1406,6 +1418,8 @@ void FunctionPrinter::print(const Instruction *inst) { return print(cast(inst)); case Instruction::Kind::For: return print(cast(inst)); + case Instruction::Kind::If: + return print(cast(inst)); } } @@ -1433,6 +1447,22 @@ void FunctionPrinter::print(const ForInst *inst) { os.indent(currentIndent) << "}"; } +void FunctionPrinter::print(const IfInst *inst) { + os.indent(currentIndent) << "if "; + IntegerSet set = inst->getIntegerSet(); + printIntegerSetReference(set); + printDimAndSymbolList(inst->getInstOperands(), set.getNumDims()); + printTrailingLocation(inst->getLoc()); + os << " {\n"; + print(inst->getThen(), /*printBlockArgs=*/false); + os.indent(currentIndent) << "}"; + if (inst->hasElse()) { + os << " else {\n"; + print(inst->getElse(), /*printBlockArgs=*/false); + os.indent(currentIndent) << "}"; + } +} + void FunctionPrinter::printValueID(const Value *value, bool printResultNo) const { int resultNo = -1; diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index e174fdc1d003..4471ff25e946 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -327,3 +327,10 @@ ForInst *FuncBuilder::createFor(Location location, int64_t lb, int64_t ub, auto ubMap = AffineMap::getConstantMap(ub, context); return createFor(location, {}, lbMap, {}, ubMap, step); } + +IfInst *FuncBuilder::createIf(Location location, ArrayRef operands, + IntegerSet set) { + auto *inst = IfInst::create(location, operands, set); + block->getInstructions().insert(insertPoint, inst); + return inst; +} diff --git a/mlir/lib/IR/Instruction.cpp b/mlir/lib/IR/Instruction.cpp index 0ccab2305ec1..6d74ed142571 100644 --- a/mlir/lib/IR/Instruction.cpp +++ b/mlir/lib/IR/Instruction.cpp @@ -73,6 +73,9 @@ void Instruction::destroy() { case Kind::For: delete cast(this); break; + case Kind::If: + delete cast(this); + break; } } @@ -138,6 +141,8 @@ unsigned Instruction::getNumOperands() const { return cast(this)->getNumOperands(); case Kind::For: return cast(this)->getNumOperands(); + case Kind::If: + return cast(this)->getNumOperands(); } } @@ -147,6 +152,8 @@ MutableArrayRef Instruction::getInstOperands() { return cast(this)->getInstOperands(); case Kind::For: return cast(this)->getInstOperands(); + case Kind::If: + return cast(this)->getInstOperands(); } } @@ -280,6 +287,15 @@ void Instruction::dropAllReferences() { // Make sure to drop references held by instructions within the body. cast(this)->getBody()->dropAllReferences(); break; + case Kind::If: { + // Make sure to drop references held by instructions within the 'then' and + // 'else' blocks. + auto *ifInst = cast(this); + ifInst->getThen()->dropAllReferences(); + if (auto *elseBlock = ifInst->getElse()) + elseBlock->dropAllReferences(); + break; + } case Kind::OperationInst: { auto *opInst = cast(this); if (isTerminator()) @@ -793,6 +809,54 @@ mlir::extractForInductionVars(ArrayRef forInsts) { results.push_back(forInst->getInductionVar()); return results; } +//===----------------------------------------------------------------------===// +// IfInst +//===----------------------------------------------------------------------===// + +IfInst::IfInst(Location location, unsigned numOperands, IntegerSet set) + : Instruction(Kind::If, location), thenClause(this), elseClause(nullptr), + set(set) { + operands.reserve(numOperands); + + // The then of an 'if' inst always has one block. + thenClause.push_back(new Block()); +} + +IfInst::~IfInst() { + if (elseClause) + delete elseClause; + + // An IfInst's IntegerSet 'set' should not be deleted since it is + // allocated through MLIRContext's bump pointer allocator. +} + +IfInst *IfInst::create(Location location, ArrayRef operands, + IntegerSet set) { + unsigned numOperands = operands.size(); + assert(numOperands == set.getNumOperands() && + "operand cound does not match the integer set operand count"); + + IfInst *inst = new IfInst(location, numOperands, set); + + for (auto *op : operands) + inst->operands.emplace_back(InstOperand(inst, op)); + + return inst; +} + +const AffineCondition IfInst::getCondition() const { + return AffineCondition(*this, set); +} + +MLIRContext *IfInst::getContext() const { + // Check for degenerate case of if instruction with no operands. + // This is unlikely, but legal. + if (operands.empty()) + return getFunction()->getContext(); + + return getOperand(0)->getType().getContext(); +} + //===----------------------------------------------------------------------===// // Instruction Cloning //===----------------------------------------------------------------------===// @@ -867,23 +931,40 @@ Instruction *Instruction::clone(BlockAndValueMapping &mapper, for (auto *opValue : getOperands()) operands.push_back(mapper.lookupOrDefault(const_cast(opValue))); - // Otherwise, this must be a ForInst. - auto *forInst = cast(this); - auto lbMap = forInst->getLowerBoundMap(); - auto ubMap = forInst->getUpperBoundMap(); + if (auto *forInst = dyn_cast(this)) { + auto lbMap = forInst->getLowerBoundMap(); + auto ubMap = forInst->getUpperBoundMap(); - auto *newFor = ForInst::create( - getLoc(), ArrayRef(operands).take_front(lbMap.getNumInputs()), - lbMap, ArrayRef(operands).take_back(ubMap.getNumInputs()), ubMap, - forInst->getStep()); + auto *newFor = ForInst::create( + getLoc(), ArrayRef(operands).take_front(lbMap.getNumInputs()), + lbMap, ArrayRef(operands).take_back(ubMap.getNumInputs()), + ubMap, forInst->getStep()); - // Remember the induction variable mapping. - mapper.map(forInst->getInductionVar(), newFor->getInductionVar()); + // Remember the induction variable mapping. + mapper.map(forInst->getInductionVar(), newFor->getInductionVar()); - // Recursively clone the body of the for loop. - for (auto &subInst : *forInst->getBody()) - newFor->getBody()->push_back(subInst.clone(mapper, context)); - return newFor; + // Recursively clone the body of the for loop. + for (auto &subInst : *forInst->getBody()) + newFor->getBody()->push_back(subInst.clone(mapper, context)); + + return newFor; + } + + // Otherwise, we must have an If instruction. + auto *ifInst = cast(this); + auto *newIf = IfInst::create(getLoc(), operands, ifInst->getIntegerSet()); + + auto *resultThen = newIf->getThen(); + for (auto &childInst : *ifInst->getThen()) + resultThen->push_back(childInst.clone(mapper, context)); + + if (ifInst->hasElse()) { + auto *resultElse = newIf->createElse(); + for (auto &childInst : *ifInst->getElse()) + resultElse->push_back(childInst.clone(mapper, context)); + } + + return newIf; } Instruction *Instruction::clone(MLIRContext *context) const { diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 2ab151f8913a..099b218892f7 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -281,7 +281,7 @@ bool OpTrait::impl::verifyIsTerminator(const OperationInst *op) { if (!block || &block->back() != op) return op->emitOpError("must be the last instruction in the parent block"); - // TODO(riverriddle) Terminators may not exist with an operation region. + // Terminators may not exist in ForInst and IfInst. if (block->getContainingInst()) return op->emitOpError("may only be at the top level of a function"); diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp index 7103eeb7389e..6418b062dc16 100644 --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -66,6 +66,8 @@ MLIRContext *IROperandOwner::getContext() const { return cast(this)->getContext(); case Kind::ForInst: return cast(this)->getContext(); + case Kind::IfInst: + return cast(this)->getContext(); } } diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index e5d6aa46565a..c477ad1bbc5c 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -996,7 +996,8 @@ Attribute Parser::parseAttribute(Type type) { AffineMap map; IntegerSet set; if (parseAffineMapOrIntegerSetReference(map, set)) - return nullptr; + return (emitError("expected affine map or integer set attribute value"), + nullptr); if (map) return builder.getAffineMapAttr(map); assert(set); @@ -2208,6 +2209,8 @@ public: const char *affineStructName); ParseResult parseBound(SmallVectorImpl &operands, AffineMap &map, bool isLower); + ParseResult parseIfInst(); + ParseResult parseElseClause(Block *elseClause); ParseResult parseInstructions(Block *block); private: @@ -2389,6 +2392,10 @@ ParseResult FunctionParser::parseBlockBody(Block *block) { if (parseForInst()) return ParseFailure; break; + case Token::kw_if: + if (parseIfInst()) + return ParseFailure; + break; } } @@ -2928,18 +2935,12 @@ public: return false; } - /// Parse an optional keyword. - bool parseOptionalKeyword(const char *keyword) override { - // Check that the current token is a bare identifier or keyword. - if (parser.getToken().isNot(Token::bare_identifier) && - !parser.getToken().isKeyword()) - return true; - - if (parser.getTokenSpelling() == keyword) { - parser.consumeToken(); - return false; - } - return true; + /// Parse a keyword followed by a type. + bool parseKeywordType(const char *keyword, Type &result) override { + if (parser.getTokenSpelling() != keyword) + return parser.emitError("expected '" + Twine(keyword) + "'"); + parser.consumeToken(); + return !(result = parser.parseType()); } /// Parse an arbitrary attribute of a given type and return it in result. This @@ -3077,15 +3078,6 @@ public: return result == nullptr; } - /// Parses a list of blocks. - bool parseBlockList() override { - SmallVector results; - if (parser.parseOperationBlockList(results)) - return true; - parsedBlockLists.emplace_back(results); - return false; - } - //===--------------------------------------------------------------------===// // Methods for interacting with the parser //===--------------------------------------------------------------------===// @@ -3107,11 +3099,6 @@ public: /// Emit a diagnostic at the specified location and return true. bool emitError(llvm::SMLoc loc, const Twine &message) override { - // If we emit an error, then cleanup any parsed block lists. - for (auto &blockList : parsedBlockLists) - parser.cleanupInvalidBlocks(blockList); - parsedBlockLists.clear(); - parser.emitError(loc, "custom op '" + Twine(opName) + "' " + message); emittedError = true; return true; @@ -3119,13 +3106,7 @@ public: bool didEmitError() const { return emittedError; } - /// Returns the block lists that were parsed. - MutableArrayRef> getParsedBlockLists() { - return parsedBlockLists; - } - private: - std::vector> parsedBlockLists; SMLoc nameLoc; StringRef opName; FunctionParser &parser; @@ -3164,25 +3145,8 @@ OperationInst *FunctionParser::parseCustomOperation() { if (opAsmParser.didEmitError()) return nullptr; - // Check that enough block lists were reserved for those that were parsed. - auto parsedBlockLists = opAsmParser.getParsedBlockLists(); - if (parsedBlockLists.size() > opState.numBlockLists) { - opAsmParser.emitError( - opLoc, - "parsed more block lists than those reserved in the operation state"); - return nullptr; - } - // Otherwise, we succeeded. Use the state it parsed as our op information. - auto *opInst = builder.createOperation(opState); - - // Resolve any parsed block lists. - for (unsigned i = 0, e = parsedBlockLists.size(); i != e; ++i) { - auto &opBlockList = opInst->getBlockList(i).getBlocks(); - opBlockList.insert(opBlockList.end(), parsedBlockLists[i].begin(), - parsedBlockLists[i].end()); - } - return opInst; + return builder.createOperation(opState); } /// For instruction. @@ -3474,6 +3438,69 @@ IntegerSet AffineParser::parseIntegerSetConstraints(unsigned numDims, return builder.getIntegerSet(numDims, numSymbols, constraints, isEqs); } +/// If instruction. +/// +/// ml-if-head ::= `if` ml-if-cond trailing-location? `{` inst* `}` +/// | ml-if-head `else` `if` ml-if-cond trailing-location? +/// `{` inst* `}` +/// ml-if-inst ::= ml-if-head +/// | ml-if-head `else` `{` inst* `}` +/// +ParseResult FunctionParser::parseIfInst() { + auto loc = getToken().getLoc(); + consumeToken(Token::kw_if); + + IntegerSet set = parseIntegerSetReference(); + if (!set) + return ParseFailure; + + SmallVector operands; + if (parseDimAndSymbolList(operands, set.getNumDims(), set.getNumOperands(), + "integer set")) + return ParseFailure; + + IfInst *ifInst = + builder.createIf(getEncodedSourceLocation(loc), operands, set); + + // Try to parse the optional trailing location. + if (parseOptionalTrailingLocation(ifInst)) + return ParseFailure; + + Block *thenClause = ifInst->getThen(); + + // When parsing of an if instruction body fails, the IR contains + // the if instruction with the portion of the body that has been + // successfully parsed. + if (parseToken(Token::l_brace, "expected '{' before instruction list") || + parseBlock(thenClause) || + parseToken(Token::r_brace, "expected '}' after instruction list")) + return ParseFailure; + + if (consumeIf(Token::kw_else)) { + auto *elseClause = ifInst->createElse(); + if (parseElseClause(elseClause)) + return ParseFailure; + } + + // Reset insertion point to the current block. + builder.setInsertionPointToEnd(ifInst->getBlock()); + + return ParseSuccess; +} + +ParseResult FunctionParser::parseElseClause(Block *elseClause) { + if (getToken().is(Token::kw_if)) { + builder.setInsertionPointToEnd(elseClause); + return parseIfInst(); + } + + if (parseToken(Token::l_brace, "expected '{' before instruction list") || + parseBlock(elseClause) || + parseToken(Token::r_brace, "expected '}' after instruction list")) + return ParseFailure; + return ParseSuccess; +} + //===----------------------------------------------------------------------===// // Top-level entity parsing. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def index ec00f98b3f5c..40e98b25cb33 100644 --- a/mlir/lib/Parser/TokenKinds.def +++ b/mlir/lib/Parser/TokenKinds.def @@ -91,6 +91,7 @@ TOK_KEYWORD(attributes) TOK_KEYWORD(bf16) TOK_KEYWORD(ceildiv) TOK_KEYWORD(dense) +TOK_KEYWORD(else) TOK_KEYWORD(splat) TOK_KEYWORD(f16) TOK_KEYWORD(f32) @@ -99,6 +100,7 @@ TOK_KEYWORD(false) TOK_KEYWORD(floordiv) TOK_KEYWORD(for) TOK_KEYWORD(func) +TOK_KEYWORD(if) TOK_KEYWORD(index) TOK_KEYWORD(loc) TOK_KEYWORD(max) diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index afd18a49b793..c2e1636626d3 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -188,6 +188,16 @@ void CSE::simplifyBlock(Block *bb) { simplifyBlock(cast(i).getBody()); break; } + case Instruction::Kind::If: { + auto &ifInst = cast(i); + if (auto *elseBlock = ifInst.getElse()) { + ScopedMapTy::ScopeTy scope(knownValues); + simplifyBlock(elseBlock); + } + ScopedMapTy::ScopeTy scope(knownValues); + simplifyBlock(ifInst.getThen()); + break; + } } } } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp index eebbbe9daa77..cee0a08a63cf 100644 --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -19,7 +19,6 @@ // //===----------------------------------------------------------------------===// -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/LoopAnalysis.h" @@ -100,16 +99,16 @@ public: SmallVector forInsts; SmallVector loadOpInsts; SmallVector storeOpInsts; - bool hasNonForRegion = false; + bool hasIfInst = false; void visitForInst(ForInst *forInst) { forInsts.push_back(forInst); } + void visitIfInst(IfInst *ifInst) { hasIfInst = true; } + void visitOperationInst(OperationInst *opInst) { - if (opInst->getNumBlockLists() != 0) - hasNonForRegion = true; - else if (opInst->isa()) + if (opInst->isa()) loadOpInsts.push_back(opInst); - else if (opInst->isa()) + if (opInst->isa()) storeOpInsts.push_back(opInst); } }; @@ -411,8 +410,8 @@ bool MemRefDependenceGraph::init(Function *f) { // all loads and store accesses it contains. LoopNestStateCollector collector; collector.walkForInst(forInst); - // Return false if a non 'for' region was found (not currently supported). - if (collector.hasNonForRegion) + // Return false if IfInsts are found (not currently supported). + if (collector.hasIfInst) return false; Node node(id++, &inst); for (auto *opInst : collector.loadOpInsts) { @@ -435,18 +434,19 @@ bool MemRefDependenceGraph::init(Function *f) { auto *memref = opInst->cast()->getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); - } else if (auto storeOp = opInst->dyn_cast()) { + } + if (auto storeOp = opInst->dyn_cast()) { // Create graph node for top-level store op. Node node(id++, &inst); node.stores.push_back(opInst); auto *memref = opInst->cast()->getMemRef(); memrefAccesses[memref].insert(node.id); nodes.insert({node.id, node}); - } else if (opInst->getNumBlockLists() != 0) { - // Return false if another region is found (not currently supported). - return false; } } + // Return false if IfInsts are found (not currently supported). + if (isa(&inst)) + return false; } // Walk memref access lists and add graph edges between dependent nodes. diff --git a/mlir/lib/Transforms/LoopUnroll.cpp b/mlir/lib/Transforms/LoopUnroll.cpp index 6d63e4afd2d4..39ef758833ba 100644 --- a/mlir/lib/Transforms/LoopUnroll.cpp +++ b/mlir/lib/Transforms/LoopUnroll.cpp @@ -119,6 +119,15 @@ PassResult LoopUnroll::runOnFunction(Function *f) { return true; } + bool walkIfInstPostOrder(IfInst *ifInst) { + bool hasInnerLoops = + walkPostOrder(ifInst->getThen()->begin(), ifInst->getThen()->end()); + if (ifInst->hasElse()) + hasInnerLoops |= + walkPostOrder(ifInst->getElse()->begin(), ifInst->getElse()->end()); + return hasInnerLoops; + } + bool walkOpInstPostOrder(OperationInst *opInst) { for (auto &blockList : opInst->getBlockLists()) for (auto &block : blockList) diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index f770684f5198..ab37ff63badd 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -20,7 +20,6 @@ // //===----------------------------------------------------------------------===// -#include "mlir/AffineOps/AffineOps.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" @@ -247,7 +246,7 @@ public: PassResult runOnFunction(Function *function) override; bool lowerForInst(ForInst *forInst); - bool lowerAffineIf(AffineIfOp *ifOp); + bool lowerIfInst(IfInst *ifInst); bool lowerAffineApply(AffineApplyOp *op); static char passID; @@ -410,7 +409,7 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) { // enabling easy nesting of "if" instructions and if-then-else-if chains. // // +--------------------------------+ -// | | +// | | // | %zero = constant 0 : index | // | %v = affine_apply #expr1(%ops) | // | %c = cmpi "sge" %v, %zero | @@ -454,11 +453,10 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) { // v v // +--------------------------------+ // | continue: | -// | | +// | | // +--------------------------------+ // -bool LowerAffinePass::lowerAffineIf(AffineIfOp *ifOp) { - auto *ifInst = ifOp->getInstruction(); +bool LowerAffinePass::lowerIfInst(IfInst *ifInst) { auto loc = ifInst->getLoc(); // Start by splitting the block containing the 'if' into two parts. The part @@ -468,38 +466,22 @@ bool LowerAffinePass::lowerAffineIf(AffineIfOp *ifOp) { auto *continueBlock = condBlock->splitBlock(ifInst); // Create a block for the 'then' code, inserting it between the cond and - // continue blocks. Move the instructions over from the AffineIfOp and add a + // continue blocks. Move the instructions over from the IfInst and add a // branch to the continuation point. Block *thenBlock = new Block(); thenBlock->insertBefore(continueBlock); - // If the 'then' block is not empty, then splice the instructions. - auto &oldThenBlocks = ifOp->getThenBlocks(); - if (!oldThenBlocks.empty()) { - // We currently only handle one 'then' block. - if (std::next(oldThenBlocks.begin()) != oldThenBlocks.end()) - return true; - - Block *oldThen = &oldThenBlocks.front(); - - thenBlock->getInstructions().splice(thenBlock->begin(), - oldThen->getInstructions(), - oldThen->begin(), oldThen->end()); - } - + auto *oldThen = ifInst->getThen(); + thenBlock->getInstructions().splice(thenBlock->begin(), + oldThen->getInstructions(), + oldThen->begin(), oldThen->end()); FuncBuilder builder(thenBlock); builder.create(loc, continueBlock); // Handle the 'else' block the same way, but we skip it if we have no else // code. Block *elseBlock = continueBlock; - auto &oldElseBlocks = ifOp->getElseBlocks(); - if (!oldElseBlocks.empty()) { - // We currently only handle one 'else' block. - if (std::next(oldElseBlocks.begin()) != oldElseBlocks.end()) - return true; - - auto *oldElse = &oldElseBlocks.front(); + if (auto *oldElse = ifInst->getElse()) { elseBlock = new Block(); elseBlock->insertBefore(continueBlock); @@ -511,7 +493,7 @@ bool LowerAffinePass::lowerAffineIf(AffineIfOp *ifOp) { } // Ok, now we just have to handle the condition logic. - auto integerSet = ifOp->getIntegerSet(); + auto integerSet = ifInst->getCondition().getIntegerSet(); // Implement short-circuit logic. For each affine expression in the 'if' // condition, convert it into an affine map and call `affine_apply` to obtain @@ -611,30 +593,29 @@ bool LowerAffinePass::lowerAffineApply(AffineApplyOp *op) { PassResult LowerAffinePass::runOnFunction(Function *function) { SmallVector instsToRewrite; - // Collect all the For instructions as well as AffineIfOps and AffineApplyOps. - // We do this as a prepass to avoid invalidating the walker with our rewrite. + // Collect all the If and For instructions as well as AffineApplyOps. We do + // this as a prepass to avoid invalidating the walker with our rewrite. function->walkInsts([&](Instruction *inst) { - if (isa(inst)) + if (isa(inst) || isa(inst)) instsToRewrite.push_back(inst); auto op = dyn_cast(inst); - if (op && (op->isa() || op->isa())) + if (op && op->isa()) instsToRewrite.push_back(inst); }); // Rewrite all of the ifs and fors. We walked the instructions in preorder, // so we know that we will rewrite them in the same order. for (auto *inst : instsToRewrite) - if (auto *forInst = dyn_cast(inst)) { + if (auto *ifInst = dyn_cast(inst)) { + if (lowerIfInst(ifInst)) + return failure(); + } else if (auto *forInst = dyn_cast(inst)) { if (lowerForInst(forInst)) return failure(); } else { auto op = cast(inst); - if (auto ifOp = op->dyn_cast()) { - if (lowerAffineIf(ifOp)) - return failure(); - } else if (lowerAffineApply(op->cast())) { + if (lowerAffineApply(op->cast())) return failure(); - } } return success(); diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 2744b1d624c0..09d961f85cd4 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -20,7 +20,6 @@ // //===----------------------------------------------------------------------===// -#include "mlir/AffineOps/AffineOps.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/Dominance.h" #include "mlir/Analysis/LoopAnalysis.h" @@ -560,6 +559,9 @@ static bool instantiateMaterialization(Instruction *inst, if (isa(inst)) return inst->emitError("NYI path ForInst"); + if (isa(inst)) + return inst->emitError("NYI path IfInst"); + // Create a builder here for unroll-and-jam effects. FuncBuilder b(inst); auto *opInst = cast(inst); @@ -568,9 +570,6 @@ static bool instantiateMaterialization(Instruction *inst, if (opInst->isa()) { return false; } - if (opInst->getNumBlockLists() != 0) - return inst->emitError("NYI path Op with region"); - if (auto write = opInst->dyn_cast()) { auto *clone = instantiate(&b, write, state->hwVectorType, state->hwVectorInstance, state->substitutionsMap); diff --git a/mlir/lib/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Transforms/SimplifyAffineStructures.cpp index ba59123c7004..bd39e47786ae 100644 --- a/mlir/lib/Transforms/SimplifyAffineStructures.cpp +++ b/mlir/lib/Transforms/SimplifyAffineStructures.cpp @@ -28,6 +28,7 @@ #define DEBUG_TYPE "simplify-affine-structure" using namespace mlir; +using llvm::report_fatal_error; namespace { @@ -41,6 +42,9 @@ struct SimplifyAffineStructures : public FunctionPass { PassResult runOnFunction(Function *f) override; + void visitIfInst(IfInst *ifInst); + void visitOperationInst(OperationInst *opInst); + static char passID; }; @@ -62,19 +66,28 @@ static IntegerSet simplifyIntegerSet(IntegerSet set) { return set; } -PassResult SimplifyAffineStructures::runOnFunction(Function *f) { - f->walkOps([&](OperationInst *opInst) { - for (auto attr : opInst->getAttrs()) { - if (auto mapAttr = attr.second.dyn_cast()) { - MutableAffineMap mMap(mapAttr.getValue()); - mMap.simplify(); - auto map = mMap.getAffineMap(); - opInst->setAttr(attr.first, AffineMapAttr::get(map)); - } else if (auto setAttr = attr.second.dyn_cast()) { - auto simplified = simplifyIntegerSet(setAttr.getValue()); - opInst->setAttr(attr.first, IntegerSetAttr::get(simplified)); - } +void SimplifyAffineStructures::visitIfInst(IfInst *ifInst) { + auto set = ifInst->getCondition().getIntegerSet(); + ifInst->setIntegerSet(simplifyIntegerSet(set)); +} + +void SimplifyAffineStructures::visitOperationInst(OperationInst *opInst) { + for (auto attr : opInst->getAttrs()) { + if (auto mapAttr = attr.second.dyn_cast()) { + MutableAffineMap mMap(mapAttr.getValue()); + mMap.simplify(); + auto map = mMap.getAffineMap(); + opInst->setAttr(attr.first, AffineMapAttr::get(map)); } + } +} + +PassResult SimplifyAffineStructures::runOnFunction(Function *f) { + f->walkInsts([&](Instruction *inst) { + if (auto *opInst = dyn_cast(inst)) + visitOperationInst(opInst); + if (auto *ifInst = dyn_cast(inst)) + visitIfInst(ifInst); }); return success(); diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index 595991c01097..bae112dd3b9f 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -243,6 +243,14 @@ func @non_instruction() { // ----- +func @invalid_if_conditional1() { + for %i = 1 to 10 { + if () { // expected-error {{expected ':' or '['}} + } +} + +// ----- + func @invalid_if_conditional2() { for %i = 1 to 10 { if (i)[N] : (i >= ) // expected-error {{expected '== 0' or '>= 0' at end of affine constraint}} @@ -656,11 +664,7 @@ func @invalid_if_operands2(%N : index) { func @invalid_if_operands3(%N : index) { for %i = 1 to 10 { if #set0(%i)[%i] { - // expected-error@-1 {{operand cannot be used as a symbol}} - } - } - return -} + // expected-error@-1 {{value '%i' cannot be used as a symbol}} // ----- // expected-error@+1 {{expected '"' in string literal}} diff --git a/mlir/test/IR/locations.mlir b/mlir/test/IR/locations.mlir index 8a90d12bd03c..e3e1bbbbfad6 100644 --- a/mlir/test/IR/locations.mlir +++ b/mlir/test/IR/locations.mlir @@ -16,9 +16,9 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) { for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) { } - // CHECK: } loc(fused<"myPass">["foo", "foo2"]) - if #set0(%2) { - } loc(fused<"myPass">["foo", "foo2"]) + // CHECK: ) loc(fused<"myPass">["foo", "foo2"]) + if #set0(%2) loc(fused<"myPass">["foo", "foo2"]) { + } // CHECK: return %0 : i32 loc(unknown) return %1 : i32 loc(unknown) diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 626f24569c68..331096065385 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -287,15 +287,13 @@ func @ifinst(%N: index) { // CHECK: %c1_i32 = constant 1 : i32 %y = "add"(%x, %i) : (i32, index) -> i32 // CHECK: %0 = "add"(%c1_i32, %i0) : (i32, index) -> i32 %z = "mul"(%y, %y) : (i32, i32) -> i32 // CHECK: %1 = "mul"(%0, %0) : (i32, i32) -> i32 - } else { // CHECK } else { - if (i)[N] : (i - 2 >= 0, 4 - i >= 0)(%i)[%N] { // CHECK if (#set1(%i0)[%arg0]) { - // CHECK: %c1 = constant 1 : index - %u = constant 1 : index - // CHECK: %2 = affine_apply #map{{.*}}(%i0, %i0)[%c1] - %w = affine_apply (d0,d1)[s0] -> (d0+d1+s0) (%i, %i) [%u] - } else { // CHECK } else { - %v = constant 3 : i32 // %c3_i32 = constant 3 : i32 - } + } else if (i)[N] : (i - 2 >= 0, 4 - i >= 0)(%i)[%N] { // CHECK } else if (#set1(%i0)[%arg0]) { + // CHECK: %c1 = constant 1 : index + %u = constant 1 : index + // CHECK: %2 = affine_apply #map{{.*}}(%i0, %i0)[%c1] + %w = affine_apply (d0,d1)[s0] -> (d0+d1+s0) (%i, %i) [%u] + } else { // CHECK } else { + %v = constant 3 : i32 // %c3_i32 = constant 3 : i32 } // CHECK } } // CHECK } return // CHECK return @@ -753,11 +751,11 @@ func @type_alias() -> !i32_type_alias { func @verbose_if(%N: index) { %c = constant 200 : index - // CHECK: if #set0(%c200)[%arg0, %c200] { - "if"(%c, %N, %c) { condition: #set0 } : (index, index, index) -> () { + // CHECK: "if"(%c200, %arg0, %c200) {cond: #set0} : (index, index, index) -> () { + "if"(%c, %N, %c) { cond: #set0 } : (index, index, index) -> () { // CHECK-NEXT: "add" %y = "add"(%c, %N) : (index, index) -> index - // CHECK-NEXT: } else { + // CHECK-NEXT: } { } { // The else block list. // CHECK-NEXT: "add" %z = "add"(%c, %c) : (index, index) -> index diff --git a/mlir/test/IR/pretty-locations.mlir b/mlir/test/IR/pretty-locations.mlir index 69dace451654..cb2e14a56d5d 100644 --- a/mlir/test/IR/pretty-locations.mlir +++ b/mlir/test/IR/pretty-locations.mlir @@ -21,10 +21,10 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) { for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) { } - // CHECK: } <"myPass">["foo", "foo2"] - if #set0(%2) { - } loc(fused<"myPass">["foo", "foo2"]) + // CHECK: ) <"myPass">["foo", "foo2"] + if #set0(%2) loc(fused<"myPass">["foo", "foo2"]) { + } // CHECK: return %0 : i32 [unknown] return %1 : i32 loc(unknown) -} +} \ No newline at end of file diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir index 162f193f6629..d170ce590f7f 100644 --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -483,7 +483,7 @@ func @should_not_fuse_if_inst_at_top_level() { %c0 = constant 4 : index if #set0(%c0) { } - // Top-level IfOp should prevent fusion. + // Top-level IfInst should prevent fusion. // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } @@ -512,7 +512,7 @@ func @should_not_fuse_if_inst_in_loop_nest() { %v0 = load %m[%i1] : memref<10xf32> } - // IfOp in ForInst should prevent fusion. + // IfInst in ForInst should prevent fusion. // CHECK: for %i0 = 0 to 10 { // CHECK-NEXT: store %cst, %0[%i0] : memref<10xf32> // CHECK-NEXT: } diff --git a/mlir/test/Transforms/memref-dependence-check.mlir b/mlir/test/Transforms/memref-dependence-check.mlir index 628044ed77ae..6f6ad3fafc79 100644 --- a/mlir/test/Transforms/memref-dependence-check.mlir +++ b/mlir/test/Transforms/memref-dependence-check.mlir @@ -10,7 +10,7 @@ func @store_may_execute_before_load() { %cf7 = constant 7.0 : f32 %c0 = constant 4 : index // There is a dependence from store 0 to load 1 at depth 1 because the - // ancestor IfOp of the store, dominates the ancestor ForSmt of the load, + // ancestor IfInst of the store, dominates the ancestor ForSmt of the load, // and thus the store "may" conditionally execute before the load. if #set0(%c0) { for %i0 = 0 to 10 { diff --git a/mlir/test/Transforms/strip-debug-info.mlir b/mlir/test/Transforms/strip-debug-info.mlir index 13f009deb701..5509c7aba551 100644 --- a/mlir/test/Transforms/strip-debug-info.mlir +++ b/mlir/test/Transforms/strip-debug-info.mlir @@ -13,10 +13,10 @@ func @inline_notation() -> i32 loc("mysource.cc":10:8) { for %i0 = 0 to 8 loc(fused["foo", "mysource.cc":10:8]) { } - // CHECK: } loc(unknown) + // CHECK: if #set0(%c4) loc(unknown) %2 = constant 4 : index - if #set0(%2) { - } loc(fused<"myPass">["foo", "foo2"]) + if #set0(%2) loc(fused<"myPass">["foo", "foo2"]) { + } // CHECK: return %0 : i32 loc(unknown) return %1 : i32 loc("bar")