[mlir] Structured transforms: introduce op splitting
Introduce a new transformation on structured ops that splits the iteration space into two parts along the specified dimension. The index at which the splitting happens may be static or dynamic. This transformation can be seen as a rudimentary form of index-set splitting that only supports the splitting along hyperplanes parallel to the iteration space hyperplanes, and is therefore decomposable into per-dimension application. It is a key low-level transformation that enables independent scheduling for different parts of the iteration space of the same op, which hasn't been possible previously. It may be used to implement, e.g., multi-sized tiling. In future, peeling can be implemented as a combination of split-off amount computation and splitting. The transformation is conceptually close to tiling in its separation of the iteration and data spaces, but cannot be currently implemented on top of TilingInterface as the latter does not properly support `linalg.index` offsetting. Note that the transformation intentionally bypasses folding of `tensor.extract_slice` operations when creating them as this folding was found to prevent repeated splitting of the same operation because due to internal assumptions about extract/insert_slice combination in dialect utilities. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D129090
This commit is contained in:
parent
1d9086bf05
commit
ff6e5508d6
|
@ -150,10 +150,19 @@ MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDimSize(MlirType type,
|
|||
/// in shaped types.
|
||||
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicSize(int64_t size);
|
||||
|
||||
/// Returns the value indicating a dynamic size in a shaped type. Prefer
|
||||
/// mlirShapedTypeIsDynamicSize to direct comparisons with this value.
|
||||
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicSize();
|
||||
|
||||
/// Checks whether the given value is used as a placeholder for dynamic strides
|
||||
/// and offsets in shaped types.
|
||||
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val);
|
||||
|
||||
/// Returns the value indicating a dynamic stride or offset in a shaped type.
|
||||
/// Prefer mlirShapedTypeGetDynamicStrideOrOffset to direct comparisons with
|
||||
/// this value.
|
||||
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicStrideOrOffset();
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Vector type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -25,6 +25,7 @@ namespace mlir {
|
|||
class AffineApplyOp;
|
||||
class AffineBound;
|
||||
class AffineValueMap;
|
||||
class IRRewriter;
|
||||
|
||||
/// TODO: These should be renamed if they are on the mlir namespace.
|
||||
/// Ideally, they should go in a mlir::affine:: namespace.
|
||||
|
@ -384,6 +385,12 @@ AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e,
|
|||
SmallVector<Value, 4> applyMapToValues(OpBuilder &b, Location loc,
|
||||
AffineMap map, ValueRange values);
|
||||
|
||||
/// Returns the values obtained by applying `map` to the list of values, which
|
||||
/// may be known constants.
|
||||
SmallVector<OpFoldResult> applyMapToValues(IRRewriter &b, Location loc,
|
||||
AffineMap map,
|
||||
ArrayRef<OpFoldResult> values);
|
||||
|
||||
/// Given an affine map `map` and its input `operands`, this method composes
|
||||
/// into `map`, maps of AffineApplyOps whose results are the values in
|
||||
/// `operands`, iteratively until no more of `operands` are the result of an
|
||||
|
|
|
@ -153,6 +153,38 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
|
|||
}];
|
||||
}
|
||||
|
||||
def SplitOp : Op<Transform_Dialect, "structured.split",
|
||||
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||
DeclareOpInterfaceMethods<TransformOpInterface>]> {
|
||||
let description = [{
|
||||
Indicates that the given `target` op should be split into two complementary
|
||||
parts, which combined cover the entire iteration domain of the original op.
|
||||
The split is performed along the iteration space dimension provided as
|
||||
attribute. In case of dimension overflow, the transformation fails. The
|
||||
split is performed at the dimension iterator value specified as either the
|
||||
static split point attribute when it is known at transform IR construction
|
||||
time or as the handle to an operation producing a single index-typed value
|
||||
when it is computed by payload IR. In the latter case, the static split
|
||||
point must be set to `ShapedType::kDynamicSize` and the dynamic size handle
|
||||
must point to as many value-producing operations as there are structured
|
||||
operations pointed to by the target handle.
|
||||
|
||||
The operation consumes the target handle, but preserves the split point
|
||||
handle if provided. It produces two new handles pointing to the two parts
|
||||
of the structured op after splitting, in the same order as the target
|
||||
operand, with the first handle corresponding to the part with lower
|
||||
iteration space indices.
|
||||
}];
|
||||
|
||||
let arguments = (ins PDL_Operation:$target,
|
||||
I64Attr:$dimension,
|
||||
Optional<PDL_Operation>:$dynamic_split_point,
|
||||
I64Attr:$static_split_point);
|
||||
let results = (outs PDL_Operation:$first, PDL_Operation:$second);
|
||||
let hasVerifier = 1;
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
|
||||
[FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
|
||||
TransformEachOpTrait, TransformOpInterface]> {
|
||||
|
|
|
@ -106,6 +106,34 @@ void populateInlineConstantOperandsPatterns(RewritePatternSet &patterns);
|
|||
/// Patterns that are used to bubble up extract slice op above linalg op.
|
||||
void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
|
||||
|
||||
/// Split the given `op` into two parts along the given iteration space
|
||||
/// `dimension` at the specified `splitPoint`, and return the two parts.
|
||||
///
|
||||
/// For example, the following op:
|
||||
///
|
||||
/// linalg.matmul ins(%0, %1 : tensor<128x32xf32>, tensor<32x64xf32>)
|
||||
/// outs(%2 : tensor<128x64xf32>)
|
||||
///
|
||||
/// split along the first dimension at position 42 will result in:
|
||||
///
|
||||
/// %3 = tensor.extract_slice %0[0, 0][42, 32][1, 1]
|
||||
/// %4 = tensor.extract_slice %2[0, 0][42, 64][1, 1]
|
||||
/// %5 = linalg.matmul ins(%3, %1 : tensor<42x32xf32>, tensor<32x64xf32>)
|
||||
/// outs(%5 : tensor<42x64xf32>)
|
||||
/// %6 = tensor.insert_slice %5 into %2[0, 0][42, 64][1, 1]
|
||||
///
|
||||
/// %7 = tensor.extract_slice %0[42, 0][86, 32][1, 1]
|
||||
/// %8 = tensor.extract_slice %6[42, 0][86, 64][1, 1]
|
||||
/// %9 = linalg.matmul ins(%7, %1 : tensor<86x32xf32>, tensor<32x64xf32>)
|
||||
/// outs(%8 : tensor<86x64xf32>)
|
||||
/// tensor.insert_slice %5 into %6[42, 0][86, 64][1, 1]
|
||||
///
|
||||
/// Note that there is no simplification other than constant propagation applied
|
||||
/// to slice extraction and insertion.
|
||||
std::pair<LinalgOp, LinalgOp> splitOp(RewriterBase &rewriter, LinalgOp op,
|
||||
unsigned dimension,
|
||||
OpFoldResult splitPoint);
|
||||
|
||||
/// Perform standalone tiling of a single LinalgOp by `tileSizes`.
|
||||
/// and permute the loop nest according to `interchangeVector`
|
||||
/// The permutation is expressed as a list of integers that specify
|
||||
|
|
|
@ -177,12 +177,18 @@ bool isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
|
|||
bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer,
|
||||
Value consumedView, LinalgOp producer);
|
||||
|
||||
/// Compute tile offsets, given a list of loop `ivs` and `tileSizes`. In case a
|
||||
/// Creates either a memref.subview or a tensor.extract_slice with the given
|
||||
/// offsets/sizes/strides based on the type of `value`.
|
||||
Value createSlice(OpBuilder &builder, Location loc, Value value,
|
||||
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
|
||||
ArrayRef<OpFoldResult> strides);
|
||||
|
||||
/// Computes tile offsets, given a list of loop `ivs` and `tileSizes`. In case a
|
||||
/// tile size is zero (i.e., no tiling), the corresponding offset is also zero.
|
||||
SmallVector<Value> computeTileOffsets(OpBuilder &b, Location loc,
|
||||
ValueRange ivs, ValueRange tileSizes);
|
||||
|
||||
/// Compute tile sizes, given a list of `tileSizes` and dimension
|
||||
/// Computes tile sizes, given a list of `tileSizes` and dimension
|
||||
/// sizes (`sizeBounds`). In case a tile size is zero (i.e., no tiling), the
|
||||
/// corresponding result size is the corresponding value from `sizeBounds`.
|
||||
/// Note: The returned tile sizes are closed intervals.
|
||||
|
@ -190,6 +196,20 @@ SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc,
|
|||
ValueRange tileSizes,
|
||||
ArrayRef<Value> sizeBounds);
|
||||
|
||||
/// Returns the list of tensor output types produced when the given structured
|
||||
/// operation `op` is applied to the given `operands`. Note that `operands` are
|
||||
/// not necessarily the actual operands of `op`.
|
||||
SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands);
|
||||
|
||||
/// Creates `insert_slice` ops that insert `results` back into larger tensors
|
||||
/// they were originally extracted from with `extract_slice` before being passed
|
||||
/// as `operands` to the given structured operation `op` or its clone. Note that
|
||||
/// `operands` are not necessarily the actual operands of `op`, the operation
|
||||
/// serves only as metadata container for operand types and positions.
|
||||
SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
|
||||
LinalgOp op, ValueRange operands,
|
||||
ValueRange results);
|
||||
|
||||
/// Creates an extract_slice/subview op for a single `valueToTile` with
|
||||
/// `builder`. This new operation extracts a tile of `valueToTile`, starting
|
||||
/// at offsets `lbs` and with sizes `subShapeSizes`. `omitPartialTileCheck`
|
||||
|
|
|
@ -301,6 +301,15 @@ public:
|
|||
return shape;
|
||||
},
|
||||
"Returns the shape of the ranked shaped type as a list of integers.");
|
||||
c.def_static(
|
||||
"_get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); },
|
||||
"Returns the value used to indicate dynamic dimensions in shaped "
|
||||
"types.");
|
||||
c.def_static(
|
||||
"_get_dynamic_stride_or_offset",
|
||||
[]() { return mlirShapedTypeGetDynamicStrideOrOffset(); },
|
||||
"Returns the value used to indicate dynamic strides or offsets in "
|
||||
"shaped types.");
|
||||
}
|
||||
|
||||
private:
|
||||
|
|
|
@ -149,6 +149,8 @@ int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) {
|
|||
return unwrap(type).cast<ShapedType>().getDimSize(static_cast<unsigned>(dim));
|
||||
}
|
||||
|
||||
int64_t mlirShapedTypeGetDynamicSize() { return ShapedType::kDynamicSize; }
|
||||
|
||||
bool mlirShapedTypeIsDynamicSize(int64_t size) {
|
||||
return ShapedType::isDynamic(size);
|
||||
}
|
||||
|
@ -157,6 +159,10 @@ bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val) {
|
|||
return ShapedType::isDynamicStrideOrOffset(val);
|
||||
}
|
||||
|
||||
int64_t mlirShapedTypeGetDynamicStrideOrOffset() {
|
||||
return ShapedType::kDynamicStrideOrOffset;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Vector type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -748,6 +748,76 @@ SmallVector<Value, 4> mlir::applyMapToValues(OpBuilder &b, Location loc,
|
|||
return res;
|
||||
}
|
||||
|
||||
SmallVector<OpFoldResult>
|
||||
mlir::applyMapToValues(IRRewriter &b, Location loc, AffineMap map,
|
||||
ArrayRef<OpFoldResult> values) {
|
||||
// Materialize constants and keep track of produced operations so we can clean
|
||||
// them up later.
|
||||
SmallVector<Operation *> constants;
|
||||
SmallVector<Value> actualValues;
|
||||
actualValues.reserve(values.size());
|
||||
auto *dialect = b.getContext()->getLoadedDialect<AffineDialect>();
|
||||
for (OpFoldResult ofr : values) {
|
||||
if (auto value = ofr.dyn_cast<Value>()) {
|
||||
actualValues.push_back(value);
|
||||
continue;
|
||||
}
|
||||
constants.push_back(dialect->materializeConstant(b, ofr.get<Attribute>(),
|
||||
b.getIndexType(), loc));
|
||||
actualValues.push_back(constants.back()->getResult(0));
|
||||
}
|
||||
|
||||
// Compose, fold and construct maps for each result independently because they
|
||||
// may simplify more effectively.
|
||||
SmallVector<OpFoldResult> results;
|
||||
results.reserve(map.getNumResults());
|
||||
bool foldedAll = true;
|
||||
for (auto i : llvm::seq<unsigned>(0, map.getNumResults())) {
|
||||
AffineMap submap = map.getSubMap({i});
|
||||
SmallVector<Value> operands = actualValues;
|
||||
fullyComposeAffineMapAndOperands(&submap, &operands);
|
||||
canonicalizeMapAndOperands(&submap, &operands);
|
||||
|
||||
// Identify the constant operands and extract their values as attributes.
|
||||
// Note that we cannot use the original values directly because the list of
|
||||
// operands may have changed due to canonicalization and composition.
|
||||
SmallVector<Attribute> constantOperands;
|
||||
constantOperands.reserve(operands.size());
|
||||
for (Value operand : operands) {
|
||||
IntegerAttr attr;
|
||||
if (matchPattern(operand, m_Constant(&attr)))
|
||||
constantOperands.push_back(attr);
|
||||
else
|
||||
constantOperands.push_back(nullptr);
|
||||
}
|
||||
|
||||
// Create an apply operation and immediately attempt to fold it. On sucess,
|
||||
// delete the operation and prepare the (unmaterialized) value for being
|
||||
// returned. On failure, return the function result.
|
||||
// TODO: arguably, the main folder (createOrFold) API should support this
|
||||
// use case instead of indiscriminately materializing constants.
|
||||
auto apply = b.create<AffineApplyOp>(loc, submap, operands);
|
||||
SmallVector<OpFoldResult, 1> foldResult;
|
||||
if (succeeded(apply->fold(constantOperands, foldResult))) {
|
||||
assert(foldResult.size() == 1 && "expected single-result map");
|
||||
b.eraseOp(apply);
|
||||
results.push_back(foldResult.front());
|
||||
} else {
|
||||
results.push_back(apply.getResult());
|
||||
foldedAll = false;
|
||||
}
|
||||
}
|
||||
|
||||
// If the entire map could be folded, remove the constants that were used in
|
||||
// the initial ops.
|
||||
if (foldedAll) {
|
||||
for (Operation *constant : constants)
|
||||
b.eraseOp(constant);
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
// A symbol may appear as a dim in affine.apply operations. This function
|
||||
// canonicalizes dims that are valid symbols into actual symbols.
|
||||
template <class MapOrSet>
|
||||
|
|
|
@ -399,6 +399,161 @@ FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target,
|
|||
return result->op;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SplitOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
|
||||
TransformState &state) {
|
||||
// Collect the dynamic split points if provided.
|
||||
ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
|
||||
SimpleRewriter rewriter(getContext());
|
||||
SmallVector<OpFoldResult> splitPoints;
|
||||
splitPoints.reserve(payload.size());
|
||||
if (getDynamicSplitPoint()) {
|
||||
auto diag = DiagnosedSilenceableFailure::success();
|
||||
splitPoints = llvm::to_vector(llvm::map_range(
|
||||
state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
|
||||
if (op->getNumResults() != 1 ||
|
||||
!op->getResult(0).getType().isIndex()) {
|
||||
diag = emitSilenceableError()
|
||||
<< "expected dynamic split point handle to point to a "
|
||||
"single-result index-typed op";
|
||||
diag.attachNote(op->getLoc()) << "dynamic split point";
|
||||
}
|
||||
return OpFoldResult(op->getResult(0));
|
||||
}));
|
||||
if (!diag.succeeded())
|
||||
return diag;
|
||||
|
||||
if (splitPoints.size() != payload.size()) {
|
||||
emitError() << "expected the dynamic split point handle to point to as "
|
||||
"many operations ("
|
||||
<< splitPoints.size() << ") as the target handle ("
|
||||
<< payload.size() << ")";
|
||||
return DiagnosedSilenceableFailure::definiteFailure();
|
||||
}
|
||||
} else {
|
||||
splitPoints.resize(payload.size(),
|
||||
rewriter.getIndexAttr(getStaticSplitPoint()));
|
||||
}
|
||||
|
||||
// Split each target operation.
|
||||
SmallVector<Operation *> first, second;
|
||||
for (const auto &pair : llvm::zip(payload, splitPoints)) {
|
||||
Operation *target = std::get<0>(pair);
|
||||
auto linalgOp = dyn_cast<LinalgOp>(target);
|
||||
if (!linalgOp) {
|
||||
auto diag = emitSilenceableError() << "only applies to structured ops";
|
||||
diag.attachNote(target->getLoc()) << "target op";
|
||||
return diag;
|
||||
}
|
||||
|
||||
if (getDimension() >= linalgOp.getNumLoops()) {
|
||||
auto diag = emitSilenceableError() << "dimension " << getDimension()
|
||||
<< " does not exist in target op";
|
||||
diag.attachNote(target->getLoc()) << "target op";
|
||||
return diag;
|
||||
}
|
||||
|
||||
rewriter.setInsertionPoint(linalgOp);
|
||||
std::tie(first.emplace_back(), second.emplace_back()) =
|
||||
linalg::splitOp(rewriter, linalgOp, getDimension(), std::get<1>(pair));
|
||||
}
|
||||
|
||||
results.set(getFirst().cast<OpResult>(), first);
|
||||
results.set(getSecond().cast<OpResult>(), second);
|
||||
return DiagnosedSilenceableFailure::success();
|
||||
}
|
||||
|
||||
void SplitOp::getEffects(
|
||||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
|
||||
// The target handle is consumed.
|
||||
effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
|
||||
TransformMappingResource::get());
|
||||
effects.emplace_back(MemoryEffects::Free::get(), getTarget(),
|
||||
TransformMappingResource::get());
|
||||
|
||||
// The dynamic split point handle is not consumed.
|
||||
if (getDynamicSplitPoint()) {
|
||||
effects.emplace_back(MemoryEffects::Read::get(), getDynamicSplitPoint(),
|
||||
TransformMappingResource::get());
|
||||
}
|
||||
|
||||
// The resulting handles are produced.
|
||||
for (Value result : getResults()) {
|
||||
effects.emplace_back(MemoryEffects::Allocate::get(), result,
|
||||
TransformMappingResource::get());
|
||||
effects.emplace_back(MemoryEffects::Write::get(), result,
|
||||
TransformMappingResource::get());
|
||||
}
|
||||
|
||||
effects.emplace_back(MemoryEffects::Read::get(), PayloadIRResource::get());
|
||||
effects.emplace_back(MemoryEffects::Write::get(), PayloadIRResource::get());
|
||||
}
|
||||
|
||||
ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
OpAsmParser::UnresolvedOperand target, dynamicSplitPoint;
|
||||
IntegerAttr staticSplitPoint;
|
||||
auto pdlOperationType =
|
||||
pdl::OperationType::get(parser.getBuilder().getContext());
|
||||
if (parser.parseOperand(target) ||
|
||||
parser.resolveOperand(target, pdlOperationType, result.operands) ||
|
||||
parser.parseKeyword("after"))
|
||||
return failure();
|
||||
|
||||
OptionalParseResult dynamicPointParseResult =
|
||||
parser.parseOptionalOperand(dynamicSplitPoint);
|
||||
if (!dynamicPointParseResult.hasValue()) {
|
||||
int64_t staticSplitPointValue;
|
||||
if (failed(parser.parseInteger(staticSplitPointValue)))
|
||||
return failure();
|
||||
|
||||
staticSplitPoint =
|
||||
parser.getBuilder().getI64IntegerAttr(staticSplitPointValue);
|
||||
} else {
|
||||
if (failed(*dynamicPointParseResult) ||
|
||||
parser.resolveOperand(dynamicSplitPoint, pdlOperationType,
|
||||
result.operands)) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
staticSplitPoint =
|
||||
parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamicSize);
|
||||
}
|
||||
|
||||
result.addAttribute(
|
||||
SplitOp::getStaticSplitPointAttrName(result.name).getValue(),
|
||||
staticSplitPoint);
|
||||
if (failed(parser.parseOptionalAttrDict(result.attributes)))
|
||||
return failure();
|
||||
|
||||
result.addTypes({pdlOperationType, pdlOperationType});
|
||||
return success();
|
||||
}
|
||||
|
||||
void SplitOp::print(OpAsmPrinter &printer) {
|
||||
printer << " " << getTarget() << " after ";
|
||||
int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint());
|
||||
if (staticSplitSize != ShapedType::kDynamicSize)
|
||||
printer << staticSplitSize;
|
||||
else
|
||||
printer << getDynamicSplitPoint();
|
||||
printer << " ";
|
||||
printer.printOptionalAttrDict(getOperation()->getAttrs(),
|
||||
{getStaticSplitPointAttrName()});
|
||||
}
|
||||
|
||||
LogicalResult SplitOp::verify() {
|
||||
if ((static_cast<int64_t>(getStaticSplitPoint()) !=
|
||||
ShapedType::kDynamicSize) ^
|
||||
(getDynamicSplitPoint() == nullptr)) {
|
||||
return emitOpError()
|
||||
<< "expects either a dynamic or a static split point to be provided";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SplitReductionOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -22,6 +22,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
|
|||
NamedOpConversions.cpp
|
||||
Promotion.cpp
|
||||
SparseTensorRewriting.cpp
|
||||
Split.cpp
|
||||
SplitReduction.cpp
|
||||
Tiling.cpp
|
||||
TilingInterfaceImpl.cpp
|
||||
|
|
|
@ -0,0 +1,158 @@
|
|||
//===- Split.cpp - Structured op splitting --------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
||||
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
||||
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::linalg;
|
||||
|
||||
/// Turns an OpFoldResult into a value, creating an index-typed constant if
|
||||
/// necessary.
|
||||
static Value materializeOpFoldResult(ImplicitLocOpBuilder &builder,
|
||||
OpFoldResult opFoldResult) {
|
||||
if (opFoldResult.is<Value>())
|
||||
return opFoldResult.get<Value>();
|
||||
auto attr = opFoldResult.get<Attribute>().cast<IntegerAttr>();
|
||||
return builder.create<arith::ConstantIndexOp>(attr.getValue().getSExtValue());
|
||||
}
|
||||
|
||||
/// Extract the slices of `operands` supplied to the given operation `op` such
|
||||
/// that they are sufficient to execute the op for the subset of its iteration
|
||||
/// space defined by `splitIterationSpace`. The subset is a part of the original
|
||||
/// iteration space split at the given `dimension`. If `offset` is provided, it
|
||||
/// indicates the iterator value at which the dimension has been split and
|
||||
/// requires the "high" part starting at the given offset of the operands to be
|
||||
/// generated; otherwise, the "low" part with no offset is generated. Note that
|
||||
/// `operands` are not necessarily the actual operands of `op`.
|
||||
static SmallVector<Value>
|
||||
getOperandSlices(ImplicitLocOpBuilder &builder, LinalgOp op,
|
||||
ValueRange splitIterationSpace, ValueRange operands,
|
||||
unsigned dimension, Value offset = nullptr) {
|
||||
SmallVector<Value> slices;
|
||||
slices.reserve(op.getNumInputsAndOutputs());
|
||||
for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
|
||||
auto type = opOperand->get().getType().dyn_cast<ShapedType>();
|
||||
AffineMap indexing = op.getTiedIndexingMap(opOperand);
|
||||
|
||||
// If the type is not sliceable, or the slice is requested along the
|
||||
// dimension that is not used in indexing this type, just use the entire
|
||||
// operand.
|
||||
if (!type || dimension >= indexing.getNumDims() ||
|
||||
!indexing.isFunctionOfDim(dimension)) {
|
||||
slices.push_back(opOperand->get());
|
||||
continue;
|
||||
}
|
||||
|
||||
SmallVector<Value, 4> sizes =
|
||||
applyMapToValues(builder, op.getLoc(), indexing, splitIterationSpace);
|
||||
SmallVector<OpFoldResult> offsets(type.getRank(), builder.getIndexAttr(0));
|
||||
SmallVector<OpFoldResult> strides(type.getRank(), builder.getIndexAttr(1));
|
||||
|
||||
if (offset) {
|
||||
offsets[dimension] = offset;
|
||||
IRRewriter rewriter(builder);
|
||||
offsets = applyMapToValues(rewriter, builder.getLoc(), indexing, offsets);
|
||||
}
|
||||
|
||||
slices.push_back(createSlice(builder, op.getLoc(),
|
||||
operands[opOperand->getOperandNumber()],
|
||||
offsets, getAsOpFoldResult(sizes), strides));
|
||||
}
|
||||
|
||||
return slices;
|
||||
}
|
||||
|
||||
/// Creates a part of the given `op` split along the iteration space `dimension`
|
||||
/// with the given `size` and an optional `offset` (default 0). Makes slices
|
||||
/// of operands, using the input operands of the original op and the output
|
||||
/// operands provided as `resultOperands`. Expects `splitIterationSpace` to be
|
||||
/// a list of values representing the shape of the iteration space of the
|
||||
/// original op and updates it to be the iteration space of the curent part.
|
||||
/// Returns the split-out op as well as the output operand values updated with
|
||||
/// the partial results produced by this op through `results`.
|
||||
static LinalgOp createSplitPart(
|
||||
ImplicitLocOpBuilder &builder, LinalgOp op, ValueRange resultOperands,
|
||||
llvm::MutableArrayRef<Value> splitIterationSpace, unsigned dimension,
|
||||
Value size, SmallVectorImpl<Value> &results, Value offset = nullptr) {
|
||||
splitIterationSpace[dimension] = size;
|
||||
SmallVector<Value> operands = llvm::to_vector(
|
||||
llvm::map_range(op.getInputOperands(),
|
||||
[](OpOperand *opOperand) { return opOperand->get(); }));
|
||||
llvm::append_range(operands, resultOperands);
|
||||
operands = getOperandSlices(builder, op, splitIterationSpace, operands,
|
||||
dimension, offset);
|
||||
Operation *part = op.clone(builder, op.getLoc(),
|
||||
getTensorOutputTypes(op, operands), operands);
|
||||
results = insertSlicesBack(builder, builder.getLoc(), op, operands,
|
||||
part->getResults());
|
||||
return cast<LinalgOp>(part);
|
||||
}
|
||||
|
||||
std::pair<LinalgOp, LinalgOp> linalg::splitOp(RewriterBase &rewriter,
|
||||
LinalgOp op, unsigned dimension,
|
||||
OpFoldResult splitPoint) {
|
||||
// Bail out on dimension overflow.
|
||||
if (dimension >= op.getNumLoops())
|
||||
return std::make_pair(op, LinalgOp());
|
||||
|
||||
// Compute the iteration space size as values.
|
||||
ImplicitLocOpBuilder builder(op.getLoc(), rewriter);
|
||||
SmallVector<Value, 4> allShapes =
|
||||
op.createFlatListOfOperandDims(builder, op.getLoc());
|
||||
AffineMap shapesToLoops = op.getShapesToLoopsMap();
|
||||
SmallVector<Value, 4> iterationSpaceShapes =
|
||||
applyMapToValues(builder, op.getLoc(), shapesToLoops, allShapes);
|
||||
|
||||
// Update the iteration space to have `splitPoint` as the size of `dimension`
|
||||
// and use it to slice operands and results for a new, smaller instance of the
|
||||
// `op`. Adjust the size if necessary to prevent overflows. Insert the partial
|
||||
// results back.
|
||||
Value splitPointValue = materializeOpFoldResult(builder, splitPoint);
|
||||
splitPointValue = builder.createOrFold<AffineMinOp>(
|
||||
builder.getIndexType(),
|
||||
AffineMap::getMultiDimIdentityMap(/*numDims=*/2, builder.getContext()),
|
||||
ValueRange({splitPointValue, iterationSpaceShapes[dimension]}));
|
||||
SmallVector<Value> splitIterationSpace =
|
||||
llvm::to_vector(iterationSpaceShapes);
|
||||
SmallVector<Value> originalResults = llvm::to_vector(
|
||||
llvm::map_range(op.getOutputOperands(),
|
||||
[](OpOperand *opOperand) { return opOperand->get(); }));
|
||||
SmallVector<Value> firstResults;
|
||||
LinalgOp first =
|
||||
createSplitPart(builder, op, originalResults, splitIterationSpace,
|
||||
dimension, splitPointValue, firstResults);
|
||||
|
||||
// Update the iteration space to cover the remaining part of the original
|
||||
// space, then create another instance of the `op` in that space. The size of
|
||||
// the remaining part may become zero, but is never negative because of the
|
||||
// adjustment above.
|
||||
AffineExpr d0 = builder.getAffineDimExpr(0);
|
||||
AffineExpr d1 = builder.getAffineDimExpr(1);
|
||||
SmallVector<Value, 4> remainingSizes = applyMapToValues(
|
||||
builder, op.getLoc(), AffineMap::inferFromExprList({d0 - d1}).front(),
|
||||
{iterationSpaceShapes[dimension], splitPointValue});
|
||||
SmallVector<Value> secondResults;
|
||||
LinalgOp second =
|
||||
createSplitPart(builder, op, firstResults, splitIterationSpace, dimension,
|
||||
remainingSizes.front(), secondResults, splitPointValue);
|
||||
|
||||
// Fixup the linalg.index results in the second part.
|
||||
SmallVector<Value> ivAdditions;
|
||||
ivAdditions.resize(splitIterationSpace.size());
|
||||
ivAdditions[dimension] = splitPointValue;
|
||||
linalg::addTileLoopIvsToIndexOpResults(builder, cast<LinalgOp>(second),
|
||||
ivAdditions);
|
||||
|
||||
// Replace the original op with the results of the two newly created ops.
|
||||
rewriter.replaceOp(op, secondResults);
|
||||
return std::make_pair(first, second);
|
||||
}
|
|
@ -182,32 +182,11 @@ tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ValueRange tileSizes,
|
|||
makeTiledShapes(b, loc, op, valuesToTile, interchangedIvs, tileSizes,
|
||||
sizeBounds, /*omitPartialTileCheck=*/false);
|
||||
|
||||
// TODO: use an interface/adaptor to avoid leaking position in
|
||||
// `tiledOperands`.
|
||||
SmallVector<Type, 4> resultTensorTypes;
|
||||
for (OpOperand *opOperand : op.getOutputTensorOperands())
|
||||
resultTensorTypes.push_back(
|
||||
tiledOperands[opOperand->getOperandNumber()].getType());
|
||||
|
||||
SmallVector<Type> resultTensorTypes =
|
||||
getTensorOutputTypes(op, tiledOperands);
|
||||
res = op.clone(b, loc, resultTensorTypes, tiledOperands);
|
||||
|
||||
// Insert a insert_slice for each output tensor.
|
||||
unsigned resultIdx = 0;
|
||||
for (OpOperand *opOperand : op.getOutputTensorOperands()) {
|
||||
// TODO: use an interface/adaptor to avoid leaking position in
|
||||
// `tiledOperands`.
|
||||
Value outputTensor = tiledOperands[opOperand->getOperandNumber()];
|
||||
// TODO: Propagate RewriterBase everywhere.
|
||||
IRRewriter rewriter(b);
|
||||
if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
|
||||
tensorResults.push_back(insertSliceIntoTensor(rewriter, loc, sliceOp,
|
||||
res->getResult(resultIdx),
|
||||
sliceOp.getSource()));
|
||||
} else {
|
||||
tensorResults.push_back(res->getResult(resultIdx));
|
||||
}
|
||||
++resultIdx;
|
||||
}
|
||||
tensorResults =
|
||||
insertSlicesBack(builder, loc, op, tiledOperands, res->getResults());
|
||||
return scf::ValueVector(tensorResults.begin(), tensorResults.end());
|
||||
};
|
||||
GenerateLoopNest<LoopTy>::doit(b, op.getLoc(), loopRanges, op, iteratorTypes,
|
||||
|
|
|
@ -913,6 +913,21 @@ Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
|
|||
return sliceOp->getResult(0);
|
||||
}
|
||||
|
||||
Value createSlice(OpBuilder &builder, Location loc, Value value,
|
||||
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
|
||||
ArrayRef<OpFoldResult> strides) {
|
||||
if (value.getType().isa<MemRefType>()) {
|
||||
return builder.create<memref::SubViewOp>(loc, value, offsets, sizes,
|
||||
strides);
|
||||
}
|
||||
|
||||
// This intentionally does not attempt to compose the extractslice operations.
|
||||
assert(value.getType().isa<RankedTensorType>() &&
|
||||
"expected a ranked tensor type");
|
||||
return builder.create<tensor::ExtractSliceOp>(loc, value, offsets, sizes,
|
||||
strides);
|
||||
}
|
||||
|
||||
SmallVector<Value> computeTileOffsets(OpBuilder &b, Location loc,
|
||||
ValueRange ivs, ValueRange tileSizes) {
|
||||
SmallVector<Value> offsets;
|
||||
|
@ -943,6 +958,41 @@ SmallVector<Value> computeTileSizes(OpBuilder &b, Location loc,
|
|||
return sizes;
|
||||
}
|
||||
|
||||
SmallVector<Type> getTensorOutputTypes(LinalgOp op, ValueRange operands) {
|
||||
// TODO: use an interface/adaptor to avoid leaking position in
|
||||
// `tiledOperands`.
|
||||
return llvm::to_vector(
|
||||
llvm::map_range(op.getOutputTensorOperands(), [&](OpOperand *opOperand) {
|
||||
return operands[opOperand->getOperandNumber()].getType();
|
||||
}));
|
||||
}
|
||||
|
||||
SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc,
|
||||
LinalgOp op, ValueRange operands,
|
||||
ValueRange results) {
|
||||
SmallVector<Value> tensorResults;
|
||||
tensorResults.reserve(results.size());
|
||||
// Insert a insert_slice for each output tensor.
|
||||
unsigned resultIdx = 0;
|
||||
for (OpOperand *opOperand : op.getOutputTensorOperands()) {
|
||||
// TODO: use an interface/adaptor to avoid leaking position in
|
||||
// `tiledOperands`.
|
||||
Value outputTensor = operands[opOperand->getOperandNumber()];
|
||||
if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) {
|
||||
Value inserted = builder.create<tensor::InsertSliceOp>(
|
||||
loc, sliceOp.source().getType(), results[resultIdx], sliceOp.source(),
|
||||
sliceOp.offsets(), sliceOp.sizes(), sliceOp.strides(),
|
||||
sliceOp.static_offsets(), sliceOp.static_sizes(),
|
||||
sliceOp.static_strides());
|
||||
tensorResults.push_back(inserted);
|
||||
} else {
|
||||
tensorResults.push_back(results[resultIdx]);
|
||||
}
|
||||
++resultIdx;
|
||||
}
|
||||
return tensorResults;
|
||||
}
|
||||
|
||||
SmallVector<Value, 4> makeTiledShapes(OpBuilder &b, Location loc,
|
||||
LinalgOp linalgOp,
|
||||
ArrayRef<Value> valuesToTile,
|
||||
|
|
|
@ -15,6 +15,12 @@ IntOrAttrList = Sequence[Union[IntegerAttr, int]]
|
|||
OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]]
|
||||
|
||||
|
||||
def _get_int64_attr(value: Union[int, Attribute]) -> IntegerAttr:
|
||||
if isinstance(value, int):
|
||||
return IntegerAttr.get(IntegerType.get_signless(64), value)
|
||||
return value
|
||||
|
||||
|
||||
def _get_array_attr(
|
||||
values: Optional[Union[ArrayAttr, Sequence[Attribute]]]) -> ArrayAttr:
|
||||
"""Creates an array attribute from its operand."""
|
||||
|
@ -41,13 +47,7 @@ def _get_int_array_attr(
|
|||
if isinstance(values, ArrayAttr):
|
||||
return values
|
||||
|
||||
attributes = []
|
||||
for value in values:
|
||||
if isinstance(value, IntegerAttr):
|
||||
attributes.append(value)
|
||||
else:
|
||||
attributes.append(IntegerAttr.get(IntegerType.get_signless(64), value))
|
||||
return ArrayAttr.get(attributes)
|
||||
return ArrayAttr.get([_get_int64_attr(v) for v in values])
|
||||
|
||||
|
||||
def _get_int_int_array_attr(
|
||||
|
@ -152,6 +152,39 @@ class ScalarizeOp:
|
|||
pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip)
|
||||
|
||||
|
||||
class SplitOp:
|
||||
"""Specialization for SplitOp class."""
|
||||
|
||||
def __init__(self,
|
||||
target: Union[Operation, Value],
|
||||
dimension: Union[int, Attribute],
|
||||
split_point: Union[int, Operation, Value, Attribute],
|
||||
*,
|
||||
loc=None,
|
||||
ip=None):
|
||||
dimension = _get_int64_attr(dimension)
|
||||
if isinstance(split_point, int):
|
||||
split_point = _get_int64_attr(split_point)
|
||||
|
||||
if isinstance(split_point, Attribute):
|
||||
static_split_point = split_point
|
||||
dynamic_split_point = None
|
||||
else:
|
||||
static_split_point = _get_int64_attr(ShapedType._get_dynamic_size())
|
||||
dynamic_split_point = _get_op_result_or_value(split_point)
|
||||
|
||||
pdl_operation_type = pdl.OperationType.get()
|
||||
super().__init__(
|
||||
pdl_operation_type,
|
||||
pdl_operation_type,
|
||||
_get_op_result_or_value(target),
|
||||
dimension=dimension,
|
||||
static_split_point=static_split_point,
|
||||
dynamic_split_point=dynamic_split_point,
|
||||
loc=loc,
|
||||
ip=ip)
|
||||
|
||||
|
||||
class TileOp:
|
||||
"""Specialization for TileOp class."""
|
||||
|
||||
|
|
|
@ -0,0 +1,366 @@
|
|||
// RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file -verify-diagnostics | FileCheck %s
|
||||
|
||||
transform.with_pdl_patterns {
|
||||
^bb0(%arg0: !pdl.operation):
|
||||
pdl.pattern @linalg_generic : benefit(1) {
|
||||
%0 = pdl.operands
|
||||
%1 = pdl.types
|
||||
%2 = pdl.operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
|
||||
pdl.rewrite %2 with "transform.dialect"
|
||||
}
|
||||
|
||||
transform.sequence %arg0 {
|
||||
^bb1(%arg1: !pdl.operation):
|
||||
%0 = transform.pdl_match @linalg_generic in %arg1
|
||||
%1:2 = transform.structured.split %0 after 42 { dimension = 0 }
|
||||
}
|
||||
}
|
||||
|
||||
func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
|
||||
|
||||
// CHECK: #[[$ADD_42_MAP:.+]] = affine_map<(d0) -> (d0 + 42)>
|
||||
// CHECK: #[[$ADD_10_MAP:.+]] = affine_map<(d0) -> (d0 + 10)>
|
||||
|
||||
// CHECK-LABEL: @one_d_static
|
||||
// CHECK-SAME: %[[IN:.+]]: tensor<100xf32>, %[[OUT:.+]]: tensor<100xf32>
|
||||
func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
|
||||
// CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN]][0] [42] [1] : tensor<100xf32> to tensor<42xf32>
|
||||
// CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT]][0] [42] [1] : tensor<100xf32> to tensor<42xf32>
|
||||
// CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic
|
||||
// CHECK: ins(%[[IN_SLICE_LOW]]
|
||||
// CHECK: outs(%[[OUT_SLICE_LOW]]
|
||||
// CHECK: linalg.index 0
|
||||
// CHECK: func.call @elem
|
||||
// CHECK: %[[RES_PARTIAL:.+]] = tensor.insert_slice %[[RES_SLICE_LOW]] into %[[OUT]][0] [42] [1]
|
||||
//
|
||||
// CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN]][42] [58] [1] : tensor<100xf32> to tensor<58xf32>
|
||||
// CHECK: %[[OUT_SLICE_HIGH:.+]] = tensor.extract_slice %[[RES_PARTIAL]][42] [58] [1] : tensor<100xf32> to tensor<58xf32>
|
||||
// CHECK: %[[RES_SLICE_HIGH:.+]] = linalg.generic
|
||||
// CHECK: ins(%[[IN_SLICE_HIGH]]
|
||||
// CHECK: outs(%[[OUT_SLICE_HIGH]]
|
||||
// CHECK: %[[IDX:.+]] = linalg.index 0
|
||||
// CHECK: affine.apply #[[$ADD_42_MAP]](%[[IDX]])
|
||||
// CHECK: func.call @elem
|
||||
// CHECK: %[[RES:.+]] = tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[RES_PARTIAL]][42] [58] [1]
|
||||
%0 = linalg.generic {
|
||||
indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
|
||||
iterator_types = ["parallel"]
|
||||
}
|
||||
ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
|
||||
^bb0(%0: f32, %1: f32):
|
||||
%i = linalg.index 0 : index
|
||||
%call_res = func.call @elem(%0, %i, %i) : (f32, index, index) -> f32
|
||||
linalg.yield %call_res : f32
|
||||
} -> tensor<100xf32>
|
||||
|
||||
// CHECK: return %[[RES]]
|
||||
return %0 : tensor<100xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @one_d_static_overflow
|
||||
// CHECK-SAME: %[[IN:.+]]: tensor<10xf32>, %[[OUT:.+]]: tensor<10xf32>
|
||||
func.func @one_d_static_overflow(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>) -> tensor<10xf32> {
|
||||
// CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN]][0] [10] [1] : tensor<10xf32> to tensor<10xf32>
|
||||
// CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT]][0] [10] [1] : tensor<10xf32> to tensor<10xf32>
|
||||
// CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic
|
||||
// CHECK: ins(%[[IN_SLICE_LOW]]
|
||||
// CHECK: outs(%[[OUT_SLICE_LOW]]
|
||||
// CHECK: linalg.index 0
|
||||
// CHECK: func.call @elem
|
||||
// CHECK: %[[RES_PARTIAL:.+]] = tensor.insert_slice %[[RES_SLICE_LOW]] into %[[OUT]][0] [10] [1]
|
||||
//
|
||||
// CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN]][10] [0] [1] : tensor<10xf32> to tensor<0xf32>
|
||||
// CHECK: %[[OUT_SLICE_HIGH:.+]] = tensor.extract_slice %[[RES_PARTIAL]][10] [0] [1] : tensor<10xf32> to tensor<0xf32>
|
||||
// CHECK: %[[RES_SLICE_HIGH:.+]] = linalg.generic
|
||||
// CHECK: ins(%[[IN_SLICE_HIGH]]
|
||||
// CHECK: outs(%[[OUT_SLICE_HIGH]]
|
||||
// CHECK: %[[IDX:.+]] = linalg.index 0
|
||||
// CHECK: affine.apply #[[$ADD_10_MAP]](%[[IDX]])
|
||||
// CHECK: func.call @elem
|
||||
// CHECK: %[[RES:.+]] = tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[RES_PARTIAL]][10] [0] [1]
|
||||
%0 = linalg.generic {
|
||||
indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
|
||||
iterator_types = ["parallel"]
|
||||
}
|
||||
ins(%arg0: tensor<10xf32>) outs(%arg1: tensor<10xf32>) {
|
||||
^bb0(%0: f32, %1: f32):
|
||||
%i = linalg.index 0 : index
|
||||
%call_res = func.call @elem(%0, %i, %i) : (f32, index, index) -> f32
|
||||
linalg.yield %call_res : f32
|
||||
} -> tensor<10xf32>
|
||||
return %0 : tensor<10xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
transform.with_pdl_patterns {
|
||||
^bb0(%arg0: !pdl.operation):
|
||||
pdl.pattern @func_call : benefit(1) {
|
||||
%0 = pdl.operands
|
||||
%1 = pdl.types
|
||||
%2 = pdl.operation "func.call"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
|
||||
pdl.rewrite %2 with "transform.dialect"
|
||||
}
|
||||
pdl.pattern @linalg_generic : benefit(1) {
|
||||
%0 = pdl.operands
|
||||
%1 = pdl.types
|
||||
%2 = pdl.operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
|
||||
pdl.rewrite %2 with "transform.dialect"
|
||||
}
|
||||
|
||||
transform.sequence %arg0 {
|
||||
^bb1(%arg1: !pdl.operation):
|
||||
%0 = transform.pdl_match @linalg_generic in %arg1
|
||||
%1 = transform.pdl_match @func_call in %arg1
|
||||
transform.structured.split %0 after %1 { dimension = 0 }
|
||||
}
|
||||
}
|
||||
|
||||
func.func private @get_size() -> index
|
||||
|
||||
// CHECK: #[[$MAP_MIN_100:.+]] = affine_map<(d0, d1) -> (d0, 100)>
|
||||
// CHECK: #[[$MAP_S_MINUS_100:.+]] = affine_map<()[s0] -> (-s0 + 100)>
|
||||
|
||||
// CHECK-LABEL: @dynamic
|
||||
func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
|
||||
// CHECK: %[[SPLIT:.+]] = call @get_size
|
||||
// CHECK: %[[SPLIT_LOW:.+]] = affine.min #[[$MAP_MIN_100]](%[[SPLIT]]
|
||||
// CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN:.+]][0] [%[[SPLIT_LOW]]] [1] : tensor<100xf32> to tensor<?xf32>
|
||||
// CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT:.+]][0] [%[[SPLIT_LOW]]] [1] : tensor<100xf32> to tensor<?xf32>
|
||||
// CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic
|
||||
// CHECK: ins(%[[IN_SLICE_LOW]]
|
||||
// CHECK: outs(%[[OUT_SLICE_LOW]]
|
||||
// CHECK: %[[PARTIAL:.+]] = tensor.insert_slice %[[RES_SLICE_LOW]] into %[[OUT]][0] [%[[SPLIT_LOW]]] [1]
|
||||
//
|
||||
// CHECK: %[[SPLIT_HIGH_1:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
|
||||
// CHECK: %[[SPLIT_HIGH_2:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
|
||||
// CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN:.+]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_2]]] [1] : tensor<100xf32> to tensor<?xf32>
|
||||
// CHECK: %[[SPLIT_HIGH_3:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]]
|
||||
// CHECK: %[[OUT_SLICE_HIGH:.+]] = tensor.extract_slice %[[PARTIAL:.+]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_3]]] [1] : tensor<100xf32> to tensor<?xf32>
|
||||
// CHECK: %[[RES_SLICE_HIGH:.+]] = linalg.generic
|
||||
// CHECK: ins(%[[IN_SLICE_HIGH]]
|
||||
// CHECK: outs(%[[OUT_SLICE_HIGH]]
|
||||
// CHECK: tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[PARTIAL]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_3]]] [1]
|
||||
%0 = func.call @get_size() : () -> index
|
||||
%1 = linalg.generic {
|
||||
indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
|
||||
iterator_types = ["parallel"]
|
||||
}
|
||||
ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
|
||||
^bb0(%3: f32, %4: f32):
|
||||
linalg.yield %3 : f32
|
||||
} -> tensor<100xf32>
|
||||
return %1 : tensor<100xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
transform.with_pdl_patterns {
|
||||
^bb0(%arg0: !pdl.operation):
|
||||
pdl.pattern @linalg_generic : benefit(1) {
|
||||
%0 = pdl.operands
|
||||
%1 = pdl.types
|
||||
%2 = pdl.operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
|
||||
pdl.rewrite %2 with "transform.dialect"
|
||||
}
|
||||
|
||||
transform.sequence %arg0 {
|
||||
^bb1(%arg1: !pdl.operation):
|
||||
%0 = transform.pdl_match @linalg_generic in %arg1
|
||||
%1:2 = transform.structured.split %0 after 4 { dimension = 0}
|
||||
%2:2 = transform.structured.split %1#1 after 16 { dimension = 1 }
|
||||
}
|
||||
}
|
||||
|
||||
func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32
|
||||
|
||||
// CHECK-LABEL: @two_d
|
||||
func.func @two_d(%arg0: tensor<10x34xf32>,
|
||||
%arg1: tensor<10x34xf32>) -> tensor<10x34xf32> {
|
||||
// Check the overall structure: split along the dimension 0, and then split
|
||||
// the second half only along the dimension 1.
|
||||
// CHECK: %[[IN_1:.+]] = tensor.extract_slice %[[IN:.+]][0, 0]
|
||||
// CHECK: %[[OUT_1:.+]] = tensor.extract_slice %[[OUT:.+]][0, 0]
|
||||
// CHECK: %[[RES_1:.+]] = linalg.generic
|
||||
// CHECK-SAME: ins(%[[IN_1]] : tensor<4x34xf32>)
|
||||
// CHECK-SAME: outs(%[[OUT_1]] : tensor<4x34xf32>)
|
||||
// CHECK: %[[PARTIAL_1:.+]] = tensor.insert_slice %[[RES_1]] into %[[OUT]]
|
||||
//
|
||||
// CHECK: %[[IN_2:.+]] = tensor.extract_slice %[[IN]]
|
||||
// CHECK: %[[OUT_2:.+]] = tensor.extract_slice %[[PARTIAL_1]]
|
||||
// CHECK: %[[IN_21:.+]] = tensor.extract_slice %[[IN_2]]
|
||||
// CHECK: %[[OUT_21:.+]] = tensor.extract_slice %[[OUT_2]]
|
||||
// CHECK: %[[RES_21:.+]] = linalg.generic
|
||||
// CHECK-SAME: ins(%[[IN_21]] : tensor<6x16xf32>)
|
||||
// CHECK-SAME: outs(%[[OUT_21]] : tensor<6x16xf32>)
|
||||
// CHECK: %[[PARTIAL_21:.+]] = tensor.insert_slice %[[RES_21]] into %[[OUT_2]]
|
||||
//
|
||||
// CHECK: %[[IN_22:.+]] = tensor.extract_slice %[[IN_2]]
|
||||
// CHECK: %[[OUT_22:.+]] = tensor.extract_slice %[[PARTIAL_21]]
|
||||
// CHECK: %[[RES_22:.+]] = linalg.generic
|
||||
// CHECK-SAME: ins(%[[IN_22]] : tensor<6x18xf32>)
|
||||
// CHECK-SAME: outs(%[[OUT_22]] : tensor<6x18xf32>)
|
||||
// CHECK: %[[PARTIAL_22:.+]] = tensor.insert_slice %[[RES_22]] into %[[PARTIAL_21]]
|
||||
// CHECK: %[[PARTIAL_2:.+]] = tensor.insert_slice %[[PARTIAL_22]] into %[[PARTIAL_1]]
|
||||
%0 = linalg.generic {
|
||||
indexing_maps = [affine_map<(i, j) -> (i, j)>,
|
||||
affine_map<(i, j) -> (i, j)>],
|
||||
iterator_types = ["parallel", "parallel"]
|
||||
}
|
||||
ins(%arg0: tensor<10x34xf32>)
|
||||
outs(%arg1: tensor<10x34xf32>) {
|
||||
^bb0(%0: f32, %1: f32):
|
||||
%i = linalg.index 0 : index
|
||||
%j = linalg.index 1 : index
|
||||
%call_res = func.call @elem(%0, %i, %j) : (f32, index, index) -> f32
|
||||
linalg.yield %call_res : f32
|
||||
} -> tensor<10x34xf32>
|
||||
return %0 : tensor<10x34xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
transform.sequence {
|
||||
^bb1(%arg1: !pdl.operation):
|
||||
// expected-error @below {{expects either a dynamic or a static split point to be provided}}
|
||||
%0:2 = "transform.structured.split"(%arg1) { dimension = 1, static_split_point = -1 } : (!pdl.operation) -> (!pdl.operation, !pdl.operation)
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
transform.with_pdl_patterns {
|
||||
^bb0(%arg0: !pdl.operation):
|
||||
pdl.pattern @func_call : benefit(1) {
|
||||
%0 = pdl.operands
|
||||
%1 = pdl.types
|
||||
%2 = pdl.operation "func.call"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
|
||||
pdl.rewrite %2 with "transform.dialect"
|
||||
}
|
||||
pdl.pattern @linalg_generic : benefit(1) {
|
||||
%0 = pdl.operands
|
||||
%1 = pdl.types
|
||||
%2 = pdl.operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
|
||||
pdl.rewrite %2 with "transform.dialect"
|
||||
}
|
||||
|
||||
transform.sequence %arg0 {
|
||||
^bb1(%arg1: !pdl.operation):
|
||||
%0 = transform.pdl_match @linalg_generic in %arg1
|
||||
%1 = transform.pdl_match @func_call in %arg1
|
||||
// expected-error @below {{expected dynamic split point handle to point to a single-result index-typed op}}
|
||||
transform.structured.split %0 after %1 { dimension = 0 }
|
||||
}
|
||||
}
|
||||
|
||||
func.func private @get_size() -> i64
|
||||
|
||||
func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
|
||||
// expected-note @below {{dynamic split point}}
|
||||
%0 = func.call @get_size() : () -> i64
|
||||
%1 = linalg.generic {
|
||||
indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
|
||||
iterator_types = ["parallel"]
|
||||
}
|
||||
ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
|
||||
^bb0(%3: f32, %4: f32):
|
||||
linalg.yield %3 : f32
|
||||
} -> tensor<100xf32>
|
||||
return %1 : tensor<100xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
transform.with_pdl_patterns {
|
||||
^bb0(%arg0: !pdl.operation):
|
||||
pdl.pattern @func_call : benefit(1) {
|
||||
%0 = pdl.operands
|
||||
%1 = pdl.types
|
||||
%2 = pdl.operation "func.call"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
|
||||
pdl.rewrite %2 with "transform.dialect"
|
||||
}
|
||||
pdl.pattern @linalg_generic : benefit(1) {
|
||||
%0 = pdl.operands
|
||||
%1 = pdl.types
|
||||
%2 = pdl.operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
|
||||
pdl.rewrite %2 with "transform.dialect"
|
||||
}
|
||||
|
||||
transform.sequence %arg0 {
|
||||
^bb1(%arg1: !pdl.operation):
|
||||
%0 = transform.pdl_match @linalg_generic in %arg1
|
||||
%1 = transform.pdl_match @func_call in %arg1
|
||||
// expected-error @below {{expected the dynamic split point handle to point to as many operations (0) as the target handle (1)}}
|
||||
transform.structured.split %0 after %1 { dimension = 0 }
|
||||
}
|
||||
}
|
||||
|
||||
func.func private @get_size() -> i64
|
||||
|
||||
func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
|
||||
%1 = linalg.generic {
|
||||
indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
|
||||
iterator_types = ["parallel"]
|
||||
}
|
||||
ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
|
||||
^bb0(%3: f32, %4: f32):
|
||||
linalg.yield %3 : f32
|
||||
} -> tensor<100xf32>
|
||||
return %1 : tensor<100xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
transform.with_pdl_patterns {
|
||||
^bb0(%arg0: !pdl.operation):
|
||||
pdl.pattern @func_return : benefit(1) {
|
||||
%0 = pdl.operands
|
||||
%1 = pdl.types
|
||||
%2 = pdl.operation "func.return"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
|
||||
pdl.rewrite %2 with "transform.dialect"
|
||||
}
|
||||
|
||||
transform.sequence %arg0 {
|
||||
^bb1(%arg1: !pdl.operation):
|
||||
%0 = transform.pdl_match @func_return in %arg1
|
||||
// expected-error @below {{only applies to structured ops}}
|
||||
transform.structured.split %0 after 16 { dimension = 1 }
|
||||
}
|
||||
}
|
||||
|
||||
func.func @noop(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
|
||||
// expected-note @below {{target op}}
|
||||
return %arg0 : tensor<100xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
transform.with_pdl_patterns {
|
||||
^bb0(%arg0: !pdl.operation):
|
||||
pdl.pattern @linalg_generic : benefit(1) {
|
||||
%0 = pdl.operands
|
||||
%1 = pdl.types
|
||||
%2 = pdl.operation "linalg.generic"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
|
||||
pdl.rewrite %2 with "transform.dialect"
|
||||
}
|
||||
|
||||
transform.sequence %arg0 {
|
||||
^bb1(%arg1: !pdl.operation):
|
||||
%0 = transform.pdl_match @linalg_generic in %arg1
|
||||
// expected-error @below {{dimension 1 does not exist in target op}}
|
||||
transform.structured.split %0 after 16 { dimension = 1 }
|
||||
}
|
||||
}
|
||||
|
||||
func.func @one_d_static(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> {
|
||||
// expected-note @below {{target op}}
|
||||
%0 = linalg.generic {
|
||||
indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>],
|
||||
iterator_types = ["parallel"]
|
||||
}
|
||||
ins(%arg0: tensor<100xf32>) outs(%arg1: tensor<100xf32>) {
|
||||
^bb0(%0: f32, %1: f32):
|
||||
linalg.yield %0 : f32
|
||||
} -> tensor<100xf32>
|
||||
return %0 : tensor<100xf32>
|
||||
}
|
||||
|
|
@ -84,6 +84,19 @@ def testScalarize():
|
|||
# CHECK: transform.structured.scalarize
|
||||
|
||||
|
||||
@run
|
||||
def testSplit():
|
||||
sequence = transform.SequenceOp()
|
||||
with InsertionPoint(sequence.body):
|
||||
split = structured.SplitOp(sequence.bodyTarget, dimension=1, split_point=42)
|
||||
structured.SplitOp(
|
||||
split.results[0], dimension=3, split_point=split.results[1])
|
||||
transform.YieldOp()
|
||||
# CHECK-LABEL: TEST: testSplit
|
||||
# CHECK: %[[F:.+]], %[[S:.+]] = transform.structured.split %{{.*}} after 42 {dimension = 1
|
||||
# CHECK: transform.structured.split %[[F]] after %[[S]] {dimension = 3
|
||||
|
||||
|
||||
@run
|
||||
def testTileCompact():
|
||||
sequence = transform.SequenceOp()
|
||||
|
|
Loading…
Reference in New Issue