[mlir][Linalg] Avoid unnecessarily propagating producer result to fused op result.

Elementwise op fusion preserves the results of the producer in the
fused op, relying on later cleanup patterns to drop unused results of
the fused op. Instead, if a producer result has no use other than the
consumer op, avoid making that producer result available in the fused
op. This saves some unnecessary IR manipulation.

Differential Revision: https://reviews.llvm.org/D138096
This commit is contained in:
Mahesh Ravishankar 2022-11-16 07:52:34 +00:00
parent 57fd7ffeff
commit 2d4b998697
5 changed files with 91 additions and 19 deletions

View File

@ -36,6 +36,11 @@ bool linalg::detail::canOpOperandsBeDroppedImpl(
continue;
indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand));
}
if (indexingMaps.empty()) {
// If there are no indexing maps, the operand can only be dropped
// if the op has no loops.
return linalgOp.getNumLoops() == 0;
}
return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
}

View File

@ -143,10 +143,10 @@ bool mlir::linalg::areElementwiseOpsFusable(OpOperand *fusedOperand) {
/// Generate the region of the fused tensor operation. The region of the fused
/// op must be empty.
static void
generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp,
AffineMap consumerToProducerLoopsMap,
OpOperand *fusedOperand, unsigned nloops) {
static void generateFusedElementwiseOpRegion(
RewriterBase &rewriter, GenericOp fusedOp,
AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand,
unsigned nloops, llvm::SmallDenseSet<int> &preservedProducerResults) {
auto producer = cast<GenericOp>(fusedOperand->get().getDefiningOp());
auto consumer = cast<GenericOp>(fusedOperand->getOwner());
// Build the region of the fused op.
@ -202,9 +202,13 @@ generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp,
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
// 6. All of the producer's output operands
for (BlockArgument bbArg :
producerBlock.getArguments().take_back(producer.getNumDpsInits()))
mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc()));
for (auto bbArg : llvm::enumerate(
producerBlock.getArguments().take_back(producer.getNumDpsInits()))) {
if (!preservedProducerResults.count(bbArg.index()))
continue;
mapper.map(bbArg.value(), fusedBlock->addArgument(bbArg.value().getType(),
bbArg.value().getLoc()));
}
// 7. All of consumer's output operands.
for (BlockArgument bbArg :
@ -247,8 +251,11 @@ generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp,
SmallVector<Value> fusedYieldValues;
fusedYieldValues.reserve(producerYieldOp.getNumOperands() +
consumerYieldOp.getNumOperands());
for (auto producerYieldVal : producerYieldOp.getOperands())
fusedYieldValues.push_back(mapper.lookupOrDefault(producerYieldVal));
for (auto producerYieldVal : llvm::enumerate(producerYieldOp.getOperands())) {
if (preservedProducerResults.count(producerYieldVal.index()))
fusedYieldValues.push_back(
mapper.lookupOrDefault(producerYieldVal.value()));
}
for (auto consumerYieldVal : consumerYieldOp.getOperands())
fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal));
rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues);
@ -269,6 +276,18 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
// TODO: allow fusing the producer of an output operand.
assert(consumer.isDpsInput(fusedOperand) &&
"expected producer of input operand");
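/// Find the results of the producer that must be preserved in the fused op:
/// those whose init operand is read by the producer payload, whose init
/// operand cannot be dropped, or that have users other than the consumer.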
llvm::SmallDenseSet<int> preservedProducerResults;
for (auto producerResult : llvm::enumerate(producer->getResults())) {
auto outputOperand = producer.getDpsInitOperand(producerResult.index());
if (producer.payloadUsesValueFromOperand(outputOperand) ||
!producer.canOpOperandsBeDropped(outputOperand) ||
llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) {
return user != consumer.getOperation();
})) {
preservedProducerResults.insert(producerResult.index());
}
}
// Compute the fused operands list and indexing maps.
SmallVector<Value> fusedInputOperands, fusedOutputOperands;
@ -276,9 +295,9 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
SmallVector<AffineMap> fusedIndexMaps;
fusedInputOperands.reserve(producer.getNumDpsInputs() +
consumer.getNumDpsInputs());
fusedOutputOperands.reserve(producer.getNumDpsInits() +
fusedOutputOperands.reserve(preservedProducerResults.size() +
consumer.getNumDpsInits());
fusedResultTypes.reserve(producer.getNumDpsInits() +
fusedResultTypes.reserve(preservedProducerResults.size() +
consumer.getNumDpsInits());
fusedIndexMaps.reserve(producer->getNumOperands() +
consumer->getNumOperands());
@ -313,13 +332,16 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
}
// 6. Collect all of the producer outputs.
for (OpOperand *opOperand : producer.getDpsInitOperands()) {
fusedOutputOperands.push_back(opOperand->get());
for (auto opOperand : llvm::enumerate(producer.getDpsInitOperands())) {
if (!preservedProducerResults.count(opOperand.index()))
continue;
fusedOutputOperands.push_back(opOperand.value()->get());
AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp(
opOperand, producerResultIndexMap,
opOperand.value(), producerResultIndexMap,
consumer.getMatchingIndexingMap(fusedOperand));
fusedIndexMaps.push_back(map);
fusedResultTypes.push_back(opOperand->get().getType());
fusedResultTypes.push_back(opOperand.value()->get().getType());
}
// 7. All of consumer's output operands (skip operands: added by the builder).
@ -358,9 +380,9 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter,
AffineMap consumerToProducerLoopsMap =
invProducerResultIndexMap.compose(consumerResultIndexMap);
generateFusedElementwiseOpRegion(rewriter, fusedOp,
consumerToProducerLoopsMap, fusedOperand,
consumer.getNumLoops());
generateFusedElementwiseOpRegion(
rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand,
consumer.getNumLoops(), preservedProducerResults);
return fusedOp.getOperation();
}

View File

@ -1,4 +1,4 @@
// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-generic-ops -split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-generic-ops-control -split-input-file | FileCheck %s
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#binary2Dpointwise = {

View File

@ -0,0 +1,30 @@
// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-generic-ops-control -split-input-file | FileCheck %s
#map = affine_map<(d0, d1) -> (d0, d1)>
// Test input for elementwise fusion when producer results have no use
// outside the consumer. The producer `%0` has two results: `%0#0` is used
// only by the consumer `%3`, and `%0#1` has no uses at all. The FileCheck
// directives after this function expect a single fused `linalg.generic`
// whose result is returned directly.
func.func @drop_unused_producer_result(%arg0 : tensor<?x?xf32>,
%arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
// Producer: one input and two outputs, all indexed with the identity map.
%0:2 = linalg.generic {
indexing_maps = [#map, #map, #map],
iterator_types = ["parallel", "parallel"]}
ins(%arg0 : tensor<?x?xf32>) outs(%arg0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>) {
^bb0(%b0: f32, %b1: f32, %b2: f32):
// The payload reads only the input element %b0; the output block
// arguments %b1 and %b2 are not used.
%1 = arith.addf %b0, %b0 : f32
%2 = arith.mulf %b0, %b0 : f32
linalg.yield %1, %2 : f32, f32
} -> (tensor<?x?xf32>, tensor<?x?xf32>)
// Consumer: takes the first producer result and %arg1 as inputs.
%3 = linalg.generic {
indexing_maps = [#map, #map, #map],
iterator_types = ["parallel", "parallel"]}
ins(%0#0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%arg0 : tensor<?x?xf32>) {
^bb0(%b0: f32, %b1: f32, %b2: f32):
// Here %b0 is the element of %0#0 and %b1 is the element of %arg1.
%4 = arith.subf %b0, %b1 : f32
linalg.yield %4 : f32
} -> tensor<?x?xf32>
// Only the consumer result leaves the function.
return %3 : tensor<?x?xf32>
}
// CHECK-LABEL: func @drop_unused_producer_result
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK: %[[FUSED_OP:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
// CHECK: return %[[FUSED_OP]]

View File

@ -75,6 +75,12 @@ struct TestLinalgElementwiseFusion
llvm::cl::desc("Test fusion of generic operations."),
llvm::cl::init(false)};
Option<bool> fuseGenericOpsControl{
*this, "fuse-generic-ops-control",
llvm::cl::desc(
"Test fusion of generic operations with a control function."),
llvm::cl::init(false)};
Option<bool> fuseWithReshapeByExpansion{
*this, "fuse-with-reshape-by-expansion",
llvm::cl::desc(
@ -108,6 +114,15 @@ struct TestLinalgElementwiseFusion
func::FuncOp funcOp = this->getOperation();
if (fuseGenericOps) {
RewritePatternSet fusionPatterns(context);
auto controlFn = [](OpOperand *operand) { return true; };
linalg::populateElementwiseOpsFusionPatterns(fusionPatterns, controlFn);
(void)applyPatternsAndFoldGreedily(funcOp.getBody(),
std::move(fusionPatterns));
return;
}
if (fuseGenericOpsControl) {
RewritePatternSet fusionPatterns(context);
linalg::populateElementwiseOpsFusionPatterns(fusionPatterns,
setFusedOpOperandLimit<4>);