Revert "[mlir][linalg] Add a new pattern to handle folding unit reduction dims."
This reverts commit 6eee66d12a
.
It breaks builds, see https://lab.llvm.org/buildbot/#/builders/61/builds/35742
Differential Revision: https://reviews.llvm.org/D138633
This commit is contained in:
parent
eca62f9204
commit
a827c5c7ab
|
@ -19,15 +19,12 @@
|
|||
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tensor/Utils/Utils.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/Transforms/FoldUtils.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
|
@ -228,125 +225,6 @@ struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
|
|||
}
|
||||
};
|
||||
|
||||
/// Pattern to add init operands to ins when all the loops are parallel and
|
||||
/// blockArgument corresponding to init is used in the region. This is a fix-up
|
||||
/// when unit reduction dimensions are all folded away. In this context, it
|
||||
/// becomes a elementwise generic op. E.g., it converts
|
||||
///
|
||||
/// %0 = tensor.empty() : tensor<1x1xf32>
|
||||
/// %1 = linalg.fill
|
||||
/// ins(%cst : f32)
|
||||
/// outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
|
||||
/// %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
|
||||
/// affine_map<(d0) -> (0, d0)>],
|
||||
/// iterator_types = ["parallel"]}
|
||||
/// ins(%arg0 : tensor<1x?x1x1xf32>)
|
||||
/// outs(%1 : tensor<1x1xf32>) {
|
||||
/// ^bb0(%in: f32, %out: f32):
|
||||
/// %3 = arith.addf %in, %out : f32
|
||||
/// linalg.yield %3 : f32
|
||||
/// } -> tensor<1x1xf32>
|
||||
///
|
||||
/// into
|
||||
///
|
||||
/// %0 = tensor.empty() : tensor<1x1xf32>
|
||||
/// %1 = linalg.fill
|
||||
/// ins(%cst : f32)
|
||||
/// outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
|
||||
/// %2 = tensor.empty() : tensor<1x1xf32>
|
||||
/// %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
|
||||
/// affine_map<(d0) -> (0, d0)>,
|
||||
/// affine_map<(d0) -> (0, d0)>],
|
||||
/// iterator_types = ["parallel"]}
|
||||
/// ins(%arg0, %1 : tensor<1x?x1x1xf32>, tensor<1x1xf32>)
|
||||
/// outs(%2 : tensor<1x1xf32>) {
|
||||
/// ^bb0(%in: f32, %in_0: f32, %out: f32):
|
||||
/// %4 = arith.addf %in, %in_0 : f32
|
||||
/// linalg.yield %4 : f32
|
||||
/// } -> tensor<1x1xf32>
|
||||
struct AddInitOperandsToInput : public OpRewritePattern<GenericOp> {
|
||||
using OpRewritePattern<GenericOp>::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(GenericOp genericOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
if (!genericOp.hasTensorSemantics())
|
||||
return failure();
|
||||
if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
|
||||
return failure();
|
||||
|
||||
auto outputOperands = genericOp.getDpsInitOperands();
|
||||
SetVector<OpOperand *> candidates;
|
||||
for (OpOperand *op : outputOperands) {
|
||||
if (genericOp.getMatchingBlockArgument(op).use_empty())
|
||||
continue;
|
||||
candidates.insert(op);
|
||||
}
|
||||
|
||||
if (candidates.empty())
|
||||
return failure();
|
||||
|
||||
// Compute the modified indexing maps.
|
||||
int64_t origNumInput = genericOp.getNumDpsInputs();
|
||||
SmallVector<Value> newInputOperands = genericOp.getDpsInputOperands();
|
||||
SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
|
||||
SmallVector<AffineMap> newIndexingMaps;
|
||||
newIndexingMaps.append(indexingMaps.begin(),
|
||||
std::next(indexingMaps.begin(), origNumInput));
|
||||
for (OpOperand *op : candidates) {
|
||||
newInputOperands.push_back(op->get());
|
||||
newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(op));
|
||||
}
|
||||
newIndexingMaps.append(std::next(indexingMaps.begin(), origNumInput),
|
||||
indexingMaps.end());
|
||||
|
||||
Location loc = genericOp.getLoc();
|
||||
SmallVector<Value> newOutputOperands = outputOperands;
|
||||
for (OpOperand *op : candidates) {
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointAfterValue(op->get());
|
||||
auto elemType = op->get().getType().cast<ShapedType>().getElementType();
|
||||
auto empty = rewriter.create<tensor::EmptyOp>(
|
||||
loc, tensor::createDimValues(rewriter, loc, op->get()), elemType);
|
||||
|
||||
auto [start, end] = genericOp.getDpsInitsPositionRange();
|
||||
newOutputOperands[op->getOperandNumber() - start] = empty.getResult();
|
||||
}
|
||||
|
||||
auto newOp = rewriter.create<GenericOp>(
|
||||
loc, genericOp.getResultTypes(), newInputOperands, newOutputOperands,
|
||||
newIndexingMaps, genericOp.getIteratorTypesArray(),
|
||||
/*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
|
||||
|
||||
Region ®ion = newOp.getRegion();
|
||||
Block *block = new Block();
|
||||
region.push_back(block);
|
||||
BlockAndValueMapping mapper;
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(block);
|
||||
for (auto bbarg : genericOp.getRegionInputArgs())
|
||||
mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
|
||||
|
||||
for (OpOperand *op : candidates) {
|
||||
BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
|
||||
mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
|
||||
}
|
||||
|
||||
for (OpOperand *op : outputOperands) {
|
||||
BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
|
||||
if (candidates.count(op))
|
||||
block->addArgument(bbarg.getType(), loc);
|
||||
else
|
||||
mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
|
||||
}
|
||||
|
||||
for (auto &op : genericOp.getBody()->getOperations()) {
|
||||
rewriter.clone(op, mapper);
|
||||
}
|
||||
rewriter.replaceOp(genericOp, newOp.getResults());
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct UnitExtentReplacementInfo {
|
||||
Type type;
|
||||
AffineMap indexMap;
|
||||
|
@ -658,8 +536,7 @@ struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
|
|||
void mlir::linalg::populateFoldUnitExtentDimsPatterns(
|
||||
RewritePatternSet &patterns) {
|
||||
auto *context = patterns.getContext();
|
||||
patterns.add<FoldUnitDimLoops, AddInitOperandsToInput, ReplaceUnitExtents,
|
||||
RankReducedExtractSliceOp,
|
||||
patterns.add<FoldUnitDimLoops, ReplaceUnitExtents, RankReducedExtractSliceOp,
|
||||
RankReducedInsertSliceOp<tensor::InsertSliceOp>,
|
||||
RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
|
||||
context);
|
||||
|
@ -667,8 +544,6 @@ void mlir::linalg::populateFoldUnitExtentDimsPatterns(
|
|||
tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
|
||||
tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
|
||||
tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
|
||||
memref::populateResolveRankedShapeTypeResultDimsPatterns(patterns);
|
||||
memref::populateResolveShapedTypeResultDimsPatterns(patterns);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -680,7 +555,7 @@ struct LinalgFoldUnitExtentDimsPass
|
|||
MLIRContext *context = op->getContext();
|
||||
RewritePatternSet patterns(context);
|
||||
if (foldOneTripLoopsOnly)
|
||||
patterns.add<FoldUnitDimLoops, AddInitOperandsToInput>(context);
|
||||
patterns.add<FoldUnitDimLoops>(context);
|
||||
else
|
||||
populateFoldUnitExtentDimsPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(op, std::move(patterns));
|
||||
|
|
|
@ -384,12 +384,11 @@ func.func @unit_dim_for_both_reduction(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1
|
|||
// CHECK-DAG: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2, 3]
|
||||
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1xf32>
|
||||
// CHECK: %[[FILL:.+]] = linalg.fill ins(%{{.+}}{{.*}}outs(%[[INIT]]
|
||||
// CHECK: %[[INIT2:.+]] = tensor.empty() : tensor<1xf32>
|
||||
// CHECK: %[[RESULT:.+]] = linalg.generic
|
||||
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
|
||||
// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]]]
|
||||
// CHECK-SAME: iterator_types = ["parallel"]
|
||||
// CHECK-SAME: ins(%[[RESHAPE]], %[[FILL]] : tensor<?xf32>, tensor<1xf32>)
|
||||
// CHECK-SAME: outs(%[[INIT2]] : tensor<1xf32>)
|
||||
// CHECK-SAME: ins(%[[RESHAPE]] : tensor<?xf32>)
|
||||
// CHECK-SAME: outs(%[[FILL]] : tensor<1xf32>)
|
||||
// CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]]
|
||||
// CHECK: return %[[RESULT_RESHAPE]]
|
||||
|
||||
|
|
|
@ -8301,7 +8301,6 @@ cc_library(
|
|||
":LinalgUtils",
|
||||
":MathDialect",
|
||||
":MemRefDialect",
|
||||
":MemRefTransforms",
|
||||
":Pass",
|
||||
":SCFDialect",
|
||||
":SCFTransforms",
|
||||
|
|
Loading…
Reference in New Issue