[mlir-reduce] Reducer refactor.

* A Reducer is a kind of RewritePattern, so it's just the same as
writing graph rewrite.
* ReductionTreePass operates on Operation rather than ModuleOp, so that
* we are able to reduce a nested structure(e.g., module in module) by
* self-nesting.

Reviewed By: jpienaar, rriddle

Differential Revision: https://reviews.llvm.org/D101046
This commit is contained in:
Chia-hung Duan 2021-06-02 07:00:19 +08:00
parent 26044c6a54
commit c484c7dd9d
26 changed files with 519 additions and 371 deletions

View File

@ -1,41 +0,0 @@
//===- OptReductionPass.h - Optimization Reduction Pass Wrapper -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the Opt Reduction Pass Wrapper. It creates a MLIR pass to
// run any optimization pass within it and only replaces the output module with
// the transformed version if it is smaller and interesting.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_REDUCER_OPTREDUCTIONPASS_H
#define MLIR_REDUCER_OPTREDUCTIONPASS_H
#include "PassDetail.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Reducer/ReductionNode.h"
#include "mlir/Reducer/ReductionTreePass.h"
#include "mlir/Reducer/Tester.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/Debug.h"
namespace mlir {
class OptReductionPass : public OptReductionBase<OptReductionPass> {
public:
OptReductionPass() = default;
OptReductionPass(const OptReductionPass &srcPass) = default;
/// Runs the pass instance in the pass pipeline.
void runOnOperation() override;
};
} // end namespace mlir
#endif

View File

@ -9,8 +9,6 @@
#define MLIR_REDUCER_PASSES_H #define MLIR_REDUCER_PASSES_H
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Reducer/OptReductionPass.h"
#include "mlir/Reducer/ReductionTreePass.h"
namespace mlir { namespace mlir {

View File

@ -24,14 +24,12 @@ def CommonReductionPassOptions {
]; ];
} }
def ReductionTree : Pass<"reduction-tree", "ModuleOp"> { def ReductionTree : Pass<"reduction-tree"> {
let summary = "A general reduction tree pass for the MLIR Reduce Tool"; let summary = "A general reduction tree pass for the MLIR Reduce Tool";
let constructor = "mlir::createReductionTreePass()"; let constructor = "mlir::createReductionTreePass()";
let options = [ let options = [
Option<"opReducerName", "op-reducer", "std::string", /* default */"",
"The OpReducer to reduce the module">,
Option<"traversalModeId", "traversal-mode", "unsigned", Option<"traversalModeId", "traversal-mode", "unsigned",
/* default */"0", "The graph traversal mode">, /* default */"0", "The graph traversal mode">,
] # CommonReductionPassOptions.options; ] # CommonReductionPassOptions.options;

View File

@ -1,76 +0,0 @@
//===- OpReducer.h - MLIR Reduce Operation Reducer ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the OpReducer class. It defines a variant generator method
// with the purpose of producing different variants by eliminating a
// parameterizable type of operations from the parent module.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_REDUCER_PASSES_OPREDUCER_H
#define MLIR_REDUCER_PASSES_OPREDUCER_H
#include <limits>
#include "mlir/Reducer/ReductionNode.h"
#include "mlir/Reducer/Tester.h"
namespace mlir {
class OpReducer {
public:
virtual ~OpReducer() = default;
/// According to rangeToKeep, try to reduce the given module. We implicitly
/// number each interesting operation and rangeToKeep indicates that if an
/// operation's number falls into certain range, then we will not try to
/// reduce that operation.
virtual void reduce(ModuleOp module,
ArrayRef<ReductionNode::Range> rangeToKeep) = 0;
/// Return the number of certain kind of operations that we would like to
/// reduce. This can be used to build a range map to exclude uninterested
/// operations.
virtual int getNumTargetOps(ModuleOp module) const = 0;
};
/// Reducer is a helper class to remove potential uninteresting operations from
/// module.
template <typename OpType>
class Reducer : public OpReducer {
public:
~Reducer() override = default;
int getNumTargetOps(ModuleOp module) const override {
return std::distance(module.getOps<OpType>().begin(),
module.getOps<OpType>().end());
}
void reduce(ModuleOp module,
ArrayRef<ReductionNode::Range> rangeToKeep) override {
std::vector<Operation *> opsToRemove;
size_t keepIndex = 0;
for (auto op : enumerate(module.getOps<OpType>())) {
int index = op.index();
if (keepIndex < rangeToKeep.size() &&
index == rangeToKeep[keepIndex].second)
++keepIndex;
if (keepIndex == rangeToKeep.size() ||
index < rangeToKeep[keepIndex].first)
opsToRemove.push_back(op.value());
}
for (Operation *o : opsToRemove) {
o->dropAllUses();
o->erase();
}
}
};
} // end namespace mlir
#endif

View File

@ -21,19 +21,25 @@
#include <vector> #include <vector>
#include "mlir/Reducer/Tester.h" #include "mlir/Reducer/Tester.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/Allocator.h" #include "llvm/Support/Allocator.h"
#include "llvm/Support/ToolOutputFile.h" #include "llvm/Support/ToolOutputFile.h"
namespace mlir { namespace mlir {
class ModuleOp;
class Region;
/// Defines the traversal method options to be used in the reduction tree /// Defines the traversal method options to be used in the reduction tree
/// traversal. /// traversal.
enum TraversalMode { SinglePath, Backtrack, MultiPath }; enum TraversalMode { SinglePath, Backtrack, MultiPath };
/// This class defines the ReductionNode which is used to generate variant and /// ReductionTreePass will build a reduction tree during module reduction and
/// keep track of the necessary metadata for the reduction pass. The nodes are /// the ReductionNode represents the vertex of the tree. A ReductionNode records
/// linked together in a reduction tree structure which defines the relationship /// the information such as the reduced module, how this node is reduced from
/// between all the different generated variants. /// the parent node, etc. This information will be used to construct a reduction
/// path to reduce the certain module.
class ReductionNode { class ReductionNode {
public: public:
template <TraversalMode mode> template <TraversalMode mode>
@ -44,23 +50,46 @@ public:
ReductionNode(ReductionNode *parent, std::vector<Range> range, ReductionNode(ReductionNode *parent, std::vector<Range> range,
llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator); llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator);
ReductionNode *getParent() const; ReductionNode *getParent() const { return parent; }
size_t getSize() const; /// If the ReductionNode hasn't been tested the interestingness, it'll be the
/// same module as the one in the parent node. Otherwise, the returned module
/// will have been applied certain reduction strategies. Note that it's not
/// necessary to be an interesting case or a reduced module (has smaller size
/// than parent's).
ModuleOp getModule() const { return module; }
/// Return the region we're reducing.
Region &getRegion() const { return *region; }
/// Return the size of the module.
size_t getSize() const { return size; }
/// Returns true if the module exhibits the interesting behavior. /// Returns true if the module exhibits the interesting behavior.
Tester::Interestingness isInteresting() const; Tester::Interestingness isInteresting() const { return interesting; }
std::vector<Range> getRanges() const; /// Return the range information that how this node is reduced from the parent
/// node.
ArrayRef<Range> getStartRanges() const { return startRanges; }
std::vector<ReductionNode *> &getVariants(); /// Return the range set we are using to generate variants.
ArrayRef<Range> getRanges() const { return ranges; }
/// Return the generated variants(the child nodes).
ArrayRef<ReductionNode *> getVariants() const { return variants; }
/// Split the ranges and generate new variants. /// Split the ranges and generate new variants.
std::vector<ReductionNode *> generateNewVariants(); ArrayRef<ReductionNode *> generateNewVariants();
/// Update the interestingness result from tester. /// Update the interestingness result from tester.
void update(std::pair<Tester::Interestingness, size_t> result); void update(std::pair<Tester::Interestingness, size_t> result);
/// Each Reduction Node contains a copy of module for applying rewrite
/// patterns. In addition, we only apply rewrite patterns in a certain region.
/// In init(), we will duplicate the module from parent node and locate the
/// corresponding region.
LogicalResult initialize(ModuleOp parentModule, Region &parentRegion);
private: private:
/// A custom BFS iterator. The difference between /// A custom BFS iterator. The difference between
/// llvm/ADT/BreadthFirstIterator.h is the graph we're exploring is dynamic. /// llvm/ADT/BreadthFirstIterator.h is the graph we're exploring is dynamic.
@ -87,8 +116,7 @@ private:
BaseIterator &operator++() { BaseIterator &operator++() {
ReductionNode *top = visitQueue.front(); ReductionNode *top = visitQueue.front();
visitQueue.pop(); visitQueue.pop();
std::vector<ReductionNode *> neighbors = getNeighbors(top); for (ReductionNode *node : getNeighbors(top))
for (ReductionNode *node : neighbors)
visitQueue.push(node); visitQueue.push(node);
return *this; return *this;
} }
@ -103,7 +131,7 @@ private:
ReductionNode *operator->() const { return visitQueue.front(); } ReductionNode *operator->() const { return visitQueue.front(); }
protected: protected:
std::vector<ReductionNode *> getNeighbors(ReductionNode *node) { ArrayRef<ReductionNode *> getNeighbors(ReductionNode *node) {
return static_cast<T *>(this)->getNeighbors(node); return static_cast<T *>(this)->getNeighbors(node);
} }
@ -111,21 +139,42 @@ private:
std::queue<ReductionNode *> visitQueue; std::queue<ReductionNode *> visitQueue;
}; };
/// The size of module after applying the range constraints. /// This is a copy of module from parent node. All the reducer patterns will
/// be applied to this instance.
ModuleOp module;
/// The region of certain operation we're reducing in the module
Region *region;
/// The node we are reduced from. It means we will be in variants of parent
/// node.
ReductionNode *parent;
/// The size of module after applying the reducer patterns with range
/// constraints. This is only valid while the interestingness has been tested.
size_t size; size_t size;
/// This is true if the module has been evaluated and it exhibits the /// This is true if the module has been evaluated and it exhibits the
/// interesting behavior. /// interesting behavior.
Tester::Interestingness interesting; Tester::Interestingness interesting;
ReductionNode *parent; /// `ranges` represents the selected subset of operations in the region. We
/// implictly number each operation in the region and ReductionTreePass will
/// We will only keep the operation with index falls into the ranges. /// apply reducer patterns on the operation falls into the `ranges`. We will
/// For example, number each function in a certain module and then we will /// generate new ReductionNode with subset of `ranges` to see if we can do
/// remove the functions with index outside the ranges and see if the /// further reduction. we may split the element in the `ranges` so that we can
/// resulting module is still interesting. /// have more subset variants from `ranges`.
/// Note that after applying the reducer patterns the number of operation in
/// the region may have changed, we need to update the `ranges` after that.
std::vector<Range> ranges; std::vector<Range> ranges;
/// `startRanges` records the ranges of operations selected from the parent
/// node to produce this ReductionNode. It can be used to construct the
/// reduction path from the root. I.e., if we apply the same reducer patterns
/// and `startRanges` selection on the parent region, we will get the same
/// module as this node.
const std::vector<Range> startRanges;
/// This points to the child variants that were created using this node as a /// This points to the child variants that were created using this node as a
/// starting point. /// starting point.
std::vector<ReductionNode *> variants; std::vector<ReductionNode *> variants;
@ -139,9 +188,9 @@ class ReductionNode::iterator<SinglePath>
: public BaseIterator<iterator<SinglePath>> { : public BaseIterator<iterator<SinglePath>> {
friend BaseIterator<iterator<SinglePath>>; friend BaseIterator<iterator<SinglePath>>;
using BaseIterator::BaseIterator; using BaseIterator::BaseIterator;
std::vector<ReductionNode *> getNeighbors(ReductionNode *node); ArrayRef<ReductionNode *> getNeighbors(ReductionNode *node);
}; };
} // end namespace mlir } // end namespace mlir
#endif #endif // MLIR_REDUCER_REDUCTIONNODE_H

View File

@ -0,0 +1,56 @@
//===- ReducePatternInterface.h - Collecting Reduce Patterns ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H
#define MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H
#include "mlir/IR/DialectInterface.h"
namespace mlir {
class RewritePatternSet;
/// This is used to report the reduction patterns for a Dialect. While using
/// mlir-reduce to reduce a module, we may want to transform certain cases into
/// simpler forms by applying certain rewrite patterns. Implement the
/// `populateReductionPatterns` to report those patterns by adding them to the
/// RewritePatternSet.
///
/// Example:
/// MyDialectReductionPattern::populateReductionPatterns(
/// RewritePatternSet &patterns) {
/// patterns.add<TensorOpReduction>(patterns.getContext());
/// }
///
/// For DRR, mlir-tblgen will generate a helper function
/// `populateWithGenerated` which has the same signature therefore you can
/// delegate to the helper function as well.
///
/// Example:
/// MyDialectReductionPattern::populateReductionPatterns(
/// RewritePatternSet &patterns) {
/// // Include the autogen file somewhere above.
/// populateWithGenerated(patterns);
/// }
class DialectReductionPatternInterface
: public DialectInterface::Base<DialectReductionPatternInterface> {
public:
/// Patterns provided here are intended to transform operations from a complex
/// form to a simpler form, without breaking the semantics of the program
/// being reduced. For example, you may want to replace the
/// tensor<?xindex> with a known rank and type, e.g. tensor<1xi32>, or
/// replacing an operation with a constant.
virtual void populateReductionPatterns(RewritePatternSet &patterns) const = 0;
protected:
DialectReductionPatternInterface(Dialect *dialect) : Base(dialect) {}
};
} // end namespace mlir
#endif // MLIR_REDUCER_REDUCTIONPATTERNINTERFACE_H

View File

@ -1,50 +0,0 @@
//===- ReductionTreePass.h - Reduction Tree Pass Implementation -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the Reduction Tree Pass class. It provides a framework for
// the implementation of different reduction passes in the MLIR Reduce tool. It
// allows for custom specification of the variant generation behavior. It
// implements methods that define the different possible traversals of the
// reduction tree.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_REDUCER_REDUCTIONTREEPASS_H
#define MLIR_REDUCER_REDUCTIONTREEPASS_H
#include <vector>
#include "PassDetail.h"
#include "ReductionNode.h"
#include "mlir/Reducer/Passes/OpReducer.h"
#include "mlir/Reducer/Tester.h"
#define DEBUG_TYPE "mlir-reduce"
namespace mlir {
/// This class defines the Reduction Tree Pass. It provides a framework to
/// to implement a reduction pass using a tree structure to keep track of the
/// generated reduced variants.
class ReductionTreePass : public ReductionTreeBase<ReductionTreePass> {
public:
ReductionTreePass() = default;
ReductionTreePass(const ReductionTreePass &pass) = default;
/// Runs the pass instance in the pass pipeline.
void runOnOperation() override;
private:
template <typename IteratorType>
ModuleOp findOptimal(ModuleOp module, std::unique_ptr<OpReducer> reducer,
ReductionNode *node);
};
} // end namespace mlir
#endif

View File

@ -1,7 +1,13 @@
add_mlir_library(MLIRReduce add_mlir_library(MLIRReduce
OptReductionPass.cpp
ReductionNode.cpp
ReductionTreePass.cpp
Tester.cpp Tester.cpp
LINK_LIBS PUBLIC LINK_LIBS PUBLIC
MLIRIR MLIRIR
MLIRPass
MLIRRewrite
MLIRTransformUtils
) )
mlir_check_all_link_libraries(MLIRReduce) mlir_check_all_link_libraries(MLIRReduce)

View File

@ -12,15 +12,27 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "mlir/Reducer/OptReductionPass.h" #include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h" #include "mlir/Pass/PassRegistry.h"
#include "mlir/Reducer/PassDetail.h"
#include "mlir/Reducer/Passes.h" #include "mlir/Reducer/Passes.h"
#include "mlir/Reducer/Tester.h" #include "mlir/Reducer/Tester.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "mlir-reduce" #define DEBUG_TYPE "mlir-reduce"
using namespace mlir; using namespace mlir;
namespace {
class OptReductionPass : public OptReductionBase<OptReductionPass> {
public:
/// Runs the pass instance in the pass pipeline.
void runOnOperation() override;
};
} // end anonymous namespace
/// Runs the pass instance in the pass pipeline. /// Runs the pass instance in the pass pipeline.
void OptReductionPass::runOnOperation() { void OptReductionPass::runOnOperation() {
LLVM_DEBUG(llvm::dbgs() << "\nOptimization Reduction pass: "); LLVM_DEBUG(llvm::dbgs() << "\nOptimization Reduction pass: ");

View File

@ -15,6 +15,7 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "mlir/Reducer/ReductionNode.h" #include "mlir/Reducer/ReductionNode.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include <algorithm> #include <algorithm>
@ -23,102 +24,102 @@
using namespace mlir; using namespace mlir;
ReductionNode::ReductionNode( ReductionNode::ReductionNode(
ReductionNode *parent, std::vector<Range> ranges, ReductionNode *parentNode, std::vector<Range> ranges,
llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator) llvm::SpecificBumpPtrAllocator<ReductionNode> &allocator)
: size(std::numeric_limits<size_t>::max()), /// Root node will have the parent pointer point to themselves.
interesting(Tester::Interestingness::Untested), : parent(parentNode == nullptr ? this : parentNode),
/// Root node will have the parent pointer point to themselves. size(std::numeric_limits<size_t>::max()),
parent(parent == nullptr ? this : parent), ranges(ranges), interesting(Tester::Interestingness::Untested), ranges(ranges),
allocator(allocator) {} startRanges(ranges), allocator(allocator) {
if (parent != this)
/// Returns the size in bytes of the module. if (failed(initialize(parent->getModule(), parent->getRegion())))
size_t ReductionNode::getSize() const { return size; } llvm_unreachable("unexpected initialization failure");
ReductionNode *ReductionNode::getParent() const { return parent; }
/// Returns true if the module exhibits the interesting behavior.
Tester::Interestingness ReductionNode::isInteresting() const {
return interesting;
} }
std::vector<ReductionNode::Range> ReductionNode::getRanges() const { LogicalResult ReductionNode::initialize(ModuleOp parentModule,
return ranges; Region &targetRegion) {
// Use the mapper help us find the corresponding region after module clone.
BlockAndValueMapping mapper;
module = cast<ModuleOp>(parentModule->clone(mapper));
// Use the first block of targetRegion to locate the cloned region.
Block *block = mapper.lookup(&*targetRegion.begin());
region = block->getParent();
return success();
} }
std::vector<ReductionNode *> &ReductionNode::getVariants() { return variants; }
#include <iostream>
/// If we haven't explored any variants from this node, we will create N /// If we haven't explored any variants from this node, we will create N
/// variants, N is the length of `ranges` if N > 1. Otherwise, we will split the /// variants, N is the length of `ranges` if N > 1. Otherwise, we will split the
/// max element in `ranges` and create 2 new variants for each call. /// max element in `ranges` and create 2 new variants for each call.
std::vector<ReductionNode *> ReductionNode::generateNewVariants() { ArrayRef<ReductionNode *> ReductionNode::generateNewVariants() {
std::vector<ReductionNode *> newNodes; int oldNumVariant = getVariants().size();
auto createNewNode = [this](std::vector<Range> ranges) {
return new (allocator.Allocate())
ReductionNode(this, std::move(ranges), allocator);
};
// If we haven't created new variant, then we can create varients by removing // If we haven't created new variant, then we can create varients by removing
// each of them respectively. For example, given {{1, 3}, {4, 9}}, we can // each of them respectively. For example, given {{1, 3}, {4, 9}}, we can
// produce variants with range {{1, 3}} and {{4, 9}}. // produce variants with range {{1, 3}} and {{4, 9}}.
if (variants.size() == 0 && ranges.size() != 1) { if (variants.size() == 0 && getRanges().size() > 1) {
for (const Range &range : ranges) { for (const Range &range : getRanges()) {
std::vector<Range> subRanges = ranges; std::vector<Range> subRanges = getRanges();
llvm::erase_value(subRanges, range); llvm::erase_value(subRanges, range);
ReductionNode *newNode = allocator.Allocate(); variants.push_back(createNewNode(std::move(subRanges)));
new (newNode) ReductionNode(this, subRanges, allocator);
newNodes.push_back(newNode);
variants.push_back(newNode);
} }
return newNodes; return getVariants().drop_front(oldNumVariant);
} }
// At here, we have created the type of variants mentioned above. We would // At here, we have created the type of variants mentioned above. We would
// like to split the max range into 2 to create 2 new variants. Continue on // like to split the max range into 2 to create 2 new variants. Continue on
// the above example, we split the range {4, 9} into {4, 6}, {6, 9}, and // the above example, we split the range {4, 9} into {4, 6}, {6, 9}, and
// create two variants with range {{1, 3}, {4, 6}} and {{1, 3}, {6, 9}}. The // create two variants with range {{1, 3}, {4, 6}} and {{1, 3}, {6, 9}}. The
// result ranges vector will be {{1, 3}, {4, 6}, {6, 9}}. // final ranges vector will be {{1, 3}, {4, 6}, {6, 9}}.
auto maxElement = std::max_element( auto maxElement = std::max_element(
ranges.begin(), ranges.end(), [](const Range &lhs, const Range &rhs) { ranges.begin(), ranges.end(), [](const Range &lhs, const Range &rhs) {
return (lhs.second - lhs.first) > (rhs.second - rhs.first); return (lhs.second - lhs.first) > (rhs.second - rhs.first);
}); });
// We can't split range with lenght 1, which means we can't produce new // The length of range is less than 1, we can't split it to create new
// variant. // variant.
if (maxElement->second - maxElement->first == 1) if (maxElement->second - maxElement->first <= 1)
return {}; return {};
auto createNewNode = [this](const std::vector<Range> &ranges) {
ReductionNode *newNode = allocator.Allocate();
new (newNode) ReductionNode(this, ranges, allocator);
return newNode;
};
Range maxRange = *maxElement; Range maxRange = *maxElement;
std::vector<Range> subRanges = ranges; std::vector<Range> subRanges = getRanges();
auto subRangesIter = subRanges.begin() + (maxElement - ranges.begin()); auto subRangesIter = subRanges.begin() + (maxElement - ranges.begin());
int half = (maxRange.first + maxRange.second) / 2; int half = (maxRange.first + maxRange.second) / 2;
*subRangesIter = std::make_pair(maxRange.first, half); *subRangesIter = std::make_pair(maxRange.first, half);
newNodes.push_back(createNewNode(subRanges)); variants.push_back(createNewNode(subRanges));
*subRangesIter = std::make_pair(half, maxRange.second); *subRangesIter = std::make_pair(half, maxRange.second);
newNodes.push_back(createNewNode(subRanges)); variants.push_back(createNewNode(std::move(subRanges)));
variants.insert(variants.end(), newNodes.begin(), newNodes.end());
auto it = ranges.insert(maxElement, std::make_pair(half, maxRange.second)); auto it = ranges.insert(maxElement, std::make_pair(half, maxRange.second));
it = ranges.insert(it, std::make_pair(maxRange.first, half)); it = ranges.insert(it, std::make_pair(maxRange.first, half));
// Remove the range that has been split. // Remove the range that has been split.
ranges.erase(it + 2); ranges.erase(it + 2);
return newNodes; return getVariants().drop_front(oldNumVariant);
} }
void ReductionNode::update(std::pair<Tester::Interestingness, size_t> result) { void ReductionNode::update(std::pair<Tester::Interestingness, size_t> result) {
std::tie(interesting, size) = result; std::tie(interesting, size) = result;
// After applying reduction, the number of operation in the region may have
// changed. Non-interesting case won't be explored thus it's safe to keep it
// in a stale status.
if (interesting == Tester::Interestingness::True) {
// This module may has been updated. Reset the range.
ranges.clear();
ranges.push_back({0, std::distance(region->op_begin(), region->op_end())});
}
} }
std::vector<ReductionNode *> ArrayRef<ReductionNode *>
ReductionNode::iterator<SinglePath>::getNeighbors(ReductionNode *node) { ReductionNode::iterator<SinglePath>::getNeighbors(ReductionNode *node) {
// Single Path: Traverses the smallest successful variant at each level until // Single Path: Traverses the smallest successful variant at each level until
// no new successful variants can be created at that level. // no new successful variants can be created at that level.
llvm::ArrayRef<ReductionNode *> variantsFromParent = ArrayRef<ReductionNode *> variantsFromParent =
node->getParent()->getVariants(); node->getParent()->getVariants();
// The parent node created several variants and they may be waiting for // The parent node created several variants and they may be waiting for
@ -139,7 +140,8 @@ ReductionNode::iterator<SinglePath>::getNeighbors(ReductionNode *node) {
smallest = node; smallest = node;
} }
if (smallest != nullptr) { if (smallest != nullptr &&
smallest->getSize() < node->getParent()->getSize()) {
// We got a smallest one, keep traversing from this node. // We got a smallest one, keep traversing from this node.
node = smallest; node = smallest;
} else { } else {

View File

@ -0,0 +1,247 @@
//===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the Reduction Tree Pass class. It provides a framework for
// the implementation of different reduction passes in the MLIR Reduce tool. It
// allows for custom specification of the variant generation behavior. It
// implements methods that define the different possible traversals of the
// reduction tree.
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/DialectInterface.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Reducer/PassDetail.h"
#include "mlir/Reducer/Passes.h"
#include "mlir/Reducer/ReductionNode.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Reducer/Tester.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/ManagedStatic.h"
using namespace mlir;
/// We implicitly number each operation in the region and if an operation's
/// number falls into rangeToKeep, we need to keep it and apply the given
/// rewrite patterns on it.
static void applyPatterns(Region &region,
const FrozenRewritePatternSet &patterns,
ArrayRef<ReductionNode::Range> rangeToKeep,
bool eraseOpNotInRange) {
std::vector<Operation *> opsNotInRange;
std::vector<Operation *> opsInRange;
size_t keepIndex = 0;
for (auto op : enumerate(region.getOps())) {
int index = op.index();
if (keepIndex < rangeToKeep.size() &&
index == rangeToKeep[keepIndex].second)
++keepIndex;
if (keepIndex == rangeToKeep.size() || index < rangeToKeep[keepIndex].first)
opsNotInRange.push_back(&op.value());
else
opsInRange.push_back(&op.value());
}
// `applyOpPatternsAndFold` may erase the ops so we can't do the pattern
// matching in above iteration. Besides, erase op not-in-range may end up in
// invalid module, so `applyOpPatternsAndFold` should come before that
// transform.
for (Operation *op : opsInRange)
// `applyOpPatternsAndFold` returns whether the op is convered. Omit it
// because we don't have expectation this reduction will be success or not.
(void)applyOpPatternsAndFold(op, patterns);
if (eraseOpNotInRange)
for (Operation *op : opsNotInRange) {
op->dropAllUses();
op->erase();
}
}
/// We will apply the reducer patterns to the operations in the ranges specified
/// by ReductionNode. Note that we are not able to remove an operation without
/// replacing it with another valid operation. However, The validity of module
/// reduction is based on the Tester provided by the user and that means certain
/// invalid module is still interested by the use. Thus we provide an
/// alternative way to remove operations, which is using `eraseOpNotInRange` to
/// erase the operations not in the range specified by ReductionNode.
template <typename IteratorType>
static void findOptimal(ModuleOp module, Region &region,
const FrozenRewritePatternSet &patterns,
const Tester &test, bool eraseOpNotInRange) {
std::pair<Tester::Interestingness, size_t> initStatus =
test.isInteresting(module);
// While exploring the reduction tree, we always branch from an interesting
// node. Thus the root node must be interesting.
if (initStatus.first != Tester::Interestingness::True)
return;
llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
std::vector<ReductionNode::Range> ranges{
{0, std::distance(region.op_begin(), region.op_end())}};
ReductionNode *root = allocator.Allocate();
new (root) ReductionNode(nullptr, std::move(ranges), allocator);
// Duplicate the module for root node and locate the region in the copy.
if (failed(root->initialize(module, region)))
llvm_unreachable("unexpected initialization failure");
root->update(initStatus);
ReductionNode *smallestNode = root;
IteratorType iter(root);
while (iter != IteratorType::end()) {
ReductionNode &currentNode = *iter;
Region &curRegion = currentNode.getRegion();
applyPatterns(curRegion, patterns, currentNode.getRanges(),
eraseOpNotInRange);
currentNode.update(test.isInteresting(currentNode.getModule()));
if (currentNode.isInteresting() == Tester::Interestingness::True &&
currentNode.getSize() < smallestNode->getSize())
smallestNode = &currentNode;
++iter;
}
// At here, we have found an optimal path to reduce the given region. Retrieve
// the path and apply the reducer to it.
SmallVector<ReductionNode *> trace;
ReductionNode *curNode = smallestNode;
trace.push_back(curNode);
while (curNode != root) {
curNode = curNode->getParent();
trace.push_back(curNode);
}
// Reduce the region through the optimal path.
while (!trace.empty()) {
ReductionNode *top = trace.pop_back_val();
applyPatterns(region, patterns, top->getStartRanges(), eraseOpNotInRange);
}
if (test.isInteresting(module).first != Tester::Interestingness::True)
llvm::report_fatal_error("Reduced module is not interesting");
if (test.isInteresting(module).second != smallestNode->getSize())
llvm::report_fatal_error(
"Reduced module doesn't have consistent size with smallestNode");
}
template <typename IteratorType>
static void findOptimal(ModuleOp module, Region &region,
const FrozenRewritePatternSet &patterns,
const Tester &test) {
// We separate the reduction process into 2 steps, the first one is to erase
// redundant operations and the second one is to apply the reducer patterns.
// In the first phase, we don't apply any patterns so that we only select the
// range of operations to keep to the module stay interesting.
findOptimal<IteratorType>(module, region, /*patterns=*/{}, test,
/*eraseOpNotInRange=*/true);
// In the second phase, we suppose that no operation is redundant, so we try
// to rewrite the operation into simpler form.
findOptimal<IteratorType>(module, region, patterns, test,
/*eraseOpNotInRange=*/false);
}
namespace {
//===----------------------------------------------------------------------===//
// Reduction Pattern Interface Collection
//===----------------------------------------------------------------------===//
class ReductionPatternInterfaceCollection
: public DialectInterfaceCollection<DialectReductionPatternInterface> {
public:
using Base::Base;
// Collect the reduce patterns defined by each dialect.
void populateReductionPatterns(RewritePatternSet &pattern) const {
for (const DialectReductionPatternInterface &interface : *this)
interface.populateReductionPatterns(pattern);
}
};
//===----------------------------------------------------------------------===//
// ReductionTreePass
//===----------------------------------------------------------------------===//
/// This class defines the Reduction Tree Pass. It provides a framework to
/// to implement a reduction pass using a tree structure to keep track of the
/// generated reduced variants.
class ReductionTreePass : public ReductionTreeBase<ReductionTreePass> {
public:
ReductionTreePass() = default;
ReductionTreePass(const ReductionTreePass &pass) = default;
LogicalResult initialize(MLIRContext *context) override;
/// Runs the pass instance in the pass pipeline.
void runOnOperation() override;
private:
void reduceOp(ModuleOp module, Region &region);
FrozenRewritePatternSet reducerPatterns;
};
} // end anonymous namespace
LogicalResult ReductionTreePass::initialize(MLIRContext *context) {
RewritePatternSet patterns(context);
ReductionPatternInterfaceCollection reducePatternCollection(context);
reducePatternCollection.populateReductionPatterns(patterns);
reducerPatterns = std::move(patterns);
return success();
}
void ReductionTreePass::runOnOperation() {
Operation *topOperation = getOperation();
while (topOperation->getParentOp() != nullptr)
topOperation = topOperation->getParentOp();
ModuleOp module = cast<ModuleOp>(topOperation);
SmallVector<Operation *, 8> workList;
workList.push_back(getOperation());
do {
Operation *op = workList.pop_back_val();
for (Region &region : op->getRegions())
if (!region.empty())
reduceOp(module, region);
for (Region &region : op->getRegions())
for (Operation &op : region.getOps())
if (op.getNumRegions() != 0)
workList.push_back(&op);
} while (!workList.empty());
}
void ReductionTreePass::reduceOp(ModuleOp module, Region &region) {
Tester test(testerName, testerArgs);
switch (traversalModeId) {
case TraversalMode::SinglePath:
findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
module, region, reducerPatterns, test);
break;
default:
llvm_unreachable("Unsupported mode");
}
}
std::unique_ptr<Pass> mlir::createReductionTreePass() {
return std::make_unique<ReductionTreePass>();
}

View File

@ -15,7 +15,7 @@
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include "mlir/Reducer/Tester.h" #include "mlir/Reducer/Tester.h"
#include "mlir/IR/Verifier.h"
#include "llvm/Support/ToolOutputFile.h" #include "llvm/Support/ToolOutputFile.h"
using namespace mlir; using namespace mlir;
@ -25,6 +25,12 @@ Tester::Tester(StringRef scriptName, ArrayRef<std::string> scriptArgs)
std::pair<Tester::Interestingness, size_t> std::pair<Tester::Interestingness, size_t>
Tester::isInteresting(ModuleOp module) const { Tester::isInteresting(ModuleOp module) const {
// The reduced module should always be vaild, or we may end up retaining the
// error message by an invalid case. Besides, an invalid module may not be
// able to print properly.
if (failed(verify(module)))
return std::make_pair(Interestingness::False, /*size=*/0);
SmallString<128> filepath; SmallString<128> filepath;
int fd; int fd;
@ -50,7 +56,6 @@ Tester::isInteresting(ModuleOp module) const {
/// true if the interesting behavior is present in the test case or false /// true if the interesting behavior is present in the test case or false
/// otherwise. /// otherwise.
Tester::Interestingness Tester::isInteresting(StringRef testCase) const { Tester::Interestingness Tester::isInteresting(StringRef testCase) const {
std::vector<StringRef> testerArgs; std::vector<StringRef> testerArgs;
testerArgs.push_back(testCase); testerArgs.push_back(testCase);

View File

@ -60,6 +60,7 @@ add_mlir_library(MLIRTestDialect
MLIRInferTypeOpInterface MLIRInferTypeOpInterface
MLIRLinalgTransforms MLIRLinalgTransforms
MLIRPass MLIRPass
MLIRReduce
MLIRStandard MLIRStandard
MLIRStandardOpsTransforms MLIRStandardOpsTransforms
MLIRTransformUtils MLIRTransformUtils

View File

@ -8,6 +8,7 @@
#include "TestDialect.h" #include "TestDialect.h"
#include "TestAttributes.h" #include "TestAttributes.h"
#include "TestInterfaces.h"
#include "TestTypes.h" #include "TestTypes.h"
#include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
@ -16,6 +17,7 @@
#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/TypeUtilities.h"
#include "mlir/Reducer/ReductionPatternInterface.h"
#include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h" #include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/StringSwitch.h"
@ -170,6 +172,18 @@ struct TestInlinerInterface : public DialectInlinerInterface {
return builder.create<TestCastOp>(conversionLoc, resultType, input); return builder.create<TestCastOp>(conversionLoc, resultType, input);
} }
}; };
struct TestReductionPatternInterface : public DialectReductionPatternInterface {
public:
TestReductionPatternInterface(Dialect *dialect)
: DialectReductionPatternInterface(dialect) {}
virtual void
populateReductionPatterns(RewritePatternSet &patterns) const final {
populateTestReductionPatterns(patterns);
}
};
} // end anonymous namespace } // end anonymous namespace
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -207,7 +221,7 @@ void TestDialect::initialize() {
#include "TestOps.cpp.inc" #include "TestOps.cpp.inc"
>(); >();
addInterfaces<TestOpAsmInterface, TestDialectFoldInterface, addInterfaces<TestOpAsmInterface, TestDialectFoldInterface,
TestInlinerInterface>(); TestInlinerInterface, TestReductionPatternInterface>();
allowUnknownOperations(); allowUnknownOperations();
// Instantiate our fallback op interface that we'll use on specific // Instantiate our fallback op interface that we'll use on specific

View File

@ -34,6 +34,7 @@
namespace mlir { namespace mlir {
class DLTIDialect; class DLTIDialect;
class RewritePatternSet;
} // namespace mlir } // namespace mlir
#include "TestOpEnums.h.inc" #include "TestOpEnums.h.inc"
@ -47,6 +48,7 @@ class DLTIDialect;
namespace mlir { namespace mlir {
namespace test { namespace test {
void registerTestDialect(DialectRegistry &registry); void registerTestDialect(DialectRegistry &registry);
void populateTestReductionPatterns(RewritePatternSet &patterns);
} // namespace test } // namespace test
} // namespace mlir } // namespace mlir

View File

@ -2113,4 +2113,19 @@ def DataLayoutQueryOp : TEST_Op<"data_layout_query"> {
let results = (outs AnyType:$res); let results = (outs AnyType:$res);
} }
//===----------------------------------------------------------------------===//
// Test Reducer Patterns
//===----------------------------------------------------------------------===//
def OpCrashLong : TEST_Op<"op_crash_long"> {
let arguments = (ins I32, I32, I32);
let results = (outs I32);
}
def OpCrashShort : TEST_Op<"op_crash_short"> {
let results = (outs I32);
}
def : Pat<(OpCrashLong $_, $_, $_), (OpCrashShort)>;
#endif // TEST_OPS #endif // TEST_OPS

View File

@ -58,6 +58,14 @@ namespace {
#include "TestPatterns.inc" #include "TestPatterns.inc"
} // end anonymous namespace } // end anonymous namespace
//===----------------------------------------------------------------------===//
// Test Reduce Pattern Interface
//===----------------------------------------------------------------------===//
void mlir::test::populateTestReductionPatterns(RewritePatternSet &patterns) {
populateWithGenerated(patterns);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Canonicalizer Driver. // Canonicalizer Driver.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -38,7 +38,7 @@ void TestReducer::runOnFunction() {
op.walk([&](Operation *op) { op.walk([&](Operation *op) {
StringRef opName = op->getName().getStringRef(); StringRef opName = op->getName().getStringRef();
if (opName == "test.crashOp") { if (opName.contains("op_crash")) {
llvm::errs() << "MLIR Reducer Test generated failure: Found " llvm::errs() << "MLIR Reducer Test generated failure: Found "
"\"crashOp\" operation\n"; "\"crashOp\" operation\n";
exit(1); exit(1);

View File

@ -0,0 +1,20 @@
// UNSUPPORTED: system-windows
// RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/failure-test.sh' | FileCheck %s
// "test.op_crash_long" should be replaced with a shorter form "test.op_crash_short".
// CHECK-NOT: func @simple1() {
func @simple1() {
return
}
// CHECK-LABEL: func @simple2(%arg0: i32, %arg1: i32, %arg2: i32) {
func @simple2(%arg0: i32, %arg1: i32, %arg2: i32) {
// CHECK-LABEL: %0 = "test.op_crash_short"() : () -> i32
%0 = "test.op_crash_long" (%arg0, %arg1, %arg2) : (i32, i32, i32) -> i32
return
}
// CHECK-NOT: func @simple5() {
func @simple5() {
return
}

View File

@ -12,6 +12,6 @@ func nested @dead_nested_function()
// CHECK-LABEL: func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { // CHECK-LABEL: func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
"test.crashOp" () : () -> () "test.op_crash" () : () -> ()
return return
} }

View File

@ -1,5 +1,5 @@
// UNSUPPORTED: system-windows // UNSUPPORTED: system-windows
// RUN: mlir-reduce %s -reduction-tree='op-reducer=func traversal-mode=0 test=%S/failure-test.sh' | FileCheck %s // RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/failure-test.sh' | FileCheck %s
// This input should be reduced by the pass pipeline so that only // This input should be reduced by the pass pipeline so that only
// the @simple5 function remains as this is the shortest function // the @simple5 function remains as this is the shortest function
// containing the interesting behavior. // containing the interesting behavior.
@ -16,7 +16,7 @@ func @simple2() {
// CHECK-LABEL: func @simple3() { // CHECK-LABEL: func @simple3() {
func @simple3() { func @simple3() {
"test.crashOp" () : () -> () "test.op_crash" () : () -> ()
return return
} }
@ -29,7 +29,7 @@ func @simple4(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
%0 = memref.alloc() : memref<2xf32> %0 = memref.alloc() : memref<2xf32>
br ^bb3(%0 : memref<2xf32>) br ^bb3(%0 : memref<2xf32>)
^bb3(%1: memref<2xf32>): ^bb3(%1: memref<2xf32>):
"test.crashOp"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> () "test.op_crash"(%1, %arg2) : (memref<2xf32>, memref<2xf32>) -> ()
return return
} }

View File

@ -1,5 +1,5 @@
// UNSUPPORTED: system-windows // UNSUPPORTED: system-windows
// RUN: mlir-reduce %s -reduction-tree='op-reducer=func traversal-mode=0 test=%S/test.sh' // RUN: mlir-reduce %s -reduction-tree='traversal-mode=0 test=%S/test.sh'
func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) { func @simple1(%arg0: i1, %arg1: memref<2xf32>, %arg2: memref<2xf32>) {
cond_br %arg0, ^bb1, ^bb2 cond_br %arg0, ^bb1, ^bb2

View File

@ -2,6 +2,6 @@
// RUN: not mlir-opt %s -test-mlir-reducer -pass-test function-reducer // RUN: not mlir-opt %s -test-mlir-reducer -pass-test function-reducer
func @test() { func @test() {
"test.crashOp"() : () -> () "test.op_crash"() : () -> ()
return return
} }

View File

@ -43,9 +43,6 @@ set(LIBS
) )
add_llvm_tool(mlir-reduce add_llvm_tool(mlir-reduce
OptReductionPass.cpp
ReductionNode.cpp
ReductionTreePass.cpp
mlir-reduce.cpp mlir-reduce.cpp
ADDITIONAL_HEADER_DIRS ADDITIONAL_HEADER_DIRS

View File

@ -1,107 +0,0 @@
//===- ReductionTreePass.cpp - ReductionTreePass Implementation -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the Reduction Tree Pass class. It provides a framework for
// the implementation of different reduction passes in the MLIR Reduce tool. It
// allows for custom specification of the variant generation behavior. It
// implements methods that define the different possible traversals of the
// reduction tree.
//
//===----------------------------------------------------------------------===//
#include "mlir/Reducer/ReductionTreePass.h"
#include "mlir/Reducer/Passes.h"
#include "llvm/Support/Allocator.h"
using namespace mlir;
static std::unique_ptr<OpReducer> getOpReducer(llvm::StringRef opType) {
if (opType == ModuleOp::getOperationName())
return std::make_unique<Reducer<ModuleOp>>();
else if (opType == FuncOp::getOperationName())
return std::make_unique<Reducer<FuncOp>>();
llvm_unreachable("Now only supports two built-in ops");
}
void ReductionTreePass::runOnOperation() {
ModuleOp module = this->getOperation();
std::unique_ptr<OpReducer> reducer = getOpReducer(opReducerName);
std::vector<std::pair<int, int>> ranges = {
{0, reducer->getNumTargetOps(module)}};
llvm::SpecificBumpPtrAllocator<ReductionNode> allocator;
ReductionNode *root = allocator.Allocate();
new (root) ReductionNode(nullptr, ranges, allocator);
ModuleOp golden = module;
switch (traversalModeId) {
case TraversalMode::SinglePath:
golden = findOptimal<ReductionNode::iterator<TraversalMode::SinglePath>>(
module, std::move(reducer), root);
break;
default:
llvm_unreachable("Unsupported mode");
}
if (golden != module) {
module.getBody()->clear();
module.getBody()->getOperations().splice(module.getBody()->begin(),
golden.getBody()->getOperations());
golden->destroy();
}
}
template <typename IteratorType>
ModuleOp ReductionTreePass::findOptimal(ModuleOp module,
std::unique_ptr<OpReducer> reducer,
ReductionNode *root) {
Tester test(testerName, testerArgs);
std::pair<Tester::Interestingness, size_t> initStatus =
test.isInteresting(module);
if (initStatus.first != Tester::Interestingness::True) {
LLVM_DEBUG(llvm::dbgs() << "\nThe original input is not interested");
return module;
}
root->update(initStatus);
ReductionNode *smallestNode = root;
ModuleOp golden = module;
IteratorType iter(root);
while (iter != IteratorType::end()) {
ModuleOp cloneModule = module.clone();
ReductionNode &currentNode = *iter;
reducer->reduce(cloneModule, currentNode.getRanges());
std::pair<Tester::Interestingness, size_t> result =
test.isInteresting(cloneModule);
currentNode.update(result);
if (result.first == Tester::Interestingness::True &&
result.second < smallestNode->getSize()) {
smallestNode = &currentNode;
golden = cloneModule;
} else {
cloneModule->destroy();
}
++iter;
}
return golden;
}
std::unique_ptr<Pass> mlir::createReductionTreePass() {
return std::make_unique<ReductionTreePass>();
}

View File

@ -13,22 +13,14 @@
// //
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
#include <vector>
#include "mlir/InitAllDialects.h" #include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h" #include "mlir/InitAllPasses.h"
#include "mlir/Parser.h" #include "mlir/Parser.h"
#include "mlir/Pass/Pass.h" #include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h" #include "mlir/Pass/PassManager.h"
#include "mlir/Reducer/OptReductionPass.h"
#include "mlir/Reducer/Passes.h" #include "mlir/Reducer/Passes.h"
#include "mlir/Reducer/Passes/OpReducer.h"
#include "mlir/Reducer/ReductionNode.h"
#include "mlir/Reducer/ReductionTreePass.h"
#include "mlir/Reducer/Tester.h"
#include "mlir/Support/FileUtilities.h" #include "mlir/Support/FileUtilities.h"
#include "mlir/Support/LogicalResult.h" #include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/Support/InitLLVM.h" #include "llvm/Support/InitLLVM.h"
#include "llvm/Support/ToolOutputFile.h" #include "llvm/Support/ToolOutputFile.h"