From 36babbd7815519db5d26f55695fa3ec500997bcd Mon Sep 17 00:00:00 2001 From: River Riddle Date: Sat, 26 Jan 2019 12:40:12 -0800 Subject: [PATCH] Change the ForInst induction variable to be a block argument of the body instead of the ForInst itself. This is a necessary step in converting ForInst into an operation. PiperOrigin-RevId: 231064139 --- mlir/include/mlir/IR/Instructions.h | 33 ++++---- mlir/include/mlir/IR/Value.h | 4 +- mlir/lib/Analysis/AffineAnalysis.cpp | 17 ++-- mlir/lib/Analysis/AffineStructures.cpp | 2 +- mlir/lib/Analysis/Dominance.cpp | 11 +-- mlir/lib/Analysis/LoopAnalysis.cpp | 8 +- mlir/lib/Analysis/SliceAnalysis.cpp | 2 +- mlir/lib/Analysis/Utils.cpp | 7 +- mlir/lib/Analysis/VectorAnalysis.cpp | 3 +- mlir/lib/EDSC/MLIREmitter.cpp | 9 +-- mlir/lib/IR/AsmPrinter.cpp | 84 ++++++++------------ mlir/lib/IR/Instruction.cpp | 50 +++++++++--- mlir/lib/IR/Value.cpp | 8 +- mlir/lib/Parser/Parser.cpp | 5 +- mlir/lib/Transforms/DmaGeneration.cpp | 3 +- mlir/lib/Transforms/LoopTiling.cpp | 13 +-- mlir/lib/Transforms/LoopUnrollAndJam.cpp | 8 +- mlir/lib/Transforms/LowerAffine.cpp | 2 +- mlir/lib/Transforms/PipelineDataTransfer.cpp | 4 +- mlir/lib/Transforms/Utils/LoopUtils.cpp | 26 +++--- mlir/lib/Transforms/Vectorize.cpp | 8 +- 21 files changed, 172 insertions(+), 135 deletions(-) diff --git a/mlir/include/mlir/IR/Instructions.h b/mlir/include/mlir/IR/Instructions.h index 8085e7202608..71d832b8b90d 100644 --- a/mlir/include/mlir/IR/Instructions.h +++ b/mlir/include/mlir/IR/Instructions.h @@ -555,19 +555,12 @@ inline auto OperationInst::getResultTypes() const } /// For instruction represents an affine loop nest. -class ForInst : public Instruction, public Value { +class ForInst : public Instruction { public: static ForInst *create(Location location, ArrayRef lbOperands, AffineMap lbMap, ArrayRef ubOperands, AffineMap ubMap, int64_t step); - ~ForInst() { - // There may be references to the induction variable of this loop within its - // body or, in case of ill-formed code during parsing, outside its body. - // Explicitly drop all uses of the induction variable before destroying it. - dropAllUses(); - } - /// Resolve base class ambiguity. using Instruction::getFunction; @@ -700,7 +693,9 @@ public: //===--------------------------------------------------------------------===// /// Return the context this operation is associated with. - MLIRContext *getContext() const { return getType().getContext(); } + MLIRContext *getContext() const { + return getInductionVar()->getType().getContext(); + } using Instruction::dump; using Instruction::print; @@ -710,11 +705,10 @@ public: return ptr->getKind() == IROperandOwner::Kind::ForInst; } - // For instruction represents implicitly represents induction variable by - // inheriting from Value class. Whenever you need to refer to the loop - // induction variable, just use the for instruction itself. - static bool classof(const Value *value) { - return value->getKind() == Value::Kind::ForInst; + /// Returns the induction variable for this loop. + Value *getInductionVar(); + const Value *getInductionVar() const { + return const_cast(this)->getInductionVar(); } private: @@ -738,6 +732,17 @@ private: AffineMap ubMap, int64_t step); }; +/// Returns if the provided value is the induction variable of a ForInst. +bool isForInductionVar(const Value *val); + +/// Returns the loop parent of an induction variable. If the provided value is +/// not an induction variable, then return nullptr. +ForInst *getForInductionVarOwner(Value *val); +const ForInst *getForInductionVarOwner(const Value *val); + +/// Extracts the induction variables from a list of ForInsts and returns them. +SmallVector extractForInductionVars(ArrayRef forInsts); + /// AffineBound represents a lower or upper bound in the for instruction. /// This class does not own the underlying operands. Instead, it refers /// to the operands stored in the ForInst. Its life span should not exceed diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h index 48af9c71be63..90f1f484b1fe 100644 --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -45,7 +45,6 @@ public: enum class Kind { BlockArgument, // block argument InstResult, // operation instruction result - ForInst, // 'for' instruction induction variable }; ~Value() {} @@ -141,6 +140,9 @@ public: /// Returns the number of this argument. unsigned getArgNumber() const; + /// Returns if the current argument is a function argument. + bool isFunctionArgument() const; + private: friend class Block; // For access to private constructor. BlockArgument(Type type, Block *owner) diff --git a/mlir/lib/Analysis/AffineAnalysis.cpp b/mlir/lib/Analysis/AffineAnalysis.cpp index 1ecad8d4e908..a4d969bc2033 100644 --- a/mlir/lib/Analysis/AffineAnalysis.cpp +++ b/mlir/lib/Analysis/AffineAnalysis.cpp @@ -555,7 +555,7 @@ void mlir::getReachableAffineApplyOps( // setExprStride(ArrayRef expr, int64_t stride) bool mlir::getIndexSet(ArrayRef forInsts, FlatAffineConstraints *domain) { - SmallVector indices(forInsts.begin(), forInsts.end()); + auto indices = extractForInductionVars(forInsts); // Reset while associated Values in 'indices' to the domain. domain->reset(forInsts.size(), /*numSymbols=*/0, /*numLocals=*/0, indices); for (auto *forInst : forInsts) { @@ -677,7 +677,7 @@ static void buildDimAndSymbolPositionMaps( auto updateValuePosMap = [&](ArrayRef values, bool isSrc) { for (unsigned i = 0, e = values.size(); i < e; ++i) { auto *value = values[i]; - if (!isa(values[i])) { + if (!isForInductionVar(values[i])) { assert(values[i]->isValidSymbol() && "access operand has to be either a loop IV or a symbol"); valuePosMap->addSymbolValue(value); @@ -739,7 +739,7 @@ void initDependenceConstraints(const FlatAffineConstraints &srcDomain, // Set values for the symbolic identifier dimensions. auto setSymbolIds = [&](ArrayRef values) { for (auto *value : values) { - if (!isa(value)) { + if (!isForInductionVar(value)) { assert(value->isValidSymbol() && "expected symbol"); dependenceConstraints->setIdValue(valuePosMap.getSymPos(value), value); } @@ -907,7 +907,7 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap, // Add equality constraints for any operands that are defined by constant ops. auto addEqForConstOperands = [&](ArrayRef operands) { for (unsigned i = 0, e = operands.size(); i < e; ++i) { - if (isa(operands[i])) + if (isForInductionVar(operands[i])) continue; auto *symbol = operands[i]; assert(symbol->isValidSymbol()); @@ -976,8 +976,8 @@ static unsigned getNumCommonLoops(const FlatAffineConstraints &srcDomain, std::min(srcDomain.getNumDimIds(), dstDomain.getNumDimIds()); unsigned numCommonLoops = 0; for (unsigned i = 0; i < minNumLoops; ++i) { - if (!isa(srcDomain.getIdValue(i)) || - !isa(dstDomain.getIdValue(i)) || + if (!isForInductionVar(srcDomain.getIdValue(i)) || + !isForInductionVar(dstDomain.getIdValue(i)) || srcDomain.getIdValue(i) != dstDomain.getIdValue(i)) break; ++numCommonLoops; @@ -998,8 +998,9 @@ static const Block *getCommonBlock(const MemRefAccess &srcAccess, return block; } auto *commonForValue = srcDomain.getIdValue(numCommonLoops - 1); - assert(isa(commonForValue)); - return cast(commonForValue)->getBody(); + auto *forInst = getForInductionVarOwner(commonForValue); + assert(forInst && "commonForValue was not an induction variable"); + return forInst->getBody(); } // Returns true if the ancestor operation instruction of 'srcAccess' appears diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 268fbe0c9c60..7aa23bbe4808 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -1251,7 +1251,7 @@ void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) { bool FlatAffineConstraints::addForInstDomain(const ForInst &forInst) { unsigned pos; // Pre-condition for this method. - if (!findId(forInst, &pos)) { + if (!findId(*forInst.getInductionVar(), &pos)) { assert(0 && "Value not found"); return false; } diff --git a/mlir/lib/Analysis/Dominance.cpp b/mlir/lib/Analysis/Dominance.cpp index b98efa73e542..5cdeebbdf4ac 100644 --- a/mlir/lib/Analysis/Dominance.cpp +++ b/mlir/lib/Analysis/Dominance.cpp @@ -53,9 +53,9 @@ bool DominanceInfo::properlyDominates(const Block *a, const Block *b) { if (blockListA == blockListB) return DominatorTreeBase::properlyDominates(a, b); - // Otherwise, 'a' properly dominates 'b' if 'b' is defined in an - // IfInst/ForInst that (recursively) ends up being dominated by 'a'. Walk up - // the list of containers enclosing B. + // Otherwise, 'a' properly dominates 'b' if 'b' is defined in an instruction + // region that (recursively) ends up being dominated by 'a'. Walk up the list + // of containers enclosing B. Instruction *bAncestor; do { bAncestor = blockListB->getContainingInst(); @@ -106,11 +106,6 @@ bool DominanceInfo::properlyDominates(const Value *a, const Instruction *b) { if (auto *aInst = a->getDefiningInst()) return properlyDominates(aInst, b); - // The induction variable of a ForInst properly dominantes its body, so we - // can just do a simple block dominance check. - if (auto *forInst = dyn_cast(a)) - return dominates(forInst->getBody(), b->getBlock()); - // block arguments properly dominate all instructions in their own block, so // we use a dominates check here, not a properlyDominates check. return dominates(cast(a)->getOwner(), b->getBlock()); diff --git a/mlir/lib/Analysis/LoopAnalysis.cpp b/mlir/lib/Analysis/LoopAnalysis.cpp index b154ebab1052..640984bf866b 100644 --- a/mlir/lib/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Analysis/LoopAnalysis.cpp @@ -125,7 +125,7 @@ uint64_t mlir::getLargestDivisorOfTripCount(const ForInst &forInst) { } bool mlir::isAccessInvariant(const Value &iv, const Value &index) { - assert(isa(iv) && "iv must be a ForInst"); + assert(isForInductionVar(&iv) && "iv must be a ForInst"); assert(index.getType().isa() && "index must be of IndexType"); SmallVector affineApplyOps; getReachableAffineApplyOps({const_cast(&index)}, affineApplyOps); @@ -288,8 +288,10 @@ bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim( [fastestVaryingDim](const ForInst &loop, const OperationInst &op) { auto load = op.dyn_cast(); auto store = op.dyn_cast(); - return load ? isContiguousAccess(loop, *load, fastestVaryingDim) - : isContiguousAccess(loop, *store, fastestVaryingDim); + return load ? isContiguousAccess(*loop.getInductionVar(), *load, + fastestVaryingDim) + : isContiguousAccess(*loop.getInductionVar(), *store, + fastestVaryingDim); }); return isVectorizableLoopWithCond(loop, fun); } diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index a8cec771f0d8..d16a7fcb1b31 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -64,7 +64,7 @@ void mlir::getForwardSlice(Instruction *inst, } } } else if (auto *forInst = dyn_cast(inst)) { - for (auto &u : forInst->getUses()) { + for (auto &u : forInst->getInductionVar()->getUses()) { auto *ownerInst = u.getOwner(); if (forwardSlice->count(ownerInst) == 0) { getForwardSlice(ownerInst, forwardSlice, filter, diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 39e58e8983c5..939a2ede618e 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -149,7 +149,8 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth, // A rank 0 memref has a 0-d region. SmallVector ivs; getLoopIVs(*opInst, &ivs); - SmallVector regionSymbols(ivs.begin(), ivs.end()); + + SmallVector regionSymbols = extractForInductionVars(ivs); regionCst->reset(0, loopDepth, 0, regionSymbols); return true; } @@ -172,7 +173,7 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth, unsigned numSymbols = accessMap.getNumSymbols(); // Add inequalties for loop lower/upper bounds. for (unsigned i = 0; i < numDims + numSymbols; ++i) { - if (auto *loop = dyn_cast(accessValueMap.getOperand(i))) { + if (auto *loop = getForInductionVarOwner(accessValueMap.getOperand(i))) { // Note that regionCst can now have more dimensions than accessMap if the // bounds expressions involve outer loops or other symbols. // TODO(bondhugula): rewrite this to use getInstIndexSet; this way @@ -207,7 +208,7 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth, outerIVs.resize(loopDepth); for (auto *operand : accessValueMap.getOperands()) { ForInst *iv; - if ((iv = dyn_cast(operand)) && + if ((iv = getForInductionVarOwner(operand)) && std::find(outerIVs.begin(), outerIVs.end(), iv) == outerIVs.end()) { regionCst->projectOut(operand); } diff --git a/mlir/lib/Analysis/VectorAnalysis.cpp b/mlir/lib/Analysis/VectorAnalysis.cpp index 37eed71508fd..125020e92a35 100644 --- a/mlir/lib/Analysis/VectorAnalysis.cpp +++ b/mlir/lib/Analysis/VectorAnalysis.cpp @@ -113,7 +113,8 @@ static AffineMap makePermutationMap( getAffineConstantExpr(0, context)); for (auto kvp : enclosingLoopToVectorDim) { assert(kvp.second < perm.size()); - auto invariants = getInvariantAccesses(*kvp.first, unwrappedIndices); + auto invariants = + getInvariantAccesses(*kvp.first->getInductionVar(), unwrappedIndices); unsigned numIndices = unwrappedIndices.size(); unsigned countInvariantIndices = 0; for (unsigned dim = 0; dim < numIndices; ++dim) { diff --git a/mlir/lib/EDSC/MLIREmitter.cpp b/mlir/lib/EDSC/MLIREmitter.cpp index 56f211ec5787..c2a6dc1f90a2 100644 --- a/mlir/lib/EDSC/MLIREmitter.cpp +++ b/mlir/lib/EDSC/MLIREmitter.cpp @@ -133,9 +133,7 @@ static void printDefininingStatement(llvm::raw_ostream &os, const Value &v) { inst->print(os); return; } - // &v is required here otherwise we get: - // non-pointer operand type 'const mlir::ForInst' incompatible with nullptr - if (auto *forInst = dyn_cast(&v)) { + if (auto *forInst = getForInductionVarOwner(&v)) { forInst->print(os); } else { os << "unknown_ssa_value"; @@ -296,7 +294,7 @@ Value *MLIREmitter::emit(Expr e) { exprs[1]->getDefiningInst()->cast()->getValue(); auto step = exprs[2]->getDefiningInst()->cast()->getValue(); - res = builder->createFor(location, lb, ub, step); + res = builder->createFor(location, lb, ub, step)->getInductionVar(); } } @@ -347,7 +345,8 @@ void MLIREmitter::emitStmt(const Stmt &stmt) { bind(stmt.getLHS(), val); if (stmt.getRHS().getKind() == ExprKind::For) { // Step into the loop. - builder->setInsertionPointToStart(cast(val)->getBody()); + builder->setInsertionPointToStart( + getForInductionVarOwner(val)->getBody()); } } emitStmts(stmt.getEnclosedStmts()); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index af996213418d..21bc3b824b12 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1078,7 +1078,7 @@ public: void print(const OperationInst *inst); void print(const ForInst *inst); void print(const IfInst *inst); - void print(const Block *block); + void print(const Block *block, bool printBlockArgs = true); void printOperation(const OperationInst *op); void printGenericOp(const OperationInst *op); @@ -1125,10 +1125,15 @@ public: unsigned index) override; /// Print a block list. - void printBlockList(const BlockList &blocks) { + void printBlockList(const BlockList &blocks, bool printEntryBlockArgs) { os << " {\n"; - for (auto &b : blocks) - print(&b); + if (!blocks.empty()) { + auto *entryBlock = &blocks.front(); + print(entryBlock, + printEntryBlockArgs && entryBlock->getNumArguments() != 0); + for (auto &b : llvm::drop_begin(blocks.getBlocks(), 1)) + print(&b); + } os.indent(currentIndent) << "}"; } @@ -1164,8 +1169,8 @@ private: /// This is the next value ID to assign in numbering. unsigned nextValueID = 0; - /// This is the ID to assign to the next induction variable. - unsigned nextLoopID = 0; + /// This is the ID to assign to the next region entry block argument. + unsigned nextRegionArgumentID = 0; /// This is the next ID to assign to a Function argument. unsigned nextArgumentID = 0; /// This is the next ID to assign when a name conflict is detected. @@ -1205,14 +1210,10 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) { numberValuesInBlock(block); break; } - case Instruction::Kind::For: { - auto *forInst = cast(&inst); - // Number the induction variable. - numberValueID(forInst); + case Instruction::Kind::For: // Recursively number the stuff in the body. - numberValuesInBlock(*forInst->getBody()); + numberValuesInBlock(*cast(&inst)->getBody()); break; - } case Instruction::Kind::If: { auto *ifInst = cast(&inst); numberValuesInBlock(*ifInst->getThen()); @@ -1251,13 +1252,19 @@ void FunctionPrinter::numberValueID(const Value *value) { if (specialNameBuffer.empty()) { switch (value->getKind()) { case Value::Kind::BlockArgument: - // If this is an argument to the function, give it an 'arg' name. - if (auto *block = cast(value)->getOwner()) - if (auto *fn = block->getFunction()) - if (&fn->getBlockList().front() == block) { + // If this is an argument to the function, give it an 'arg' name. If the + // argument is to an entry block of an operation region, give it an 'i' + // name. + if (auto *block = cast(value)->getOwner()) { + auto *parentBlockList = block->getParent(); + if (parentBlockList && block == &parentBlockList->front()) { + if (parentBlockList->getContainingFunction()) specialName << "arg" << nextArgumentID++; - break; - } + else + specialName << "i" << nextRegionArgumentID++; + break; + } + } // Otherwise number it normally. valueIDs[value] = nextValueID++; return; @@ -1266,9 +1273,6 @@ void FunctionPrinter::numberValueID(const Value *value) { // done with it. valueIDs[value] = nextValueID++; return; - case Value::Kind::ForInst: - specialName << 'i' << nextLoopID++; - break; } } @@ -1312,10 +1316,8 @@ void FunctionPrinter::print() { printTrailingLocation(function->getLoc()); if (!function->empty()) { - os << " {\n"; - for (const auto &block : *function) - print(&block); - os << "}\n"; + printBlockList(function->getBlockList(), /*printEntryBlockArgs=*/false); + os << "\n"; } os << '\n'; } @@ -1357,26 +1359,10 @@ void FunctionPrinter::printFunctionSignature() { } } -/// Return true if the introducer for the specified block should be printed. -static bool shouldPrintBlockArguments(const Block *block) { - // Never print the entry block of the function - it is included in the - // argument list. - if (block == &block->getFunction()->front()) - return false; - - // If this is the first block in a nested region, and if there are no - // arguments, then we can omit it. - if (block == &block->getParent()->front() && block->getNumArguments() == 0) - return false; - - // Otherwise print it. - return true; -} - -void FunctionPrinter::print(const Block *block) { +void FunctionPrinter::print(const Block *block, bool printBlockArgs) { // Print the block label and argument list, unless this is the first block of // the function, or the first block of an IfInst/ForInst with no arguments. - if (shouldPrintBlockArguments(block)) { + if (printBlockArgs) { os.indent(currentIndent); printBlockName(block); @@ -1445,7 +1431,7 @@ void FunctionPrinter::print(const OperationInst *inst) { void FunctionPrinter::print(const ForInst *inst) { os.indent(currentIndent) << "for "; - printOperand(inst); + printOperand(inst->getInductionVar()); os << " = "; printBound(inst->getLowerBound(), "max"); os << " to "; @@ -1457,7 +1443,7 @@ void FunctionPrinter::print(const ForInst *inst) { printTrailingLocation(inst->getLoc()); os << " {\n"; - print(inst->getBody()); + print(inst->getBody(), /*printBlockArgs=*/false); os.indent(currentIndent) << "}"; } @@ -1468,11 +1454,11 @@ void FunctionPrinter::print(const IfInst *inst) { printDimAndSymbolList(inst->getInstOperands(), set.getNumDims()); printTrailingLocation(inst->getLoc()); os << " {\n"; - print(inst->getThen()); + print(inst->getThen(), /*printBlockArgs=*/false); os.indent(currentIndent) << "}"; if (inst->hasElse()) { os << " else {\n"; - print(inst->getElse()); + print(inst->getElse(), /*printBlockArgs=*/false); os.indent(currentIndent) << "}"; } } @@ -1583,7 +1569,7 @@ void FunctionPrinter::printGenericOp(const OperationInst *op) { // Print any trailing block lists. for (auto &blockList : op->getBlockLists()) - printBlockList(blockList); + printBlockList(blockList, /*printEntryBlockArgs=*/true); } void FunctionPrinter::printSuccessorAndUseList(const OperationInst *term, @@ -1729,8 +1715,6 @@ void Value::print(raw_ostream &os) const { return; case Value::Kind::InstResult: return getDefiningInst()->print(os); - case Value::Kind::ForInst: - return cast(this)->print(os); } } diff --git a/mlir/lib/IR/Instruction.cpp b/mlir/lib/IR/Instruction.cpp index b8a3e5813293..6d74ed142571 100644 --- a/mlir/lib/IR/Instruction.cpp +++ b/mlir/lib/IR/Instruction.cpp @@ -126,9 +126,9 @@ bool Value::isValidSymbol() const { return op->isValidSymbol(); return false; } - // This value is either a function argument or an induction variable. - // Function argument is ok, induction variable is not. - return isa(this); + // Otherwise, the only valid symbol is a function argument. + auto *arg = dyn_cast(this); + return arg && arg->isFunctionArgument(); } void Instruction::setOperand(unsigned idx, Value *value) { @@ -635,13 +635,16 @@ ForInst *ForInst::create(Location location, ArrayRef lbOperands, ForInst::ForInst(Location location, unsigned numOperands, AffineMap lbMap, AffineMap ubMap, int64_t step) - : Instruction(Instruction::Kind::For, location), - Value(Value::Kind::ForInst, - Type::getIndex(lbMap.getResult(0).getContext())), - body(this), lbMap(lbMap), ubMap(ubMap), step(step) { + : Instruction(Instruction::Kind::For, location), body(this), lbMap(lbMap), + ubMap(ubMap), step(step) { // The body of a for inst always has one block. - body.push_back(new Block()); + auto *bodyEntry = new Block(); + body.push_back(bodyEntry); + + // Add an argument to the block for the induction variable. + bodyEntry->addArgument(Type::getIndex(lbMap.getResult(0).getContext())); + operands.reserve(numOperands); } @@ -777,6 +780,35 @@ void ForInst::walkOpsPostOrder(std::function callback) { v.walkPostOrder(this); } +/// Returns the induction variable for this loop. +Value *ForInst::getInductionVar() { return getBody()->getArgument(0); } + +/// Returns if the provided value is the induction variable of a ForInst. +bool mlir::isForInductionVar(const Value *val) { + return getForInductionVarOwner(val) != nullptr; +} + +/// Returns the loop parent of an induction variable. If the provided value is +/// not an induction variable, then return nullptr. +ForInst *mlir::getForInductionVarOwner(Value *val) { + const BlockArgument *ivArg = dyn_cast(val); + if (!ivArg || !ivArg->getOwner()) + return nullptr; + return dyn_cast_or_null( + ivArg->getOwner()->getParent()->getContainingInst()); +} +const ForInst *mlir::getForInductionVarOwner(const Value *val) { + return getForInductionVarOwner(const_cast(val)); +} + +/// Extracts the induction variables from a list of ForInsts and returns them. +SmallVector +mlir::extractForInductionVars(ArrayRef forInsts) { + SmallVector results; + for (auto *forInst : forInsts) + results.push_back(forInst->getInductionVar()); + return results; +} //===----------------------------------------------------------------------===// // IfInst //===----------------------------------------------------------------------===// @@ -909,7 +941,7 @@ Instruction *Instruction::clone(BlockAndValueMapping &mapper, ubMap, forInst->getStep()); // Remember the induction variable mapping. - mapper.map(forInst, newFor); + mapper.map(forInst->getInductionVar(), newFor->getInductionVar()); // Recursively clone the body of the for loop. for (auto &subInst : *forInst->getBody()) diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp index a2cb9910ab84..6418b062dc16 100644 --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -35,8 +35,6 @@ Function *Value::getFunction() { return cast(this)->getFunction(); case Value::Kind::InstResult: return getDefiningInst()->getFunction(); - case Value::Kind::ForInst: - return cast(this)->getFunction(); } } @@ -83,3 +81,9 @@ Function *BlockArgument::getFunction() { return owner->getFunction(); return nullptr; } + +/// Returns if the current argument is a function argument. +bool BlockArgument::isFunctionArgument() const { + auto *containingFn = getFunction(); + return containingFn && &containingFn->front() == getOwner(); +} diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index ecb7fbc779e5..c477ad1bbc5c 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -3201,7 +3201,8 @@ ParseResult FunctionParser::parseForInst() { ubOperands, ubMap, step); // Create SSA value definition for the induction variable. - if (addDefinition({inductionVariableName, 0, loc}, forInst)) + if (addDefinition({inductionVariableName, 0, loc}, + forInst->getInductionVar())) return ParseFailure; // Try to parse the optional trailing location. @@ -3347,7 +3348,7 @@ ParseResult FunctionParser::parseBound(SmallVectorImpl &operands, // Create an identity map using dim id for an induction variable and // symbol otherwise. This representation is optimized for storage. // Analysis passes may expand it into a multi-dimensional map if desired. - if (isa(operands[0])) + if (isForInductionVar(operands[0])) map = builder.getDimIdentityMap(); else map = builder.getSymbolIdentityMap(); diff --git a/mlir/lib/Transforms/DmaGeneration.cpp b/mlir/lib/Transforms/DmaGeneration.cpp index 0437fb143e0a..04eb38e9fc94 100644 --- a/mlir/lib/Transforms/DmaGeneration.cpp +++ b/mlir/lib/Transforms/DmaGeneration.cpp @@ -171,7 +171,8 @@ static bool getFullMemRefAsRegion(OperationInst *opInst, unsigned numSymbols, getLoopIVs(*opInst, &ivs); auto *regionCst = region->getConstraints(); - SmallVector symbols(ivs.begin(), ivs.end()); + + SmallVector symbols = extractForInductionVars(ivs); regionCst->reset(rank, numSymbols, 0, symbols); // Memref dim sizes provide the bounds. diff --git a/mlir/lib/Transforms/LoopTiling.cpp b/mlir/lib/Transforms/LoopTiling.cpp index 2a4b7bcd262d..396fc8eb658c 100644 --- a/mlir/lib/Transforms/LoopTiling.cpp +++ b/mlir/lib/Transforms/LoopTiling.cpp @@ -103,7 +103,8 @@ static void constructTiledIndexSetHyperRect(ArrayRef origLoops, auto mayBeConstantCount = getConstantTripCount(*origLoops[i]); // The lower bound is just the tile-space loop. AffineMap lbMap = b.getDimIdentityMap(); - newLoops[width + i]->setLowerBound(/*operands=*/newLoops[i], lbMap); + newLoops[width + i]->setLowerBound( + /*operands=*/newLoops[i]->getInductionVar(), lbMap); // Set the upper bound. if (mayBeConstantCount.hasValue() && @@ -117,7 +118,7 @@ static void constructTiledIndexSetHyperRect(ArrayRef origLoops, // with 'i' (tile-space loop) appended to it. The new upper bound map is // the original one with an additional expression i + tileSize appended. SmallVector ubOperands(origLoops[i]->getUpperBoundOperands()); - ubOperands.push_back(newLoops[i]); + ubOperands.push_back(newLoops[i]->getInductionVar()); auto origUbMap = origLoops[i]->getUpperBoundMap(); SmallVector boundExprs; @@ -135,7 +136,7 @@ static void constructTiledIndexSetHyperRect(ArrayRef origLoops, // No need of the min expression. auto dim = b.getAffineDimExpr(0); auto ubMap = b.getAffineMap(1, 0, dim + tileSizes[i], {}); - newLoops[width + i]->setUpperBound(newLoops[i], ubMap); + newLoops[width + i]->setUpperBound(newLoops[i]->getInductionVar(), ubMap); } } } @@ -194,8 +195,8 @@ UtilResult mlir::tileCodeGen(ArrayRef band, // Move the loop body of the original nest to the new one. moveLoopBody(origLoops[origLoops.size() - 1], innermostPointLoop); - SmallVector origLoopIVs(band.begin(), band.end()); - SmallVector, 6> ids(band.begin(), band.end()); + SmallVector origLoopIVs = extractForInductionVars(band); + SmallVector, 6> ids(origLoopIVs.begin(), origLoopIVs.end()); FlatAffineConstraints cst; getIndexSet(band, &cst); @@ -208,7 +209,7 @@ UtilResult mlir::tileCodeGen(ArrayRef band, constructTiledIndexSetHyperRect(origLoops, newLoops, tileSizes); // In this case, the point loop IVs just replace the original ones. for (unsigned i = 0; i < width; i++) { - origLoopIVs[i]->replaceAllUsesWith(newLoops[i + width]); + origLoopIVs[i]->replaceAllUsesWith(newLoops[i + width]->getInductionVar()); } // Erase the old loop nest. diff --git a/mlir/lib/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Transforms/LoopUnrollAndJam.cpp index 71d778172548..a8ec57c04269 100644 --- a/mlir/lib/Transforms/LoopUnrollAndJam.cpp +++ b/mlir/lib/Transforms/LoopUnrollAndJam.cpp @@ -215,6 +215,7 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) { int64_t step = forInst->getStep(); forInst->setStep(step * unrollJamFactor); + auto *forInstIV = forInst->getInductionVar(); for (auto &subBlock : subBlocks) { // Builder to insert unroll-jammed bodies. Insert right at the end of // sub-block. @@ -226,14 +227,15 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) { // If the induction variable is used, create a remapping to the value for // this unrolled instance. - if (!forInst->use_empty()) { + if (!forInstIV->use_empty()) { // iv' = iv + i, i = 1 to unrollJamFactor-1. auto d0 = builder.getAffineDimExpr(0); auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {}); auto *ivUnroll = - builder.create(forInst->getLoc(), bumpMap, forInst) + builder + .create(forInst->getLoc(), bumpMap, forInstIV) ->getResult(0); - operandMapping.map(forInst, ivUnroll); + operandMapping.map(forInstIV, ivUnroll); } // Clone the sub-block being unroll-jammed. for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) { diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 99ee603bb05e..94f300bd16a2 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -348,7 +348,7 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) { oldBody->begin(), oldBody->end()); // The code in the body of the forInst now uses 'iv' as its indvar. - forInst->replaceAllUsesWith(iv); + forInst->getInductionVar()->replaceAllUsesWith(iv); // Append the induction variable stepping logic and branch back to the exit // condition block. Construct an affine expression f : (x -> x+step) and diff --git a/mlir/lib/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Transforms/PipelineDataTransfer.cpp index e72b9ef80df8..0019714b6a3d 100644 --- a/mlir/lib/Transforms/PipelineDataTransfer.cpp +++ b/mlir/lib/Transforms/PipelineDataTransfer.cpp @@ -121,8 +121,8 @@ static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) { int64_t step = forInst->getStep(); auto modTwoMap = bInner.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0, {d0.floorDiv(step) % 2}, {}); - auto ivModTwoOp = - bInner.create(forInst->getLoc(), modTwoMap, forInst); + auto ivModTwoOp = bInner.create(forInst->getLoc(), modTwoMap, + forInst->getInductionVar()); // replaceAllMemRefUsesWith will always succeed unless the forInst body has // non-deferencing uses of the memref. diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp index d41614545d21..03673eaa535a 100644 --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -99,24 +99,25 @@ bool mlir::promoteIfSingleIteration(ForInst *forInst) { return false; // Replaces all IV uses to its single iteration value. - if (!forInst->use_empty()) { + auto *iv = forInst->getInductionVar(); + if (!iv->use_empty()) { if (forInst->hasConstantLowerBound()) { auto *mlFunc = forInst->getFunction(); FuncBuilder topBuilder(mlFunc); auto constOp = topBuilder.create( forInst->getLoc(), forInst->getConstantLowerBound()); - forInst->replaceAllUsesWith(constOp); + iv->replaceAllUsesWith(constOp); } else { const AffineBound lb = forInst->getLowerBound(); SmallVector lbOperands(lb.operand_begin(), lb.operand_end()); FuncBuilder builder(forInst->getBlock(), Block::iterator(forInst)); if (lb.getMap() == builder.getDimIdentityMap()) { // No need of generating an affine_apply. - forInst->replaceAllUsesWith(lbOperands[0]); + iv->replaceAllUsesWith(lbOperands[0]); } else { auto affineApplyOp = builder.create( forInst->getLoc(), lb.getMap(), lbOperands); - forInst->replaceAllUsesWith(affineApplyOp->getResult(0)); + iv->replaceAllUsesWith(affineApplyOp->getResult(0)); } } } @@ -161,6 +162,8 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, auto *loopChunk = b->createFor(srcForInst->getLoc(), lbOperands, lbMap, ubOperands, ubMap, srcForInst->getStep()); + auto *loopChunkIV = loopChunk->getInductionVar(); + auto *srcIV = srcForInst->getInductionVar(); BlockAndValueMapping operandMap; @@ -172,17 +175,17 @@ generateLoop(AffineMap lbMap, AffineMap ubMap, // remapped to results of cloned instructions, and their IV used remapped. // Generate the remapping if the shift is not zero: remappedIV = newIV - // shift. - if (!srcForInst->use_empty() && shift != 0) { + if (!srcIV->use_empty() && shift != 0) { auto b = FuncBuilder::getForInstBodyBuilder(loopChunk); auto *ivRemap = b.create( srcForInst->getLoc(), b.getSingleDimShiftAffineMap(-static_cast( srcForInst->getStep() * shift)), - loopChunk) + loopChunkIV) ->getResult(0); - operandMap.map(srcForInst, ivRemap); + operandMap.map(srcIV, ivRemap); } else { - operandMap.map(srcForInst, loopChunk); + operandMap.map(srcIV, loopChunkIV); } for (auto *inst : insts) { loopChunk->getBody()->push_back(inst->clone(operandMap, b->getContext())); @@ -419,19 +422,20 @@ bool mlir::loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor) { Block::iterator srcBlockEnd = std::prev(forInst->getBody()->end()); // Unroll the contents of 'forInst' (append unrollFactor-1 additional copies). + auto *forInstIV = forInst->getInductionVar(); for (unsigned i = 1; i < unrollFactor; i++) { BlockAndValueMapping operandMap; // If the induction variable is used, create a remapping to the value for // this unrolled instance. - if (!forInst->use_empty()) { + if (!forInstIV->use_empty()) { // iv' = iv + 1/2/3...unrollFactor-1; auto d0 = builder.getAffineDimExpr(0); auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {}); auto *ivUnroll = - builder.create(forInst->getLoc(), bumpMap, forInst) + builder.create(forInst->getLoc(), bumpMap, forInstIV) ->getResult(0); - operandMap.map(forInst, ivUnroll); + operandMap.map(forInstIV, ivUnroll); } // Clone the original body of 'forInst'. diff --git a/mlir/lib/Transforms/Vectorize.cpp b/mlir/lib/Transforms/Vectorize.cpp index e9b37fcc04c2..cfde1ecf0a8d 100644 --- a/mlir/lib/Transforms/Vectorize.cpp +++ b/mlir/lib/Transforms/Vectorize.cpp @@ -881,8 +881,9 @@ static bool vectorizeForInst(ForInst *loop, int64_t step, auto load = opInst->dyn_cast(); auto store = opInst->dyn_cast(); LLVM_DEBUG(opInst->print(dbgs())); - auto fail = load ? vectorizeRootOrTerminal(loop, load, state) - : vectorizeRootOrTerminal(loop, store, state); + auto fail = + load ? vectorizeRootOrTerminal(loop->getInductionVar(), load, state) + : vectorizeRootOrTerminal(loop->getInductionVar(), store, state); if (fail) { return fail; } @@ -1210,7 +1211,8 @@ static bool vectorizeRootMatches(NestedMatch matches, /// RAII. ScopeGuard sg2([&fail, loop, clonedLoop]() { if (fail) { - loop->replaceAllUsesWith(clonedLoop); + loop->getInductionVar()->replaceAllUsesWith( + clonedLoop->getInductionVar()); loop->erase(); } else { clonedLoop->erase();