Replace the walkOps/visitOperationInst variants from the InstWalkers with the Instruction variants.
PiperOrigin-RevId: 232322030
This commit is contained in:
parent
9ca0691b06
commit
a3d9ccaecb
|
@ -182,11 +182,11 @@ public:
|
||||||
|
|
||||||
/// Walk the operation instructions in the 'for' instruction in preorder,
|
/// Walk the operation instructions in the 'for' instruction in preorder,
|
||||||
/// calling the callback for each operation.
|
/// calling the callback for each operation.
|
||||||
void walkOps(std::function<void(Instruction *)> callback);
|
void walk(std::function<void(Instruction *)> callback);
|
||||||
|
|
||||||
/// Walk the operation instructions in the 'for' instruction in postorder,
|
/// Walk the operation instructions in the 'for' instruction in postorder,
|
||||||
/// calling the callback for each operation.
|
/// calling the callback for each operation.
|
||||||
void walkOpsPostOrder(std::function<void(Instruction *)> callback);
|
void walkPostOrder(std::function<void(Instruction *)> callback);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend class Instruction;
|
friend class Instruction;
|
||||||
|
|
|
@ -127,7 +127,7 @@ private:
|
||||||
struct State : public InstWalker<State> {
|
struct State : public InstWalker<State> {
|
||||||
State(NestedPattern &pattern, SmallVectorImpl<NestedMatch> *matches)
|
State(NestedPattern &pattern, SmallVectorImpl<NestedMatch> *matches)
|
||||||
: pattern(pattern), matches(matches) {}
|
: pattern(pattern), matches(matches) {}
|
||||||
void visitOperationInst(Instruction *opInst) {
|
void visitInstruction(Instruction *opInst) {
|
||||||
pattern.matchOne(opInst, matches);
|
pattern.matchOne(opInst, matches);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -311,12 +311,12 @@ public:
|
||||||
return &Block::instructions;
|
return &Block::instructions;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Walk the operation instructions of this block in preorder, calling the
|
/// Walk the instructions of this block in preorder, calling the callback for
|
||||||
/// callback for each operation.
|
/// each operation.
|
||||||
void walk(std::function<void(Instruction *)> callback);
|
void walk(std::function<void(Instruction *)> callback);
|
||||||
|
|
||||||
/// Walk the operation instructions in this block in postorder, calling the
|
/// Walk the instructions in this block in postorder, calling the callback for
|
||||||
/// callback for each operation.
|
/// each operation.
|
||||||
void walkPostOrder(std::function<void(Instruction *)> callback);
|
void walkPostOrder(std::function<void(Instruction *)> callback);
|
||||||
|
|
||||||
/// Walk the operation instructions in the specified [begin, end) range of
|
/// Walk the operation instructions in the specified [begin, end) range of
|
||||||
|
|
|
@ -117,13 +117,11 @@ public:
|
||||||
|
|
||||||
/// Walk the instructions in the function in preorder, calling the callback
|
/// Walk the instructions in the function in preorder, calling the callback
|
||||||
/// for each instruction or operation.
|
/// for each instruction or operation.
|
||||||
void walkInsts(std::function<void(Instruction *)> callback);
|
void walk(std::function<void(Instruction *)> callback);
|
||||||
void walkOps(std::function<void(Instruction *)> callback);
|
|
||||||
|
|
||||||
/// Walk the instructions in the function in postorder, calling the callback
|
/// Walk the instructions in the function in postorder, calling the callback
|
||||||
/// for each instruction or operation.
|
/// for each instruction or operation.
|
||||||
void walkInstsPostOrder(std::function<void(Instruction *)> callback);
|
void walkPostOrder(std::function<void(Instruction *)> callback);
|
||||||
void walkOpsPostOrder(std::function<void(Instruction *)> callback);
|
|
||||||
|
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
// Arguments
|
// Arguments
|
||||||
|
|
|
@ -67,34 +67,6 @@
|
||||||
#include "mlir/IR/Instruction.h"
|
#include "mlir/IR/Instruction.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
|
||||||
/// Base class for instruction visitors.
|
|
||||||
template <typename SubClass, typename RetTy = void> class InstVisitor {
|
|
||||||
//===--------------------------------------------------------------------===//
|
|
||||||
// Interface code - This is the public interface of the InstVisitor that you
|
|
||||||
// use to visit instructions.
|
|
||||||
|
|
||||||
public:
|
|
||||||
// Function to visit a instruction.
|
|
||||||
RetTy visit(Instruction *s) {
|
|
||||||
static_assert(std::is_base_of<InstVisitor, SubClass>::value,
|
|
||||||
"Must pass the derived type to this template!");
|
|
||||||
return static_cast<SubClass *>(this)->visitOperationInst(s);
|
|
||||||
}
|
|
||||||
|
|
||||||
//===--------------------------------------------------------------------===//
|
|
||||||
// Visitation functions... these functions provide default fallbacks in case
|
|
||||||
// the user does not specify what to do for a particular instruction type.
|
|
||||||
// The default behavior is to generalize the instruction type to its subtype
|
|
||||||
// and try visiting the subtype. All of this should be inlined perfectly,
|
|
||||||
// because there are no virtual functions to get in the way.
|
|
||||||
//
|
|
||||||
|
|
||||||
// When visiting a for inst, if inst, or an operation inst directly, these
|
|
||||||
// methods get called to indicate when transitioning into a new unit.
|
|
||||||
void visitOperationInst(Instruction *opInst) {}
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Base class for instruction walkers. A walker can traverse depth first in
|
/// Base class for instruction walkers. A walker can traverse depth first in
|
||||||
/// pre-order or post order. The walk methods without a suffix do a pre-order
|
/// pre-order or post order. The walk methods without a suffix do a pre-order
|
||||||
/// traversal while those that traverse in post order have a PostOrder suffix.
|
/// traversal while those that traverse in post order have a PostOrder suffix.
|
||||||
|
@ -127,36 +99,26 @@ public:
|
||||||
static_cast<SubClass *>(this)->walkPostOrder(it->begin(), it->end());
|
static_cast<SubClass *>(this)->walkPostOrder(it->begin(), it->end());
|
||||||
}
|
}
|
||||||
|
|
||||||
void walkOpInst(Instruction *opInst) {
|
|
||||||
static_cast<SubClass *>(this)->visitOperationInst(opInst);
|
|
||||||
for (auto &blockList : opInst->getBlockLists())
|
|
||||||
for (auto &block : blockList)
|
|
||||||
static_cast<SubClass *>(this)->walk(block.begin(), block.end());
|
|
||||||
}
|
|
||||||
|
|
||||||
void walkOpInstPostOrder(Instruction *opInst) {
|
|
||||||
for (auto &blockList : opInst->getBlockLists())
|
|
||||||
for (auto &block : blockList)
|
|
||||||
static_cast<SubClass *>(this)->walkPostOrder(block.begin(),
|
|
||||||
block.end());
|
|
||||||
static_cast<SubClass *>(this)->visitOperationInst(opInst);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Function to walk a instruction.
|
// Function to walk a instruction.
|
||||||
RetTy walk(Instruction *s) {
|
RetTy walk(Instruction *s) {
|
||||||
static_assert(std::is_base_of<InstWalker, SubClass>::value,
|
static_assert(std::is_base_of<InstWalker, SubClass>::value,
|
||||||
"Must pass the derived type to this template!");
|
"Must pass the derived type to this template!");
|
||||||
|
|
||||||
static_cast<SubClass *>(this)->visitInstruction(s);
|
static_cast<SubClass *>(this)->visitInstruction(s);
|
||||||
return static_cast<SubClass *>(this)->walkOpInst(s);
|
for (auto &blockList : s->getBlockLists())
|
||||||
|
for (auto &block : blockList)
|
||||||
|
static_cast<SubClass *>(this)->walk(block.begin(), block.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Function to walk a instruction in post order DFS.
|
// Function to walk a instruction in post order DFS.
|
||||||
RetTy walkPostOrder(Instruction *s) {
|
RetTy walkPostOrder(Instruction *s) {
|
||||||
static_assert(std::is_base_of<InstWalker, SubClass>::value,
|
static_assert(std::is_base_of<InstWalker, SubClass>::value,
|
||||||
"Must pass the derived type to this template!");
|
"Must pass the derived type to this template!");
|
||||||
|
for (auto &blockList : s->getBlockLists())
|
||||||
|
for (auto &block : blockList)
|
||||||
|
static_cast<SubClass *>(this)->walkPostOrder(block.begin(),
|
||||||
|
block.end());
|
||||||
static_cast<SubClass *>(this)->visitInstruction(s);
|
static_cast<SubClass *>(this)->visitInstruction(s);
|
||||||
return static_cast<SubClass *>(this)->walkOpInstPostOrder(s);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//===--------------------------------------------------------------------===//
|
//===--------------------------------------------------------------------===//
|
||||||
|
@ -170,7 +132,6 @@ public:
|
||||||
// called. These are typically O(1) complexity and shouldn't be recursively
|
// called. These are typically O(1) complexity and shouldn't be recursively
|
||||||
// processing their descendants in some way. When using RetTy, all of these
|
// processing their descendants in some way. When using RetTy, all of these
|
||||||
// need to be overridden.
|
// need to be overridden.
|
||||||
void visitOperationInst(Instruction *opInst) {}
|
|
||||||
void visitInstruction(Instruction *inst) {}
|
void visitInstruction(Instruction *inst) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -144,7 +144,7 @@ PassResult MLPatternLoweringPass<Patterns...>::runOnFunction(Function *f) {
|
||||||
MLFuncLoweringRewriter rewriter(&builder);
|
MLFuncLoweringRewriter rewriter(&builder);
|
||||||
|
|
||||||
llvm::SmallVector<Instruction *, 16> ops;
|
llvm::SmallVector<Instruction *, 16> ops;
|
||||||
f->walkOps([&ops](Instruction *inst) { ops.push_back(inst); });
|
f->walk([&ops](Instruction *inst) { ops.push_back(inst); });
|
||||||
|
|
||||||
for (Instruction *inst : ops) {
|
for (Instruction *inst : ops) {
|
||||||
for (const auto &pattern : patterns) {
|
for (const auto &pattern : patterns) {
|
||||||
|
|
|
@ -410,27 +410,26 @@ bool AffineForOp::matchingBoundOperandList() const {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void AffineForOp::walkOps(std::function<void(Instruction *)> callback) {
|
void AffineForOp::walk(std::function<void(Instruction *)> callback) {
|
||||||
struct Walker : public InstWalker<Walker> {
|
struct Walker : public InstWalker<Walker> {
|
||||||
std::function<void(Instruction *)> const &callback;
|
std::function<void(Instruction *)> const &callback;
|
||||||
Walker(std::function<void(Instruction *)> const &callback)
|
Walker(std::function<void(Instruction *)> const &callback)
|
||||||
: callback(callback) {}
|
: callback(callback) {}
|
||||||
|
|
||||||
void visitOperationInst(Instruction *opInst) { callback(opInst); }
|
void visitInstruction(Instruction *opInst) { callback(opInst); }
|
||||||
};
|
};
|
||||||
|
|
||||||
Walker w(callback);
|
Walker w(callback);
|
||||||
w.walk(getInstruction());
|
w.walk(getInstruction());
|
||||||
}
|
}
|
||||||
|
|
||||||
void AffineForOp::walkOpsPostOrder(
|
void AffineForOp::walkPostOrder(std::function<void(Instruction *)> callback) {
|
||||||
std::function<void(Instruction *)> callback) {
|
|
||||||
struct Walker : public InstWalker<Walker> {
|
struct Walker : public InstWalker<Walker> {
|
||||||
std::function<void(Instruction *)> const &callback;
|
std::function<void(Instruction *)> const &callback;
|
||||||
Walker(std::function<void(Instruction *)> const &callback)
|
Walker(std::function<void(Instruction *)> const &callback)
|
||||||
: callback(callback) {}
|
: callback(callback) {}
|
||||||
|
|
||||||
void visitOperationInst(Instruction *opInst) { callback(opInst); }
|
void visitInstruction(Instruction *opInst) { callback(opInst); }
|
||||||
};
|
};
|
||||||
|
|
||||||
Walker v(callback);
|
Walker v(callback);
|
||||||
|
|
|
@ -46,7 +46,7 @@ struct MemRefDependenceCheck : public FunctionPass,
|
||||||
|
|
||||||
PassResult runOnFunction(Function *f) override;
|
PassResult runOnFunction(Function *f) override;
|
||||||
|
|
||||||
void visitOperationInst(Instruction *opInst) {
|
void visitInstruction(Instruction *opInst) {
|
||||||
if (opInst->isa<LoadOp>() || opInst->isa<StoreOp>()) {
|
if (opInst->isa<LoadOp>() || opInst->isa<StoreOp>()) {
|
||||||
loadsAndStores.push_back(opInst);
|
loadsAndStores.push_back(opInst);
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,7 +45,7 @@ char LowerEDSCTestPass::passID = 0;
|
||||||
#include "mlir/EDSC/reference-impl.inc"
|
#include "mlir/EDSC/reference-impl.inc"
|
||||||
|
|
||||||
PassResult LowerEDSCTestPass::runOnFunction(Function *f) {
|
PassResult LowerEDSCTestPass::runOnFunction(Function *f) {
|
||||||
f->walkOps([](OperationInst *op) {
|
f->walk([](OperationInst *op) {
|
||||||
if (op->getName().getStringRef() == "print") {
|
if (op->getName().getStringRef() == "print") {
|
||||||
auto opName = op->getAttrOfType<StringAttr>("op");
|
auto opName = op->getAttrOfType<StringAttr>("op");
|
||||||
if (!opName) {
|
if (!opName) {
|
||||||
|
|
|
@ -263,7 +263,7 @@ void ModuleState::initialize(const Module *module) {
|
||||||
for (auto &fn : *module) {
|
for (auto &fn : *module) {
|
||||||
visitType(fn.getType());
|
visitType(fn.getType());
|
||||||
|
|
||||||
const_cast<Function &>(fn).walkInsts(
|
const_cast<Function &>(fn).walk(
|
||||||
[&](Instruction *op) { ModuleState::visitInstruction(op); });
|
[&](Instruction *op) { ModuleState::visitInstruction(op); });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -256,31 +256,31 @@ Block *Block::splitBlock(iterator splitBefore) {
|
||||||
return newBB;
|
return newBB;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Block::walk(std::function<void(OperationInst *)> callback) {
|
void Block::walk(std::function<void(Instruction *)> callback) {
|
||||||
walk(begin(), end(), callback);
|
walk(begin(), end(), callback);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Block::walk(Block::iterator begin, Block::iterator end,
|
void Block::walk(Block::iterator begin, Block::iterator end,
|
||||||
std::function<void(OperationInst *)> callback) {
|
std::function<void(Instruction *)> callback) {
|
||||||
struct Walker : public InstWalker<Walker> {
|
struct Walker : public InstWalker<Walker> {
|
||||||
std::function<void(OperationInst *)> const &callback;
|
std::function<void(Instruction *)> const &callback;
|
||||||
Walker(std::function<void(OperationInst *)> const &callback)
|
Walker(std::function<void(Instruction *)> const &callback)
|
||||||
: callback(callback) {}
|
: callback(callback) {}
|
||||||
|
|
||||||
void visitOperationInst(OperationInst *opInst) { callback(opInst); }
|
void visitInstruction(Instruction *opInst) { callback(opInst); }
|
||||||
};
|
};
|
||||||
|
|
||||||
Walker w(callback);
|
Walker w(callback);
|
||||||
w.walk(begin, end);
|
w.walk(begin, end);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Block::walkPostOrder(std::function<void(OperationInst *)> callback) {
|
void Block::walkPostOrder(std::function<void(Instruction *)> callback) {
|
||||||
struct Walker : public InstWalker<Walker> {
|
struct Walker : public InstWalker<Walker> {
|
||||||
std::function<void(OperationInst *)> const &callback;
|
std::function<void(Instruction *)> const &callback;
|
||||||
Walker(std::function<void(OperationInst *)> const &callback)
|
Walker(std::function<void(Instruction *)> const &callback)
|
||||||
: callback(callback) {}
|
: callback(callback) {}
|
||||||
|
|
||||||
void visitOperationInst(OperationInst *opInst) { callback(opInst); }
|
void visitInstruction(Instruction *opInst) { callback(opInst); }
|
||||||
};
|
};
|
||||||
|
|
||||||
Walker v(callback);
|
Walker v(callback);
|
||||||
|
@ -338,17 +338,13 @@ void BlockList::cloneInto(BlockList *dest, BlockAndValueMapping &mapper,
|
||||||
BlockAndValueMapping &mapper;
|
BlockAndValueMapping &mapper;
|
||||||
Walker(BlockAndValueMapping &mapper) : mapper(mapper) {}
|
Walker(BlockAndValueMapping &mapper) : mapper(mapper) {}
|
||||||
|
|
||||||
/// Remap the instruction operands.
|
/// Remap the instruction and successor block operands.
|
||||||
void visitInstruction(Instruction *inst) {
|
void visitInstruction(OperationInst *inst) {
|
||||||
for (auto &instOp : inst->getInstOperands())
|
for (auto &instOp : inst->getInstOperands())
|
||||||
if (auto *mappedOp = mapper.lookupOrNull(instOp.get()))
|
if (auto *mappedOp = mapper.lookupOrNull(instOp.get()))
|
||||||
instOp.set(mappedOp);
|
instOp.set(mappedOp);
|
||||||
}
|
if (inst->isTerminator())
|
||||||
// Remap the successor block operands.
|
for (auto &succOp : inst->getBlockOperands())
|
||||||
void visitOperationInst(OperationInst *opInst) {
|
|
||||||
if (!opInst->isTerminator())
|
|
||||||
return;
|
|
||||||
for (auto &succOp : opInst->getBlockOperands())
|
|
||||||
if (auto *mappedOp = mapper.lookupOrNull(succOp.get()))
|
if (auto *mappedOp = mapper.lookupOrNull(succOp.get()))
|
||||||
succOp.set(mappedOp);
|
succOp.set(mappedOp);
|
||||||
}
|
}
|
||||||
|
|
|
@ -214,7 +214,7 @@ void Function::addEntryBlock() {
|
||||||
entry->addArguments(type.getInputs());
|
entry->addArguments(type.getInputs());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Function::walkInsts(std::function<void(Instruction *)> callback) {
|
void Function::walk(std::function<void(Instruction *)> callback) {
|
||||||
struct Walker : public InstWalker<Walker> {
|
struct Walker : public InstWalker<Walker> {
|
||||||
std::function<void(Instruction *)> const &callback;
|
std::function<void(Instruction *)> const &callback;
|
||||||
Walker(std::function<void(Instruction *)> const &callback)
|
Walker(std::function<void(Instruction *)> const &callback)
|
||||||
|
@ -227,39 +227,13 @@ void Function::walkInsts(std::function<void(Instruction *)> callback) {
|
||||||
v.walk(this);
|
v.walk(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Function::walkOps(std::function<void(OperationInst *)> callback) {
|
void Function::walkPostOrder(std::function<void(Instruction *)> callback) {
|
||||||
struct Walker : public InstWalker<Walker> {
|
|
||||||
std::function<void(OperationInst *)> const &callback;
|
|
||||||
Walker(std::function<void(OperationInst *)> const &callback)
|
|
||||||
: callback(callback) {}
|
|
||||||
|
|
||||||
void visitOperationInst(OperationInst *opInst) { callback(opInst); }
|
|
||||||
};
|
|
||||||
|
|
||||||
Walker v(callback);
|
|
||||||
v.walk(this);
|
|
||||||
}
|
|
||||||
|
|
||||||
void Function::walkInstsPostOrder(std::function<void(Instruction *)> callback) {
|
|
||||||
struct Walker : public InstWalker<Walker> {
|
struct Walker : public InstWalker<Walker> {
|
||||||
std::function<void(Instruction *)> const &callback;
|
std::function<void(Instruction *)> const &callback;
|
||||||
Walker(std::function<void(Instruction *)> const &callback)
|
Walker(std::function<void(Instruction *)> const &callback)
|
||||||
: callback(callback) {}
|
: callback(callback) {}
|
||||||
|
|
||||||
void visitOperationInst(Instruction *inst) { callback(inst); }
|
void visitInstruction(Instruction *inst) { callback(inst); }
|
||||||
};
|
|
||||||
|
|
||||||
Walker v(callback);
|
|
||||||
v.walkPostOrder(this);
|
|
||||||
}
|
|
||||||
|
|
||||||
void Function::walkOpsPostOrder(std::function<void(OperationInst *)> callback) {
|
|
||||||
struct Walker : public InstWalker<Walker> {
|
|
||||||
std::function<void(OperationInst *)> const &callback;
|
|
||||||
Walker(std::function<void(OperationInst *)> const &callback)
|
|
||||||
: callback(callback) {}
|
|
||||||
|
|
||||||
void visitOperationInst(OperationInst *opInst) { callback(opInst); }
|
|
||||||
};
|
};
|
||||||
|
|
||||||
Walker v(callback);
|
Walker v(callback);
|
||||||
|
|
|
@ -48,7 +48,7 @@ namespace {
|
||||||
struct ComposeAffineMaps : public FunctionPass, InstWalker<ComposeAffineMaps> {
|
struct ComposeAffineMaps : public FunctionPass, InstWalker<ComposeAffineMaps> {
|
||||||
explicit ComposeAffineMaps() : FunctionPass(&ComposeAffineMaps::passID) {}
|
explicit ComposeAffineMaps() : FunctionPass(&ComposeAffineMaps::passID) {}
|
||||||
PassResult runOnFunction(Function *f) override;
|
PassResult runOnFunction(Function *f) override;
|
||||||
void visitOperationInst(OperationInst *opInst);
|
void visitInstruction(OperationInst *opInst);
|
||||||
|
|
||||||
SmallVector<OpPointer<AffineApplyOp>, 8> affineApplyOps;
|
SmallVector<OpPointer<AffineApplyOp>, 8> affineApplyOps;
|
||||||
|
|
||||||
|
@ -68,7 +68,7 @@ static bool affineApplyOp(const Instruction &inst) {
|
||||||
return opInst.isa<AffineApplyOp>();
|
return opInst.isa<AffineApplyOp>();
|
||||||
}
|
}
|
||||||
|
|
||||||
void ComposeAffineMaps::visitOperationInst(OperationInst *opInst) {
|
void ComposeAffineMaps::visitInstruction(OperationInst *opInst) {
|
||||||
if (auto afOp = opInst->dyn_cast<AffineApplyOp>()) {
|
if (auto afOp = opInst->dyn_cast<AffineApplyOp>()) {
|
||||||
affineApplyOps.push_back(afOp);
|
affineApplyOps.push_back(afOp);
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,7 +37,7 @@ struct ConstantFold : public FunctionPass, InstWalker<ConstantFold> {
|
||||||
|
|
||||||
bool foldOperation(OperationInst *op,
|
bool foldOperation(OperationInst *op,
|
||||||
SmallVectorImpl<Value *> &existingConstants);
|
SmallVectorImpl<Value *> &existingConstants);
|
||||||
void visitOperationInst(OperationInst *inst);
|
void visitInstruction(OperationInst *op);
|
||||||
PassResult runOnFunction(Function *f) override;
|
PassResult runOnFunction(Function *f) override;
|
||||||
|
|
||||||
static char passID;
|
static char passID;
|
||||||
|
@ -49,7 +49,7 @@ char ConstantFold::passID = 0;
|
||||||
/// Attempt to fold the specified operation, updating the IR to match. If
|
/// Attempt to fold the specified operation, updating the IR to match. If
|
||||||
/// constants are found, we keep track of them in the existingConstants list.
|
/// constants are found, we keep track of them in the existingConstants list.
|
||||||
///
|
///
|
||||||
void ConstantFold::visitOperationInst(OperationInst *op) {
|
void ConstantFold::visitInstruction(OperationInst *op) {
|
||||||
// If this operation is an AffineForOp, then fold the bounds.
|
// If this operation is an AffineForOp, then fold the bounds.
|
||||||
if (auto forOp = op->dyn_cast<AffineForOp>()) {
|
if (auto forOp = op->dyn_cast<AffineForOp>()) {
|
||||||
constantFoldBounds(forOp);
|
constantFoldBounds(forOp);
|
||||||
|
|
|
@ -118,7 +118,7 @@ public:
|
||||||
SmallVector<OperationInst *, 4> storeOpInsts;
|
SmallVector<OperationInst *, 4> storeOpInsts;
|
||||||
bool hasNonForRegion = false;
|
bool hasNonForRegion = false;
|
||||||
|
|
||||||
void visitOperationInst(OperationInst *opInst) {
|
void visitInstruction(OperationInst *opInst) {
|
||||||
if (opInst->isa<AffineForOp>())
|
if (opInst->isa<AffineForOp>())
|
||||||
forOps.push_back(opInst->cast<AffineForOp>());
|
forOps.push_back(opInst->cast<AffineForOp>());
|
||||||
else if (opInst->getNumBlockLists() != 0)
|
else if (opInst->getNumBlockLists() != 0)
|
||||||
|
@ -619,7 +619,7 @@ public:
|
||||||
|
|
||||||
LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {}
|
LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {}
|
||||||
|
|
||||||
void visitOperationInst(OperationInst *opInst) {
|
void visitInstruction(OperationInst *opInst) {
|
||||||
auto forOp = opInst->dyn_cast<AffineForOp>();
|
auto forOp = opInst->dyn_cast<AffineForOp>();
|
||||||
if (!forOp)
|
if (!forOp)
|
||||||
return;
|
return;
|
||||||
|
|
|
@ -113,7 +113,7 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
|
||||||
return hasInnerLoops;
|
return hasInnerLoops;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool walkOpInstPostOrder(OperationInst *opInst) {
|
bool walkPostOrder(OperationInst *opInst) {
|
||||||
bool hasInnerLoops = false;
|
bool hasInnerLoops = false;
|
||||||
for (auto &blockList : opInst->getBlockLists())
|
for (auto &blockList : opInst->getBlockLists())
|
||||||
for (auto &block : blockList)
|
for (auto &block : blockList)
|
||||||
|
@ -140,7 +140,7 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
|
||||||
const unsigned minTripCount;
|
const unsigned minTripCount;
|
||||||
ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {}
|
ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {}
|
||||||
|
|
||||||
void visitOperationInst(OperationInst *opInst) {
|
void visitInstruction(OperationInst *opInst) {
|
||||||
auto forOp = opInst->dyn_cast<AffineForOp>();
|
auto forOp = opInst->dyn_cast<AffineForOp>();
|
||||||
if (!forOp)
|
if (!forOp)
|
||||||
return;
|
return;
|
||||||
|
|
|
@ -196,7 +196,7 @@ bool mlir::loopUnrollJamByFactor(OpPointer<AffineForOp> forOp,
|
||||||
|
|
||||||
// Gather all sub-blocks to jam upon the loop being unrolled.
|
// Gather all sub-blocks to jam upon the loop being unrolled.
|
||||||
JamBlockGatherer jbg;
|
JamBlockGatherer jbg;
|
||||||
jbg.walkOpInst(forInst);
|
jbg.walk(forInst);
|
||||||
auto &subBlocks = jbg.subBlocks;
|
auto &subBlocks = jbg.subBlocks;
|
||||||
|
|
||||||
// Generate the cleanup loop if trip count isn't a multiple of
|
// Generate the cleanup loop if trip count isn't a multiple of
|
||||||
|
|
|
@ -615,7 +615,7 @@ PassResult LowerAffinePass::runOnFunction(Function *function) {
|
||||||
|
|
||||||
// Collect all the For instructions as well as AffineIfOps and AffineApplyOps.
|
// Collect all the For instructions as well as AffineIfOps and AffineApplyOps.
|
||||||
// We do this as a prepass to avoid invalidating the walker with our rewrite.
|
// We do this as a prepass to avoid invalidating the walker with our rewrite.
|
||||||
function->walkInsts([&](Instruction *inst) {
|
function->walk([&](Instruction *inst) {
|
||||||
auto op = cast<OperationInst>(inst);
|
auto op = cast<OperationInst>(inst);
|
||||||
if (op->isa<AffineApplyOp>() || op->isa<AffineForOp>() ||
|
if (op->isa<AffineApplyOp>() || op->isa<AffineForOp>() ||
|
||||||
op->isa<AffineIfOp>())
|
op->isa<AffineIfOp>())
|
||||||
|
|
|
@ -75,7 +75,7 @@ struct MemRefDataFlowOpt : public FunctionPass, InstWalker<MemRefDataFlowOpt> {
|
||||||
|
|
||||||
PassResult runOnFunction(Function *f) override;
|
PassResult runOnFunction(Function *f) override;
|
||||||
|
|
||||||
void visitOperationInst(OperationInst *opInst);
|
void visitInstruction(OperationInst *opInst);
|
||||||
|
|
||||||
// A list of memref's that are potentially dead / could be eliminated.
|
// A list of memref's that are potentially dead / could be eliminated.
|
||||||
SmallPtrSet<Value *, 4> memrefsToErase;
|
SmallPtrSet<Value *, 4> memrefsToErase;
|
||||||
|
@ -100,7 +100,7 @@ FunctionPass *mlir::createMemRefDataFlowOptPass() {
|
||||||
|
|
||||||
// This is a straightforward implementation not optimized for speed. Optimize
|
// This is a straightforward implementation not optimized for speed. Optimize
|
||||||
// this in the future if needed.
|
// this in the future if needed.
|
||||||
void MemRefDataFlowOpt::visitOperationInst(OperationInst *opInst) {
|
void MemRefDataFlowOpt::visitInstruction(OperationInst *opInst) {
|
||||||
OperationInst *lastWriteStoreOp = nullptr;
|
OperationInst *lastWriteStoreOp = nullptr;
|
||||||
|
|
||||||
auto loadOp = opInst->dyn_cast<LoadOp>();
|
auto loadOp = opInst->dyn_cast<LoadOp>();
|
||||||
|
|
|
@ -142,7 +142,7 @@ PassResult PipelineDataTransfer::runOnFunction(Function *f) {
|
||||||
// deleted and replaced by a prologue, a new steady-state loop and an
|
// deleted and replaced by a prologue, a new steady-state loop and an
|
||||||
// epilogue).
|
// epilogue).
|
||||||
forOps.clear();
|
forOps.clear();
|
||||||
f->walkOpsPostOrder([&](OperationInst *opInst) {
|
f->walkPostOrder([&](OperationInst *opInst) {
|
||||||
if (auto forOp = opInst->dyn_cast<AffineForOp>())
|
if (auto forOp = opInst->dyn_cast<AffineForOp>())
|
||||||
forOps.push_back(forOp);
|
forOps.push_back(forOp);
|
||||||
});
|
});
|
||||||
|
|
|
@ -64,7 +64,7 @@ static IntegerSet simplifyIntegerSet(IntegerSet set) {
|
||||||
}
|
}
|
||||||
|
|
||||||
PassResult SimplifyAffineStructures::runOnFunction(Function *f) {
|
PassResult SimplifyAffineStructures::runOnFunction(Function *f) {
|
||||||
f->walkOps([&](OperationInst *opInst) {
|
f->walk([&](OperationInst *opInst) {
|
||||||
for (auto attr : opInst->getAttrs()) {
|
for (auto attr : opInst->getAttrs()) {
|
||||||
if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) {
|
if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) {
|
||||||
MutableAffineMap mMap(mapAttr.getValue());
|
MutableAffineMap mMap(mapAttr.getValue());
|
||||||
|
|
|
@ -39,7 +39,7 @@ PassResult StripDebugInfo::runOnFunction(Function *f) {
|
||||||
|
|
||||||
// Strip the debug info from the function and its instructions.
|
// Strip the debug info from the function and its instructions.
|
||||||
f->setLoc(unknownLoc);
|
f->setLoc(unknownLoc);
|
||||||
f->walkInsts([&](Instruction *inst) { inst->setLoc(unknownLoc); });
|
f->walk([&](Instruction *inst) { inst->setLoc(unknownLoc); });
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -38,7 +38,7 @@ public:
|
||||||
worklist.reserve(64);
|
worklist.reserve(64);
|
||||||
|
|
||||||
// Add all operations to the worklist.
|
// Add all operations to the worklist.
|
||||||
fn->walkOps([&](OperationInst *inst) { addToWorklist(inst); });
|
fn->walk([&](OperationInst *inst) { addToWorklist(inst); });
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Perform the rewrites.
|
/// Perform the rewrites.
|
||||||
|
|
|
@ -135,7 +135,7 @@ bool mlir::promoteIfSingleIteration(OpPointer<AffineForOp> forOp) {
|
||||||
/// their body into the containing Block.
|
/// their body into the containing Block.
|
||||||
void mlir::promoteSingleIterationLoops(Function *f) {
|
void mlir::promoteSingleIterationLoops(Function *f) {
|
||||||
// Gathers all innermost loops through a post order pruned walk.
|
// Gathers all innermost loops through a post order pruned walk.
|
||||||
f->walkOpsPostOrder([](OperationInst *inst) {
|
f->walkPostOrder([](OperationInst *inst) {
|
||||||
if (auto forOp = inst->dyn_cast<AffineForOp>())
|
if (auto forOp = inst->dyn_cast<AffineForOp>())
|
||||||
promoteIfSingleIteration(forOp);
|
promoteIfSingleIteration(forOp);
|
||||||
});
|
});
|
||||||
|
|
|
@ -362,8 +362,8 @@ void mlir::remapFunctionAttrs(
|
||||||
Function &fn, const DenseMap<Attribute, FunctionAttr> &remappingTable) {
|
Function &fn, const DenseMap<Attribute, FunctionAttr> &remappingTable) {
|
||||||
|
|
||||||
// Look at all instructions in a Function.
|
// Look at all instructions in a Function.
|
||||||
fn.walkOps(
|
fn.walk(
|
||||||
[&](OperationInst *inst) { remapFunctionAttrs(*inst, remappingTable); });
|
[&](Instruction *inst) { remapFunctionAttrs(*inst, remappingTable); });
|
||||||
}
|
}
|
||||||
|
|
||||||
void mlir::remapFunctionAttrs(
|
void mlir::remapFunctionAttrs(
|
||||||
|
|
Loading…
Reference in New Issue