Remove InstWalker and move all instruction walking to the api facilities on Function/Block/Instruction.

PiperOrigin-RevId: 232388113
This commit is contained in:
River Riddle 2019-02-04 16:24:44 -08:00 committed by jpienaar
parent c9ad4621ce
commit bf9c381d1d
25 changed files with 280 additions and 458 deletions

View File

@ -229,14 +229,6 @@ public:
/// (same operands in the same order).
bool matchingBoundOperandList() const;
/// Walk the operation instructions in the 'for' instruction in preorder,
/// calling the callback for each operation.
void walk(std::function<void(Instruction *)> callback);
/// Walk the operation instructions in the 'for' instruction in postorder,
/// calling the callback for each operation.
void walkPostOrder(std::function<void(Instruction *)> callback);
private:
friend class Instruction;
explicit AffineForOp(const Instruction *state) : Op(state) {}

View File

@ -18,7 +18,7 @@
#ifndef MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_
#define MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/Function.h"
#include "llvm/Support/Allocator.h"
namespace mlir {
@ -76,7 +76,7 @@ private:
ArrayRef<NestedMatch> matchedChildren;
};
/// A NestedPattern is a nested InstWalker that:
/// A NestedPattern is a nested instruction walker that:
/// 1. recursively matches a substructure in the tree;
/// 2. uses a filter function to refine matches with extra semantic
/// constraints (passed via a lambda of type FilterFunctionType);
@ -92,8 +92,8 @@ private:
///
/// The NestedMatches captured in the IR can grow large, especially after
/// aggressive unrolling. As experience has shown, it is generally better to use
/// a plain InstWalker to match flat patterns but the current implementation is
/// competitive nonetheless.
/// a plain walk over instructions to match flat patterns but the current
/// implementation is competitive nonetheless.
using FilterFunctionType = std::function<bool(const Instruction &)>;
static bool defaultFilterFunction(const Instruction &) { return true; };
struct NestedPattern {
@ -102,16 +102,14 @@ struct NestedPattern {
NestedPattern(const NestedPattern &) = default;
NestedPattern &operator=(const NestedPattern &) = default;
/// Returns all the top-level matches in `function`.
void match(Function *function, SmallVectorImpl<NestedMatch> *matches) {
State state(*this, matches);
state.walkPostOrder(function);
/// Returns all the top-level matches in `func`.
void match(Function *func, SmallVectorImpl<NestedMatch> *matches) {
func->walkPostOrder([&](Instruction *inst) { matchOne(inst, matches); });
}
/// Returns all the top-level matches in `inst`.
void match(Instruction *inst, SmallVectorImpl<NestedMatch> *matches) {
State state(*this, matches);
state.walkPostOrder(inst);
inst->walkPostOrder([&](Instruction *child) { matchOne(child, matches); });
}
/// Returns the depth of the pattern.
@ -120,22 +118,8 @@ struct NestedPattern {
private:
friend class NestedPatternContext;
friend class NestedMatch;
friend class InstWalker<NestedPattern>;
friend struct State;
/// Helper state that temporarily holds matches for the next level of nesting.
struct State : public InstWalker<State> {
State(NestedPattern &pattern, SmallVectorImpl<NestedMatch> *matches)
: pattern(pattern), matches(matches) {}
void visitInstruction(Instruction *opInst) {
pattern.matchOne(opInst, matches);
}
private:
NestedPattern &pattern;
SmallVectorImpl<NestedMatch> *matches;
};
/// Underlying global bump allocator managed by a NestedPatternContext.
static llvm::BumpPtrAllocator *&allocator();
@ -153,8 +137,9 @@ private:
/// without switching on the type of the Instruction. The idea is that a
/// NestedPattern first checks if it matches locally and then recursively
/// applies its nested matchers to its elem->nested. Since we want to rely on
/// the InstWalker impl rather than duplicate its the logic, we allow an
/// off-by-one traversal to account for the fact that we write:
/// the existing instruction walking functionality rather than duplicate
/// it, we allow an off-by-one traversal to account for the fact that we
/// write:
///
/// void match(Instruction *elem) {
/// for (auto &c : getNestedPatterns()) {

View File

@ -287,6 +287,28 @@ public:
succ_iterator succ_end();
llvm::iterator_range<succ_iterator> getSuccessors();
//===--------------------------------------------------------------------===//
// Instruction Walkers
//===--------------------------------------------------------------------===//
/// Walk the instructions of this block in preorder, calling the callback for
/// each operation.
void walk(const std::function<void(Instruction *)> &callback);
/// Walk the instructions in the specified [begin, end) range of
/// this block, calling the callback for each operation.
void walk(Block::iterator begin, Block::iterator end,
const std::function<void(Instruction *)> &callback);
/// Walk the instructions in this block in postorder, calling the callback for
/// each operation.
void walkPostOrder(const std::function<void(Instruction *)> &callback);
/// Walk the instructions in the specified [begin, end) range of this block
/// in postorder, calling the callback for each operation.
void walkPostOrder(Block::iterator begin, Block::iterator end,
const std::function<void(Instruction *)> &callback);
//===--------------------------------------------------------------------===//
// Other
//===--------------------------------------------------------------------===//
@ -311,19 +333,6 @@ public:
return &Block::instructions;
}
/// Walk the instructions of this block in preorder, calling the callback for
/// each operation.
void walk(std::function<void(Instruction *)> callback);
/// Walk the instructions in this block in postorder, calling the callback for
/// each operation.
void walkPostOrder(std::function<void(Instruction *)> callback);
/// Walk the instructions in the specified [begin, end) range of
/// this block, calling the callback for each operation.
void walk(Block::iterator begin, Block::iterator end,
std::function<void(Instruction *)> callback);
void print(raw_ostream &os) const;
void dump() const;

View File

@ -27,6 +27,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/Instruction.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"
@ -39,6 +40,7 @@ class FunctionType;
class MLIRContext;
class Module;
template <typename ObjectType, typename ElementType> class ArgumentIterator;
template <typename T> class OpPointer;
/// NamedAttribute is used for function attribute lists, it holds an
/// identifier for the name and a value for the attribute. The attribute
@ -115,13 +117,35 @@ public:
Block &front() { return blocks.front(); }
const Block &front() const { return const_cast<Function *>(this)->front(); }
//===--------------------------------------------------------------------===//
// Instruction Walkers
//===--------------------------------------------------------------------===//
/// Walk the instructions in the function in preorder, calling the callback
/// for each instruction or operation.
void walk(std::function<void(Instruction *)> callback);
/// for each instruction.
void walk(const std::function<void(Instruction *)> &callback);
/// Specialization of walk to only visit operations of 'OpTy'.
template <typename OpTy>
void walk(std::function<void(OpPointer<OpTy>)> callback) {
walk([&](Instruction *inst) {
if (auto op = inst->dyn_cast<OpTy>())
callback(op);
});
}
/// Walk the instructions in the function in postorder, calling the callback
/// for each instruction or operation.
void walkPostOrder(std::function<void(Instruction *)> callback);
/// for each instruction.
void walkPostOrder(const std::function<void(Instruction *)> &callback);
/// Specialization of walkPostOrder to only visit operations of 'OpTy'.
template <typename OpTy>
void walkPostOrder(std::function<void(OpPointer<OpTy>)> callback) {
walkPostOrder([&](Instruction *inst) {
if (auto op = inst->dyn_cast<OpTy>())
callback(op);
});
}
//===--------------------------------------------------------------------===//
// Arguments

View File

@ -1,140 +0,0 @@
//===- InstVisitor.h - MLIR Instruction Visitor Class -----------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines the base classes for Function's instruction visitors and
// walkers. A visit is a O(1) operation that visits just the node in question. A
// walk visits the node it's called on as well as the node's descendants.
//
// Instruction visitors/walkers are used when you want to perform different
// actions for different kinds of instructions without having to use lots of
// casts and a big switch instruction.
//
// To define your own visitor/walker, inherit from these classes, specifying
// your new type for the 'SubClass' template parameter, and "override" visitXXX
// functions in your class. This class is defined in terms of statically
// resolved overloading, not virtual functions.
//
// For example, here is a walker that counts the number of for loops in an
// Function.
//
// /// Declare the class. Note that we derive from InstWalker instantiated
// /// with _our new subclasses_ type.
// struct LoopCounter : public InstWalker<LoopCounter> {
// unsigned numLoops;
// LoopCounter() : numLoops(0) {}
// void visitForInst(ForInst &fs) { ++numLoops; }
// };
//
// And this class would be used like this:
// LoopCounter lc;
// lc.walk(function);
// numLoops = lc.numLoops;
//
// There are 'visit' methods for Instruction and Function, which recursively
// process all contained instructions.
//
// Note that if you don't implement visitXXX for some instruction type,
// the visitXXX method for Instruction superclass will be invoked.
//
// The optional second template argument specifies the type that instruction
// visitation functions should return. If you specify this, you *MUST* provide
// an implementation of every visit<#Instruction>(InstType *).
//
// Note that these classes are specifically designed as a template to avoid
// virtual function call overhead.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_INSTVISITOR_H
#define MLIR_IR_INSTVISITOR_H
#include "mlir/IR/Function.h"
#include "mlir/IR/Instruction.h"
namespace mlir {
/// Base class for instruction walkers. A walker can traverse depth first in
/// pre-order or post order. The walk methods without a suffix do a pre-order
/// traversal while those that traverse in post order have a PostOrder suffix.
template <typename SubClass, typename RetTy = void> class InstWalker {
//===--------------------------------------------------------------------===//
// Interface code - This is the public interface of the InstWalker used to
// walk instructions.
public:
// Generic walk method - allow walk to all instructions in a range.
template <class Iterator> void walk(Iterator Start, Iterator End) {
while (Start != End) {
walk(&(*Start++));
}
}
template <class Iterator> void walkPostOrder(Iterator Start, Iterator End) {
while (Start != End) {
walkPostOrder(&(*Start++));
}
}
// Define walkers for Function and all Function instruction kinds.
void walk(Function *f) {
for (auto &block : *f)
static_cast<SubClass *>(this)->walk(block.begin(), block.end());
}
void walkPostOrder(Function *f) {
for (auto it = f->rbegin(), e = f->rend(); it != e; ++it)
static_cast<SubClass *>(this)->walkPostOrder(it->begin(), it->end());
}
// Function to walk a instruction.
RetTy walk(Instruction *s) {
static_assert(std::is_base_of<InstWalker, SubClass>::value,
"Must pass the derived type to this template!");
static_cast<SubClass *>(this)->visitInstruction(s);
for (auto &blockList : s->getBlockLists())
for (auto &block : blockList)
static_cast<SubClass *>(this)->walk(block.begin(), block.end());
}
// Function to walk a instruction in post order DFS.
RetTy walkPostOrder(Instruction *s) {
static_assert(std::is_base_of<InstWalker, SubClass>::value,
"Must pass the derived type to this template!");
for (auto &blockList : s->getBlockLists())
for (auto &block : blockList)
static_cast<SubClass *>(this)->walkPostOrder(block.begin(),
block.end());
static_cast<SubClass *>(this)->visitInstruction(s);
}
//===--------------------------------------------------------------------===//
// Visitation functions... these functions provide default fallbacks in case
// the user does not specify what to do for a particular instruction type.
// The default behavior is to generalize the instruction type to its subtype
// and try visiting the subtype. All of this should be inlined perfectly,
// because there are no virtual functions to get in the way.
// When visiting a specific inst directly during a walk, these methods get
// called. These are typically O(1) complexity and shouldn't be recursively
// processing their descendants in some way. When using RetTy, all of these
// need to be overridden.
void visitInstruction(Instruction *inst) {}
};
} // end namespace mlir
#endif // MLIR_IR_INSTVISITOR_H

View File

@ -613,6 +613,36 @@ public:
return OpClass::isClassFor(this);
}
//===--------------------------------------------------------------------===//
// Instruction Walkers
//===--------------------------------------------------------------------===//
/// Walk the instructions held by this instruction in preorder, calling the
/// callback for each instruction.
void walk(const std::function<void(Instruction *)> &callback);
/// Specialization of walk to only visit operations of 'OpTy'.
template <typename OpTy>
void walk(std::function<void(OpPointer<OpTy>)> callback) {
walk([&](Instruction *inst) {
if (auto op = inst->dyn_cast<OpTy>())
callback(op);
});
}
/// Walk the instructions held by this function in postorder, calling the
/// callback for each instruction.
void walkPostOrder(const std::function<void(Instruction *)> &callback);
/// Specialization of walkPostOrder to only visit operations of 'OpTy'.
template <typename OpTy>
void walkPostOrder(std::function<void(OpPointer<OpTy>)> callback) {
walkPostOrder([&](Instruction *inst) {
if (auto op = inst->dyn_cast<OpTy>())
callback(op);
});
}
//===--------------------------------------------------------------------===//
// Other
//===--------------------------------------------------------------------===//

View File

@ -19,7 +19,6 @@
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
@ -646,32 +645,6 @@ bool AffineForOp::matchingBoundOperandList() const {
return true;
}
void AffineForOp::walk(std::function<void(Instruction *)> callback) {
struct Walker : public InstWalker<Walker> {
std::function<void(Instruction *)> const &callback;
Walker(std::function<void(Instruction *)> const &callback)
: callback(callback) {}
void visitInstruction(Instruction *opInst) { callback(opInst); }
};
Walker w(callback);
w.walk(getInstruction());
}
void AffineForOp::walkPostOrder(std::function<void(Instruction *)> callback) {
struct Walker : public InstWalker<Walker> {
std::function<void(Instruction *)> const &callback;
Walker(std::function<void(Instruction *)> const &callback)
: callback(callback) {}
void visitInstruction(Instruction *opInst) { callback(opInst); }
};
Walker v(callback);
v.walkPostOrder(getInstruction());
}
/// Returns the induction variable for this loop.
Value *AffineForOp::getInductionVar() { return getBody()->getArgument(0); }

View File

@ -26,7 +26,6 @@
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "llvm/Support/Debug.h"
@ -38,13 +37,11 @@ using namespace mlir;
namespace {
/// Checks for out of bound memef access subscripts..
struct MemRefBoundCheck : public FunctionPass, InstWalker<MemRefBoundCheck> {
struct MemRefBoundCheck : public FunctionPass {
explicit MemRefBoundCheck() : FunctionPass(&MemRefBoundCheck::passID) {}
PassResult runOnFunction(Function *f) override;
void visitInstruction(Instruction *opInst);
static char passID;
};
@ -56,17 +53,16 @@ FunctionPass *mlir::createMemRefBoundCheckPass() {
return new MemRefBoundCheck();
}
void MemRefBoundCheck::visitInstruction(Instruction *opInst) {
if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
boundCheckLoadOrStoreOp(loadOp);
} else if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
boundCheckLoadOrStoreOp(storeOp);
}
// TODO(bondhugula): do this for DMA ops as well.
}
PassResult MemRefBoundCheck::runOnFunction(Function *f) {
return walk(f), success();
f->walk([](Instruction *opInst) {
if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
boundCheckLoadOrStoreOp(loadOp);
} else if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
boundCheckLoadOrStoreOp(storeOp);
}
// TODO(bondhugula): do this for DMA ops as well.
});
return success();
}
static PassRegistration<MemRefBoundCheck>

View File

@ -25,7 +25,6 @@
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "llvm/Support/Debug.h"
@ -38,19 +37,13 @@ namespace {
// TODO(andydavis) Add common surrounding loop depth-wise dependence checks.
/// Checks dependences between all pairs of memref accesses in a Function.
struct MemRefDependenceCheck : public FunctionPass,
InstWalker<MemRefDependenceCheck> {
struct MemRefDependenceCheck : public FunctionPass {
SmallVector<Instruction *, 4> loadsAndStores;
explicit MemRefDependenceCheck()
: FunctionPass(&MemRefDependenceCheck::passID) {}
PassResult runOnFunction(Function *f) override;
void visitInstruction(Instruction *opInst) {
if (opInst->isa<LoadOp>() || opInst->isa<StoreOp>()) {
loadsAndStores.push_back(opInst);
}
}
static char passID;
};
@ -120,8 +113,13 @@ static void checkDependences(ArrayRef<Instruction *> loadsAndStores) {
// Walks the Function 'f' adding load and store ops to 'loadsAndStores'.
// Runs pair-wise dependence checks.
PassResult MemRefDependenceCheck::runOnFunction(Function *f) {
// Collect the loads and stores within the function.
loadsAndStores.clear();
walk(f);
f->walk([&](Instruction *inst) {
if (inst->isa<LoadOp>() || inst->isa<StoreOp>())
loadsAndStores.push_back(inst);
});
checkDependences(loadsAndStores);
return success();
}

View File

@ -15,7 +15,6 @@
// limitations under the License.
// =============================================================================
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/Instruction.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OperationSupport.h"
@ -27,16 +26,13 @@
using namespace mlir;
namespace {
struct PrintOpStatsPass : public ModulePass, InstWalker<PrintOpStatsPass> {
struct PrintOpStatsPass : public ModulePass {
explicit PrintOpStatsPass(llvm::raw_ostream &os = llvm::errs())
: ModulePass(&PrintOpStatsPass::passID), os(os) {}
// Prints the resultant operation statistics post iterating over the module.
PassResult runOnModule(Module *m) override;
// Updates the operation statistics for the given instruction.
void visitInstruction(Instruction *inst);
// Print summary of op stats.
void printSummary();
@ -44,7 +40,6 @@ struct PrintOpStatsPass : public ModulePass, InstWalker<PrintOpStatsPass> {
private:
llvm::StringMap<int64_t> opCount;
llvm::raw_ostream &os;
};
} // namespace
@ -52,16 +47,16 @@ private:
char PrintOpStatsPass::passID = 0;
PassResult PrintOpStatsPass::runOnModule(Module *m) {
opCount.clear();
// Compute the operation statistics for each function in the module.
for (auto &fn : *m)
walk(&fn);
fn.walk(
[&](Instruction *inst) { ++opCount[inst->getName().getStringRef()]; });
printSummary();
return success();
}
void PrintOpStatsPass::visitInstruction(Instruction *inst) {
++opCount[inst->getName().getStringRef()];
}
void PrintOpStatsPass::printSummary() {
os << "Operations encountered:\n";
os << "-----------------------\n";

View File

@ -19,7 +19,6 @@
#include "mlir/EDSC/Types.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"

View File

@ -25,7 +25,6 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/Instruction.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"

View File

@ -18,7 +18,6 @@
#include "mlir/IR/Block.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/Instruction.h"
using namespace mlir;
@ -226,6 +225,34 @@ Block *Block::getSinglePredecessor() {
return it == pred_end() ? firstPred : nullptr;
}
//===----------------------------------------------------------------------===//
// Instruction Walkers
//===----------------------------------------------------------------------===//
void Block::walk(const std::function<void(Instruction *)> &callback) {
walk(begin(), end(), callback);
}
void Block::walk(Block::iterator begin, Block::iterator end,
const std::function<void(Instruction *)> &callback) {
// Walk the instructions within this block.
for (auto &inst : llvm::make_early_inc_range(llvm::make_range(begin, end)))
inst.walk(callback);
}
void Block::walkPostOrder(const std::function<void(Instruction *)> &callback) {
walkPostOrder(begin(), end(), callback);
}
/// Walk the instructions in the specified [begin, end) range of this block
/// in postorder, calling the callback for each operation.
void Block::walkPostOrder(Block::iterator begin, Block::iterator end,
const std::function<void(Instruction *)> &callback) {
// Walk the instructions within this block.
for (auto &inst : llvm::make_early_inc_range(llvm::make_range(begin, end)))
inst.walkPostOrder(callback);
}
//===----------------------------------------------------------------------===//
// Other
//===----------------------------------------------------------------------===//
@ -253,37 +280,6 @@ Block *Block::splitBlock(iterator splitBefore) {
return newBB;
}
void Block::walk(std::function<void(Instruction *)> callback) {
walk(begin(), end(), callback);
}
void Block::walk(Block::iterator begin, Block::iterator end,
std::function<void(Instruction *)> callback) {
struct Walker : public InstWalker<Walker> {
std::function<void(Instruction *)> const &callback;
Walker(std::function<void(Instruction *)> const &callback)
: callback(callback) {}
void visitInstruction(Instruction *opInst) { callback(opInst); }
};
Walker w(callback);
w.walk(begin, end);
}
void Block::walkPostOrder(std::function<void(Instruction *)> callback) {
struct Walker : public InstWalker<Walker> {
std::function<void(Instruction *)> const &callback;
Walker(std::function<void(Instruction *)> const &callback)
: callback(callback) {}
void visitInstruction(Instruction *opInst) { callback(opInst); }
};
Walker v(callback);
v.walkPostOrder(begin(), end());
}
//===----------------------------------------------------------------------===//
// BlockList
//===----------------------------------------------------------------------===//
@ -331,25 +327,18 @@ void BlockList::cloneInto(BlockList *dest, BlockAndValueMapping &mapper,
// Now that each of the blocks have been cloned, go through and remap the
// operands of each of the instructions.
struct Walker : public InstWalker<Walker> {
BlockAndValueMapping &mapper;
Walker(BlockAndValueMapping &mapper) : mapper(mapper) {}
/// Remap the instruction and successor block operands.
void visitInstruction(Instruction *inst) {
for (auto &instOp : inst->getInstOperands())
if (auto *mappedOp = mapper.lookupOrNull(instOp.get()))
instOp.set(mappedOp);
if (inst->isTerminator())
for (auto &succOp : inst->getBlockOperands())
if (auto *mappedOp = mapper.lookupOrNull(succOp.get()))
succOp.set(mappedOp);
}
auto remapOperands = [&](Instruction *inst) {
for (auto &instOp : inst->getInstOperands())
if (auto *mappedOp = mapper.lookupOrNull(instOp.get()))
instOp.set(mappedOp);
if (inst->isTerminator())
for (auto &succOp : inst->getBlockOperands())
if (auto *mappedOp = mapper.lookupOrNull(succOp.get()))
succOp.set(mappedOp);
};
Walker v(mapper);
for (auto it = std::next(lastOldBlock), e = dest->end(); it != e; ++it)
v.walk(it->begin(), it->end());
it->walk(remapOperands);
}
BlockList *llvm::ilist_traits<::mlir::Block>::getContainingBlockList() {

View File

@ -19,7 +19,6 @@
#include "AttributeListStorage.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Types.h"
@ -214,28 +213,15 @@ void Function::addEntryBlock() {
entry->addArguments(type.getInputs());
}
void Function::walk(std::function<void(Instruction *)> callback) {
struct Walker : public InstWalker<Walker> {
std::function<void(Instruction *)> const &callback;
Walker(std::function<void(Instruction *)> const &callback)
: callback(callback) {}
void visitInstruction(Instruction *inst) { callback(inst); }
};
Walker v(callback);
v.walk(this);
void Function::walk(const std::function<void(Instruction *)> &callback) {
// Walk each of the blocks within the function.
for (auto &block : getBlocks())
block.walk(callback);
}
void Function::walkPostOrder(std::function<void(Instruction *)> callback) {
struct Walker : public InstWalker<Walker> {
std::function<void(Instruction *)> const &callback;
Walker(std::function<void(Instruction *)> const &callback)
: callback(callback) {}
void visitInstruction(Instruction *inst) { callback(inst); }
};
Walker v(callback);
v.walkPostOrder(this);
void Function::walkPostOrder(
const std::function<void(Instruction *)> &callback) {
// Walk each of the blocks within the function.
for (auto &block : llvm::reverse(getBlocks()))
block.walkPostOrder(callback);
}

View File

@ -22,7 +22,6 @@
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/DenseMap.h"
@ -300,6 +299,35 @@ Function *Instruction::getFunction() const {
return block ? block->getFunction() : nullptr;
}
//===----------------------------------------------------------------------===//
// Instruction Walkers
//===----------------------------------------------------------------------===//
void Instruction::walk(const std::function<void(Instruction *)> &callback) {
// Visit the current instruction.
callback(this);
// Visit any internal instructions.
for (auto &blockList : getBlockLists())
for (auto &block : blockList)
block.walk(callback);
}
void Instruction::walkPostOrder(
const std::function<void(Instruction *)> &callback) {
// Visit any internal instructions.
for (auto &blockList : llvm::reverse(getBlockLists()))
for (auto &block : llvm::reverse(blockList))
block.walkPostOrder(callback);
// Visit the current instruction.
callback(this);
}
//===----------------------------------------------------------------------===//
// Other
//===----------------------------------------------------------------------===//
/// Emit a note about this instruction, reporting up to any diagnostic
/// handlers that may be listening.
void Instruction::emitNote(const Twine &message) const {

View File

@ -26,7 +26,6 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"

View File

@ -24,7 +24,6 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/Support/Functional.h"
#include "mlir/Transforms/Passes.h"

View File

@ -27,7 +27,6 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/Passes.h"
@ -46,10 +45,9 @@ namespace {
// result of any AffineApplyOp). After this composition, AffineApplyOps with no
// remaining uses are erased.
// TODO(andydavis) Remove this when Chris adds instruction combiner pass.
struct ComposeAffineMaps : public FunctionPass, InstWalker<ComposeAffineMaps> {
struct ComposeAffineMaps : public FunctionPass {
explicit ComposeAffineMaps() : FunctionPass(&ComposeAffineMaps::passID) {}
PassResult runOnFunction(Function *f) override;
void visitInstruction(Instruction *opInst);
SmallVector<OpPointer<AffineApplyOp>, 8> affineApplyOps;
@ -68,15 +66,11 @@ static bool affineApplyOp(const Instruction &inst) {
return inst.isa<AffineApplyOp>();
}
void ComposeAffineMaps::visitInstruction(Instruction *opInst) {
if (auto afOp = opInst->dyn_cast<AffineApplyOp>())
affineApplyOps.push_back(afOp);
}
PassResult ComposeAffineMaps::runOnFunction(Function *f) {
// If needed for future efficiency, reserve space based on a pre-walk.
affineApplyOps.clear();
walk(f);
f->walk<AffineApplyOp>(
[&](OpPointer<AffineApplyOp> afOp) { affineApplyOps.push_back(afOp); });
for (auto afOp : affineApplyOps) {
SmallVector<Value *, 8> operands(afOp->getOperands());
FuncBuilder b(afOp->getInstruction());
@ -87,7 +81,8 @@ PassResult ComposeAffineMaps::runOnFunction(Function *f) {
// Erase dead affine apply ops.
affineApplyOps.clear();
walk(f);
f->walk<AffineApplyOp>(
[&](OpPointer<AffineApplyOp> afOp) { affineApplyOps.push_back(afOp); });
for (auto it = affineApplyOps.rbegin(); it != affineApplyOps.rend(); ++it) {
if ((*it)->use_empty()) {
(*it)->erase();

View File

@ -18,7 +18,6 @@
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"
@ -27,7 +26,7 @@ using namespace mlir;
namespace {
/// Simple constant folding pass.
struct ConstantFold : public FunctionPass, InstWalker<ConstantFold> {
struct ConstantFold : public FunctionPass {
ConstantFold() : FunctionPass(&ConstantFold::passID) {}
// All constants in the function post folding.
@ -35,9 +34,7 @@ struct ConstantFold : public FunctionPass, InstWalker<ConstantFold> {
// Operations that were folded and that need to be erased.
std::vector<Instruction *> opInstsToErase;
bool foldOperation(Instruction *op,
SmallVectorImpl<Value *> &existingConstants);
void visitInstruction(Instruction *op);
void foldInstruction(Instruction *op);
PassResult runOnFunction(Function *f) override;
static char passID;
@ -49,7 +46,7 @@ char ConstantFold::passID = 0;
/// Attempt to fold the specified operation, updating the IR to match. If
/// constants are found, we keep track of them in the existingConstants list.
///
void ConstantFold::visitInstruction(Instruction *op) {
void ConstantFold::foldInstruction(Instruction *op) {
// If this operation is an AffineForOp, then fold the bounds.
if (auto forOp = op->dyn_cast<AffineForOp>()) {
constantFoldBounds(forOp);
@ -111,7 +108,7 @@ PassResult ConstantFold::runOnFunction(Function *f) {
existingConstants.clear();
opInstsToErase.clear();
walk(f);
f->walk([&](Instruction *inst) { foldInstruction(inst); });
// At this point, these operations are dead, remove them.
// TODO: This is assuming that all constant foldable operations have no

View File

@ -28,7 +28,6 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"
@ -111,22 +110,23 @@ namespace {
// LoopNestStateCollector walks loop nests and collects load and store
// operations, and whether or not an IfInst was encountered in the loop nest.
class LoopNestStateCollector : public InstWalker<LoopNestStateCollector> {
public:
struct LoopNestStateCollector {
SmallVector<OpPointer<AffineForOp>, 4> forOps;
SmallVector<Instruction *, 4> loadOpInsts;
SmallVector<Instruction *, 4> storeOpInsts;
bool hasNonForRegion = false;
void visitInstruction(Instruction *opInst) {
if (opInst->isa<AffineForOp>())
forOps.push_back(opInst->cast<AffineForOp>());
else if (opInst->getNumBlockLists() != 0)
hasNonForRegion = true;
else if (opInst->isa<LoadOp>())
loadOpInsts.push_back(opInst);
else if (opInst->isa<StoreOp>())
storeOpInsts.push_back(opInst);
void collect(Instruction *instToWalk) {
instToWalk->walk([&](Instruction *opInst) {
if (opInst->isa<AffineForOp>())
forOps.push_back(opInst->cast<AffineForOp>());
else if (opInst->getNumBlockLists() != 0)
hasNonForRegion = true;
else if (opInst->isa<LoadOp>())
loadOpInsts.push_back(opInst);
else if (opInst->isa<StoreOp>())
storeOpInsts.push_back(opInst);
});
}
};
@ -510,7 +510,7 @@ bool MemRefDependenceGraph::init(Function *f) {
// Create graph node 'id' to represent top-level 'forOp' and record
// all loads and store accesses it contains.
LoopNestStateCollector collector;
collector.walk(&inst);
collector.collect(&inst);
// Return false if a non 'for' region was found (not currently supported).
if (collector.hasNonForRegion)
return false;
@ -606,41 +606,39 @@ struct LoopNestStats {
// LoopNestStatsCollector walks a single loop nest and gathers per-loop
// trip count and operation count statistics and records them in 'stats'.
class LoopNestStatsCollector : public InstWalker<LoopNestStatsCollector> {
public:
struct LoopNestStatsCollector {
LoopNestStats *stats;
bool hasLoopWithNonConstTripCount = false;
LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {}
void visitInstruction(Instruction *opInst) {
auto forOp = opInst->dyn_cast<AffineForOp>();
if (!forOp)
return;
void collect(Instruction *inst) {
inst->walk<AffineForOp>([&](OpPointer<AffineForOp> forOp) {
auto *forInst = forOp->getInstruction();
auto *parentInst = forOp->getInstruction()->getParentInst();
if (parentInst != nullptr) {
assert(parentInst->isa<AffineForOp>() && "Expected parent AffineForOp");
// Add mapping to 'forOp' from its parent AffineForOp.
stats->loopMap[parentInst].push_back(forOp);
}
auto *forInst = forOp->getInstruction();
auto *parentInst = forOp->getInstruction()->getParentInst();
if (parentInst != nullptr) {
assert(parentInst->isa<AffineForOp>() && "Expected parent AffineForOp");
// Add mapping to 'forOp' from its parent AffineForOp.
stats->loopMap[parentInst].push_back(forOp);
}
// Record the number of op instructions in the body of 'forOp'.
unsigned count = 0;
stats->opCountMap[forInst] = 0;
for (auto &inst : *forOp->getBody()) {
if (!(inst.isa<AffineForOp>() || inst.isa<AffineIfOp>()))
++count;
}
stats->opCountMap[forInst] = count;
// Record trip count for 'forOp'. Set flag if trip count is not constant.
Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
if (!maybeConstTripCount.hasValue()) {
hasLoopWithNonConstTripCount = true;
return;
}
stats->tripCountMap[forInst] = maybeConstTripCount.getValue();
// Record the number of op instructions in the body of 'forOp'.
unsigned count = 0;
stats->opCountMap[forInst] = 0;
for (auto &inst : *forOp->getBody()) {
if (!(inst.isa<AffineForOp>() || inst.isa<AffineIfOp>()))
++count;
}
stats->opCountMap[forInst] = count;
// Record trip count for 'forOp'. Set flag if trip count is not
// constant.
Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
if (!maybeConstTripCount.hasValue()) {
hasLoopWithNonConstTripCount = true;
return;
}
stats->tripCountMap[forInst] = maybeConstTripCount.getValue();
});
}
};
@ -1078,7 +1076,7 @@ static bool isFusionProfitable(Instruction *srcOpInst,
// Walk src loop nest and collect stats.
LoopNestStats srcLoopNestStats;
LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats);
srcStatsCollector.walk(srcLoopIVs[0]->getInstruction());
srcStatsCollector.collect(srcLoopIVs[0]->getInstruction());
// Currently only constant trip count loop nests are supported.
if (srcStatsCollector.hasLoopWithNonConstTripCount)
return false;
@ -1089,7 +1087,7 @@ static bool isFusionProfitable(Instruction *srcOpInst,
LoopNestStats dstLoopNestStats;
LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats);
dstStatsCollector.walk(dstLoopIVs[0]->getInstruction());
dstStatsCollector.collect(dstLoopIVs[0]->getInstruction());
// Currently only constant trip count loop nests are supported.
if (dstStatsCollector.hasLoopWithNonConstTripCount)
return false;
@ -1474,7 +1472,7 @@ public:
// Collect slice loop stats.
LoopNestStateCollector sliceCollector;
sliceCollector.walk(sliceLoopNest->getInstruction());
sliceCollector.collect(sliceLoopNest->getInstruction());
// Promote single iteration slice loops to single IV value.
for (auto forOp : sliceCollector.forOps) {
promoteIfSingleIteration(forOp);
@ -1498,7 +1496,7 @@ public:
// Collect dst loop stats after memref privatizaton transformation.
LoopNestStateCollector dstLoopCollector;
dstLoopCollector.walk(dstAffineForOp->getInstruction());
dstLoopCollector.collect(dstAffineForOp->getInstruction());
// Add new load ops to current Node load op list 'loads' to
// continue fusing based on new operands.

View File

@ -27,7 +27,6 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/Transforms/LoopUtils.h"
#include "llvm/ADT/DenseMap.h"
@ -95,15 +94,16 @@ char LoopUnroll::passID = 0;
PassResult LoopUnroll::runOnFunction(Function *f) {
// Gathers all innermost loops through a post order pruned walk.
class InnermostLoopGatherer : public InstWalker<InnermostLoopGatherer, bool> {
public:
struct InnermostLoopGatherer {
// Store innermost loops as we walk.
std::vector<OpPointer<AffineForOp>> loops;
// This method specialized to encode custom return logic.
using InstListType = llvm::iplist<Instruction>;
bool walkPostOrder(InstListType::iterator Start,
InstListType::iterator End) {
void walkPostOrder(Function *f) {
for (auto &b : *f)
walkPostOrder(b.begin(), b.end());
}
bool walkPostOrder(Block::iterator Start, Block::iterator End) {
bool hasInnerLoops = false;
// We need to walk all elements since all innermost loops need to be
// gathered as opposed to determining whether this list has any inner
@ -112,7 +112,6 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
hasInnerLoops |= walkPostOrder(&(*Start++));
return hasInnerLoops;
}
bool walkPostOrder(Instruction *opInst) {
bool hasInnerLoops = false;
for (auto &blockList : opInst->getBlockLists())
@ -125,39 +124,21 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
}
return hasInnerLoops;
}
// FIXME: can't use base class method for this because that in turn would
// need to use the derived class method above. CRTP doesn't allow it, and
// the compiler error resulting from it is also misleading.
using InstWalker<InnermostLoopGatherer, bool>::walkPostOrder;
};
// Gathers all loops with trip count <= minTripCount.
class ShortLoopGatherer : public InstWalker<ShortLoopGatherer> {
public:
// Store short loops as we walk.
std::vector<OpPointer<AffineForOp>> loops;
const unsigned minTripCount;
ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {}
void visitInstruction(Instruction *opInst) {
auto forOp = opInst->dyn_cast<AffineForOp>();
if (!forOp)
return;
Optional<uint64_t> tripCount = getConstantTripCount(forOp);
if (tripCount.hasValue() && tripCount.getValue() <= minTripCount)
loops.push_back(forOp);
}
};
if (clUnrollFull.getNumOccurrences() > 0 &&
clUnrollFullThreshold.getNumOccurrences() > 0) {
ShortLoopGatherer slg(clUnrollFullThreshold);
// Do a post order walk so that loops are gathered from innermost to
// outermost (or else unrolling an outer one may delete gathered inner
// ones).
slg.walkPostOrder(f);
auto &loops = slg.loops;
// Store short loops as we walk.
std::vector<OpPointer<AffineForOp>> loops;
// Gathers all loops with trip count <= minTripCount. Do a post order walk
// so that loops are gathered from innermost to outermost (or else unrolling
// an outer one may delete gathered inner ones).
f->walkPostOrder<AffineForOp>([&](OpPointer<AffineForOp> forOp) {
Optional<uint64_t> tripCount = getConstantTripCount(forOp);
if (tripCount.hasValue() && tripCount.getValue() <= clUnrollFullThreshold)
loops.push_back(forOp);
});
for (auto forOp : loops)
loopUnrollFull(forOp);
return success();

View File

@ -50,7 +50,6 @@
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/Transforms/LoopUtils.h"
#include "llvm/ADT/DenseMap.h"
@ -136,24 +135,25 @@ bool mlir::loopUnrollJamByFactor(OpPointer<AffineForOp> forOp,
// Gathers all maximal sub-blocks of instructions that do not themselves
// include a for inst (a instruction could have a descendant for inst though
// in its tree).
class JamBlockGatherer : public InstWalker<JamBlockGatherer> {
public:
using InstListType = llvm::iplist<Instruction>;
using InstWalker<JamBlockGatherer>::walk;
struct JamBlockGatherer {
// Store iterators to the first and last inst of each sub-block found.
std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;
// This is a linear time walk.
void walk(InstListType::iterator Start, InstListType::iterator End) {
for (auto it = Start; it != End;) {
void walk(Instruction *inst) {
for (auto &blockList : inst->getBlockLists())
for (auto &block : blockList)
walk(block);
}
void walk(Block &block) {
for (auto it = block.begin(), e = block.end(); it != e;) {
auto subBlockStart = it;
while (it != End && !it->isa<AffineForOp>())
while (it != e && !it->isa<AffineForOp>())
++it;
if (it != subBlockStart)
subBlocks.push_back({subBlockStart, std::prev(it)});
// Process all for insts that appear next.
while (it != End && it->isa<AffineForOp>())
while (it != e && it->isa<AffineForOp>())
walk(&*it++);
}
}

View File

@ -25,7 +25,6 @@
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/Dominance.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/Passes.h"
@ -70,12 +69,12 @@ namespace {
// currently only eliminates the stores only if no other loads/uses (other
// than dealloc) remain.
//
struct MemRefDataFlowOpt : public FunctionPass, InstWalker<MemRefDataFlowOpt> {
struct MemRefDataFlowOpt : public FunctionPass {
explicit MemRefDataFlowOpt() : FunctionPass(&MemRefDataFlowOpt::passID) {}
PassResult runOnFunction(Function *f) override;
void visitInstruction(Instruction *opInst);
void forwardStoreToLoad(OpPointer<LoadOp> loadOp);
// A list of memref's that are potentially dead / could be eliminated.
SmallPtrSet<Value *, 4> memrefsToErase;
@ -100,14 +99,9 @@ FunctionPass *mlir::createMemRefDataFlowOptPass() {
// This is a straightforward implementation not optimized for speed. Optimize
// this in the future if needed.
void MemRefDataFlowOpt::visitInstruction(Instruction *opInst) {
void MemRefDataFlowOpt::forwardStoreToLoad(OpPointer<LoadOp> loadOp) {
Instruction *lastWriteStoreOp = nullptr;
auto loadOp = opInst->dyn_cast<LoadOp>();
if (!loadOp)
return;
Instruction *loadOpInst = opInst;
Instruction *loadOpInst = loadOp->getInstruction();
// First pass over the use list to get minimum number of surrounding
// loops common between the load op and the store op, with min taken across
@ -235,7 +229,8 @@ PassResult MemRefDataFlowOpt::runOnFunction(Function *f) {
memrefsToErase.clear();
// Walk all load's and perform load/store forwarding.
walk(f);
f->walk<LoadOp>(
[&](OpPointer<LoadOp> loadOp) { forwardStoreToLoad(loadOp); });
// Erase all load op's whose results were replaced with store fwd'ed ones.
for (auto *loadOp : loadOpsToErase) {

View File

@ -142,10 +142,8 @@ PassResult PipelineDataTransfer::runOnFunction(Function *f) {
// deleted and replaced by a prologue, a new steady-state loop and an
// epilogue).
forOps.clear();
f->walkPostOrder([&](Instruction *opInst) {
if (auto forOp = opInst->dyn_cast<AffineForOp>())
forOps.push_back(forOp);
});
f->walkPostOrder<AffineForOp>(
[&](OpPointer<AffineForOp> forOp) { forOps.push_back(forOp); });
bool ret = false;
for (auto forOp : forOps) {
ret = ret | runOnAffineForOp(forOp);

View File

@ -28,7 +28,6 @@
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/Instruction.h"
#include "mlir/StandardOps/StandardOps.h"
#include "llvm/ADT/DenseMap.h"
@ -135,10 +134,8 @@ bool mlir::promoteIfSingleIteration(OpPointer<AffineForOp> forOp) {
/// their body into the containing Block.
void mlir::promoteSingleIterationLoops(Function *f) {
// Gathers all innermost loops through a post order pruned walk.
f->walkPostOrder([](Instruction *inst) {
if (auto forOp = inst->dyn_cast<AffineForOp>())
promoteIfSingleIteration(forOp);
});
f->walkPostOrder<AffineForOp>(
[](OpPointer<AffineForOp> forOp) { promoteIfSingleIteration(forOp); });
}
/// Generates a 'for' inst with the specified lower and upper bounds while