[mlir] Add canonicalizer to remove redundant shape.cstr_broadcastable ops
Depends On D119025 Reviewed By: frgossen Differential Revision: https://reviews.llvm.org/D119043
This commit is contained in:
parent
1ef04326ec
commit
edca177cbe
|
@ -21,6 +21,7 @@
|
|||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Transforms/InliningUtils.h"
|
||||
#include "llvm/ADT/SetOperations.h"
|
||||
#include "llvm/ADT/SmallString.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
@ -493,6 +494,99 @@ struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
|
|||
}
|
||||
};
|
||||
|
||||
// Eliminate `cstr_broadcastable` operands from `assuming_all` operation that
|
||||
// are subsumed by others.
|
||||
//
|
||||
// %0 = shape.cstr_broadcastable %shape0, %shape1
|
||||
// %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2
|
||||
//
|
||||
// %2 = shape.cstr_broadcastable %shape3, %shape4
|
||||
// %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5
|
||||
//
|
||||
// %4 = shape.assuming_all %0, %1, %2, %3
|
||||
//
|
||||
// to:
|
||||
//
|
||||
// %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2
|
||||
// %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5
|
||||
// %2 = shape.assuming_all %0, %1
|
||||
//
|
||||
// In this example if shapes [0, 1, 2] are broadcastable, then it means that
|
||||
// shapes [0, 1] are broadcastable too, and can be removed from the list of
|
||||
// constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't
|
||||
// matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]).
|
||||
struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  /// Replaces `op` with an `assuming_all` that only takes the
  /// `cstr_broadcastable` witnesses not subsumed by another operand.
  ///
  /// A constraint is subsumed when the set of shapes it checks is a subset of
  /// the set checked by another operand of `op`.
  ///
  /// Returns failure (leaving the IR untouched) when some operand is not
  /// produced by `cstr_broadcastable`, when there are fewer than two distinct
  /// constraints, or when no constraint is subsumed.
  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    // Collect all `CstrBroadcastableOp` operands first. `SetVector` drops
    // repeated operands while keeping the first-seen order.
    SetVector<CstrBroadcastableOp> operands;
    for (Value operand : op.getInputs()) {
      // TODO: Apply this optimization if some of the witnesses are not
      // produced by the `cstr_broadcastable`.
      auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
      if (!broadcastable)
        return failure();

      operands.insert(broadcastable);
    }

    // Skip trivial `assuming_all` operations.
    if (operands.size() <= 1)
      return failure();

    // Pair every `cstr_broadcastable` operand with the set of shapes it
    // checks.
    using Constraint = std::pair<CstrBroadcastableOp, DenseSet<Value>>;
    SmallVector<Constraint> shapes;
    for (auto cstr : operands) {
      DenseSet<Value> shapesSet(cstr->operand_begin(), cstr->operand_end());
      shapes.emplace_back(cstr, std::move(shapesSet));
    }

    // Sort by the number of shape operands (larger to smaller). The sort is
    // stable so that constraints with the same operand count keep their
    // original relative order, and the comparator takes references so that no
    // `DenseSet` is copied per comparison.
    llvm::stable_sort(shapes, [](const Constraint &a, const Constraint &b) {
      return a.first->getNumOperands() > b.first->getNumOperands();
    });

    // We start from the `cstr_broadcastable` operations with the largest
    // number of shape operands, and remove redundant `cstr_broadcastable`
    // operations. We do this until we find a set of `cstr_broadcastable`
    // operations with non-overlapping constraints.
    SmallVector<CstrBroadcastableOp> markedForErase;

    for (unsigned i = 0; i < shapes.size(); ++i) {
      auto isNotSubsumed = [&](const Constraint &candidate) {
        return !llvm::set_is_subset(candidate.second, shapes[i].second);
      };

      // Move the constraints subsumed by `shapes[i]` to the tail.
      // `std::stable_partition` is used instead of `std::remove_if` because
      // the latter leaves the tail elements in an unspecified state, so the
      // operations read from it would not be the redundant ones.
      auto *firstRedundant = std::stable_partition(shapes.begin() + i + 1,
                                                   shapes.end(), isNotSubsumed);
      for (auto *it = firstRedundant; it != shapes.end(); ++it)
        markedForErase.push_back(it->first);
      shapes.erase(firstRedundant, shapes.end());
    }

    // We didn't find any operands that could be removed.
    if (markedForErase.empty())
      return failure();

    // Collect non-overlapping `cstr_broadcastable` constraints.
    SmallVector<Value> uniqueConstraints;
    uniqueConstraints.reserve(shapes.size());
    for (auto &shape : shapes)
      uniqueConstraints.push_back(shape.first.getResult());

    // Replace with a new `assuming_all` operation ...
    rewriter.replaceOpWithNewOp<AssumingAllOp>(op, uniqueConstraints);

    // ... and erase the redundant `cstr_broadcastable` ops that have no other
    // uses left.
    for (CstrBroadcastableOp redundant : markedForErase)
      if (redundant->use_empty())
        rewriter.eraseOp(redundant);

    return success();
  }
};
|
||||
|
||||
struct AssumingAllToCstrEqCanonicalization
|
||||
: public OpRewritePattern<AssumingAllOp> {
|
||||
using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
|
||||
|
@ -539,9 +633,10 @@ struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
|
|||
|
||||
/// Registers the canonicalization patterns of `shape.assuming_all` into
/// `patterns`.
void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Every pattern is registered exactly once, in a single `add` call that
  // includes `AssumingAllOfCstrBroadcastable`.
  patterns
      .add<MergeAssumingAllOps, AssumingAllOneOp,
           AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
           RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
}
|
||||
|
||||
OpFoldResult AssumingAllOp::fold(ArrayRef<Attribute> operands) {
|
||||
|
|
|
@ -565,6 +565,46 @@ func @f() {
|
|||
|
||||
// -----
|
||||
|
||||
// Remove a cstr_broadcastable operand of assuming_all whose shapes are all covered by another cstr_broadcastable operand.
|
||||
//
|
||||
// CHECK-LABEL: func @f
|
||||
// CHECK: %[[ARG0:[a-z0-9]*]]: !shape.shape
|
||||
// CHECK-SAME: %[[ARG1:[a-z0-9]*]]: !shape.shape
|
||||
// CHECK-SAME: %[[ARG2:[a-z0-9]*]]: !shape.shape
|
||||
func @f(%arg0 : !shape.shape, %arg1 : !shape.shape, %arg2 : !shape.shape) {
|
||||
// CHECK-NEXT: %[[W:.*]] = shape.cstr_broadcastable %[[ARG0]], %[[ARG1]], %[[ARG2]]
|
||||
// CHECK-NEXT: "consume.witness"(%[[W]])
|
||||
// CHECK-NEXT: return
|
||||
// %0 constrains (%arg0, %arg1), a subset of the shapes constrained by %1, so
// the checks above expect only the three-operand constraint to remain and to
// feed "consume.witness" directly, with no assuming_all in between.
%0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape
|
||||
%1 = shape.cstr_broadcastable %arg0, %arg1, %arg2 : !shape.shape, !shape.shape, !shape.shape
|
||||
%2 = shape.assuming_all %0, %1
|
||||
"consume.witness"(%2) : (!shape.witness) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Keep both cstr_broadcastable operands of assuming_all when neither covers all shapes of the other.
|
||||
//
|
||||
// CHECK-LABEL: func @f
|
||||
// CHECK: %[[ARG0:[a-z0-9]*]]: !shape.shape
|
||||
// CHECK-SAME: %[[ARG1:[a-z0-9]*]]: !shape.shape
|
||||
// CHECK-SAME: %[[ARG2:[a-z0-9]*]]: !shape.shape
|
||||
func @f(%arg0 : !shape.shape, %arg1 : !shape.shape, %arg2 : !shape.shape) {
|
||||
// CHECK-NEXT: %[[W0:.*]] = shape.cstr_broadcastable %[[ARG0]], %[[ARG1]]
|
||||
// CHECK-NEXT: %[[W1:.*]] = shape.cstr_broadcastable %[[ARG1]], %[[ARG2]]
|
||||
// CHECK-NEXT: %[[W2:.*]] = shape.assuming_all %[[W0]], %[[W1]]
|
||||
// CHECK-NEXT: "consume.witness"(%[[W2]])
|
||||
// CHECK-NEXT: return
|
||||
// %0 constrains (%arg0, %arg1) and %1 constrains (%arg1, %arg2); neither
// shape set contains the other, so the checks above expect both constraints
// and the assuming_all to be left in place.
%0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape
|
||||
%1 = shape.cstr_broadcastable %arg1, %arg2 : !shape.shape, !shape.shape
|
||||
%2 = shape.assuming_all %0, %1
|
||||
"consume.witness"(%2) : (!shape.witness) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// any can be replaced with a constant input if it has one.
|
||||
// CHECK-LABEL: func @f
|
||||
func @f(%arg : !shape.shape) -> !shape.shape {
|
||||
|
|
Loading…
Reference in New Issue