Change the ForInst induction variable to be a block argument of the body instead of the ForInst itself. This is a necessary step in converting ForInst into an operation.

PiperOrigin-RevId: 231064139
This commit is contained in:
River Riddle 2019-01-26 12:40:12 -08:00 committed by jpienaar
parent 0e7a8a9027
commit 36babbd781
21 changed files with 172 additions and 135 deletions

View File

@ -555,19 +555,12 @@ inline auto OperationInst::getResultTypes() const
}
/// For instruction represents an affine loop nest.
class ForInst : public Instruction, public Value {
class ForInst : public Instruction {
public:
static ForInst *create(Location location, ArrayRef<Value *> lbOperands,
AffineMap lbMap, ArrayRef<Value *> ubOperands,
AffineMap ubMap, int64_t step);
~ForInst() {
// There may be references to the induction variable of this loop within its
// body or, in case of ill-formed code during parsing, outside its body.
// Explicitly drop all uses of the induction variable before destroying it.
dropAllUses();
}
/// Resolve base class ambiguity.
using Instruction::getFunction;
@ -700,7 +693,9 @@ public:
//===--------------------------------------------------------------------===//
/// Return the context this operation is associated with.
MLIRContext *getContext() const { return getType().getContext(); }
MLIRContext *getContext() const {
return getInductionVar()->getType().getContext();
}
using Instruction::dump;
using Instruction::print;
@ -710,11 +705,10 @@ public:
return ptr->getKind() == IROperandOwner::Kind::ForInst;
}
// For instruction represents implicitly represents induction variable by
// inheriting from Value class. Whenever you need to refer to the loop
// induction variable, just use the for instruction itself.
static bool classof(const Value *value) {
return value->getKind() == Value::Kind::ForInst;
/// Returns the induction variable for this loop.
Value *getInductionVar();
const Value *getInductionVar() const {
return const_cast<ForInst *>(this)->getInductionVar();
}
private:
@ -738,6 +732,17 @@ private:
AffineMap ubMap, int64_t step);
};
/// Returns if the provided value is the induction variable of a ForInst.
bool isForInductionVar(const Value *val);
/// Returns the loop parent of an induction variable. If the provided value is
/// not an induction variable, then return nullptr.
ForInst *getForInductionVarOwner(Value *val);
const ForInst *getForInductionVarOwner(const Value *val);
/// Extracts the induction variables from a list of ForInsts and returns them.
SmallVector<Value *, 8> extractForInductionVars(ArrayRef<ForInst *> forInsts);
/// AffineBound represents a lower or upper bound in the for instruction.
/// This class does not own the underlying operands. Instead, it refers
/// to the operands stored in the ForInst. Its life span should not exceed

View File

@ -45,7 +45,6 @@ public:
enum class Kind {
BlockArgument, // block argument
InstResult, // operation instruction result
ForInst, // 'for' instruction induction variable
};
~Value() {}
@ -141,6 +140,9 @@ public:
/// Returns the number of this argument.
unsigned getArgNumber() const;
/// Returns if the current argument is a function argument.
bool isFunctionArgument() const;
private:
friend class Block; // For access to private constructor.
BlockArgument(Type type, Block *owner)

View File

@ -555,7 +555,7 @@ void mlir::getReachableAffineApplyOps(
// setExprStride(ArrayRef<int64_t> expr, int64_t stride)
bool mlir::getIndexSet(ArrayRef<ForInst *> forInsts,
FlatAffineConstraints *domain) {
SmallVector<Value *, 4> indices(forInsts.begin(), forInsts.end());
auto indices = extractForInductionVars(forInsts);
// Reset while associated Values in 'indices' to the domain.
domain->reset(forInsts.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
for (auto *forInst : forInsts) {
@ -677,7 +677,7 @@ static void buildDimAndSymbolPositionMaps(
auto updateValuePosMap = [&](ArrayRef<Value *> values, bool isSrc) {
for (unsigned i = 0, e = values.size(); i < e; ++i) {
auto *value = values[i];
if (!isa<ForInst>(values[i])) {
if (!isForInductionVar(values[i])) {
assert(values[i]->isValidSymbol() &&
"access operand has to be either a loop IV or a symbol");
valuePosMap->addSymbolValue(value);
@ -739,7 +739,7 @@ void initDependenceConstraints(const FlatAffineConstraints &srcDomain,
// Set values for the symbolic identifier dimensions.
auto setSymbolIds = [&](ArrayRef<Value *> values) {
for (auto *value : values) {
if (!isa<ForInst>(value)) {
if (!isForInductionVar(value)) {
assert(value->isValidSymbol() && "expected symbol");
dependenceConstraints->setIdValue(valuePosMap.getSymPos(value), value);
}
@ -907,7 +907,7 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
// Add equality constraints for any operands that are defined by constant ops.
auto addEqForConstOperands = [&](ArrayRef<const Value *> operands) {
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
if (isa<ForInst>(operands[i]))
if (isForInductionVar(operands[i]))
continue;
auto *symbol = operands[i];
assert(symbol->isValidSymbol());
@ -976,8 +976,8 @@ static unsigned getNumCommonLoops(const FlatAffineConstraints &srcDomain,
std::min(srcDomain.getNumDimIds(), dstDomain.getNumDimIds());
unsigned numCommonLoops = 0;
for (unsigned i = 0; i < minNumLoops; ++i) {
if (!isa<ForInst>(srcDomain.getIdValue(i)) ||
!isa<ForInst>(dstDomain.getIdValue(i)) ||
if (!isForInductionVar(srcDomain.getIdValue(i)) ||
!isForInductionVar(dstDomain.getIdValue(i)) ||
srcDomain.getIdValue(i) != dstDomain.getIdValue(i))
break;
++numCommonLoops;
@ -998,8 +998,9 @@ static const Block *getCommonBlock(const MemRefAccess &srcAccess,
return block;
}
auto *commonForValue = srcDomain.getIdValue(numCommonLoops - 1);
assert(isa<ForInst>(commonForValue));
return cast<ForInst>(commonForValue)->getBody();
auto *forInst = getForInductionVarOwner(commonForValue);
assert(forInst && "commonForValue was not an induction variable");
return forInst->getBody();
}
// Returns true if the ancestor operation instruction of 'srcAccess' appears

View File

@ -1251,7 +1251,7 @@ void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) {
bool FlatAffineConstraints::addForInstDomain(const ForInst &forInst) {
unsigned pos;
// Pre-condition for this method.
if (!findId(forInst, &pos)) {
if (!findId(*forInst.getInductionVar(), &pos)) {
assert(0 && "Value not found");
return false;
}

View File

@ -53,9 +53,9 @@ bool DominanceInfo::properlyDominates(const Block *a, const Block *b) {
if (blockListA == blockListB)
return DominatorTreeBase::properlyDominates(a, b);
// Otherwise, 'a' properly dominates 'b' if 'b' is defined in an
// IfInst/ForInst that (recursively) ends up being dominated by 'a'. Walk up
// the list of containers enclosing B.
// Otherwise, 'a' properly dominates 'b' if 'b' is defined in an instruction
// region that (recursively) ends up being dominated by 'a'. Walk up the list
// of containers enclosing B.
Instruction *bAncestor;
do {
bAncestor = blockListB->getContainingInst();
@ -106,11 +106,6 @@ bool DominanceInfo::properlyDominates(const Value *a, const Instruction *b) {
if (auto *aInst = a->getDefiningInst())
return properlyDominates(aInst, b);
// The induction variable of a ForInst properly dominantes its body, so we
// can just do a simple block dominance check.
if (auto *forInst = dyn_cast<ForInst>(a))
return dominates(forInst->getBody(), b->getBlock());
// block arguments properly dominate all instructions in their own block, so
// we use a dominates check here, not a properlyDominates check.
return dominates(cast<BlockArgument>(a)->getOwner(), b->getBlock());

View File

@ -125,7 +125,7 @@ uint64_t mlir::getLargestDivisorOfTripCount(const ForInst &forInst) {
}
bool mlir::isAccessInvariant(const Value &iv, const Value &index) {
assert(isa<ForInst>(iv) && "iv must be a ForInst");
assert(isForInductionVar(&iv) && "iv must be a ForInst");
assert(index.getType().isa<IndexType>() && "index must be of IndexType");
SmallVector<OperationInst *, 4> affineApplyOps;
getReachableAffineApplyOps({const_cast<Value *>(&index)}, affineApplyOps);
@ -288,8 +288,10 @@ bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim(
[fastestVaryingDim](const ForInst &loop, const OperationInst &op) {
auto load = op.dyn_cast<LoadOp>();
auto store = op.dyn_cast<StoreOp>();
return load ? isContiguousAccess(loop, *load, fastestVaryingDim)
: isContiguousAccess(loop, *store, fastestVaryingDim);
return load ? isContiguousAccess(*loop.getInductionVar(), *load,
fastestVaryingDim)
: isContiguousAccess(*loop.getInductionVar(), *store,
fastestVaryingDim);
});
return isVectorizableLoopWithCond(loop, fun);
}

View File

@ -64,7 +64,7 @@ void mlir::getForwardSlice(Instruction *inst,
}
}
} else if (auto *forInst = dyn_cast<ForInst>(inst)) {
for (auto &u : forInst->getUses()) {
for (auto &u : forInst->getInductionVar()->getUses()) {
auto *ownerInst = u.getOwner();
if (forwardSlice->count(ownerInst) == 0) {
getForwardSlice(ownerInst, forwardSlice, filter,

View File

@ -149,7 +149,8 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth,
// A rank 0 memref has a 0-d region.
SmallVector<ForInst *, 4> ivs;
getLoopIVs(*opInst, &ivs);
SmallVector<Value *, 4> regionSymbols(ivs.begin(), ivs.end());
SmallVector<Value *, 8> regionSymbols = extractForInductionVars(ivs);
regionCst->reset(0, loopDepth, 0, regionSymbols);
return true;
}
@ -172,7 +173,7 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth,
unsigned numSymbols = accessMap.getNumSymbols();
// Add inequalties for loop lower/upper bounds.
for (unsigned i = 0; i < numDims + numSymbols; ++i) {
if (auto *loop = dyn_cast<ForInst>(accessValueMap.getOperand(i))) {
if (auto *loop = getForInductionVarOwner(accessValueMap.getOperand(i))) {
// Note that regionCst can now have more dimensions than accessMap if the
// bounds expressions involve outer loops or other symbols.
// TODO(bondhugula): rewrite this to use getInstIndexSet; this way
@ -207,7 +208,7 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth,
outerIVs.resize(loopDepth);
for (auto *operand : accessValueMap.getOperands()) {
ForInst *iv;
if ((iv = dyn_cast<ForInst>(operand)) &&
if ((iv = getForInductionVarOwner(operand)) &&
std::find(outerIVs.begin(), outerIVs.end(), iv) == outerIVs.end()) {
regionCst->projectOut(operand);
}

View File

@ -113,7 +113,8 @@ static AffineMap makePermutationMap(
getAffineConstantExpr(0, context));
for (auto kvp : enclosingLoopToVectorDim) {
assert(kvp.second < perm.size());
auto invariants = getInvariantAccesses(*kvp.first, unwrappedIndices);
auto invariants =
getInvariantAccesses(*kvp.first->getInductionVar(), unwrappedIndices);
unsigned numIndices = unwrappedIndices.size();
unsigned countInvariantIndices = 0;
for (unsigned dim = 0; dim < numIndices; ++dim) {

View File

@ -133,9 +133,7 @@ static void printDefininingStatement(llvm::raw_ostream &os, const Value &v) {
inst->print(os);
return;
}
// &v is required here otherwise we get:
// non-pointer operand type 'const mlir::ForInst' incompatible with nullptr
if (auto *forInst = dyn_cast<ForInst>(&v)) {
if (auto *forInst = getForInductionVarOwner(&v)) {
forInst->print(os);
} else {
os << "unknown_ssa_value";
@ -296,7 +294,7 @@ Value *MLIREmitter::emit(Expr e) {
exprs[1]->getDefiningInst()->cast<ConstantIndexOp>()->getValue();
auto step =
exprs[2]->getDefiningInst()->cast<ConstantIndexOp>()->getValue();
res = builder->createFor(location, lb, ub, step);
res = builder->createFor(location, lb, ub, step)->getInductionVar();
}
}
@ -347,7 +345,8 @@ void MLIREmitter::emitStmt(const Stmt &stmt) {
bind(stmt.getLHS(), val);
if (stmt.getRHS().getKind() == ExprKind::For) {
// Step into the loop.
builder->setInsertionPointToStart(cast<ForInst>(val)->getBody());
builder->setInsertionPointToStart(
getForInductionVarOwner(val)->getBody());
}
}
emitStmts(stmt.getEnclosedStmts());

View File

@ -1078,7 +1078,7 @@ public:
void print(const OperationInst *inst);
void print(const ForInst *inst);
void print(const IfInst *inst);
void print(const Block *block);
void print(const Block *block, bool printBlockArgs = true);
void printOperation(const OperationInst *op);
void printGenericOp(const OperationInst *op);
@ -1125,10 +1125,15 @@ public:
unsigned index) override;
/// Print a block list.
void printBlockList(const BlockList &blocks) {
void printBlockList(const BlockList &blocks, bool printEntryBlockArgs) {
os << " {\n";
for (auto &b : blocks)
print(&b);
if (!blocks.empty()) {
auto *entryBlock = &blocks.front();
print(entryBlock,
printEntryBlockArgs && entryBlock->getNumArguments() != 0);
for (auto &b : llvm::drop_begin(blocks.getBlocks(), 1))
print(&b);
}
os.indent(currentIndent) << "}";
}
@ -1164,8 +1169,8 @@ private:
/// This is the next value ID to assign in numbering.
unsigned nextValueID = 0;
/// This is the ID to assign to the next induction variable.
unsigned nextLoopID = 0;
/// This is the ID to assign to the next region entry block argument.
unsigned nextRegionArgumentID = 0;
/// This is the next ID to assign to a Function argument.
unsigned nextArgumentID = 0;
/// This is the next ID to assign when a name conflict is detected.
@ -1205,14 +1210,10 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) {
numberValuesInBlock(block);
break;
}
case Instruction::Kind::For: {
auto *forInst = cast<ForInst>(&inst);
// Number the induction variable.
numberValueID(forInst);
case Instruction::Kind::For:
// Recursively number the stuff in the body.
numberValuesInBlock(*forInst->getBody());
numberValuesInBlock(*cast<ForInst>(&inst)->getBody());
break;
}
case Instruction::Kind::If: {
auto *ifInst = cast<IfInst>(&inst);
numberValuesInBlock(*ifInst->getThen());
@ -1251,13 +1252,19 @@ void FunctionPrinter::numberValueID(const Value *value) {
if (specialNameBuffer.empty()) {
switch (value->getKind()) {
case Value::Kind::BlockArgument:
// If this is an argument to the function, give it an 'arg' name.
if (auto *block = cast<BlockArgument>(value)->getOwner())
if (auto *fn = block->getFunction())
if (&fn->getBlockList().front() == block) {
// If this is an argument to the function, give it an 'arg' name. If the
// argument is to an entry block of an operation region, give it an 'i'
// name.
if (auto *block = cast<BlockArgument>(value)->getOwner()) {
auto *parentBlockList = block->getParent();
if (parentBlockList && block == &parentBlockList->front()) {
if (parentBlockList->getContainingFunction())
specialName << "arg" << nextArgumentID++;
break;
}
else
specialName << "i" << nextRegionArgumentID++;
break;
}
}
// Otherwise number it normally.
valueIDs[value] = nextValueID++;
return;
@ -1266,9 +1273,6 @@ void FunctionPrinter::numberValueID(const Value *value) {
// done with it.
valueIDs[value] = nextValueID++;
return;
case Value::Kind::ForInst:
specialName << 'i' << nextLoopID++;
break;
}
}
@ -1312,10 +1316,8 @@ void FunctionPrinter::print() {
printTrailingLocation(function->getLoc());
if (!function->empty()) {
os << " {\n";
for (const auto &block : *function)
print(&block);
os << "}\n";
printBlockList(function->getBlockList(), /*printEntryBlockArgs=*/false);
os << "\n";
}
os << '\n';
}
@ -1357,26 +1359,10 @@ void FunctionPrinter::printFunctionSignature() {
}
}
/// Return true if the introducer for the specified block should be printed.
static bool shouldPrintBlockArguments(const Block *block) {
// Never print the entry block of the function - it is included in the
// argument list.
if (block == &block->getFunction()->front())
return false;
// If this is the first block in a nested region, and if there are no
// arguments, then we can omit it.
if (block == &block->getParent()->front() && block->getNumArguments() == 0)
return false;
// Otherwise print it.
return true;
}
void FunctionPrinter::print(const Block *block) {
void FunctionPrinter::print(const Block *block, bool printBlockArgs) {
// Print the block label and argument list, unless this is the first block of
// the function, or the first block of an IfInst/ForInst with no arguments.
if (shouldPrintBlockArguments(block)) {
if (printBlockArgs) {
os.indent(currentIndent);
printBlockName(block);
@ -1445,7 +1431,7 @@ void FunctionPrinter::print(const OperationInst *inst) {
void FunctionPrinter::print(const ForInst *inst) {
os.indent(currentIndent) << "for ";
printOperand(inst);
printOperand(inst->getInductionVar());
os << " = ";
printBound(inst->getLowerBound(), "max");
os << " to ";
@ -1457,7 +1443,7 @@ void FunctionPrinter::print(const ForInst *inst) {
printTrailingLocation(inst->getLoc());
os << " {\n";
print(inst->getBody());
print(inst->getBody(), /*printBlockArgs=*/false);
os.indent(currentIndent) << "}";
}
@ -1468,11 +1454,11 @@ void FunctionPrinter::print(const IfInst *inst) {
printDimAndSymbolList(inst->getInstOperands(), set.getNumDims());
printTrailingLocation(inst->getLoc());
os << " {\n";
print(inst->getThen());
print(inst->getThen(), /*printBlockArgs=*/false);
os.indent(currentIndent) << "}";
if (inst->hasElse()) {
os << " else {\n";
print(inst->getElse());
print(inst->getElse(), /*printBlockArgs=*/false);
os.indent(currentIndent) << "}";
}
}
@ -1583,7 +1569,7 @@ void FunctionPrinter::printGenericOp(const OperationInst *op) {
// Print any trailing block lists.
for (auto &blockList : op->getBlockLists())
printBlockList(blockList);
printBlockList(blockList, /*printEntryBlockArgs=*/true);
}
void FunctionPrinter::printSuccessorAndUseList(const OperationInst *term,
@ -1729,8 +1715,6 @@ void Value::print(raw_ostream &os) const {
return;
case Value::Kind::InstResult:
return getDefiningInst()->print(os);
case Value::Kind::ForInst:
return cast<ForInst>(this)->print(os);
}
}

View File

@ -126,9 +126,9 @@ bool Value::isValidSymbol() const {
return op->isValidSymbol();
return false;
}
// This value is either a function argument or an induction variable.
// Function argument is ok, induction variable is not.
return isa<BlockArgument>(this);
// Otherwise, the only valid symbol is a function argument.
auto *arg = dyn_cast<BlockArgument>(this);
return arg && arg->isFunctionArgument();
}
void Instruction::setOperand(unsigned idx, Value *value) {
@ -635,13 +635,16 @@ ForInst *ForInst::create(Location location, ArrayRef<Value *> lbOperands,
ForInst::ForInst(Location location, unsigned numOperands, AffineMap lbMap,
AffineMap ubMap, int64_t step)
: Instruction(Instruction::Kind::For, location),
Value(Value::Kind::ForInst,
Type::getIndex(lbMap.getResult(0).getContext())),
body(this), lbMap(lbMap), ubMap(ubMap), step(step) {
: Instruction(Instruction::Kind::For, location), body(this), lbMap(lbMap),
ubMap(ubMap), step(step) {
// The body of a for inst always has one block.
body.push_back(new Block());
auto *bodyEntry = new Block();
body.push_back(bodyEntry);
// Add an argument to the block for the induction variable.
bodyEntry->addArgument(Type::getIndex(lbMap.getResult(0).getContext()));
operands.reserve(numOperands);
}
@ -777,6 +780,35 @@ void ForInst::walkOpsPostOrder(std::function<void(OperationInst *)> callback) {
v.walkPostOrder(this);
}
/// Returns the induction variable for this loop.
Value *ForInst::getInductionVar() { return getBody()->getArgument(0); }
/// Returns if the provided value is the induction variable of a ForInst.
bool mlir::isForInductionVar(const Value *val) {
return getForInductionVarOwner(val) != nullptr;
}
/// Returns the loop parent of an induction variable. If the provided value is
/// not an induction variable, then return nullptr.
ForInst *mlir::getForInductionVarOwner(Value *val) {
const BlockArgument *ivArg = dyn_cast<BlockArgument>(val);
if (!ivArg || !ivArg->getOwner())
return nullptr;
return dyn_cast_or_null<ForInst>(
ivArg->getOwner()->getParent()->getContainingInst());
}
const ForInst *mlir::getForInductionVarOwner(const Value *val) {
return getForInductionVarOwner(const_cast<Value *>(val));
}
/// Extracts the induction variables from a list of ForInsts and returns them.
SmallVector<Value *, 8>
mlir::extractForInductionVars(ArrayRef<ForInst *> forInsts) {
SmallVector<Value *, 8> results;
for (auto *forInst : forInsts)
results.push_back(forInst->getInductionVar());
return results;
}
//===----------------------------------------------------------------------===//
// IfInst
//===----------------------------------------------------------------------===//
@ -909,7 +941,7 @@ Instruction *Instruction::clone(BlockAndValueMapping &mapper,
ubMap, forInst->getStep());
// Remember the induction variable mapping.
mapper.map(forInst, newFor);
mapper.map(forInst->getInductionVar(), newFor->getInductionVar());
// Recursively clone the body of the for loop.
for (auto &subInst : *forInst->getBody())

View File

@ -35,8 +35,6 @@ Function *Value::getFunction() {
return cast<BlockArgument>(this)->getFunction();
case Value::Kind::InstResult:
return getDefiningInst()->getFunction();
case Value::Kind::ForInst:
return cast<ForInst>(this)->getFunction();
}
}
@ -83,3 +81,9 @@ Function *BlockArgument::getFunction() {
return owner->getFunction();
return nullptr;
}
/// Returns if the current argument is a function argument.
bool BlockArgument::isFunctionArgument() const {
auto *containingFn = getFunction();
return containingFn && &containingFn->front() == getOwner();
}

View File

@ -3201,7 +3201,8 @@ ParseResult FunctionParser::parseForInst() {
ubOperands, ubMap, step);
// Create SSA value definition for the induction variable.
if (addDefinition({inductionVariableName, 0, loc}, forInst))
if (addDefinition({inductionVariableName, 0, loc},
forInst->getInductionVar()))
return ParseFailure;
// Try to parse the optional trailing location.
@ -3347,7 +3348,7 @@ ParseResult FunctionParser::parseBound(SmallVectorImpl<Value *> &operands,
// Create an identity map using dim id for an induction variable and
// symbol otherwise. This representation is optimized for storage.
// Analysis passes may expand it into a multi-dimensional map if desired.
if (isa<ForInst>(operands[0]))
if (isForInductionVar(operands[0]))
map = builder.getDimIdentityMap();
else
map = builder.getSymbolIdentityMap();

View File

@ -171,7 +171,8 @@ static bool getFullMemRefAsRegion(OperationInst *opInst, unsigned numSymbols,
getLoopIVs(*opInst, &ivs);
auto *regionCst = region->getConstraints();
SmallVector<Value *, 4> symbols(ivs.begin(), ivs.end());
SmallVector<Value *, 8> symbols = extractForInductionVars(ivs);
regionCst->reset(rank, numSymbols, 0, symbols);
// Memref dim sizes provide the bounds.

View File

@ -103,7 +103,8 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForInst *> origLoops,
auto mayBeConstantCount = getConstantTripCount(*origLoops[i]);
// The lower bound is just the tile-space loop.
AffineMap lbMap = b.getDimIdentityMap();
newLoops[width + i]->setLowerBound(/*operands=*/newLoops[i], lbMap);
newLoops[width + i]->setLowerBound(
/*operands=*/newLoops[i]->getInductionVar(), lbMap);
// Set the upper bound.
if (mayBeConstantCount.hasValue() &&
@ -117,7 +118,7 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForInst *> origLoops,
// with 'i' (tile-space loop) appended to it. The new upper bound map is
// the original one with an additional expression i + tileSize appended.
SmallVector<Value *, 4> ubOperands(origLoops[i]->getUpperBoundOperands());
ubOperands.push_back(newLoops[i]);
ubOperands.push_back(newLoops[i]->getInductionVar());
auto origUbMap = origLoops[i]->getUpperBoundMap();
SmallVector<AffineExpr, 4> boundExprs;
@ -135,7 +136,7 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForInst *> origLoops,
// No need of the min expression.
auto dim = b.getAffineDimExpr(0);
auto ubMap = b.getAffineMap(1, 0, dim + tileSizes[i], {});
newLoops[width + i]->setUpperBound(newLoops[i], ubMap);
newLoops[width + i]->setUpperBound(newLoops[i]->getInductionVar(), ubMap);
}
}
}
@ -194,8 +195,8 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForInst *> band,
// Move the loop body of the original nest to the new one.
moveLoopBody(origLoops[origLoops.size() - 1], innermostPointLoop);
SmallVector<Value *, 6> origLoopIVs(band.begin(), band.end());
SmallVector<Optional<Value *>, 6> ids(band.begin(), band.end());
SmallVector<Value *, 8> origLoopIVs = extractForInductionVars(band);
SmallVector<Optional<Value *>, 6> ids(origLoopIVs.begin(), origLoopIVs.end());
FlatAffineConstraints cst;
getIndexSet(band, &cst);
@ -208,7 +209,7 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForInst *> band,
constructTiledIndexSetHyperRect(origLoops, newLoops, tileSizes);
// In this case, the point loop IVs just replace the original ones.
for (unsigned i = 0; i < width; i++) {
origLoopIVs[i]->replaceAllUsesWith(newLoops[i + width]);
origLoopIVs[i]->replaceAllUsesWith(newLoops[i + width]->getInductionVar());
}
// Erase the old loop nest.

View File

@ -215,6 +215,7 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) {
int64_t step = forInst->getStep();
forInst->setStep(step * unrollJamFactor);
auto *forInstIV = forInst->getInductionVar();
for (auto &subBlock : subBlocks) {
// Builder to insert unroll-jammed bodies. Insert right at the end of
// sub-block.
@ -226,14 +227,15 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) {
// If the induction variable is used, create a remapping to the value for
// this unrolled instance.
if (!forInst->use_empty()) {
if (!forInstIV->use_empty()) {
// iv' = iv + i, i = 1 to unrollJamFactor-1.
auto d0 = builder.getAffineDimExpr(0);
auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});
auto *ivUnroll =
builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forInst)
builder
.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forInstIV)
->getResult(0);
operandMapping.map(forInst, ivUnroll);
operandMapping.map(forInstIV, ivUnroll);
}
// Clone the sub-block being unroll-jammed.
for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) {

View File

@ -348,7 +348,7 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) {
oldBody->begin(), oldBody->end());
// The code in the body of the forInst now uses 'iv' as its indvar.
forInst->replaceAllUsesWith(iv);
forInst->getInductionVar()->replaceAllUsesWith(iv);
// Append the induction variable stepping logic and branch back to the exit
// condition block. Construct an affine expression f : (x -> x+step) and

View File

@ -121,8 +121,8 @@ static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) {
int64_t step = forInst->getStep();
auto modTwoMap = bInner.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0,
{d0.floorDiv(step) % 2}, {});
auto ivModTwoOp =
bInner.create<AffineApplyOp>(forInst->getLoc(), modTwoMap, forInst);
auto ivModTwoOp = bInner.create<AffineApplyOp>(forInst->getLoc(), modTwoMap,
forInst->getInductionVar());
// replaceAllMemRefUsesWith will always succeed unless the forInst body has
// non-deferencing uses of the memref.

View File

@ -99,24 +99,25 @@ bool mlir::promoteIfSingleIteration(ForInst *forInst) {
return false;
// Replaces all IV uses to its single iteration value.
if (!forInst->use_empty()) {
auto *iv = forInst->getInductionVar();
if (!iv->use_empty()) {
if (forInst->hasConstantLowerBound()) {
auto *mlFunc = forInst->getFunction();
FuncBuilder topBuilder(mlFunc);
auto constOp = topBuilder.create<ConstantIndexOp>(
forInst->getLoc(), forInst->getConstantLowerBound());
forInst->replaceAllUsesWith(constOp);
iv->replaceAllUsesWith(constOp);
} else {
const AffineBound lb = forInst->getLowerBound();
SmallVector<Value *, 4> lbOperands(lb.operand_begin(), lb.operand_end());
FuncBuilder builder(forInst->getBlock(), Block::iterator(forInst));
if (lb.getMap() == builder.getDimIdentityMap()) {
// No need of generating an affine_apply.
forInst->replaceAllUsesWith(lbOperands[0]);
iv->replaceAllUsesWith(lbOperands[0]);
} else {
auto affineApplyOp = builder.create<AffineApplyOp>(
forInst->getLoc(), lb.getMap(), lbOperands);
forInst->replaceAllUsesWith(affineApplyOp->getResult(0));
iv->replaceAllUsesWith(affineApplyOp->getResult(0));
}
}
}
@ -161,6 +162,8 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
auto *loopChunk = b->createFor(srcForInst->getLoc(), lbOperands, lbMap,
ubOperands, ubMap, srcForInst->getStep());
auto *loopChunkIV = loopChunk->getInductionVar();
auto *srcIV = srcForInst->getInductionVar();
BlockAndValueMapping operandMap;
@ -172,17 +175,17 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
// remapped to results of cloned instructions, and their IV used remapped.
// Generate the remapping if the shift is not zero: remappedIV = newIV -
// shift.
if (!srcForInst->use_empty() && shift != 0) {
if (!srcIV->use_empty() && shift != 0) {
auto b = FuncBuilder::getForInstBodyBuilder(loopChunk);
auto *ivRemap = b.create<AffineApplyOp>(
srcForInst->getLoc(),
b.getSingleDimShiftAffineMap(-static_cast<int64_t>(
srcForInst->getStep() * shift)),
loopChunk)
loopChunkIV)
->getResult(0);
operandMap.map(srcForInst, ivRemap);
operandMap.map(srcIV, ivRemap);
} else {
operandMap.map(srcForInst, loopChunk);
operandMap.map(srcIV, loopChunkIV);
}
for (auto *inst : insts) {
loopChunk->getBody()->push_back(inst->clone(operandMap, b->getContext()));
@ -419,19 +422,20 @@ bool mlir::loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor) {
Block::iterator srcBlockEnd = std::prev(forInst->getBody()->end());
// Unroll the contents of 'forInst' (append unrollFactor-1 additional copies).
auto *forInstIV = forInst->getInductionVar();
for (unsigned i = 1; i < unrollFactor; i++) {
BlockAndValueMapping operandMap;
// If the induction variable is used, create a remapping to the value for
// this unrolled instance.
if (!forInst->use_empty()) {
if (!forInstIV->use_empty()) {
// iv' = iv + 1/2/3...unrollFactor-1;
auto d0 = builder.getAffineDimExpr(0);
auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});
auto *ivUnroll =
builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forInst)
builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forInstIV)
->getResult(0);
operandMap.map(forInst, ivUnroll);
operandMap.map(forInstIV, ivUnroll);
}
// Clone the original body of 'forInst'.

View File

@ -881,8 +881,9 @@ static bool vectorizeForInst(ForInst *loop, int64_t step,
auto load = opInst->dyn_cast<LoadOp>();
auto store = opInst->dyn_cast<StoreOp>();
LLVM_DEBUG(opInst->print(dbgs()));
auto fail = load ? vectorizeRootOrTerminal(loop, load, state)
: vectorizeRootOrTerminal(loop, store, state);
auto fail =
load ? vectorizeRootOrTerminal(loop->getInductionVar(), load, state)
: vectorizeRootOrTerminal(loop->getInductionVar(), store, state);
if (fail) {
return fail;
}
@ -1210,7 +1211,8 @@ static bool vectorizeRootMatches(NestedMatch matches,
/// RAII.
ScopeGuard sg2([&fail, loop, clonedLoop]() {
if (fail) {
loop->replaceAllUsesWith(clonedLoop);
loop->getInductionVar()->replaceAllUsesWith(
clonedLoop->getInductionVar());
loop->erase();
} else {
clonedLoop->erase();