Change the ForInst induction variable to be a block argument of the body instead of the ForInst itself. This is a necessary step in converting ForInst into an operation.
PiperOrigin-RevId: 231064139
This commit is contained in:
parent
0e7a8a9027
commit
36babbd781
|
@ -555,19 +555,12 @@ inline auto OperationInst::getResultTypes() const
|
|||
}
|
||||
|
||||
/// For instruction represents an affine loop nest.
|
||||
class ForInst : public Instruction, public Value {
|
||||
class ForInst : public Instruction {
|
||||
public:
|
||||
static ForInst *create(Location location, ArrayRef<Value *> lbOperands,
|
||||
AffineMap lbMap, ArrayRef<Value *> ubOperands,
|
||||
AffineMap ubMap, int64_t step);
|
||||
|
||||
~ForInst() {
|
||||
// There may be references to the induction variable of this loop within its
|
||||
// body or, in case of ill-formed code during parsing, outside its body.
|
||||
// Explicitly drop all uses of the induction variable before destroying it.
|
||||
dropAllUses();
|
||||
}
|
||||
|
||||
/// Resolve base class ambiguity.
|
||||
using Instruction::getFunction;
|
||||
|
||||
|
@ -700,7 +693,9 @@ public:
|
|||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Return the context this operation is associated with.
|
||||
MLIRContext *getContext() const { return getType().getContext(); }
|
||||
MLIRContext *getContext() const {
|
||||
return getInductionVar()->getType().getContext();
|
||||
}
|
||||
|
||||
using Instruction::dump;
|
||||
using Instruction::print;
|
||||
|
@ -710,11 +705,10 @@ public:
|
|||
return ptr->getKind() == IROperandOwner::Kind::ForInst;
|
||||
}
|
||||
|
||||
// For instruction represents implicitly represents induction variable by
|
||||
// inheriting from Value class. Whenever you need to refer to the loop
|
||||
// induction variable, just use the for instruction itself.
|
||||
static bool classof(const Value *value) {
|
||||
return value->getKind() == Value::Kind::ForInst;
|
||||
/// Returns the induction variable for this loop.
|
||||
Value *getInductionVar();
|
||||
const Value *getInductionVar() const {
|
||||
return const_cast<ForInst *>(this)->getInductionVar();
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -738,6 +732,17 @@ private:
|
|||
AffineMap ubMap, int64_t step);
|
||||
};
|
||||
|
||||
/// Returns if the provided value is the induction variable of a ForInst.
|
||||
bool isForInductionVar(const Value *val);
|
||||
|
||||
/// Returns the loop parent of an induction variable. If the provided value is
|
||||
/// not an induction variable, then return nullptr.
|
||||
ForInst *getForInductionVarOwner(Value *val);
|
||||
const ForInst *getForInductionVarOwner(const Value *val);
|
||||
|
||||
/// Extracts the induction variables from a list of ForInsts and returns them.
|
||||
SmallVector<Value *, 8> extractForInductionVars(ArrayRef<ForInst *> forInsts);
|
||||
|
||||
/// AffineBound represents a lower or upper bound in the for instruction.
|
||||
/// This class does not own the underlying operands. Instead, it refers
|
||||
/// to the operands stored in the ForInst. Its life span should not exceed
|
||||
|
|
|
@ -45,7 +45,6 @@ public:
|
|||
enum class Kind {
|
||||
BlockArgument, // block argument
|
||||
InstResult, // operation instruction result
|
||||
ForInst, // 'for' instruction induction variable
|
||||
};
|
||||
|
||||
~Value() {}
|
||||
|
@ -141,6 +140,9 @@ public:
|
|||
/// Returns the number of this argument.
|
||||
unsigned getArgNumber() const;
|
||||
|
||||
/// Returns if the current argument is a function argument.
|
||||
bool isFunctionArgument() const;
|
||||
|
||||
private:
|
||||
friend class Block; // For access to private constructor.
|
||||
BlockArgument(Type type, Block *owner)
|
||||
|
|
|
@ -555,7 +555,7 @@ void mlir::getReachableAffineApplyOps(
|
|||
// setExprStride(ArrayRef<int64_t> expr, int64_t stride)
|
||||
bool mlir::getIndexSet(ArrayRef<ForInst *> forInsts,
|
||||
FlatAffineConstraints *domain) {
|
||||
SmallVector<Value *, 4> indices(forInsts.begin(), forInsts.end());
|
||||
auto indices = extractForInductionVars(forInsts);
|
||||
// Reset while associated Values in 'indices' to the domain.
|
||||
domain->reset(forInsts.size(), /*numSymbols=*/0, /*numLocals=*/0, indices);
|
||||
for (auto *forInst : forInsts) {
|
||||
|
@ -677,7 +677,7 @@ static void buildDimAndSymbolPositionMaps(
|
|||
auto updateValuePosMap = [&](ArrayRef<Value *> values, bool isSrc) {
|
||||
for (unsigned i = 0, e = values.size(); i < e; ++i) {
|
||||
auto *value = values[i];
|
||||
if (!isa<ForInst>(values[i])) {
|
||||
if (!isForInductionVar(values[i])) {
|
||||
assert(values[i]->isValidSymbol() &&
|
||||
"access operand has to be either a loop IV or a symbol");
|
||||
valuePosMap->addSymbolValue(value);
|
||||
|
@ -739,7 +739,7 @@ void initDependenceConstraints(const FlatAffineConstraints &srcDomain,
|
|||
// Set values for the symbolic identifier dimensions.
|
||||
auto setSymbolIds = [&](ArrayRef<Value *> values) {
|
||||
for (auto *value : values) {
|
||||
if (!isa<ForInst>(value)) {
|
||||
if (!isForInductionVar(value)) {
|
||||
assert(value->isValidSymbol() && "expected symbol");
|
||||
dependenceConstraints->setIdValue(valuePosMap.getSymPos(value), value);
|
||||
}
|
||||
|
@ -907,7 +907,7 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
|
|||
// Add equality constraints for any operands that are defined by constant ops.
|
||||
auto addEqForConstOperands = [&](ArrayRef<const Value *> operands) {
|
||||
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
|
||||
if (isa<ForInst>(operands[i]))
|
||||
if (isForInductionVar(operands[i]))
|
||||
continue;
|
||||
auto *symbol = operands[i];
|
||||
assert(symbol->isValidSymbol());
|
||||
|
@ -976,8 +976,8 @@ static unsigned getNumCommonLoops(const FlatAffineConstraints &srcDomain,
|
|||
std::min(srcDomain.getNumDimIds(), dstDomain.getNumDimIds());
|
||||
unsigned numCommonLoops = 0;
|
||||
for (unsigned i = 0; i < minNumLoops; ++i) {
|
||||
if (!isa<ForInst>(srcDomain.getIdValue(i)) ||
|
||||
!isa<ForInst>(dstDomain.getIdValue(i)) ||
|
||||
if (!isForInductionVar(srcDomain.getIdValue(i)) ||
|
||||
!isForInductionVar(dstDomain.getIdValue(i)) ||
|
||||
srcDomain.getIdValue(i) != dstDomain.getIdValue(i))
|
||||
break;
|
||||
++numCommonLoops;
|
||||
|
@ -998,8 +998,9 @@ static const Block *getCommonBlock(const MemRefAccess &srcAccess,
|
|||
return block;
|
||||
}
|
||||
auto *commonForValue = srcDomain.getIdValue(numCommonLoops - 1);
|
||||
assert(isa<ForInst>(commonForValue));
|
||||
return cast<ForInst>(commonForValue)->getBody();
|
||||
auto *forInst = getForInductionVarOwner(commonForValue);
|
||||
assert(forInst && "commonForValue was not an induction variable");
|
||||
return forInst->getBody();
|
||||
}
|
||||
|
||||
// Returns true if the ancestor operation instruction of 'srcAccess' appears
|
||||
|
|
|
@ -1251,7 +1251,7 @@ void FlatAffineConstraints::setDimSymbolSeparation(unsigned newSymbolCount) {
|
|||
bool FlatAffineConstraints::addForInstDomain(const ForInst &forInst) {
|
||||
unsigned pos;
|
||||
// Pre-condition for this method.
|
||||
if (!findId(forInst, &pos)) {
|
||||
if (!findId(*forInst.getInductionVar(), &pos)) {
|
||||
assert(0 && "Value not found");
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -53,9 +53,9 @@ bool DominanceInfo::properlyDominates(const Block *a, const Block *b) {
|
|||
if (blockListA == blockListB)
|
||||
return DominatorTreeBase::properlyDominates(a, b);
|
||||
|
||||
// Otherwise, 'a' properly dominates 'b' if 'b' is defined in an
|
||||
// IfInst/ForInst that (recursively) ends up being dominated by 'a'. Walk up
|
||||
// the list of containers enclosing B.
|
||||
// Otherwise, 'a' properly dominates 'b' if 'b' is defined in an instruction
|
||||
// region that (recursively) ends up being dominated by 'a'. Walk up the list
|
||||
// of containers enclosing B.
|
||||
Instruction *bAncestor;
|
||||
do {
|
||||
bAncestor = blockListB->getContainingInst();
|
||||
|
@ -106,11 +106,6 @@ bool DominanceInfo::properlyDominates(const Value *a, const Instruction *b) {
|
|||
if (auto *aInst = a->getDefiningInst())
|
||||
return properlyDominates(aInst, b);
|
||||
|
||||
// The induction variable of a ForInst properly dominantes its body, so we
|
||||
// can just do a simple block dominance check.
|
||||
if (auto *forInst = dyn_cast<ForInst>(a))
|
||||
return dominates(forInst->getBody(), b->getBlock());
|
||||
|
||||
// block arguments properly dominate all instructions in their own block, so
|
||||
// we use a dominates check here, not a properlyDominates check.
|
||||
return dominates(cast<BlockArgument>(a)->getOwner(), b->getBlock());
|
||||
|
|
|
@ -125,7 +125,7 @@ uint64_t mlir::getLargestDivisorOfTripCount(const ForInst &forInst) {
|
|||
}
|
||||
|
||||
bool mlir::isAccessInvariant(const Value &iv, const Value &index) {
|
||||
assert(isa<ForInst>(iv) && "iv must be a ForInst");
|
||||
assert(isForInductionVar(&iv) && "iv must be a ForInst");
|
||||
assert(index.getType().isa<IndexType>() && "index must be of IndexType");
|
||||
SmallVector<OperationInst *, 4> affineApplyOps;
|
||||
getReachableAffineApplyOps({const_cast<Value *>(&index)}, affineApplyOps);
|
||||
|
@ -288,8 +288,10 @@ bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim(
|
|||
[fastestVaryingDim](const ForInst &loop, const OperationInst &op) {
|
||||
auto load = op.dyn_cast<LoadOp>();
|
||||
auto store = op.dyn_cast<StoreOp>();
|
||||
return load ? isContiguousAccess(loop, *load, fastestVaryingDim)
|
||||
: isContiguousAccess(loop, *store, fastestVaryingDim);
|
||||
return load ? isContiguousAccess(*loop.getInductionVar(), *load,
|
||||
fastestVaryingDim)
|
||||
: isContiguousAccess(*loop.getInductionVar(), *store,
|
||||
fastestVaryingDim);
|
||||
});
|
||||
return isVectorizableLoopWithCond(loop, fun);
|
||||
}
|
||||
|
|
|
@ -64,7 +64,7 @@ void mlir::getForwardSlice(Instruction *inst,
|
|||
}
|
||||
}
|
||||
} else if (auto *forInst = dyn_cast<ForInst>(inst)) {
|
||||
for (auto &u : forInst->getUses()) {
|
||||
for (auto &u : forInst->getInductionVar()->getUses()) {
|
||||
auto *ownerInst = u.getOwner();
|
||||
if (forwardSlice->count(ownerInst) == 0) {
|
||||
getForwardSlice(ownerInst, forwardSlice, filter,
|
||||
|
|
|
@ -149,7 +149,8 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth,
|
|||
// A rank 0 memref has a 0-d region.
|
||||
SmallVector<ForInst *, 4> ivs;
|
||||
getLoopIVs(*opInst, &ivs);
|
||||
SmallVector<Value *, 4> regionSymbols(ivs.begin(), ivs.end());
|
||||
|
||||
SmallVector<Value *, 8> regionSymbols = extractForInductionVars(ivs);
|
||||
regionCst->reset(0, loopDepth, 0, regionSymbols);
|
||||
return true;
|
||||
}
|
||||
|
@ -172,7 +173,7 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth,
|
|||
unsigned numSymbols = accessMap.getNumSymbols();
|
||||
// Add inequalties for loop lower/upper bounds.
|
||||
for (unsigned i = 0; i < numDims + numSymbols; ++i) {
|
||||
if (auto *loop = dyn_cast<ForInst>(accessValueMap.getOperand(i))) {
|
||||
if (auto *loop = getForInductionVarOwner(accessValueMap.getOperand(i))) {
|
||||
// Note that regionCst can now have more dimensions than accessMap if the
|
||||
// bounds expressions involve outer loops or other symbols.
|
||||
// TODO(bondhugula): rewrite this to use getInstIndexSet; this way
|
||||
|
@ -207,7 +208,7 @@ bool mlir::getMemRefRegion(OperationInst *opInst, unsigned loopDepth,
|
|||
outerIVs.resize(loopDepth);
|
||||
for (auto *operand : accessValueMap.getOperands()) {
|
||||
ForInst *iv;
|
||||
if ((iv = dyn_cast<ForInst>(operand)) &&
|
||||
if ((iv = getForInductionVarOwner(operand)) &&
|
||||
std::find(outerIVs.begin(), outerIVs.end(), iv) == outerIVs.end()) {
|
||||
regionCst->projectOut(operand);
|
||||
}
|
||||
|
|
|
@ -113,7 +113,8 @@ static AffineMap makePermutationMap(
|
|||
getAffineConstantExpr(0, context));
|
||||
for (auto kvp : enclosingLoopToVectorDim) {
|
||||
assert(kvp.second < perm.size());
|
||||
auto invariants = getInvariantAccesses(*kvp.first, unwrappedIndices);
|
||||
auto invariants =
|
||||
getInvariantAccesses(*kvp.first->getInductionVar(), unwrappedIndices);
|
||||
unsigned numIndices = unwrappedIndices.size();
|
||||
unsigned countInvariantIndices = 0;
|
||||
for (unsigned dim = 0; dim < numIndices; ++dim) {
|
||||
|
|
|
@ -133,9 +133,7 @@ static void printDefininingStatement(llvm::raw_ostream &os, const Value &v) {
|
|||
inst->print(os);
|
||||
return;
|
||||
}
|
||||
// &v is required here otherwise we get:
|
||||
// non-pointer operand type 'const mlir::ForInst' incompatible with nullptr
|
||||
if (auto *forInst = dyn_cast<ForInst>(&v)) {
|
||||
if (auto *forInst = getForInductionVarOwner(&v)) {
|
||||
forInst->print(os);
|
||||
} else {
|
||||
os << "unknown_ssa_value";
|
||||
|
@ -296,7 +294,7 @@ Value *MLIREmitter::emit(Expr e) {
|
|||
exprs[1]->getDefiningInst()->cast<ConstantIndexOp>()->getValue();
|
||||
auto step =
|
||||
exprs[2]->getDefiningInst()->cast<ConstantIndexOp>()->getValue();
|
||||
res = builder->createFor(location, lb, ub, step);
|
||||
res = builder->createFor(location, lb, ub, step)->getInductionVar();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -347,7 +345,8 @@ void MLIREmitter::emitStmt(const Stmt &stmt) {
|
|||
bind(stmt.getLHS(), val);
|
||||
if (stmt.getRHS().getKind() == ExprKind::For) {
|
||||
// Step into the loop.
|
||||
builder->setInsertionPointToStart(cast<ForInst>(val)->getBody());
|
||||
builder->setInsertionPointToStart(
|
||||
getForInductionVarOwner(val)->getBody());
|
||||
}
|
||||
}
|
||||
emitStmts(stmt.getEnclosedStmts());
|
||||
|
|
|
@ -1078,7 +1078,7 @@ public:
|
|||
void print(const OperationInst *inst);
|
||||
void print(const ForInst *inst);
|
||||
void print(const IfInst *inst);
|
||||
void print(const Block *block);
|
||||
void print(const Block *block, bool printBlockArgs = true);
|
||||
|
||||
void printOperation(const OperationInst *op);
|
||||
void printGenericOp(const OperationInst *op);
|
||||
|
@ -1125,10 +1125,15 @@ public:
|
|||
unsigned index) override;
|
||||
|
||||
/// Print a block list.
|
||||
void printBlockList(const BlockList &blocks) {
|
||||
void printBlockList(const BlockList &blocks, bool printEntryBlockArgs) {
|
||||
os << " {\n";
|
||||
for (auto &b : blocks)
|
||||
print(&b);
|
||||
if (!blocks.empty()) {
|
||||
auto *entryBlock = &blocks.front();
|
||||
print(entryBlock,
|
||||
printEntryBlockArgs && entryBlock->getNumArguments() != 0);
|
||||
for (auto &b : llvm::drop_begin(blocks.getBlocks(), 1))
|
||||
print(&b);
|
||||
}
|
||||
os.indent(currentIndent) << "}";
|
||||
}
|
||||
|
||||
|
@ -1164,8 +1169,8 @@ private:
|
|||
|
||||
/// This is the next value ID to assign in numbering.
|
||||
unsigned nextValueID = 0;
|
||||
/// This is the ID to assign to the next induction variable.
|
||||
unsigned nextLoopID = 0;
|
||||
/// This is the ID to assign to the next region entry block argument.
|
||||
unsigned nextRegionArgumentID = 0;
|
||||
/// This is the next ID to assign to a Function argument.
|
||||
unsigned nextArgumentID = 0;
|
||||
/// This is the next ID to assign when a name conflict is detected.
|
||||
|
@ -1205,14 +1210,10 @@ void FunctionPrinter::numberValuesInBlock(const Block &block) {
|
|||
numberValuesInBlock(block);
|
||||
break;
|
||||
}
|
||||
case Instruction::Kind::For: {
|
||||
auto *forInst = cast<ForInst>(&inst);
|
||||
// Number the induction variable.
|
||||
numberValueID(forInst);
|
||||
case Instruction::Kind::For:
|
||||
// Recursively number the stuff in the body.
|
||||
numberValuesInBlock(*forInst->getBody());
|
||||
numberValuesInBlock(*cast<ForInst>(&inst)->getBody());
|
||||
break;
|
||||
}
|
||||
case Instruction::Kind::If: {
|
||||
auto *ifInst = cast<IfInst>(&inst);
|
||||
numberValuesInBlock(*ifInst->getThen());
|
||||
|
@ -1251,13 +1252,19 @@ void FunctionPrinter::numberValueID(const Value *value) {
|
|||
if (specialNameBuffer.empty()) {
|
||||
switch (value->getKind()) {
|
||||
case Value::Kind::BlockArgument:
|
||||
// If this is an argument to the function, give it an 'arg' name.
|
||||
if (auto *block = cast<BlockArgument>(value)->getOwner())
|
||||
if (auto *fn = block->getFunction())
|
||||
if (&fn->getBlockList().front() == block) {
|
||||
// If this is an argument to the function, give it an 'arg' name. If the
|
||||
// argument is to an entry block of an operation region, give it an 'i'
|
||||
// name.
|
||||
if (auto *block = cast<BlockArgument>(value)->getOwner()) {
|
||||
auto *parentBlockList = block->getParent();
|
||||
if (parentBlockList && block == &parentBlockList->front()) {
|
||||
if (parentBlockList->getContainingFunction())
|
||||
specialName << "arg" << nextArgumentID++;
|
||||
break;
|
||||
}
|
||||
else
|
||||
specialName << "i" << nextRegionArgumentID++;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Otherwise number it normally.
|
||||
valueIDs[value] = nextValueID++;
|
||||
return;
|
||||
|
@ -1266,9 +1273,6 @@ void FunctionPrinter::numberValueID(const Value *value) {
|
|||
// done with it.
|
||||
valueIDs[value] = nextValueID++;
|
||||
return;
|
||||
case Value::Kind::ForInst:
|
||||
specialName << 'i' << nextLoopID++;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1312,10 +1316,8 @@ void FunctionPrinter::print() {
|
|||
printTrailingLocation(function->getLoc());
|
||||
|
||||
if (!function->empty()) {
|
||||
os << " {\n";
|
||||
for (const auto &block : *function)
|
||||
print(&block);
|
||||
os << "}\n";
|
||||
printBlockList(function->getBlockList(), /*printEntryBlockArgs=*/false);
|
||||
os << "\n";
|
||||
}
|
||||
os << '\n';
|
||||
}
|
||||
|
@ -1357,26 +1359,10 @@ void FunctionPrinter::printFunctionSignature() {
|
|||
}
|
||||
}
|
||||
|
||||
/// Return true if the introducer for the specified block should be printed.
|
||||
static bool shouldPrintBlockArguments(const Block *block) {
|
||||
// Never print the entry block of the function - it is included in the
|
||||
// argument list.
|
||||
if (block == &block->getFunction()->front())
|
||||
return false;
|
||||
|
||||
// If this is the first block in a nested region, and if there are no
|
||||
// arguments, then we can omit it.
|
||||
if (block == &block->getParent()->front() && block->getNumArguments() == 0)
|
||||
return false;
|
||||
|
||||
// Otherwise print it.
|
||||
return true;
|
||||
}
|
||||
|
||||
void FunctionPrinter::print(const Block *block) {
|
||||
void FunctionPrinter::print(const Block *block, bool printBlockArgs) {
|
||||
// Print the block label and argument list, unless this is the first block of
|
||||
// the function, or the first block of an IfInst/ForInst with no arguments.
|
||||
if (shouldPrintBlockArguments(block)) {
|
||||
if (printBlockArgs) {
|
||||
os.indent(currentIndent);
|
||||
printBlockName(block);
|
||||
|
||||
|
@ -1445,7 +1431,7 @@ void FunctionPrinter::print(const OperationInst *inst) {
|
|||
|
||||
void FunctionPrinter::print(const ForInst *inst) {
|
||||
os.indent(currentIndent) << "for ";
|
||||
printOperand(inst);
|
||||
printOperand(inst->getInductionVar());
|
||||
os << " = ";
|
||||
printBound(inst->getLowerBound(), "max");
|
||||
os << " to ";
|
||||
|
@ -1457,7 +1443,7 @@ void FunctionPrinter::print(const ForInst *inst) {
|
|||
printTrailingLocation(inst->getLoc());
|
||||
|
||||
os << " {\n";
|
||||
print(inst->getBody());
|
||||
print(inst->getBody(), /*printBlockArgs=*/false);
|
||||
os.indent(currentIndent) << "}";
|
||||
}
|
||||
|
||||
|
@ -1468,11 +1454,11 @@ void FunctionPrinter::print(const IfInst *inst) {
|
|||
printDimAndSymbolList(inst->getInstOperands(), set.getNumDims());
|
||||
printTrailingLocation(inst->getLoc());
|
||||
os << " {\n";
|
||||
print(inst->getThen());
|
||||
print(inst->getThen(), /*printBlockArgs=*/false);
|
||||
os.indent(currentIndent) << "}";
|
||||
if (inst->hasElse()) {
|
||||
os << " else {\n";
|
||||
print(inst->getElse());
|
||||
print(inst->getElse(), /*printBlockArgs=*/false);
|
||||
os.indent(currentIndent) << "}";
|
||||
}
|
||||
}
|
||||
|
@ -1583,7 +1569,7 @@ void FunctionPrinter::printGenericOp(const OperationInst *op) {
|
|||
|
||||
// Print any trailing block lists.
|
||||
for (auto &blockList : op->getBlockLists())
|
||||
printBlockList(blockList);
|
||||
printBlockList(blockList, /*printEntryBlockArgs=*/true);
|
||||
}
|
||||
|
||||
void FunctionPrinter::printSuccessorAndUseList(const OperationInst *term,
|
||||
|
@ -1729,8 +1715,6 @@ void Value::print(raw_ostream &os) const {
|
|||
return;
|
||||
case Value::Kind::InstResult:
|
||||
return getDefiningInst()->print(os);
|
||||
case Value::Kind::ForInst:
|
||||
return cast<ForInst>(this)->print(os);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -126,9 +126,9 @@ bool Value::isValidSymbol() const {
|
|||
return op->isValidSymbol();
|
||||
return false;
|
||||
}
|
||||
// This value is either a function argument or an induction variable.
|
||||
// Function argument is ok, induction variable is not.
|
||||
return isa<BlockArgument>(this);
|
||||
// Otherwise, the only valid symbol is a function argument.
|
||||
auto *arg = dyn_cast<BlockArgument>(this);
|
||||
return arg && arg->isFunctionArgument();
|
||||
}
|
||||
|
||||
void Instruction::setOperand(unsigned idx, Value *value) {
|
||||
|
@ -635,13 +635,16 @@ ForInst *ForInst::create(Location location, ArrayRef<Value *> lbOperands,
|
|||
|
||||
ForInst::ForInst(Location location, unsigned numOperands, AffineMap lbMap,
|
||||
AffineMap ubMap, int64_t step)
|
||||
: Instruction(Instruction::Kind::For, location),
|
||||
Value(Value::Kind::ForInst,
|
||||
Type::getIndex(lbMap.getResult(0).getContext())),
|
||||
body(this), lbMap(lbMap), ubMap(ubMap), step(step) {
|
||||
: Instruction(Instruction::Kind::For, location), body(this), lbMap(lbMap),
|
||||
ubMap(ubMap), step(step) {
|
||||
|
||||
// The body of a for inst always has one block.
|
||||
body.push_back(new Block());
|
||||
auto *bodyEntry = new Block();
|
||||
body.push_back(bodyEntry);
|
||||
|
||||
// Add an argument to the block for the induction variable.
|
||||
bodyEntry->addArgument(Type::getIndex(lbMap.getResult(0).getContext()));
|
||||
|
||||
operands.reserve(numOperands);
|
||||
}
|
||||
|
||||
|
@ -777,6 +780,35 @@ void ForInst::walkOpsPostOrder(std::function<void(OperationInst *)> callback) {
|
|||
v.walkPostOrder(this);
|
||||
}
|
||||
|
||||
/// Returns the induction variable for this loop.
|
||||
Value *ForInst::getInductionVar() { return getBody()->getArgument(0); }
|
||||
|
||||
/// Returns if the provided value is the induction variable of a ForInst.
|
||||
bool mlir::isForInductionVar(const Value *val) {
|
||||
return getForInductionVarOwner(val) != nullptr;
|
||||
}
|
||||
|
||||
/// Returns the loop parent of an induction variable. If the provided value is
|
||||
/// not an induction variable, then return nullptr.
|
||||
ForInst *mlir::getForInductionVarOwner(Value *val) {
|
||||
const BlockArgument *ivArg = dyn_cast<BlockArgument>(val);
|
||||
if (!ivArg || !ivArg->getOwner())
|
||||
return nullptr;
|
||||
return dyn_cast_or_null<ForInst>(
|
||||
ivArg->getOwner()->getParent()->getContainingInst());
|
||||
}
|
||||
const ForInst *mlir::getForInductionVarOwner(const Value *val) {
|
||||
return getForInductionVarOwner(const_cast<Value *>(val));
|
||||
}
|
||||
|
||||
/// Extracts the induction variables from a list of ForInsts and returns them.
|
||||
SmallVector<Value *, 8>
|
||||
mlir::extractForInductionVars(ArrayRef<ForInst *> forInsts) {
|
||||
SmallVector<Value *, 8> results;
|
||||
for (auto *forInst : forInsts)
|
||||
results.push_back(forInst->getInductionVar());
|
||||
return results;
|
||||
}
|
||||
//===----------------------------------------------------------------------===//
|
||||
// IfInst
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -909,7 +941,7 @@ Instruction *Instruction::clone(BlockAndValueMapping &mapper,
|
|||
ubMap, forInst->getStep());
|
||||
|
||||
// Remember the induction variable mapping.
|
||||
mapper.map(forInst, newFor);
|
||||
mapper.map(forInst->getInductionVar(), newFor->getInductionVar());
|
||||
|
||||
// Recursively clone the body of the for loop.
|
||||
for (auto &subInst : *forInst->getBody())
|
||||
|
|
|
@ -35,8 +35,6 @@ Function *Value::getFunction() {
|
|||
return cast<BlockArgument>(this)->getFunction();
|
||||
case Value::Kind::InstResult:
|
||||
return getDefiningInst()->getFunction();
|
||||
case Value::Kind::ForInst:
|
||||
return cast<ForInst>(this)->getFunction();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -83,3 +81,9 @@ Function *BlockArgument::getFunction() {
|
|||
return owner->getFunction();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
/// Returns if the current argument is a function argument.
|
||||
bool BlockArgument::isFunctionArgument() const {
|
||||
auto *containingFn = getFunction();
|
||||
return containingFn && &containingFn->front() == getOwner();
|
||||
}
|
||||
|
|
|
@ -3201,7 +3201,8 @@ ParseResult FunctionParser::parseForInst() {
|
|||
ubOperands, ubMap, step);
|
||||
|
||||
// Create SSA value definition for the induction variable.
|
||||
if (addDefinition({inductionVariableName, 0, loc}, forInst))
|
||||
if (addDefinition({inductionVariableName, 0, loc},
|
||||
forInst->getInductionVar()))
|
||||
return ParseFailure;
|
||||
|
||||
// Try to parse the optional trailing location.
|
||||
|
@ -3347,7 +3348,7 @@ ParseResult FunctionParser::parseBound(SmallVectorImpl<Value *> &operands,
|
|||
// Create an identity map using dim id for an induction variable and
|
||||
// symbol otherwise. This representation is optimized for storage.
|
||||
// Analysis passes may expand it into a multi-dimensional map if desired.
|
||||
if (isa<ForInst>(operands[0]))
|
||||
if (isForInductionVar(operands[0]))
|
||||
map = builder.getDimIdentityMap();
|
||||
else
|
||||
map = builder.getSymbolIdentityMap();
|
||||
|
|
|
@ -171,7 +171,8 @@ static bool getFullMemRefAsRegion(OperationInst *opInst, unsigned numSymbols,
|
|||
getLoopIVs(*opInst, &ivs);
|
||||
|
||||
auto *regionCst = region->getConstraints();
|
||||
SmallVector<Value *, 4> symbols(ivs.begin(), ivs.end());
|
||||
|
||||
SmallVector<Value *, 8> symbols = extractForInductionVars(ivs);
|
||||
regionCst->reset(rank, numSymbols, 0, symbols);
|
||||
|
||||
// Memref dim sizes provide the bounds.
|
||||
|
|
|
@ -103,7 +103,8 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForInst *> origLoops,
|
|||
auto mayBeConstantCount = getConstantTripCount(*origLoops[i]);
|
||||
// The lower bound is just the tile-space loop.
|
||||
AffineMap lbMap = b.getDimIdentityMap();
|
||||
newLoops[width + i]->setLowerBound(/*operands=*/newLoops[i], lbMap);
|
||||
newLoops[width + i]->setLowerBound(
|
||||
/*operands=*/newLoops[i]->getInductionVar(), lbMap);
|
||||
|
||||
// Set the upper bound.
|
||||
if (mayBeConstantCount.hasValue() &&
|
||||
|
@ -117,7 +118,7 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForInst *> origLoops,
|
|||
// with 'i' (tile-space loop) appended to it. The new upper bound map is
|
||||
// the original one with an additional expression i + tileSize appended.
|
||||
SmallVector<Value *, 4> ubOperands(origLoops[i]->getUpperBoundOperands());
|
||||
ubOperands.push_back(newLoops[i]);
|
||||
ubOperands.push_back(newLoops[i]->getInductionVar());
|
||||
|
||||
auto origUbMap = origLoops[i]->getUpperBoundMap();
|
||||
SmallVector<AffineExpr, 4> boundExprs;
|
||||
|
@ -135,7 +136,7 @@ static void constructTiledIndexSetHyperRect(ArrayRef<ForInst *> origLoops,
|
|||
// No need of the min expression.
|
||||
auto dim = b.getAffineDimExpr(0);
|
||||
auto ubMap = b.getAffineMap(1, 0, dim + tileSizes[i], {});
|
||||
newLoops[width + i]->setUpperBound(newLoops[i], ubMap);
|
||||
newLoops[width + i]->setUpperBound(newLoops[i]->getInductionVar(), ubMap);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -194,8 +195,8 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForInst *> band,
|
|||
// Move the loop body of the original nest to the new one.
|
||||
moveLoopBody(origLoops[origLoops.size() - 1], innermostPointLoop);
|
||||
|
||||
SmallVector<Value *, 6> origLoopIVs(band.begin(), band.end());
|
||||
SmallVector<Optional<Value *>, 6> ids(band.begin(), band.end());
|
||||
SmallVector<Value *, 8> origLoopIVs = extractForInductionVars(band);
|
||||
SmallVector<Optional<Value *>, 6> ids(origLoopIVs.begin(), origLoopIVs.end());
|
||||
FlatAffineConstraints cst;
|
||||
getIndexSet(band, &cst);
|
||||
|
||||
|
@ -208,7 +209,7 @@ UtilResult mlir::tileCodeGen(ArrayRef<ForInst *> band,
|
|||
constructTiledIndexSetHyperRect(origLoops, newLoops, tileSizes);
|
||||
// In this case, the point loop IVs just replace the original ones.
|
||||
for (unsigned i = 0; i < width; i++) {
|
||||
origLoopIVs[i]->replaceAllUsesWith(newLoops[i + width]);
|
||||
origLoopIVs[i]->replaceAllUsesWith(newLoops[i + width]->getInductionVar());
|
||||
}
|
||||
|
||||
// Erase the old loop nest.
|
||||
|
|
|
@ -215,6 +215,7 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) {
|
|||
int64_t step = forInst->getStep();
|
||||
forInst->setStep(step * unrollJamFactor);
|
||||
|
||||
auto *forInstIV = forInst->getInductionVar();
|
||||
for (auto &subBlock : subBlocks) {
|
||||
// Builder to insert unroll-jammed bodies. Insert right at the end of
|
||||
// sub-block.
|
||||
|
@ -226,14 +227,15 @@ bool mlir::loopUnrollJamByFactor(ForInst *forInst, uint64_t unrollJamFactor) {
|
|||
|
||||
// If the induction variable is used, create a remapping to the value for
|
||||
// this unrolled instance.
|
||||
if (!forInst->use_empty()) {
|
||||
if (!forInstIV->use_empty()) {
|
||||
// iv' = iv + i, i = 1 to unrollJamFactor-1.
|
||||
auto d0 = builder.getAffineDimExpr(0);
|
||||
auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});
|
||||
auto *ivUnroll =
|
||||
builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forInst)
|
||||
builder
|
||||
.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forInstIV)
|
||||
->getResult(0);
|
||||
operandMapping.map(forInst, ivUnroll);
|
||||
operandMapping.map(forInstIV, ivUnroll);
|
||||
}
|
||||
// Clone the sub-block being unroll-jammed.
|
||||
for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) {
|
||||
|
|
|
@ -348,7 +348,7 @@ bool LowerAffinePass::lowerForInst(ForInst *forInst) {
|
|||
oldBody->begin(), oldBody->end());
|
||||
|
||||
// The code in the body of the forInst now uses 'iv' as its indvar.
|
||||
forInst->replaceAllUsesWith(iv);
|
||||
forInst->getInductionVar()->replaceAllUsesWith(iv);
|
||||
|
||||
// Append the induction variable stepping logic and branch back to the exit
|
||||
// condition block. Construct an affine expression f : (x -> x+step) and
|
||||
|
|
|
@ -121,8 +121,8 @@ static bool doubleBuffer(Value *oldMemRef, ForInst *forInst) {
|
|||
int64_t step = forInst->getStep();
|
||||
auto modTwoMap = bInner.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0,
|
||||
{d0.floorDiv(step) % 2}, {});
|
||||
auto ivModTwoOp =
|
||||
bInner.create<AffineApplyOp>(forInst->getLoc(), modTwoMap, forInst);
|
||||
auto ivModTwoOp = bInner.create<AffineApplyOp>(forInst->getLoc(), modTwoMap,
|
||||
forInst->getInductionVar());
|
||||
|
||||
// replaceAllMemRefUsesWith will always succeed unless the forInst body has
|
||||
// non-deferencing uses of the memref.
|
||||
|
|
|
@ -99,24 +99,25 @@ bool mlir::promoteIfSingleIteration(ForInst *forInst) {
|
|||
return false;
|
||||
|
||||
// Replaces all IV uses to its single iteration value.
|
||||
if (!forInst->use_empty()) {
|
||||
auto *iv = forInst->getInductionVar();
|
||||
if (!iv->use_empty()) {
|
||||
if (forInst->hasConstantLowerBound()) {
|
||||
auto *mlFunc = forInst->getFunction();
|
||||
FuncBuilder topBuilder(mlFunc);
|
||||
auto constOp = topBuilder.create<ConstantIndexOp>(
|
||||
forInst->getLoc(), forInst->getConstantLowerBound());
|
||||
forInst->replaceAllUsesWith(constOp);
|
||||
iv->replaceAllUsesWith(constOp);
|
||||
} else {
|
||||
const AffineBound lb = forInst->getLowerBound();
|
||||
SmallVector<Value *, 4> lbOperands(lb.operand_begin(), lb.operand_end());
|
||||
FuncBuilder builder(forInst->getBlock(), Block::iterator(forInst));
|
||||
if (lb.getMap() == builder.getDimIdentityMap()) {
|
||||
// No need of generating an affine_apply.
|
||||
forInst->replaceAllUsesWith(lbOperands[0]);
|
||||
iv->replaceAllUsesWith(lbOperands[0]);
|
||||
} else {
|
||||
auto affineApplyOp = builder.create<AffineApplyOp>(
|
||||
forInst->getLoc(), lb.getMap(), lbOperands);
|
||||
forInst->replaceAllUsesWith(affineApplyOp->getResult(0));
|
||||
iv->replaceAllUsesWith(affineApplyOp->getResult(0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -161,6 +162,8 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
|
|||
|
||||
auto *loopChunk = b->createFor(srcForInst->getLoc(), lbOperands, lbMap,
|
||||
ubOperands, ubMap, srcForInst->getStep());
|
||||
auto *loopChunkIV = loopChunk->getInductionVar();
|
||||
auto *srcIV = srcForInst->getInductionVar();
|
||||
|
||||
BlockAndValueMapping operandMap;
|
||||
|
||||
|
@ -172,17 +175,17 @@ generateLoop(AffineMap lbMap, AffineMap ubMap,
|
|||
// remapped to results of cloned instructions, and their IV used remapped.
|
||||
// Generate the remapping if the shift is not zero: remappedIV = newIV -
|
||||
// shift.
|
||||
if (!srcForInst->use_empty() && shift != 0) {
|
||||
if (!srcIV->use_empty() && shift != 0) {
|
||||
auto b = FuncBuilder::getForInstBodyBuilder(loopChunk);
|
||||
auto *ivRemap = b.create<AffineApplyOp>(
|
||||
srcForInst->getLoc(),
|
||||
b.getSingleDimShiftAffineMap(-static_cast<int64_t>(
|
||||
srcForInst->getStep() * shift)),
|
||||
loopChunk)
|
||||
loopChunkIV)
|
||||
->getResult(0);
|
||||
operandMap.map(srcForInst, ivRemap);
|
||||
operandMap.map(srcIV, ivRemap);
|
||||
} else {
|
||||
operandMap.map(srcForInst, loopChunk);
|
||||
operandMap.map(srcIV, loopChunkIV);
|
||||
}
|
||||
for (auto *inst : insts) {
|
||||
loopChunk->getBody()->push_back(inst->clone(operandMap, b->getContext()));
|
||||
|
@ -419,19 +422,20 @@ bool mlir::loopUnrollByFactor(ForInst *forInst, uint64_t unrollFactor) {
|
|||
Block::iterator srcBlockEnd = std::prev(forInst->getBody()->end());
|
||||
|
||||
// Unroll the contents of 'forInst' (append unrollFactor-1 additional copies).
|
||||
auto *forInstIV = forInst->getInductionVar();
|
||||
for (unsigned i = 1; i < unrollFactor; i++) {
|
||||
BlockAndValueMapping operandMap;
|
||||
|
||||
// If the induction variable is used, create a remapping to the value for
|
||||
// this unrolled instance.
|
||||
if (!forInst->use_empty()) {
|
||||
if (!forInstIV->use_empty()) {
|
||||
// iv' = iv + 1/2/3...unrollFactor-1;
|
||||
auto d0 = builder.getAffineDimExpr(0);
|
||||
auto bumpMap = builder.getAffineMap(1, 0, {d0 + i * step}, {});
|
||||
auto *ivUnroll =
|
||||
builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forInst)
|
||||
builder.create<AffineApplyOp>(forInst->getLoc(), bumpMap, forInstIV)
|
||||
->getResult(0);
|
||||
operandMap.map(forInst, ivUnroll);
|
||||
operandMap.map(forInstIV, ivUnroll);
|
||||
}
|
||||
|
||||
// Clone the original body of 'forInst'.
|
||||
|
|
|
@ -881,8 +881,9 @@ static bool vectorizeForInst(ForInst *loop, int64_t step,
|
|||
auto load = opInst->dyn_cast<LoadOp>();
|
||||
auto store = opInst->dyn_cast<StoreOp>();
|
||||
LLVM_DEBUG(opInst->print(dbgs()));
|
||||
auto fail = load ? vectorizeRootOrTerminal(loop, load, state)
|
||||
: vectorizeRootOrTerminal(loop, store, state);
|
||||
auto fail =
|
||||
load ? vectorizeRootOrTerminal(loop->getInductionVar(), load, state)
|
||||
: vectorizeRootOrTerminal(loop->getInductionVar(), store, state);
|
||||
if (fail) {
|
||||
return fail;
|
||||
}
|
||||
|
@ -1210,7 +1211,8 @@ static bool vectorizeRootMatches(NestedMatch matches,
|
|||
/// RAII.
|
||||
ScopeGuard sg2([&fail, loop, clonedLoop]() {
|
||||
if (fail) {
|
||||
loop->replaceAllUsesWith(clonedLoop);
|
||||
loop->getInductionVar()->replaceAllUsesWith(
|
||||
clonedLoop->getInductionVar());
|
||||
loop->erase();
|
||||
} else {
|
||||
clonedLoop->erase();
|
||||
|
|
Loading…
Reference in New Issue