Remove InstWalker and move all instruction walking to the API facilities on Function/Block/Instruction.
PiperOrigin-RevId: 232388113
parent c9ad4621ce
commit bf9c381d1d
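As context for the diff below, the migration pattern is: a pass that previously derived from InstWalker and overrode visitInstruction() now passes a callback to walk()/walkPostOrder() on Function, Block, or Instruction. The following is only an illustrative sketch mirroring the updated passes in this change (the names f, loadsAndStores, and loops are placeholders, not part of any one file):

    // Preorder walk over every instruction nested under the function.
    f->walk([&](Instruction *inst) {
      if (inst->isa<LoadOp>() || inst->isa<StoreOp>())
        loadsAndStores.push_back(inst);
    });

    // Post-order walk restricted to one op type via the templated overload.
    f->walkPostOrder<AffineForOp>(
        [&](OpPointer<AffineForOp> forOp) { loops.push_back(forOp); });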
@@ -229,14 +229,6 @@ public:
  /// (same operands in the same order).
  bool matchingBoundOperandList() const;

  /// Walk the operation instructions in the 'for' instruction in preorder,
  /// calling the callback for each operation.
  void walk(std::function<void(Instruction *)> callback);

  /// Walk the operation instructions in the 'for' instruction in postorder,
  /// calling the callback for each operation.
  void walkPostOrder(std::function<void(Instruction *)> callback);

private:
  friend class Instruction;
  explicit AffineForOp(const Instruction *state) : Op(state) {}
@@ -18,7 +18,7 @@
#ifndef MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_
#define MLIR_ANALYSIS_MLFUNCTIONMATCHER_H_

#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/Function.h"
#include "llvm/Support/Allocator.h"

namespace mlir {

@@ -76,7 +76,7 @@ private:
  ArrayRef<NestedMatch> matchedChildren;
};

/// A NestedPattern is a nested InstWalker that:
/// A NestedPattern is a nested instruction walker that:
///   1. recursively matches a substructure in the tree;
///   2. uses a filter function to refine matches with extra semantic
///      constraints (passed via a lambda of type FilterFunctionType);
@@ -92,8 +92,8 @@ private:
///
/// The NestedMatches captured in the IR can grow large, especially after
/// aggressive unrolling. As experience has shown, it is generally better to use
/// a plain InstWalker to match flat patterns but the current implementation is
/// competitive nonetheless.
/// a plain walk over instructions to match flat patterns but the current
/// implementation is competitive nonetheless.
using FilterFunctionType = std::function<bool(const Instruction &)>;
static bool defaultFilterFunction(const Instruction &) { return true; };
struct NestedPattern {

@@ -102,16 +102,14 @@ struct NestedPattern {
  NestedPattern(const NestedPattern &) = default;
  NestedPattern &operator=(const NestedPattern &) = default;

  /// Returns all the top-level matches in `function`.
  void match(Function *function, SmallVectorImpl<NestedMatch> *matches) {
    State state(*this, matches);
    state.walkPostOrder(function);
  /// Returns all the top-level matches in `func`.
  void match(Function *func, SmallVectorImpl<NestedMatch> *matches) {
    func->walkPostOrder([&](Instruction *inst) { matchOne(inst, matches); });
  }

  /// Returns all the top-level matches in `inst`.
  void match(Instruction *inst, SmallVectorImpl<NestedMatch> *matches) {
    State state(*this, matches);
    state.walkPostOrder(inst);
    inst->walkPostOrder([&](Instruction *child) { matchOne(child, matches); });
  }

  /// Returns the depth of the pattern.
@@ -120,22 +118,8 @@ struct NestedPattern {
private:
  friend class NestedPatternContext;
  friend class NestedMatch;
  friend class InstWalker<NestedPattern>;
  friend struct State;

  /// Helper state that temporarily holds matches for the next level of nesting.
  struct State : public InstWalker<State> {
    State(NestedPattern &pattern, SmallVectorImpl<NestedMatch> *matches)
        : pattern(pattern), matches(matches) {}
    void visitInstruction(Instruction *opInst) {
      pattern.matchOne(opInst, matches);
    }

  private:
    NestedPattern &pattern;
    SmallVectorImpl<NestedMatch> *matches;
  };

  /// Underlying global bump allocator managed by a NestedPatternContext.
  static llvm::BumpPtrAllocator *&allocator();

@@ -153,8 +137,9 @@ private:
/// without switching on the type of the Instruction. The idea is that a
/// NestedPattern first checks if it matches locally and then recursively
/// applies its nested matchers to its elem->nested. Since we want to rely on
/// the InstWalker impl rather than duplicate its the logic, we allow an
/// off-by-one traversal to account for the fact that we write:
/// the existing instruction walking functionality rather than duplicate
/// it, we allow an off-by-one traversal to account for the fact that we
/// write:
///
///  void match(Instruction *elem) {
///    for (auto &c : getNestedPatterns()) {
@@ -287,6 +287,28 @@ public:
  succ_iterator succ_end();
  llvm::iterator_range<succ_iterator> getSuccessors();

  //===--------------------------------------------------------------------===//
  // Instruction Walkers
  //===--------------------------------------------------------------------===//

  /// Walk the instructions of this block in preorder, calling the callback for
  /// each operation.
  void walk(const std::function<void(Instruction *)> &callback);

  /// Walk the instructions in the specified [begin, end) range of
  /// this block, calling the callback for each operation.
  void walk(Block::iterator begin, Block::iterator end,
            const std::function<void(Instruction *)> &callback);

  /// Walk the instructions in this block in postorder, calling the callback for
  /// each operation.
  void walkPostOrder(const std::function<void(Instruction *)> &callback);

  /// Walk the instructions in the specified [begin, end) range of this block
  /// in postorder, calling the callback for each operation.
  void walkPostOrder(Block::iterator begin, Block::iterator end,
                     const std::function<void(Instruction *)> &callback);

  //===--------------------------------------------------------------------===//
  // Other
  //===--------------------------------------------------------------------===//

@@ -311,19 +333,6 @@ public:
    return &Block::instructions;
  }

  /// Walk the instructions of this block in preorder, calling the callback for
  /// each operation.
  void walk(std::function<void(Instruction *)> callback);

  /// Walk the instructions in this block in postorder, calling the callback for
  /// each operation.
  void walkPostOrder(std::function<void(Instruction *)> callback);

  /// Walk the instructions in the specified [begin, end) range of
  /// this block, calling the callback for each operation.
  void walk(Block::iterator begin, Block::iterator end,
            std::function<void(Instruction *)> callback);

  void print(raw_ostream &os) const;
  void dump() const;
@@ -27,6 +27,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/Instruction.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"

@@ -39,6 +40,7 @@ class FunctionType;
class MLIRContext;
class Module;
template <typename ObjectType, typename ElementType> class ArgumentIterator;
template <typename T> class OpPointer;

/// NamedAttribute is used for function attribute lists, it holds an
/// identifier for the name and a value for the attribute. The attribute
@@ -115,13 +117,35 @@ public:
  Block &front() { return blocks.front(); }
  const Block &front() const { return const_cast<Function *>(this)->front(); }

  //===--------------------------------------------------------------------===//
  // Instruction Walkers
  //===--------------------------------------------------------------------===//

  /// Walk the instructions in the function in preorder, calling the callback
  /// for each instruction or operation.
  void walk(std::function<void(Instruction *)> callback);
  /// for each instruction.
  void walk(const std::function<void(Instruction *)> &callback);

  /// Specialization of walk to only visit operations of 'OpTy'.
  template <typename OpTy>
  void walk(std::function<void(OpPointer<OpTy>)> callback) {
    walk([&](Instruction *inst) {
      if (auto op = inst->dyn_cast<OpTy>())
        callback(op);
    });
  }

  /// Walk the instructions in the function in postorder, calling the callback
  /// for each instruction or operation.
  void walkPostOrder(std::function<void(Instruction *)> callback);
  /// for each instruction.
  void walkPostOrder(const std::function<void(Instruction *)> &callback);

  /// Specialization of walkPostOrder to only visit operations of 'OpTy'.
  template <typename OpTy>
  void walkPostOrder(std::function<void(OpPointer<OpTy>)> callback) {
    walkPostOrder([&](Instruction *inst) {
      if (auto op = inst->dyn_cast<OpTy>())
        callback(op);
    });
  }

  //===--------------------------------------------------------------------===//
  // Arguments
@@ -1,140 +0,0 @@
//===- InstVisitor.h - MLIR Instruction Visitor Class -----------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines the base classes for Function's instruction visitors and
// walkers. A visit is a O(1) operation that visits just the node in question. A
// walk visits the node it's called on as well as the node's descendants.
//
// Instruction visitors/walkers are used when you want to perform different
// actions for different kinds of instructions without having to use lots of
// casts and a big switch instruction.
//
// To define your own visitor/walker, inherit from these classes, specifying
// your new type for the 'SubClass' template parameter, and "override" visitXXX
// functions in your class. This class is defined in terms of statically
// resolved overloading, not virtual functions.
//
// For example, here is a walker that counts the number of for loops in an
// Function.
//
//  /// Declare the class. Note that we derive from InstWalker instantiated
//  /// with _our new subclasses_ type.
//  struct LoopCounter : public InstWalker<LoopCounter> {
//    unsigned numLoops;
//    LoopCounter() : numLoops(0) {}
//    void visitForInst(ForInst &fs) { ++numLoops; }
//  };
//
//  And this class would be used like this:
//    LoopCounter lc;
//    lc.walk(function);
//    numLoops = lc.numLoops;
//
// There are 'visit' methods for Instruction and Function, which recursively
// process all contained instructions.
//
// Note that if you don't implement visitXXX for some instruction type,
// the visitXXX method for Instruction superclass will be invoked.
//
// The optional second template argument specifies the type that instruction
// visitation functions should return. If you specify this, you *MUST* provide
// an implementation of every visit<#Instruction>(InstType *).
//
// Note that these classes are specifically designed as a template to avoid
// virtual function call overhead.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_IR_INSTVISITOR_H
#define MLIR_IR_INSTVISITOR_H

#include "mlir/IR/Function.h"
#include "mlir/IR/Instruction.h"

namespace mlir {
/// Base class for instruction walkers. A walker can traverse depth first in
/// pre-order or post order. The walk methods without a suffix do a pre-order
/// traversal while those that traverse in post order have a PostOrder suffix.
template <typename SubClass, typename RetTy = void> class InstWalker {
  //===--------------------------------------------------------------------===//
  // Interface code - This is the public interface of the InstWalker used to
  // walk instructions.

public:
  // Generic walk method - allow walk to all instructions in a range.
  template <class Iterator> void walk(Iterator Start, Iterator End) {
    while (Start != End) {
      walk(&(*Start++));
    }
  }
  template <class Iterator> void walkPostOrder(Iterator Start, Iterator End) {
    while (Start != End) {
      walkPostOrder(&(*Start++));
    }
  }

  // Define walkers for Function and all Function instruction kinds.
  void walk(Function *f) {
    for (auto &block : *f)
      static_cast<SubClass *>(this)->walk(block.begin(), block.end());
  }

  void walkPostOrder(Function *f) {
    for (auto it = f->rbegin(), e = f->rend(); it != e; ++it)
      static_cast<SubClass *>(this)->walkPostOrder(it->begin(), it->end());
  }

  // Function to walk a instruction.
  RetTy walk(Instruction *s) {
    static_assert(std::is_base_of<InstWalker, SubClass>::value,
                  "Must pass the derived type to this template!");

    static_cast<SubClass *>(this)->visitInstruction(s);
    for (auto &blockList : s->getBlockLists())
      for (auto &block : blockList)
        static_cast<SubClass *>(this)->walk(block.begin(), block.end());
  }

  // Function to walk a instruction in post order DFS.
  RetTy walkPostOrder(Instruction *s) {
    static_assert(std::is_base_of<InstWalker, SubClass>::value,
                  "Must pass the derived type to this template!");
    for (auto &blockList : s->getBlockLists())
      for (auto &block : blockList)
        static_cast<SubClass *>(this)->walkPostOrder(block.begin(),
                                                     block.end());
    static_cast<SubClass *>(this)->visitInstruction(s);
  }

  //===--------------------------------------------------------------------===//
  // Visitation functions... these functions provide default fallbacks in case
  // the user does not specify what to do for a particular instruction type.
  // The default behavior is to generalize the instruction type to its subtype
  // and try visiting the subtype. All of this should be inlined perfectly,
  // because there are no virtual functions to get in the way.

  // When visiting a specific inst directly during a walk, these methods get
  // called. These are typically O(1) complexity and shouldn't be recursively
  // processing their descendants in some way. When using RetTy, all of these
  // need to be overridden.
  void visitInstruction(Instruction *inst) {}
};

} // end namespace mlir

#endif // MLIR_IR_INSTVISITOR_H
@@ -613,6 +613,36 @@ public:
    return OpClass::isClassFor(this);
  }

  //===--------------------------------------------------------------------===//
  // Instruction Walkers
  //===--------------------------------------------------------------------===//

  /// Walk the instructions held by this instruction in preorder, calling the
  /// callback for each instruction.
  void walk(const std::function<void(Instruction *)> &callback);

  /// Specialization of walk to only visit operations of 'OpTy'.
  template <typename OpTy>
  void walk(std::function<void(OpPointer<OpTy>)> callback) {
    walk([&](Instruction *inst) {
      if (auto op = inst->dyn_cast<OpTy>())
        callback(op);
    });
  }

  /// Walk the instructions held by this function in postorder, calling the
  /// callback for each instruction.
  void walkPostOrder(const std::function<void(Instruction *)> &callback);

  /// Specialization of walkPostOrder to only visit operations of 'OpTy'.
  template <typename OpTy>
  void walkPostOrder(std::function<void(OpPointer<OpTy>)> callback) {
    walkPostOrder([&](Instruction *inst) {
      if (auto op = inst->dyn_cast<OpTy>())
        callback(op);
    });
  }

  //===--------------------------------------------------------------------===//
  // Other
  //===--------------------------------------------------------------------===//
@@ -19,7 +19,6 @@
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"

@@ -646,32 +645,6 @@ bool AffineForOp::matchingBoundOperandList() const {
  return true;
}

void AffineForOp::walk(std::function<void(Instruction *)> callback) {
  struct Walker : public InstWalker<Walker> {
    std::function<void(Instruction *)> const &callback;
    Walker(std::function<void(Instruction *)> const &callback)
        : callback(callback) {}

    void visitInstruction(Instruction *opInst) { callback(opInst); }
  };

  Walker w(callback);
  w.walk(getInstruction());
}

void AffineForOp::walkPostOrder(std::function<void(Instruction *)> callback) {
  struct Walker : public InstWalker<Walker> {
    std::function<void(Instruction *)> const &callback;
    Walker(std::function<void(Instruction *)> const &callback)
        : callback(callback) {}

    void visitInstruction(Instruction *opInst) { callback(opInst); }
  };

  Walker v(callback);
  v.walkPostOrder(getInstruction());
}

/// Returns the induction variable for this loop.
Value *AffineForOp::getInductionVar() { return getBody()->getArgument(0); }
@@ -26,7 +26,6 @@
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "llvm/Support/Debug.h"

@@ -38,13 +37,11 @@ using namespace mlir;
namespace {

/// Checks for out of bound memef access subscripts..
struct MemRefBoundCheck : public FunctionPass, InstWalker<MemRefBoundCheck> {
struct MemRefBoundCheck : public FunctionPass {
  explicit MemRefBoundCheck() : FunctionPass(&MemRefBoundCheck::passID) {}

  PassResult runOnFunction(Function *f) override;

  void visitInstruction(Instruction *opInst);

  static char passID;
};

@@ -56,17 +53,16 @@ FunctionPass *mlir::createMemRefBoundCheckPass() {
  return new MemRefBoundCheck();
}

void MemRefBoundCheck::visitInstruction(Instruction *opInst) {
  if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
    boundCheckLoadOrStoreOp(loadOp);
  } else if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
    boundCheckLoadOrStoreOp(storeOp);
  }
  // TODO(bondhugula): do this for DMA ops as well.
}

PassResult MemRefBoundCheck::runOnFunction(Function *f) {
  return walk(f), success();
  f->walk([](Instruction *opInst) {
    if (auto loadOp = opInst->dyn_cast<LoadOp>()) {
      boundCheckLoadOrStoreOp(loadOp);
    } else if (auto storeOp = opInst->dyn_cast<StoreOp>()) {
      boundCheckLoadOrStoreOp(storeOp);
    }
    // TODO(bondhugula): do this for DMA ops as well.
  });
  return success();
}

static PassRegistration<MemRefBoundCheck>
@@ -25,7 +25,6 @@
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "llvm/Support/Debug.h"

@@ -38,19 +37,13 @@ namespace {

// TODO(andydavis) Add common surrounding loop depth-wise dependence checks.
/// Checks dependences between all pairs of memref accesses in a Function.
struct MemRefDependenceCheck : public FunctionPass,
                               InstWalker<MemRefDependenceCheck> {
struct MemRefDependenceCheck : public FunctionPass {
  SmallVector<Instruction *, 4> loadsAndStores;
  explicit MemRefDependenceCheck()
      : FunctionPass(&MemRefDependenceCheck::passID) {}

  PassResult runOnFunction(Function *f) override;

  void visitInstruction(Instruction *opInst) {
    if (opInst->isa<LoadOp>() || opInst->isa<StoreOp>()) {
      loadsAndStores.push_back(opInst);
    }
  }
  static char passID;
};

@@ -120,8 +113,13 @@ static void checkDependences(ArrayRef<Instruction *> loadsAndStores) {
// Walks the Function 'f' adding load and store ops to 'loadsAndStores'.
// Runs pair-wise dependence checks.
PassResult MemRefDependenceCheck::runOnFunction(Function *f) {
  // Collect the loads and stores within the function.
  loadsAndStores.clear();
  walk(f);
  f->walk([&](Instruction *inst) {
    if (inst->isa<LoadOp>() || inst->isa<StoreOp>())
      loadsAndStores.push_back(inst);
  });

  checkDependences(loadsAndStores);
  return success();
}
@@ -15,7 +15,6 @@
// limitations under the License.
// =============================================================================

#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/Instruction.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OperationSupport.h"

@@ -27,16 +26,13 @@
using namespace mlir;

namespace {
struct PrintOpStatsPass : public ModulePass, InstWalker<PrintOpStatsPass> {
struct PrintOpStatsPass : public ModulePass {
  explicit PrintOpStatsPass(llvm::raw_ostream &os = llvm::errs())
      : ModulePass(&PrintOpStatsPass::passID), os(os) {}

  // Prints the resultant operation statistics post iterating over the module.
  PassResult runOnModule(Module *m) override;

  // Updates the operation statistics for the given instruction.
  void visitInstruction(Instruction *inst);

  // Print summary of op stats.
  void printSummary();

@@ -44,7 +40,6 @@ struct PrintOpStatsPass : public ModulePass, InstWalker<PrintOpStatsPass> {

private:
  llvm::StringMap<int64_t> opCount;

  llvm::raw_ostream &os;
};
} // namespace

@@ -52,16 +47,16 @@ private:
char PrintOpStatsPass::passID = 0;

PassResult PrintOpStatsPass::runOnModule(Module *m) {
  opCount.clear();

  // Compute the operation statistics for each function in the module.
  for (auto &fn : *m)
    walk(&fn);
    fn.walk(
        [&](Instruction *inst) { ++opCount[inst->getName().getStringRef()]; });
  printSummary();
  return success();
}

void PrintOpStatsPass::visitInstruction(Instruction *inst) {
  ++opCount[inst->getName().getStringRef()];
}

void PrintOpStatsPass::printSummary() {
  os << "Operations encountered:\n";
  os << "-----------------------\n";
@@ -19,7 +19,6 @@
#include "mlir/EDSC/Types.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"

@@ -25,7 +25,6 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/Instruction.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
@@ -18,7 +18,6 @@
#include "mlir/IR/Block.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/Instruction.h"
using namespace mlir;

@@ -226,6 +225,34 @@ Block *Block::getSinglePredecessor() {
  return it == pred_end() ? firstPred : nullptr;
}

//===----------------------------------------------------------------------===//
// Instruction Walkers
//===----------------------------------------------------------------------===//

void Block::walk(const std::function<void(Instruction *)> &callback) {
  walk(begin(), end(), callback);
}

void Block::walk(Block::iterator begin, Block::iterator end,
                 const std::function<void(Instruction *)> &callback) {
  // Walk the instructions within this block.
  for (auto &inst : llvm::make_early_inc_range(llvm::make_range(begin, end)))
    inst.walk(callback);
}

void Block::walkPostOrder(const std::function<void(Instruction *)> &callback) {
  walkPostOrder(begin(), end(), callback);
}

/// Walk the instructions in the specified [begin, end) range of this block
/// in postorder, calling the callback for each operation.
void Block::walkPostOrder(Block::iterator begin, Block::iterator end,
                          const std::function<void(Instruction *)> &callback) {
  // Walk the instructions within this block.
  for (auto &inst : llvm::make_early_inc_range(llvm::make_range(begin, end)))
    inst.walkPostOrder(callback);
}

//===----------------------------------------------------------------------===//
// Other
//===----------------------------------------------------------------------===//
@@ -253,37 +280,6 @@ Block *Block::splitBlock(iterator splitBefore) {
  return newBB;
}

void Block::walk(std::function<void(Instruction *)> callback) {
  walk(begin(), end(), callback);
}

void Block::walk(Block::iterator begin, Block::iterator end,
                 std::function<void(Instruction *)> callback) {
  struct Walker : public InstWalker<Walker> {
    std::function<void(Instruction *)> const &callback;
    Walker(std::function<void(Instruction *)> const &callback)
        : callback(callback) {}

    void visitInstruction(Instruction *opInst) { callback(opInst); }
  };

  Walker w(callback);
  w.walk(begin, end);
}

void Block::walkPostOrder(std::function<void(Instruction *)> callback) {
  struct Walker : public InstWalker<Walker> {
    std::function<void(Instruction *)> const &callback;
    Walker(std::function<void(Instruction *)> const &callback)
        : callback(callback) {}

    void visitInstruction(Instruction *opInst) { callback(opInst); }
  };

  Walker v(callback);
  v.walkPostOrder(begin(), end());
}

//===----------------------------------------------------------------------===//
// BlockList
//===----------------------------------------------------------------------===//
@@ -331,25 +327,18 @@ void BlockList::cloneInto(BlockList *dest, BlockAndValueMapping &mapper,

  // Now that each of the blocks have been cloned, go through and remap the
  // operands of each of the instructions.
  struct Walker : public InstWalker<Walker> {
    BlockAndValueMapping &mapper;
    Walker(BlockAndValueMapping &mapper) : mapper(mapper) {}

    /// Remap the instruction and successor block operands.
    void visitInstruction(Instruction *inst) {
      for (auto &instOp : inst->getInstOperands())
        if (auto *mappedOp = mapper.lookupOrNull(instOp.get()))
          instOp.set(mappedOp);
      if (inst->isTerminator())
        for (auto &succOp : inst->getBlockOperands())
          if (auto *mappedOp = mapper.lookupOrNull(succOp.get()))
            succOp.set(mappedOp);
    }
  auto remapOperands = [&](Instruction *inst) {
    for (auto &instOp : inst->getInstOperands())
      if (auto *mappedOp = mapper.lookupOrNull(instOp.get()))
        instOp.set(mappedOp);
    if (inst->isTerminator())
      for (auto &succOp : inst->getBlockOperands())
        if (auto *mappedOp = mapper.lookupOrNull(succOp.get()))
          succOp.set(mappedOp);
  };

  Walker v(mapper);
  for (auto it = std::next(lastOldBlock), e = dest->end(); it != e; ++it)
    v.walk(it->begin(), it->end());
    it->walk(remapOperands);
}

BlockList *llvm::ilist_traits<::mlir::Block>::getContainingBlockList() {

@@ -19,7 +19,6 @@
#include "AttributeListStorage.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Types.h"
@@ -214,28 +213,15 @@ void Function::addEntryBlock() {
  entry->addArguments(type.getInputs());
}

void Function::walk(std::function<void(Instruction *)> callback) {
  struct Walker : public InstWalker<Walker> {
    std::function<void(Instruction *)> const &callback;
    Walker(std::function<void(Instruction *)> const &callback)
        : callback(callback) {}

    void visitInstruction(Instruction *inst) { callback(inst); }
  };

  Walker v(callback);
  v.walk(this);
void Function::walk(const std::function<void(Instruction *)> &callback) {
  // Walk each of the blocks within the function.
  for (auto &block : getBlocks())
    block.walk(callback);
}

void Function::walkPostOrder(std::function<void(Instruction *)> callback) {
  struct Walker : public InstWalker<Walker> {
    std::function<void(Instruction *)> const &callback;
    Walker(std::function<void(Instruction *)> const &callback)
        : callback(callback) {}

    void visitInstruction(Instruction *inst) { callback(inst); }
  };

  Walker v(callback);
  v.walkPostOrder(this);
void Function::walkPostOrder(
    const std::function<void(Instruction *)> &callback) {
  // Walk each of the blocks within the function.
  for (auto &block : llvm::reverse(getBlocks()))
    block.walkPostOrder(callback);
}

@@ -22,7 +22,6 @@
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/ADT/DenseMap.h"
@@ -300,6 +299,35 @@ Function *Instruction::getFunction() const {
  return block ? block->getFunction() : nullptr;
}

//===----------------------------------------------------------------------===//
// Instruction Walkers
//===----------------------------------------------------------------------===//

void Instruction::walk(const std::function<void(Instruction *)> &callback) {
  // Visit the current instruction.
  callback(this);

  // Visit any internal instructions.
  for (auto &blockList : getBlockLists())
    for (auto &block : blockList)
      block.walk(callback);
}

void Instruction::walkPostOrder(
    const std::function<void(Instruction *)> &callback) {
  // Visit any internal instructions.
  for (auto &blockList : llvm::reverse(getBlockLists()))
    for (auto &block : llvm::reverse(blockList))
      block.walkPostOrder(callback);

  // Visit the current instruction.
  callback(this);
}

//===----------------------------------------------------------------------===//
// Other
//===----------------------------------------------------------------------===//

/// Emit a note about this instruction, reporting up to any diagnostic
/// handlers that may be listening.
void Instruction::emitNote(const Twine &message) const {
@@ -26,7 +26,6 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"

@@ -24,7 +24,6 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/Support/Functional.h"
#include "mlir/Transforms/Passes.h"

@@ -27,7 +27,6 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/Passes.h"

@@ -46,10 +45,9 @@ namespace {
// result of any AffineApplyOp). After this composition, AffineApplyOps with no
// remaining uses are erased.
// TODO(andydavis) Remove this when Chris adds instruction combiner pass.
struct ComposeAffineMaps : public FunctionPass, InstWalker<ComposeAffineMaps> {
struct ComposeAffineMaps : public FunctionPass {
  explicit ComposeAffineMaps() : FunctionPass(&ComposeAffineMaps::passID) {}
  PassResult runOnFunction(Function *f) override;
  void visitInstruction(Instruction *opInst);

  SmallVector<OpPointer<AffineApplyOp>, 8> affineApplyOps;
@@ -68,15 +66,11 @@ static bool affineApplyOp(const Instruction &inst) {
  return inst.isa<AffineApplyOp>();
}

void ComposeAffineMaps::visitInstruction(Instruction *opInst) {
  if (auto afOp = opInst->dyn_cast<AffineApplyOp>())
    affineApplyOps.push_back(afOp);
}

PassResult ComposeAffineMaps::runOnFunction(Function *f) {
  // If needed for future efficiency, reserve space based on a pre-walk.
  affineApplyOps.clear();
  walk(f);
  f->walk<AffineApplyOp>(
      [&](OpPointer<AffineApplyOp> afOp) { affineApplyOps.push_back(afOp); });
  for (auto afOp : affineApplyOps) {
    SmallVector<Value *, 8> operands(afOp->getOperands());
    FuncBuilder b(afOp->getInstruction());

@@ -87,7 +81,8 @@ PassResult ComposeAffineMaps::runOnFunction(Function *f) {

  // Erase dead affine apply ops.
  affineApplyOps.clear();
  walk(f);
  f->walk<AffineApplyOp>(
      [&](OpPointer<AffineApplyOp> afOp) { affineApplyOps.push_back(afOp); });
  for (auto it = affineApplyOps.rbegin(); it != affineApplyOps.rend(); ++it) {
    if ((*it)->use_empty()) {
      (*it)->erase();
@@ -18,7 +18,6 @@
#include "mlir/AffineOps/AffineOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"

@@ -27,7 +26,7 @@ using namespace mlir;

namespace {
/// Simple constant folding pass.
struct ConstantFold : public FunctionPass, InstWalker<ConstantFold> {
struct ConstantFold : public FunctionPass {
  ConstantFold() : FunctionPass(&ConstantFold::passID) {}

  // All constants in the function post folding.

@@ -35,9 +34,7 @@ struct ConstantFold : public FunctionPass, InstWalker<ConstantFold> {
  // Operations that were folded and that need to be erased.
  std::vector<Instruction *> opInstsToErase;

  bool foldOperation(Instruction *op,
                     SmallVectorImpl<Value *> &existingConstants);
  void visitInstruction(Instruction *op);
  void foldInstruction(Instruction *op);
  PassResult runOnFunction(Function *f) override;

  static char passID;

@@ -49,7 +46,7 @@ char ConstantFold::passID = 0;
/// Attempt to fold the specified operation, updating the IR to match. If
/// constants are found, we keep track of them in the existingConstants list.
///
void ConstantFold::visitInstruction(Instruction *op) {
void ConstantFold::foldInstruction(Instruction *op) {
  // If this operation is an AffineForOp, then fold the bounds.
  if (auto forOp = op->dyn_cast<AffineForOp>()) {
    constantFoldBounds(forOp);

@@ -111,7 +108,7 @@ PassResult ConstantFold::runOnFunction(Function *f) {
  existingConstants.clear();
  opInstsToErase.clear();

  walk(f);
  f->walk([&](Instruction *inst) { foldInstruction(inst); });

  // At this point, these operations are dead, remove them.
  // TODO: This is assuming that all constant foldable operations have no
@@ -28,7 +28,6 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/LoopUtils.h"

@@ -111,22 +110,23 @@ namespace {

// LoopNestStateCollector walks loop nests and collects load and store
// operations, and whether or not an IfInst was encountered in the loop nest.
class LoopNestStateCollector : public InstWalker<LoopNestStateCollector> {
public:
struct LoopNestStateCollector {
  SmallVector<OpPointer<AffineForOp>, 4> forOps;
  SmallVector<Instruction *, 4> loadOpInsts;
  SmallVector<Instruction *, 4> storeOpInsts;
  bool hasNonForRegion = false;

  void visitInstruction(Instruction *opInst) {
    if (opInst->isa<AffineForOp>())
      forOps.push_back(opInst->cast<AffineForOp>());
    else if (opInst->getNumBlockLists() != 0)
      hasNonForRegion = true;
    else if (opInst->isa<LoadOp>())
      loadOpInsts.push_back(opInst);
    else if (opInst->isa<StoreOp>())
      storeOpInsts.push_back(opInst);
  void collect(Instruction *instToWalk) {
    instToWalk->walk([&](Instruction *opInst) {
      if (opInst->isa<AffineForOp>())
        forOps.push_back(opInst->cast<AffineForOp>());
      else if (opInst->getNumBlockLists() != 0)
        hasNonForRegion = true;
      else if (opInst->isa<LoadOp>())
        loadOpInsts.push_back(opInst);
      else if (opInst->isa<StoreOp>())
        storeOpInsts.push_back(opInst);
    });
  }
};
@@ -510,7 +510,7 @@ bool MemRefDependenceGraph::init(Function *f) {
    // Create graph node 'id' to represent top-level 'forOp' and record
    // all loads and store accesses it contains.
    LoopNestStateCollector collector;
    collector.walk(&inst);
    collector.collect(&inst);
    // Return false if a non 'for' region was found (not currently supported).
    if (collector.hasNonForRegion)
      return false;

@@ -606,41 +606,39 @@ struct LoopNestStats {

// LoopNestStatsCollector walks a single loop nest and gathers per-loop
// trip count and operation count statistics and records them in 'stats'.
class LoopNestStatsCollector : public InstWalker<LoopNestStatsCollector> {
public:
struct LoopNestStatsCollector {
  LoopNestStats *stats;
  bool hasLoopWithNonConstTripCount = false;

  LoopNestStatsCollector(LoopNestStats *stats) : stats(stats) {}

  void visitInstruction(Instruction *opInst) {
    auto forOp = opInst->dyn_cast<AffineForOp>();
    if (!forOp)
      return;
  void collect(Instruction *inst) {
    inst->walk<AffineForOp>([&](OpPointer<AffineForOp> forOp) {
      auto *forInst = forOp->getInstruction();
      auto *parentInst = forOp->getInstruction()->getParentInst();
      if (parentInst != nullptr) {
        assert(parentInst->isa<AffineForOp>() && "Expected parent AffineForOp");
        // Add mapping to 'forOp' from its parent AffineForOp.
        stats->loopMap[parentInst].push_back(forOp);
      }

    auto *forInst = forOp->getInstruction();
    auto *parentInst = forOp->getInstruction()->getParentInst();
    if (parentInst != nullptr) {
      assert(parentInst->isa<AffineForOp>() && "Expected parent AffineForOp");
      // Add mapping to 'forOp' from its parent AffineForOp.
      stats->loopMap[parentInst].push_back(forOp);
    }

    // Record the number of op instructions in the body of 'forOp'.
    unsigned count = 0;
    stats->opCountMap[forInst] = 0;
    for (auto &inst : *forOp->getBody()) {
      if (!(inst.isa<AffineForOp>() || inst.isa<AffineIfOp>()))
        ++count;
    }
    stats->opCountMap[forInst] = count;
    // Record trip count for 'forOp'. Set flag if trip count is not constant.
    Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
    if (!maybeConstTripCount.hasValue()) {
      hasLoopWithNonConstTripCount = true;
      return;
    }
    stats->tripCountMap[forInst] = maybeConstTripCount.getValue();
      // Record the number of op instructions in the body of 'forOp'.
      unsigned count = 0;
      stats->opCountMap[forInst] = 0;
      for (auto &inst : *forOp->getBody()) {
        if (!(inst.isa<AffineForOp>() || inst.isa<AffineIfOp>()))
          ++count;
      }
      stats->opCountMap[forInst] = count;
      // Record trip count for 'forOp'. Set flag if trip count is not
      // constant.
      Optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp);
      if (!maybeConstTripCount.hasValue()) {
        hasLoopWithNonConstTripCount = true;
        return;
      }
      stats->tripCountMap[forInst] = maybeConstTripCount.getValue();
    });
  }
};
@@ -1078,7 +1076,7 @@ static bool isFusionProfitable(Instruction *srcOpInst,
  // Walk src loop nest and collect stats.
  LoopNestStats srcLoopNestStats;
  LoopNestStatsCollector srcStatsCollector(&srcLoopNestStats);
  srcStatsCollector.walk(srcLoopIVs[0]->getInstruction());
  srcStatsCollector.collect(srcLoopIVs[0]->getInstruction());
  // Currently only constant trip count loop nests are supported.
  if (srcStatsCollector.hasLoopWithNonConstTripCount)
    return false;

@@ -1089,7 +1087,7 @@ static bool isFusionProfitable(Instruction *srcOpInst,

  LoopNestStats dstLoopNestStats;
  LoopNestStatsCollector dstStatsCollector(&dstLoopNestStats);
  dstStatsCollector.walk(dstLoopIVs[0]->getInstruction());
  dstStatsCollector.collect(dstLoopIVs[0]->getInstruction());
  // Currently only constant trip count loop nests are supported.
  if (dstStatsCollector.hasLoopWithNonConstTripCount)
    return false;

@@ -1474,7 +1472,7 @@ public:

    // Collect slice loop stats.
    LoopNestStateCollector sliceCollector;
    sliceCollector.walk(sliceLoopNest->getInstruction());
    sliceCollector.collect(sliceLoopNest->getInstruction());
    // Promote single iteration slice loops to single IV value.
    for (auto forOp : sliceCollector.forOps) {
      promoteIfSingleIteration(forOp);

@@ -1498,7 +1496,7 @@ public:

    // Collect dst loop stats after memref privatizaton transformation.
    LoopNestStateCollector dstLoopCollector;
    dstLoopCollector.walk(dstAffineForOp->getInstruction());
    dstLoopCollector.collect(dstAffineForOp->getInstruction());

    // Add new load ops to current Node load op list 'loads' to
    // continue fusing based on new operands.
@@ -27,7 +27,6 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/Transforms/LoopUtils.h"
#include "llvm/ADT/DenseMap.h"

@@ -95,15 +94,16 @@ char LoopUnroll::passID = 0;

PassResult LoopUnroll::runOnFunction(Function *f) {
  // Gathers all innermost loops through a post order pruned walk.
  class InnermostLoopGatherer : public InstWalker<InnermostLoopGatherer, bool> {
  public:
  struct InnermostLoopGatherer {
    // Store innermost loops as we walk.
    std::vector<OpPointer<AffineForOp>> loops;

    // This method specialized to encode custom return logic.
    using InstListType = llvm::iplist<Instruction>;
    bool walkPostOrder(InstListType::iterator Start,
                       InstListType::iterator End) {
    void walkPostOrder(Function *f) {
      for (auto &b : *f)
        walkPostOrder(b.begin(), b.end());
    }

    bool walkPostOrder(Block::iterator Start, Block::iterator End) {
      bool hasInnerLoops = false;
      // We need to walk all elements since all innermost loops need to be
      // gathered as opposed to determining whether this list has any inner
@@ -112,7 +112,6 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
      hasInnerLoops |= walkPostOrder(&(*Start++));
      return hasInnerLoops;
    }

    bool walkPostOrder(Instruction *opInst) {
      bool hasInnerLoops = false;
      for (auto &blockList : opInst->getBlockLists())

@@ -125,39 +124,21 @@ PassResult LoopUnroll::runOnFunction(Function *f) {
      }
      return hasInnerLoops;
    }

    // FIXME: can't use base class method for this because that in turn would
    // need to use the derived class method above. CRTP doesn't allow it, and
    // the compiler error resulting from it is also misleading.
    using InstWalker<InnermostLoopGatherer, bool>::walkPostOrder;
  };

  // Gathers all loops with trip count <= minTripCount.
  class ShortLoopGatherer : public InstWalker<ShortLoopGatherer> {
  public:
    // Store short loops as we walk.
    std::vector<OpPointer<AffineForOp>> loops;
    const unsigned minTripCount;
    ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {}

    void visitInstruction(Instruction *opInst) {
      auto forOp = opInst->dyn_cast<AffineForOp>();
      if (!forOp)
        return;
      Optional<uint64_t> tripCount = getConstantTripCount(forOp);
      if (tripCount.hasValue() && tripCount.getValue() <= minTripCount)
        loops.push_back(forOp);
    }
  };

  if (clUnrollFull.getNumOccurrences() > 0 &&
      clUnrollFullThreshold.getNumOccurrences() > 0) {
    ShortLoopGatherer slg(clUnrollFullThreshold);
    // Do a post order walk so that loops are gathered from innermost to
    // outermost (or else unrolling an outer one may delete gathered inner
    // ones).
    slg.walkPostOrder(f);
    auto &loops = slg.loops;
    // Store short loops as we walk.
    std::vector<OpPointer<AffineForOp>> loops;

    // Gathers all loops with trip count <= minTripCount. Do a post order walk
    // so that loops are gathered from innermost to outermost (or else unrolling
    // an outer one may delete gathered inner ones).
    f->walkPostOrder<AffineForOp>([&](OpPointer<AffineForOp> forOp) {
      Optional<uint64_t> tripCount = getConstantTripCount(forOp);
      if (tripCount.hasValue() && tripCount.getValue() <= clUnrollFullThreshold)
        loops.push_back(forOp);
    });
    for (auto forOp : loops)
      loopUnrollFull(forOp);
    return success();
@@ -50,7 +50,6 @@
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/Transforms/LoopUtils.h"
#include "llvm/ADT/DenseMap.h"

@@ -136,24 +135,25 @@ bool mlir::loopUnrollJamByFactor(OpPointer<AffineForOp> forOp,
// Gathers all maximal sub-blocks of instructions that do not themselves
// include a for inst (a instruction could have a descendant for inst though
// in its tree).
class JamBlockGatherer : public InstWalker<JamBlockGatherer> {
public:
  using InstListType = llvm::iplist<Instruction>;
  using InstWalker<JamBlockGatherer>::walk;

struct JamBlockGatherer {
  // Store iterators to the first and last inst of each sub-block found.
  std::vector<std::pair<Block::iterator, Block::iterator>> subBlocks;

  // This is a linear time walk.
  void walk(InstListType::iterator Start, InstListType::iterator End) {
    for (auto it = Start; it != End;) {
  void walk(Instruction *inst) {
    for (auto &blockList : inst->getBlockLists())
      for (auto &block : blockList)
        walk(block);
  }
  void walk(Block &block) {
    for (auto it = block.begin(), e = block.end(); it != e;) {
      auto subBlockStart = it;
      while (it != End && !it->isa<AffineForOp>())
      while (it != e && !it->isa<AffineForOp>())
        ++it;
      if (it != subBlockStart)
        subBlocks.push_back({subBlockStart, std::prev(it)});
      // Process all for insts that appear next.
      while (it != End && it->isa<AffineForOp>())
      while (it != e && it->isa<AffineForOp>())
        walk(&*it++);
    }
  }
@@ -25,7 +25,6 @@
#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/Dominance.h"
#include "mlir/Analysis/Utils.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/Pass.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Transforms/Passes.h"

@@ -70,12 +69,12 @@ namespace {
// currently only eliminates the stores only if no other loads/uses (other
// than dealloc) remain.
//
struct MemRefDataFlowOpt : public FunctionPass, InstWalker<MemRefDataFlowOpt> {
struct MemRefDataFlowOpt : public FunctionPass {
  explicit MemRefDataFlowOpt() : FunctionPass(&MemRefDataFlowOpt::passID) {}

  PassResult runOnFunction(Function *f) override;

  void visitInstruction(Instruction *opInst);
  void forwardStoreToLoad(OpPointer<LoadOp> loadOp);

  // A list of memref's that are potentially dead / could be eliminated.
  SmallPtrSet<Value *, 4> memrefsToErase;

@@ -100,14 +99,9 @@ FunctionPass *mlir::createMemRefDataFlowOptPass() {

// This is a straightforward implementation not optimized for speed. Optimize
// this in the future if needed.
void MemRefDataFlowOpt::visitInstruction(Instruction *opInst) {
void MemRefDataFlowOpt::forwardStoreToLoad(OpPointer<LoadOp> loadOp) {
  Instruction *lastWriteStoreOp = nullptr;

  auto loadOp = opInst->dyn_cast<LoadOp>();
  if (!loadOp)
    return;

  Instruction *loadOpInst = opInst;
  Instruction *loadOpInst = loadOp->getInstruction();

  // First pass over the use list to get minimum number of surrounding
  // loops common between the load op and the store op, with min taken across

@@ -235,7 +229,8 @@ PassResult MemRefDataFlowOpt::runOnFunction(Function *f) {
  memrefsToErase.clear();

  // Walk all load's and perform load/store forwarding.
  walk(f);
  f->walk<LoadOp>(
      [&](OpPointer<LoadOp> loadOp) { forwardStoreToLoad(loadOp); });

  // Erase all load op's whose results were replaced with store fwd'ed ones.
  for (auto *loadOp : loadOpsToErase) {
@@ -142,10 +142,8 @@ PassResult PipelineDataTransfer::runOnFunction(Function *f) {
  // deleted and replaced by a prologue, a new steady-state loop and an
  // epilogue).
  forOps.clear();
  f->walkPostOrder([&](Instruction *opInst) {
    if (auto forOp = opInst->dyn_cast<AffineForOp>())
      forOps.push_back(forOp);
  });
  f->walkPostOrder<AffineForOp>(
      [&](OpPointer<AffineForOp> forOp) { forOps.push_back(forOp); });
  bool ret = false;
  for (auto forOp : forOps) {
    ret = ret | runOnAffineForOp(forOp);

@@ -28,7 +28,6 @@
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/Instruction.h"
#include "mlir/StandardOps/StandardOps.h"
#include "llvm/ADT/DenseMap.h"
@@ -135,10 +134,8 @@ bool mlir::promoteIfSingleIteration(OpPointer<AffineForOp> forOp) {
/// their body into the containing Block.
void mlir::promoteSingleIterationLoops(Function *f) {
  // Gathers all innermost loops through a post order pruned walk.
  f->walkPostOrder([](Instruction *inst) {
    if (auto forOp = inst->dyn_cast<AffineForOp>())
      promoteIfSingleIteration(forOp);
  });
  f->walkPostOrder<AffineForOp>(
      [](OpPointer<AffineForOp> forOp) { promoteIfSingleIteration(forOp); });
}

/// Generates a 'for' inst with the specified lower and upper bounds while