[mlir] Add canonicalizer to remove redundant shape.cstr_broadcastable ops

Depends On D119025

Reviewed By: frgossen

Differential Revision: https://reviews.llvm.org/D119043
Eugene Zhulenev 2022-02-06 14:38:33 -08:00
parent 1ef04326ec
commit edca177cbe
2 changed files with 138 additions and 3 deletions


@@ -21,6 +21,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"
@@ -493,6 +494,99 @@ struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
}
};
// Eliminate `cstr_broadcastable` operands of an `assuming_all` operation that
// are subsumed by other `cstr_broadcastable` operands.
//
// %0 = shape.cstr_broadcastable %shape0, %shape1
// %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2
//
// %2 = shape.cstr_broadcastable %shape3, %shape4
// %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5
//
// %4 = shape.assuming_all %0, %1, %2, %3
//
// to:
//
// %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2
// %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5
// %2 = shape.assuming_all %0, %1
//
// In this example, if shapes [0, 1, 2] are broadcastable, then shapes [0, 1]
// are broadcastable too and can be removed from the list of constraints. If
// shapes [0, 1, 2] are not broadcastable, it does not matter whether shapes
// [0, 1] are broadcastable (and likewise for shapes [3, 4, 5]).
struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {
using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AssumingAllOp op,
PatternRewriter &rewriter) const override {
// Collect all `CstrBroadcastableOp` operands first.
SetVector<CstrBroadcastableOp> operands;
for (Value operand : op.getInputs()) {
// TODO: Apply this optimization even when some of the witnesses are not
// produced by `cstr_broadcastable` ops.
auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
if (!broadcastable)
return failure();
operands.insert(broadcastable);
}
// Skip trivial `assuming_all` operations.
if (operands.size() <= 1)
return failure();
// Collect shapes checked by `cstr_broadcastable` operands.
SmallVector<std::pair<CstrBroadcastableOp, DenseSet<Value>>> shapes;
for (auto cstr : operands) {
DenseSet<Value> shapes_set(cstr->operand_begin(), cstr->operand_end());
shapes.emplace_back(cstr, std::move(shapes_set));
}
// Sort by the number of shape operands (largest to smallest).
llvm::sort(shapes, [](const auto &a, const auto &b) {
return a.first.getNumOperands() > b.first.getNumOperands();
});
// We start from the `cstr_broadcastable` operations with the largest number
// of shape operands and remove redundant `cstr_broadcastable` operations,
// until none of the remaining constraints is subsumed by another one.
SmallVector<CstrBroadcastableOp> marked_for_erase;
for (unsigned i = 0; i < shapes.size(); ++i) {
auto isSubset = [&](auto pair) {
return llvm::set_is_subset(pair.second, shapes[i].second);
};
// Remember the subsumed `cstr_broadcastable` operations; they are erased
// later if they end up without uses.
auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset);
for (auto *it0 = it; it0 < shapes.end(); ++it0)
marked_for_erase.push_back(it0->first);
shapes.erase(it, shapes.end());
}
// We didn't find any operands that could be removed.
if (marked_for_erase.empty())
return failure();
// Collect the remaining, non-subsumed `cstr_broadcastable` constraints.
SmallVector<Value> unique_constraints;
for (auto &shape : shapes)
unique_constraints.push_back(shape.first.getResult());
// Replace with a new `assuming_all` operation ...
rewriter.replaceOpWithNewOp<AssumingAllOp>(op, unique_constraints);
// ... and maybe erase `cstr_broadcastable` ops without uses.
for (auto &cstr : marked_for_erase)
if (cstr->use_empty())
rewriter.eraseOp(cstr);
return success();
}
};
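
To see the greedy subsumption step in isolation, here is a minimal standalone sketch (not part of this patch; the helper name dropSubsumedSets and the use of plain integer sets are purely illustrative) of the same idea built on llvm::set_is_subset:

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>

// Assumes `sets` is already sorted from the largest to the smallest set,
// mirroring the sort by operand count in the pattern above. Any set that is
// a subset of an earlier set is dropped, because the larger constraint
// already implies it.
static llvm::SmallVector<llvm::DenseSet<int>>
dropSubsumedSets(llvm::SmallVector<llvm::DenseSet<int>> sets) {
  for (unsigned i = 0; i < sets.size(); ++i) {
    auto isSubset = [&](const llvm::DenseSet<int> &candidate) {
      return llvm::set_is_subset(candidate, sets[i]);
    };
    sets.erase(std::remove_if(sets.begin() + i + 1, sets.end(), isSubset),
               sets.end());
  }
  return sets;
}

For the example above, the inputs {0, 1, 2}, {3, 4, 5}, {0, 1}, {3, 4} reduce to {0, 1, 2} and {3, 4, 5}, which is exactly the set of constraints kept by the pattern.
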
struct AssumingAllToCstrEqCanonicalization
: public OpRewritePattern<AssumingAllOp> {
using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
@@ -539,9 +633,10 @@ struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
-  patterns.add<MergeAssumingAllOps, AssumingAllOneOp,
-               AssumingAllToCstrEqCanonicalization,
-               RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
+  patterns
+      .add<MergeAssumingAllOps, AssumingAllOneOp,
+           AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
+           RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
}
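
For context on how this registration is exercised: canonicalization patterns added here are normally picked up by the -canonicalize pass. The sketch below is not part of the patch and uses a hypothetical helper name; it shows one way to run just the AssumingAllOp canonicalizations with the greedy pattern rewrite driver:

#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical driver: collect the AssumingAllOp canonicalization patterns
// (including AssumingAllOfCstrBroadcastable above) and apply them greedily
// to all operations nested under `root`.
static mlir::LogicalResult canonicalizeAssumingAll(mlir::Operation *root) {
  mlir::MLIRContext *ctx = root->getContext();
  mlir::RewritePatternSet patterns(ctx);
  mlir::shape::AssumingAllOp::getCanonicalizationPatterns(patterns, ctx);
  return mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}
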
OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {


@@ -565,6 +565,46 @@ func @f() {
// -----
// merge cstr_broadcastable operations: the first constraint's shapes are a
// subset of the second constraint's shapes, so it is subsumed.
//
// CHECK-LABEL: func @f
// CHECK: %[[ARG0:[a-z0-9]*]]: !shape.shape
// CHECK-SAME: %[[ARG1:[a-z0-9]*]]: !shape.shape
// CHECK-SAME: %[[ARG2:[a-z0-9]*]]: !shape.shape
func @f(%arg0 : !shape.shape, %arg1 : !shape.shape, %arg2 : !shape.shape) {
// CHECK-NEXT: %[[W:.*]] = shape.cstr_broadcastable %[[ARG0]], %[[ARG1]], %[[ARG2]]
// CHECK-NEXT: "consume.witness"(%[[W]])
// CHECK-NEXT: return
%0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape
%1 = shape.cstr_broadcastable %arg0, %arg1, %arg2 : !shape.shape, !shape.shape, !shape.shape
%2 = shape.assuming_all %0, %1
"consume.witness"(%2) : (!shape.witness) -> ()
return
}
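
Note that these FileCheck cases assume the canonicalizer runs over each `// -----`-separated input; the RUN line of the test file is not shown in this diff, but it would typically pipe the file through something like `mlir-opt -split-input-file -canonicalize` into `FileCheck` (the exact flags are whatever the existing test file already uses).
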
// -----
// do not merge cstr_broadcastable operations: neither constraint's shape
// operands are a subset of the other's, so nothing is subsumed.
//
// CHECK-LABEL: func @f
// CHECK: %[[ARG0:[a-z0-9]*]]: !shape.shape
// CHECK-SAME: %[[ARG1:[a-z0-9]*]]: !shape.shape
// CHECK-SAME: %[[ARG2:[a-z0-9]*]]: !shape.shape
func @f(%arg0 : !shape.shape, %arg1 : !shape.shape, %arg2 : !shape.shape) {
// CHECK-NEXT: %[[W0:.*]] = shape.cstr_broadcastable %[[ARG0]], %[[ARG1]]
// CHECK-NEXT: %[[W1:.*]] = shape.cstr_broadcastable %[[ARG1]], %[[ARG2]]
// CHECK-NEXT: %[[W2:.*]] = shape.assuming_all %[[W0]], %[[W1]]
// CHECK-NEXT: "consume.witness"(%[[W2]])
// CHECK-NEXT: return
%0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape
%1 = shape.cstr_broadcastable %arg1, %arg2 : !shape.shape, !shape.shape
%2 = shape.assuming_all %0, %1
"consume.witness"(%2) : (!shape.witness) -> ()
return
}
// -----
// any can be replaced with a constant input if it has one.
// CHECK-LABEL: func @f
func @f(%arg : !shape.shape) -> !shape.shape {