[mlir][linalg] Replace AffineMinSCFCanonicalizationPattern with SCF reimplementation
Use the new canonicalization pattern in the SCF dialect. Differential Revision: https://reviews.llvm.org/D107732
This commit is contained in:
parent
629411d799
commit
2de2dbef2a
|
@ -982,69 +982,6 @@ struct LinalgCopyVTWForwardingPattern
|
|||
PatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
using GetMinMaxExprFn =
|
||||
std::function<Optional<std::pair<AffineExpr, AffineExpr>>(
|
||||
Value value, SmallVectorImpl<Value> &dims,
|
||||
SmallVectorImpl<Value> &symbols)>;
|
||||
|
||||
/// Canonicalize AffineMinOp operations in the context of ops with a known range
|
||||
/// by:
|
||||
/// 1. building an affine map where uses of the known ops are replaced by
|
||||
/// their min annd max expressions returned by the lambda `getMinMaxFn`.
|
||||
/// 2. checking whether any of the results of this affine map is known to be
|
||||
/// greater than all other results.
|
||||
/// 3. replacing the AffineMinOp by the result of (2).
|
||||
struct AffineMinRangeCanonicalizationPattern
|
||||
: public OpRewritePattern<AffineMinOp> {
|
||||
AffineMinRangeCanonicalizationPattern(MLIRContext *context,
|
||||
GetMinMaxExprFn getMinMaxFn)
|
||||
: OpRewritePattern<AffineMinOp>(context), getMinMaxFn(getMinMaxFn) {}
|
||||
LogicalResult matchAndRewrite(AffineMinOp minOp,
|
||||
PatternRewriter &rewriter) const override;
|
||||
|
||||
protected:
|
||||
GetMinMaxExprFn getMinMaxFn;
|
||||
};
|
||||
|
||||
/// Specialized version of `AffineMinRangeCanonicalizationPattern` pattern
|
||||
/// using `getSCFMinMaxExpr` to know the min and max expression of induction
|
||||
/// variables from scf loops.
|
||||
// TODO: move to a more appropriate place when it is determined. For now Linalg
|
||||
// depends both on Affine and SCF but they do not depend on each other.
|
||||
struct AffineMinSCFCanonicalizationPattern
|
||||
: public AffineMinRangeCanonicalizationPattern {
|
||||
static Optional<std::pair<AffineExpr, AffineExpr>>
|
||||
getMinMax(Value value, SmallVectorImpl<Value> &dims,
|
||||
SmallVectorImpl<Value> &symbols) {
|
||||
return getSCFMinMaxExpr(value, dims, symbols);
|
||||
}
|
||||
AffineMinSCFCanonicalizationPattern(MLIRContext *context)
|
||||
: AffineMinRangeCanonicalizationPattern(context, getMinMax) {}
|
||||
};
|
||||
|
||||
/// Helper struct to return the results of `substituteMin`.
|
||||
struct AffineMapAndOperands {
|
||||
AffineMap map;
|
||||
SmallVector<Value> dims;
|
||||
SmallVector<Value> symbols;
|
||||
};
|
||||
|
||||
/// Traverse the dims of the AffineMap of `affineMinOp` and substitute
|
||||
/// dimensions with known range by new expressions involving the min or max
|
||||
/// expression:
|
||||
/// - If the AffineDimExpr mapped to a known value has a positive sign, it
|
||||
/// is replaced by the min expression.
|
||||
/// - If the AffineDimExpr mapped to a known value has a negative sign, it is
|
||||
/// replaced by the max expression.
|
||||
/// All known values are iteratively replaced.
|
||||
/// This is used as an intermediate step in computing bounding boxes and
|
||||
/// canonicalize AffineMinOps. All dim and symbol operands are assumed to have
|
||||
/// positive values (positive orthant assumptions).
|
||||
/// Return a new AffineMap, dims and symbols that have been canonicalized and
|
||||
/// simplified.
|
||||
AffineMapAndOperands substituteMin(AffineMinOp affineMinOp,
|
||||
GetMinMaxExprFn getMinMaxExpr);
|
||||
|
||||
/// Converts Convolution op into vector contraction.
|
||||
///
|
||||
/// Conversion expects ConvOp to have dimensions marked in the *mask* as
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#include "mlir/Dialect/Linalg/Transforms/CodegenStrategy.h"
|
||||
|
||||
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
|
||||
#include "mlir/Dialect/SCF/Transforms.h"
|
||||
#include "mlir/Dialect/Vector/VectorOps.h"
|
||||
#include "mlir/Dialect/Vector/VectorTransforms.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
|
@ -47,7 +48,7 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
|
|||
|
||||
RewritePatternSet stage2Patterns =
|
||||
linalg::getLinalgTilingCanonicalizationPatterns(context);
|
||||
stage2Patterns.add<AffineMinSCFCanonicalizationPattern>(context);
|
||||
scf::populateSCFLoopBodyCanonicalizationPatterns(stage2Patterns);
|
||||
|
||||
auto stage3Transforms = [&](Operation *op) {
|
||||
// Some of these may be too aggressive as a stage 3 that is applied on each
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/Transforms.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
|
@ -536,7 +537,7 @@ applyTilingToLoopPatterns(LinalgTilingLoopType loopType, FuncOp funcOp,
|
|||
MLIRContext *ctx = funcOp.getContext();
|
||||
RewritePatternSet patterns(ctx);
|
||||
insertTilingPatterns(patterns, options);
|
||||
patterns.add<AffineMinSCFCanonicalizationPattern>(patterns.getContext());
|
||||
scf::populateSCFLoopBodyCanonicalizationPatterns(patterns);
|
||||
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
|
||||
(void)applyPatternsAndFoldGreedily(
|
||||
funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
|
||||
|
|
|
@ -494,145 +494,6 @@ LogicalResult mlir::linalg::applyStagedPatterns(
|
|||
return success();
|
||||
}
|
||||
|
||||
/// Traverse the `dims` and substitute known min or max expressions returned by
|
||||
/// the lambda |getMinMaxExpr|.
|
||||
static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
|
||||
SmallVectorImpl<Value> &symbols,
|
||||
GetMinMaxExprFn getMinMaxExpr) {
|
||||
auto exprs = llvm::to_vector<4>(map.getResults());
|
||||
for (AffineExpr &expr : exprs) {
|
||||
bool substituted = true;
|
||||
while (substituted) {
|
||||
substituted = false;
|
||||
for (unsigned dimIdx = 0; dimIdx < dims.size(); ++dimIdx) {
|
||||
Value dim = dims[dimIdx];
|
||||
auto minMax = getMinMaxExpr(dim, dims, symbols);
|
||||
if (!minMax)
|
||||
continue;
|
||||
AffineExpr dimExpr = getAffineDimExpr(dimIdx, expr.getContext());
|
||||
LLVM_DEBUG(DBGS() << "Subst: " << dim << " @ " << dimExpr << "\n");
|
||||
LLVM_DEBUG(DBGS() << "Before: " << expr << "\n");
|
||||
// Substitute occurrences of `dimExpr` by either the min expression or
|
||||
// the max expression depending on whether the value is used with a
|
||||
// positive or negative coefficient.
|
||||
AffineExpr substitutedExpr =
|
||||
substWithMin(expr, dimExpr, minMax->first, minMax->second);
|
||||
LLVM_DEBUG(DBGS() << "After: " << substitutedExpr << "\n");
|
||||
substituted = (substitutedExpr != expr);
|
||||
expr = substitutedExpr;
|
||||
}
|
||||
}
|
||||
|
||||
// Cleanup and simplify the results.
|
||||
// This needs to happen outside of the loop iterating on dims.size() since
|
||||
// it modifies dims.
|
||||
SmallVector<Value, 4> operands(dims.begin(), dims.end());
|
||||
operands.append(symbols.begin(), symbols.end());
|
||||
auto map = AffineMap::get(dims.size(), symbols.size(), exprs,
|
||||
exprs.front().getContext());
|
||||
|
||||
LLVM_DEBUG({
|
||||
DBGS() << "Map to simplify: " << map << "\n";
|
||||
DBGS() << "Operands:\n";
|
||||
for (Value v : operands)
|
||||
DBGS() << v << "\n";
|
||||
});
|
||||
|
||||
// Pull in affine.apply operations and compose them fully into the
|
||||
// result.
|
||||
fullyComposeAffineMapAndOperands(&map, &operands);
|
||||
canonicalizeMapAndOperands(&map, &operands);
|
||||
map = simplifyAffineMap(map);
|
||||
// Assign the results.
|
||||
exprs.assign(map.getResults().begin(), map.getResults().end());
|
||||
dims.assign(operands.begin(), operands.begin() + map.getNumDims());
|
||||
symbols.assign(operands.begin() + map.getNumDims(), operands.end());
|
||||
|
||||
LLVM_DEBUG(DBGS() << "Map simplified: " << map << "\n");
|
||||
}
|
||||
|
||||
assert(!exprs.empty() && "Unexpected empty exprs");
|
||||
return AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
|
||||
}
|
||||
|
||||
/// Traverse the dims of the AffineMap of `affineMinOp` and substitute
|
||||
/// dimensions with known range by new expressions involving the min or max
|
||||
/// expression:
|
||||
/// - If the AffineDimExpr mapped to a known value has a positive sign, it
|
||||
/// is replaced by the min expression.
|
||||
/// - If the AffineDimExpr mapped to a known value has a negative sign, it is
|
||||
/// replaced by the max expression.
|
||||
/// All known values are iteratively replaced.
|
||||
/// This is used as an intermediate step in computing bounding boxes and
|
||||
/// canonicalize AffineMinOps. All dim and symbol operands are assumed to have
|
||||
/// positive values (positive orthant assumptions).
|
||||
/// Return a new AffineMap, dims and symbols that have been canonicalized and
|
||||
/// simplified.
|
||||
AffineMapAndOperands
|
||||
mlir::linalg::substituteMin(AffineMinOp affineMinOp,
|
||||
GetMinMaxExprFn getMinMaxExpr) {
|
||||
AffineMapAndOperands res{affineMinOp.getAffineMap(),
|
||||
SmallVector<Value>(affineMinOp.getDimOperands()),
|
||||
SmallVector<Value>(affineMinOp.getSymbolOperands())};
|
||||
res.map = substitute(affineMinOp.getAffineMap(), res.dims, res.symbols,
|
||||
getMinMaxExpr);
|
||||
return res;
|
||||
}
|
||||
|
||||
LogicalResult AffineMinRangeCanonicalizationPattern::matchAndRewrite(
|
||||
AffineMinOp minOp, PatternRewriter &rewriter) const {
|
||||
LLVM_DEBUG(DBGS() << "Canonicalize AffineMinSCF: " << *minOp.getOperation()
|
||||
<< "\n");
|
||||
|
||||
auto affineMapAndOperands = substituteMin(minOp, getMinMaxFn);
|
||||
AffineMap map = affineMapAndOperands.map;
|
||||
|
||||
LLVM_DEBUG(DBGS() << "Resulting map: " << map << "\n");
|
||||
|
||||
// Check whether any of the expressions, when subtracted from all other
|
||||
// expressions, produces only >= 0 constants. If so, it is the min.
|
||||
for (auto e : minOp.getAffineMap().getResults()) {
|
||||
LLVM_DEBUG(DBGS() << "Candidate min: " << e << "\n");
|
||||
if (!e.isSymbolicOrConstant())
|
||||
continue;
|
||||
|
||||
auto isNonPositive = [](AffineExpr e) {
|
||||
if (auto cst = e.dyn_cast<AffineConstantExpr>())
|
||||
return cst.getValue() < 0;
|
||||
return true;
|
||||
};
|
||||
|
||||
// Build the subMap and check everything is statically known to be
|
||||
// positive.
|
||||
SmallVector<AffineExpr, 4> subExprs;
|
||||
subExprs.reserve(map.getNumResults());
|
||||
for (auto ee : map.getResults())
|
||||
subExprs.push_back(ee - e);
|
||||
MLIRContext *ctx = minOp.getContext();
|
||||
AffineMap subMap = simplifyAffineMap(
|
||||
AffineMap::get(map.getNumDims(), map.getNumSymbols(), subExprs, ctx));
|
||||
LLVM_DEBUG(DBGS() << "simplified subMap: " << subMap << "\n");
|
||||
if (llvm::any_of(subMap.getResults(), isNonPositive))
|
||||
continue;
|
||||
|
||||
// Static min found.
|
||||
if (auto cst = e.dyn_cast<AffineConstantExpr>()) {
|
||||
rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp, cst.getValue());
|
||||
} else {
|
||||
auto resultMap = AffineMap::get(0, map.getNumSymbols(), {e}, ctx);
|
||||
SmallVector<Value> resultOperands = affineMapAndOperands.dims;
|
||||
llvm::append_range(resultOperands, affineMapAndOperands.symbols);
|
||||
canonicalizeMapAndOperands(&resultMap, &resultOperands);
|
||||
resultMap = simplifyAffineMap(resultMap);
|
||||
rewriter.replaceOpWithNewOp<AffineApplyOp>(minOp, resultMap,
|
||||
resultOperands);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
return failure();
|
||||
}
|
||||
|
||||
static SmallVector<StringRef> getNParallelLoopsAttrs(unsigned nParallelLoops) {
|
||||
return SmallVector<StringRef>(nParallelLoops, getParallelIteratorTypeName());
|
||||
}
|
||||
|
|
|
@ -145,43 +145,3 @@ bool mlir::getInnermostParallelLoops(Operation *rootOp,
|
|||
}
|
||||
return rootEnclosesPloops;
|
||||
}
|
||||
|
||||
/// Given the `lbVal`, `ubVal` and `stepVal` of a loop, append `lbVal` and
|
||||
/// `ubVal` to `dims` and `stepVal` to `symbols`.
|
||||
/// Create new AffineDimExpr (`%lb` and `%ub`) and AffineSymbolExpr (`%step`)
|
||||
/// with positions matching the newly appended values. Then create a min
|
||||
/// expression (i.e. `%lb`) and a max expression
|
||||
/// (i.e. `%lb + %step * floordiv(%ub -1 - %lb, %step)`.
|
||||
static std::pair<AffineExpr, AffineExpr>
|
||||
getMinMaxLoopIndVar(Value lbVal, Value ubVal, Value stepVal,
|
||||
SmallVectorImpl<Value> &dims,
|
||||
SmallVectorImpl<Value> &symbols) {
|
||||
MLIRContext *ctx = lbVal.getContext();
|
||||
AffineExpr lb = getAffineDimExpr(dims.size(), ctx);
|
||||
dims.push_back(lbVal);
|
||||
AffineExpr ub = getAffineDimExpr(dims.size(), ctx);
|
||||
dims.push_back(ubVal);
|
||||
AffineExpr step = getAffineSymbolExpr(symbols.size(), ctx);
|
||||
symbols.push_back(stepVal);
|
||||
return std::make_pair(lb, lb + step * ((ub - 1) - lb).floorDiv(step));
|
||||
}
|
||||
|
||||
/// Return the min/max expressions for `value` if it is an induction variable
|
||||
/// from scf.for or scf.parallel loop.
|
||||
/// if `loopFilter` is passed, the filter determines which loop to consider.
|
||||
/// Other induction variables are ignored.
|
||||
Optional<std::pair<AffineExpr, AffineExpr>> mlir::getSCFMinMaxExpr(
|
||||
Value value, SmallVectorImpl<Value> &dims, SmallVectorImpl<Value> &symbols,
|
||||
llvm::function_ref<bool(Operation *)> substituteOperation) {
|
||||
if (auto forOp = scf::getForInductionVarOwner(value))
|
||||
return getMinMaxLoopIndVar(forOp.lowerBound(), forOp.upperBound(),
|
||||
forOp.step(), dims, symbols);
|
||||
|
||||
if (auto parallelForOp = scf::getParallelForInductionVarOwner(value))
|
||||
for (unsigned idx = 0, e = parallelForOp.getNumLoops(); idx < e; ++idx)
|
||||
if (parallelForOp.getInductionVars()[idx] == value)
|
||||
return getMinMaxLoopIndVar(parallelForOp.lowerBound()[idx],
|
||||
parallelForOp.upperBound()[idx],
|
||||
parallelForOp.step()[idx], dims, symbols);
|
||||
return {};
|
||||
}
|
||||
|
|
|
@ -1,141 +0,0 @@
|
|||
// RUN: mlir-opt %s -test-linalg-transform-patterns=test-affine-min-scf-canonicalization-patterns | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: scf_for
|
||||
func @scf_for(%A : memref<i64>, %step : index) {
|
||||
%c0 = constant 0 : index
|
||||
%c1 = constant 1 : index
|
||||
%c2 = constant 2 : index
|
||||
%c7 = constant 7 : index
|
||||
%c4 = constant 4 : index
|
||||
%c16 = constant 16 : index
|
||||
%c1024 = constant 1024 : index
|
||||
|
||||
// CHECK: %[[C2:.*]] = constant 2 : i64
|
||||
// CHECK: scf.for
|
||||
// CHECK-NEXT: memref.store %[[C2]], %{{.*}}[] : memref<i64>
|
||||
scf.for %i = %c0 to %c4 step %c2 {
|
||||
%1 = affine.min affine_map<(d0, d1)[] -> (2, d1 - d0)> (%i, %c4)
|
||||
%2 = index_cast %1: index to i64
|
||||
memref.store %2, %A[]: memref<i64>
|
||||
}
|
||||
|
||||
// CHECK: scf.for
|
||||
// CHECK-NEXT: memref.store %[[C2]], %{{.*}}[] : memref<i64>
|
||||
scf.for %i = %c1 to %c7 step %c2 {
|
||||
%1 = affine.min affine_map<(d0)[s0] -> (s0 - d0, 2)> (%i)[%c7]
|
||||
%2 = index_cast %1: index to i64
|
||||
memref.store %2, %A[]: memref<i64>
|
||||
}
|
||||
|
||||
// This should not canonicalize because: 4 - %i may take the value 1 < 2.
|
||||
// CHECK: scf.for
|
||||
// CHECK: affine.min
|
||||
// CHECK: index_cast
|
||||
scf.for %i = %c1 to %c4 step %c2 {
|
||||
%1 = affine.min affine_map<(d0)[s0] -> (2, s0 - d0)> (%i)[%c4]
|
||||
%2 = index_cast %1: index to i64
|
||||
memref.store %2, %A[]: memref<i64>
|
||||
}
|
||||
|
||||
// This should not canonicalize because: 16 - %i may take the value 15 < 1024.
|
||||
// CHECK: scf.for
|
||||
// CHECK: affine.min
|
||||
// CHECK: index_cast
|
||||
scf.for %i = %c1 to %c16 step %c1024 {
|
||||
%1 = affine.min affine_map<(d0) -> (1024, 16 - d0)> (%i)
|
||||
%2 = index_cast %1: index to i64
|
||||
memref.store %2, %A[]: memref<i64>
|
||||
}
|
||||
|
||||
// This example should simplify but affine_map is currently missing
|
||||
// semi-affine canonicalizations: `((s0 * 42 - 1) floordiv s0) * s0`
|
||||
// should evaluate to 41 * s0.
|
||||
// Note that this may require positivity assumptions on `s0`.
|
||||
// Revisit when support is added.
|
||||
// CHECK: scf.for
|
||||
// CHECK: affine.min
|
||||
// CHECK: index_cast
|
||||
%ub = affine.apply affine_map<(d0) -> (42 * d0)> (%step)
|
||||
scf.for %i = %c0 to %ub step %step {
|
||||
%1 = affine.min affine_map<(d0, d1, d2) -> (d0, d1 - d2)> (%step, %ub, %i)
|
||||
%2 = index_cast %1: index to i64
|
||||
memref.store %2, %A[]: memref<i64>
|
||||
}
|
||||
|
||||
// This example should simplify but affine_map is currently missing
|
||||
// semi-affine canonicalizations.
|
||||
// This example should simplify but affine_map is currently missing
|
||||
// semi-affine canonicalizations: ` -(((s0 * s0 - 1) floordiv s0) * s0)`
|
||||
// should evaluate to (s0 - 1) * s0.
|
||||
// Note that this may require positivity assumptions on `s0`.
|
||||
// Revisit when support is added.
|
||||
// CHECK: scf.for
|
||||
// CHECK: affine.min
|
||||
// CHECK: index_cast
|
||||
%ub2 = affine.apply affine_map<(d0)[s0] -> (s0 * d0)> (%step)[%step]
|
||||
scf.for %i = %c0 to %ub2 step %step {
|
||||
%1 = affine.min affine_map<(d0, d1, d2) -> (d0, d2 - d1)> (%step, %i, %ub2)
|
||||
%2 = index_cast %1: index to i64
|
||||
memref.store %2, %A[]: memref<i64>
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: scf_parallel
|
||||
func @scf_parallel(%A : memref<i64>, %step : index) {
|
||||
%c0 = constant 0 : index
|
||||
%c1 = constant 1 : index
|
||||
%c2 = constant 2 : index
|
||||
%c7 = constant 7 : index
|
||||
%c4 = constant 4 : index
|
||||
|
||||
// CHECK: %[[C2:.*]] = constant 2 : i64
|
||||
// CHECK: scf.parallel
|
||||
// CHECK-NEXT: memref.store %[[C2]], %{{.*}}[] : memref<i64>
|
||||
scf.parallel (%i) = (%c0) to (%c4) step (%c2) {
|
||||
%1 = affine.min affine_map<(d0, d1)[] -> (2, d1 - d0)> (%i, %c4)
|
||||
%2 = index_cast %1: index to i64
|
||||
memref.store %2, %A[]: memref<i64>
|
||||
}
|
||||
|
||||
// CHECK: scf.parallel
|
||||
// CHECK-NEXT: memref.store %[[C2]], %{{.*}}[] : memref<i64>
|
||||
scf.parallel (%i) = (%c1) to (%c7) step (%c2) {
|
||||
%1 = affine.min affine_map<(d0)[s0] -> (2, s0 - d0)> (%i)[%c7]
|
||||
%2 = index_cast %1: index to i64
|
||||
memref.store %2, %A[]: memref<i64>
|
||||
}
|
||||
|
||||
// This example should simplify but affine_map is currently missing
|
||||
// semi-affine canonicalizations.
|
||||
// This affine map does not currently evaluate to (0, 0):
|
||||
// (d0)[s0] -> (s0 mod s0, (-((d0 floordiv s0) * s0) + s0 * 42) mod s0)
|
||||
// TODO: Revisit when support is added.
|
||||
// CHECK: scf.parallel
|
||||
// CHECK: affine.min
|
||||
// CHECK: index_cast
|
||||
%ub = affine.apply affine_map<(d0) -> (42 * d0)> (%step)
|
||||
scf.parallel (%i) = (%c0) to (%ub) step (%step) {
|
||||
%1 = affine.min affine_map<(d0, d1, d2) -> (d0, d2 - d1)> (%step, %i, %ub)
|
||||
%2 = index_cast %1: index to i64
|
||||
memref.store %2, %A[]: memref<i64>
|
||||
}
|
||||
|
||||
// This example should simplify but affine_map is currently missing
|
||||
// semi-affine canonicalizations.
|
||||
// This affine map does not currently evaluate to (0, 0):
|
||||
// (d0)[s0] -> (s0 mod s0, (-((d0 floordiv s0) * s0) + s0 * s0) mod s0)
|
||||
// TODO: Revisit when support is added.
|
||||
// CHECK: scf.parallel
|
||||
// CHECK: affine.min
|
||||
// CHECK: index_cast
|
||||
%ub2 = affine.apply affine_map<(d0)[s0] -> (s0 * d0)> (%step)[%step]
|
||||
scf.parallel (%i) = (%c0) to (%ub2) step (%step) {
|
||||
%1 = affine.min affine_map<(d0, d1, d2) -> (d0, d2 - d1)> (%step, %i, %ub2)
|
||||
%2 = index_cast %1: index to i64
|
||||
memref.store %2, %A[]: memref<i64>
|
||||
}
|
||||
|
||||
return
|
||||
}
|
|
@ -1,7 +1,5 @@
|
|||
// RUN: mlir-opt %s -canonicalize-scf-affine-min -split-input-file | FileCheck %s
|
||||
|
||||
// Note: This is mostly a copy of test/Dialect/Linalg/fold-affine-min-scf.mlir
|
||||
|
||||
// CHECK-LABEL: func @scf_for_canonicalize_min
|
||||
// CHECK: %[[C2:.*]] = constant 2 : i64
|
||||
// CHECK: scf.for
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#include "mlir/Dialect/Linalg/Passes.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/SCF/Transforms.h"
|
||||
#include "mlir/Dialect/Vector/VectorTransforms.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
|
@ -70,7 +71,7 @@ void TestConvVectorization::runOnOperation() {
|
|||
|
||||
RewritePatternSet stage2Patterns =
|
||||
linalg::getLinalgTilingCanonicalizationPatterns(context);
|
||||
stage2Patterns.add<linalg::AffineMinSCFCanonicalizationPattern>(context);
|
||||
scf::populateSCFLoopBodyCanonicalizationPatterns(stage2Patterns);
|
||||
|
||||
auto stage3Transforms = [](Operation *op) {
|
||||
PassManager pm(op->getContext());
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
|
||||
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/SCF/Transforms.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
|
@ -235,8 +236,8 @@ struct TestLinalgGreedyFusion
|
|||
MLIRContext *context = &getContext();
|
||||
RewritePatternSet patterns =
|
||||
linalg::getLinalgTilingCanonicalizationPatterns(context);
|
||||
patterns.add<AffineMinSCFCanonicalizationPattern,
|
||||
ExtractSliceOfPadTensorSwapPattern>(context);
|
||||
patterns.add<ExtractSliceOfPadTensorSwapPattern>(context);
|
||||
scf::populateSCFLoopBodyCanonicalizationPatterns(patterns);
|
||||
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
|
||||
do {
|
||||
(void)applyPatternsAndFoldGreedily(getFunction(), frozenPatterns);
|
||||
|
|
|
@ -83,10 +83,6 @@ struct TestLinalgTransforms
|
|||
llvm::cl::desc("Test a set of patterns that rewrite a linalg contraction "
|
||||
"in vector.contract form"),
|
||||
llvm::cl::init(false)};
|
||||
Option<bool> testAffineMinSCFCanonicalizationPatterns{
|
||||
*this, "test-affine-min-scf-canonicalization-patterns",
|
||||
llvm::cl::desc("Test affine-min + scf canonicalization patterns."),
|
||||
llvm::cl::init(false)};
|
||||
Option<bool> testTileAndPadPattern{
|
||||
*this, "test-tile-and-pad-pattern",
|
||||
llvm::cl::desc("Test tile and pad pattern"), llvm::cl::init(false)};
|
||||
|
@ -546,18 +542,6 @@ static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) {
|
|||
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
|
||||
}
|
||||
|
||||
static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) {
|
||||
RewritePatternSet foldPattern(funcOp.getContext());
|
||||
foldPattern.add<AffineMinSCFCanonicalizationPattern>(funcOp.getContext());
|
||||
FrozenRewritePatternSet frozenPatterns(std::move(foldPattern));
|
||||
|
||||
// Explicitly apply the pattern on affected ops to avoid more general folding
|
||||
// on the rest of the IR.
|
||||
SmallVector<Operation *, 4> minOps;
|
||||
funcOp.walk([&](AffineMinOp minOp) { minOps.push_back(minOp); });
|
||||
(void)applyOpPatternsAndFold(minOps, frozenPatterns, /*strict=*/false);
|
||||
}
|
||||
|
||||
// For now, just assume it is the zero of type.
|
||||
// In the future, it should be the zero of type + op.
|
||||
static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
|
||||
|
@ -628,8 +612,6 @@ void TestLinalgTransforms::runOnFunction() {
|
|||
return applyGeneralizePadTensorPatterns(getFunction());
|
||||
if (testSwapSubTensorPadTensor)
|
||||
return applyExtractSliceOfPadTensorSwapPattern(getFunction());
|
||||
if (testAffineMinSCFCanonicalizationPatterns)
|
||||
return applyAffineMinSCFCanonicalizationPatterns(getFunction());
|
||||
if (testTileAndPadPattern)
|
||||
return applyTileAndPadPattern(getFunction(), tileSizesForPadding);
|
||||
if (testHoistPadding) {
|
||||
|
|
|
@ -396,6 +396,7 @@ cc_library(
|
|||
"//mlir:LinalgOps",
|
||||
"//mlir:LinalgTransforms",
|
||||
"//mlir:Pass",
|
||||
"//mlir:SCFTransforms",
|
||||
"//mlir:StandardOps",
|
||||
"//mlir:TransformUtils",
|
||||
"//mlir:VectorOps",
|
||||
|
|
Loading…
Reference in New Issue