//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements the Linalg operations.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
|
|
#include "mlir/AsmParser/AsmParser.h"
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
|
|
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
|
|
#include "mlir/Dialect/Complex/IR/Complex.h"
|
|
#include "mlir/Dialect/Math/IR/Math.h"
|
|
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
|
|
#include "mlir/Dialect/Utils/StaticValueUtils.h"
|
|
#include "mlir/IR/AffineExprVisitor.h"
|
|
#include "mlir/IR/AffineMap.h"
|
|
#include "mlir/IR/Matchers.h"
|
|
#include "mlir/IR/OpImplementation.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/Interfaces/InferTypeOpInterface.h"
|
|
|
|
#include "llvm/ADT/DenseMap.h"
|
|
#include "llvm/ADT/SetVector.h"
|
|
#include "llvm/ADT/SmallSet.h"
|
|
#include "llvm/ADT/StringSet.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
#include "llvm/Support/FormatVariadic.h"
|
|
#include "llvm/Support/MathExtras.h"
|
|
#include "llvm/Support/raw_ostream.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::linalg;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Support for named Linalg ops defined in ods-gen.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
|
|
ArrayRef<NamedAttribute>)>;
|
|
|
|
/// Fills the region of a structured operation using the provided
|
|
/// `regionBuilder`. The method is used by both named structured ops created by
|
|
/// ods-gen and by manually defined C++ ops. It is called by both builders and
|
|
/// parsers and creates a block with arguments corresponding to the elemental
|
|
/// types of `inputTypes` and `outputTypes`. All output types are asserted to be
|
|
/// ShapedType.
|
|
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
|
|
TypeRange inputTypes, TypeRange outputTypes,
|
|
ArrayRef<NamedAttribute> attrs,
|
|
RegionBuilderFn regionBuilder) {
|
|
assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); }));
|
|
|
|
// TODO: at the moment all operands go through getElementTypeOrSelf,
|
|
// reconsider when we have evidence we need to.
|
|
SmallVector<Type, 8> argTypes;
|
|
SmallVector<Location, 8> argLocs;
|
|
for (auto containers : {inputTypes, outputTypes}) {
|
|
for (auto t : containers) {
|
|
argTypes.push_back(getElementTypeOrSelf(t));
|
|
|
|
// TODO: Pass in a proper location here.
|
|
argLocs.push_back(opBuilder.getUnknownLoc());
|
|
}
|
|
}
|
|
|
|
// RAII.
|
|
OpBuilder::InsertionGuard guard(opBuilder);
|
|
Block *body =
|
|
opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);
|
|
|
|
opBuilder.setInsertionPointToStart(body);
|
|
ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
|
|
regionBuilder(b, *body, attrs);
|
|
|
|
// indexing_maps is an auto-generated method.
|
|
|
|
// iterator_types is an auto-generated method.
|
|
}
|
|
|
|
/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
|
|
/// The result types are derived automatically if `resultTensorTypes` is none.
|
|
/// The body of the operation is filled using `regionBuilder`. All ods-gen
|
|
/// created structured operations use the method to implement their builders.
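///
/// As a hedged illustration only (the generated `build` signature and the
/// `getRegionBuilder` accessor below are assumptions, not copied from ods-gen
/// output), a named op builder is expected to forward to this helper roughly
/// as follows:
///
///   void MatmulOp::build(OpBuilder &b, OperationState &state,
///                        ValueRange inputs, ValueRange outputs) {
///     buildStructuredOp(b, state, /*resultTensorTypes=*/llvm::None, inputs,
///                       outputs, /*attributes=*/{},
///                       MatmulOp::getRegionBuilder());
///   }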
|
|
static void buildStructuredOp(OpBuilder &b, OperationState &state,
|
|
llvm::Optional<TypeRange> resultTensorTypes,
|
|
ValueRange inputs, ValueRange outputs,
|
|
ArrayRef<NamedAttribute> attributes,
|
|
RegionBuilderFn regionBuilder) {
|
|
// Derive the result types if needed.
|
|
SmallVector<Type> derivedResultTypes =
|
|
resultTensorTypes.value_or(TypeRange());
|
|
if (!resultTensorTypes)
|
|
copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
|
|
[](Type type) { return type.isa<RankedTensorType>(); });
|
|
|
|
state.addOperands(inputs);
|
|
state.addOperands(outputs);
|
|
state.addTypes(derivedResultTypes);
|
|
state.addAttributes(attributes);
|
|
state.addAttribute(
|
|
"operand_segment_sizes",
|
|
b.getI32VectorAttr({static_cast<int32_t>(inputs.size()),
|
|
static_cast<int32_t>(outputs.size())}));
|
|
|
|
// Create and fill the region of the structured operation.
|
|
Region &region = *state.addRegion();
|
|
fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
|
|
state.attributes.getAttrs(), regionBuilder);
|
|
}
|
|
|
|
/// Common parsing used for both named structured ops created by ods-gen and by
|
|
/// manually defined C++ ops. Does not handle regions.
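///
/// For illustration (op name and types chosen arbitrarily, syntax is only a
/// sketch), this handles the `ins`/`outs` clauses of assembly such as:
///
///   %0 = linalg.matmul ins(%a, %b : tensor<4x8xf32>, tensor<8x16xf32>)
///                      outs(%c : tensor<4x16xf32>) -> tensor<4x16xf32>
///
/// The trailing result types (`-> ...`) and the region are parsed by the
/// callers, not here.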
|
|
static ParseResult
|
|
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
|
|
SmallVectorImpl<Type> &inputTypes,
|
|
SmallVectorImpl<Type> &outputTypes) {
|
|
SMLoc inputsOperandsLoc, outputsOperandsLoc;
|
|
SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
|
|
outputsOperands;
|
|
|
|
if (parser.parseOptionalAttrDict(result.attributes))
|
|
return failure();
|
|
|
|
if (succeeded(parser.parseOptionalKeyword("ins"))) {
|
|
if (parser.parseLParen())
|
|
return failure();
|
|
|
|
inputsOperandsLoc = parser.getCurrentLocation();
|
|
if (parser.parseOperandList(inputsOperands) ||
|
|
parser.parseColonTypeList(inputTypes) || parser.parseRParen())
|
|
return failure();
|
|
}
|
|
|
|
if (succeeded(parser.parseOptionalKeyword("outs"))) {
|
|
outputsOperandsLoc = parser.getCurrentLocation();
|
|
if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
|
|
parser.parseColonTypeList(outputTypes) || parser.parseRParen())
|
|
return failure();
|
|
}
|
|
|
|
if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
|
|
result.operands) ||
|
|
parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
|
|
result.operands))
|
|
return failure();
|
|
|
|
result.addAttribute("operand_segment_sizes",
|
|
parser.getBuilder().getI32VectorAttr(
|
|
{static_cast<int32_t>(inputsOperands.size()),
|
|
static_cast<int32_t>(outputsOperands.size())}));
|
|
return success();
|
|
}
|
|
|
|
static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
|
|
ValueRange outputs) {
|
|
if (!inputs.empty())
|
|
p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
|
|
if (!outputs.empty())
|
|
p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Specific parsing and printing for named structured ops created by ods-gen.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
static ParseResult parseNamedStructuredOpRegion(
|
|
OpAsmParser &parser, Region &region, unsigned numRegionArgs,
|
|
TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
|
|
RegionBuilderFn regionBuilder) {
|
|
if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
|
|
return parser.emitError(
|
|
parser.getCurrentLocation(),
|
|
llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
|
|
"region expects {0} args, got {1}",
|
|
numRegionArgs, inputTypes.size() + outputTypes.size()));
|
|
}
|
|
|
|
OpBuilder opBuilder(parser.getContext());
|
|
fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
|
|
regionBuilder);
|
|
return success();
|
|
}
|
|
|
|
static ParseResult
|
|
parseNamedStructuredOpResults(OpAsmParser &parser,
|
|
SmallVectorImpl<Type> &resultTypes) {
|
|
if (parser.parseOptionalArrowTypeList(resultTypes))
|
|
return failure();
|
|
return success();
|
|
}
|
|
|
|
static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
|
|
OperationState &result,
|
|
unsigned numRegionArgs,
|
|
RegionBuilderFn regionBuilder) {
|
|
// TODO: Enable when ods-gen supports captures.
|
|
SmallVector<Type, 1> inputTypes, outputTypes;
|
|
if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
|
|
return failure();
|
|
|
|
// TODO: consider merging results parsing into region parsing.
|
|
// Need to wait for declarative assembly resolution to decide.
|
|
SmallVector<Type, 1> outputTensorsTypes;
|
|
if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
|
|
return failure();
|
|
result.addTypes(outputTensorsTypes);
|
|
|
|
std::unique_ptr<Region> region = std::make_unique<Region>();
|
|
if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
|
|
outputTypes, result.attributes.getAttrs(),
|
|
regionBuilder))
|
|
return failure();
|
|
result.addRegion(std::move(region));
|
|
|
|
return success();
|
|
}
|
|
|
|
static void printNamedStructuredOpResults(OpAsmPrinter &p,
|
|
TypeRange resultTypes) {
|
|
if (resultTypes.empty())
|
|
return;
|
|
p.printOptionalArrowTypeList(resultTypes);
|
|
}
|
|
|
|
static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
|
|
ValueRange inputs, ValueRange outputs) {
|
|
p.printOptionalAttrDict(
|
|
op->getAttrs(),
|
|
/*elidedAttrs=*/{"operand_segment_sizes",
|
|
// See generated code in mlir-linalg-ods-yaml-gen.cpp
|
|
"linalg.memoized_indexing_maps"});
|
|
|
|
// Printing is shared with generic ops, except for the region and
|
|
// attributes.
|
|
printCommonStructuredOpParts(p, inputs, outputs);
|
|
|
|
// Results printing.
|
|
printNamedStructuredOpResults(p, op->getResultTypes());
|
|
|
|
// Region is elided.
|
|
}
|
|
|
|
/// This is a common class used for patterns of the form
|
|
/// ```
|
|
/// someop(memrefcast(%src)) -> someop(%src)
|
|
/// ```
|
|
/// It folds the source of the memref.cast into the root operation directly.
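///
/// Illustrative IR (a sketch, not taken from a test):
///
///   %0 = memref.cast %src : memref<4x4xf32> to memref<?x?xf32>
///   linalg.generic ... ins(%0 : memref<?x?xf32>) outs(...)
///
/// folds to
///
///   linalg.generic ... ins(%src : memref<4x4xf32>) outs(...)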
|
|
static LogicalResult foldMemRefCast(Operation *op) {
|
|
bool folded = false;
|
|
for (OpOperand &operand : op->getOpOperands()) {
|
|
auto castOp = operand.get().getDefiningOp<memref::CastOp>();
|
|
if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
|
|
operand.set(castOp.getOperand());
|
|
folded = true;
|
|
}
|
|
}
|
|
return success(folded);
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Region builder helper.
|
|
// TODO: Move this to a utility library.
|
|
// The public methods on this class are referenced directly from generated code.
|
|
// Helpers to build the unary, binary, and type conversion functions defined by the
|
|
// DSL. See mlir-linalg-ods-yaml-gen.cpp for the code that uses this class.
|
|
//
|
|
// Implementations of the math functions must be polymorphic over numeric types,
|
|
// internally performing necessary casts. If the function application makes no
|
|
// sense, then the only recourse is to assert and return nullptr. This can be
|
|
// extended later if it becomes possible to fail construction of the region. The
|
|
// invariant should be enforced at a higher level.
|
|
//
|
|
// TODO: These helpers are currently type polymorphic over the class of integer
|
|
// and floating point types, but they will not internally cast within bit
|
|
// widths of a class (mixed precision such as i8->i32) or across classes
|
|
// (i.e. mixed float and integer). Many such combinations are ambiguous or need
|
|
// to be handled with care and work is being considered to extend the op
|
|
// language to make such cases explicit. In the meantime, violating this will
|
|
// fail verification, which is deemed acceptable.
|
|
//===----------------------------------------------------------------------===//
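//
// As a rough sketch only (the exact shape of the yaml-generated code is an
// assumption), a generated region builder for an elementwise add would use
// this helper along these lines:
//
//   static void regionBuilder(ImplicitLocOpBuilder &b, Block &block,
//                             ArrayRef<NamedAttribute> attrs) {
//     RegionBuilderHelper helper(block.getArgument(0).getContext(), block);
//     Type outType = block.getArgument(2).getType();
//     Value lhs = helper.buildTypeFn(TypeFn::cast_signed, outType,
//                                    block.getArgument(0));
//     Value rhs = helper.buildTypeFn(TypeFn::cast_signed, outType,
//                                    block.getArgument(1));
//     SmallVector<Value> yields;
//     yields.push_back(helper.buildBinaryFn(BinaryFn::add, lhs, rhs));
//     helper.yieldOutputs(yields);
//   }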
|
|
|
|
namespace {
|
|
|
|
class RegionBuilderHelper {
|
|
public:
|
|
RegionBuilderHelper(MLIRContext *context, Block &block)
|
|
: context(context), block(block) {}
|
|
|
|
// Build the unary functions defined by OpDSL.
|
|
Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
|
|
if (!isFloatingPoint(arg))
|
|
llvm_unreachable("unsupported non numeric type");
|
|
OpBuilder builder = getBuilder();
|
|
switch (unaryFn) {
|
|
case UnaryFn::exp:
|
|
return builder.create<math::ExpOp>(arg.getLoc(), arg);
|
|
case UnaryFn::log:
|
|
return builder.create<math::LogOp>(arg.getLoc(), arg);
|
|
case UnaryFn::abs:
|
|
return builder.create<math::AbsFOp>(arg.getLoc(), arg);
|
|
case UnaryFn::ceil:
|
|
return builder.create<math::CeilOp>(arg.getLoc(), arg);
|
|
case UnaryFn::floor:
|
|
return builder.create<math::FloorOp>(arg.getLoc(), arg);
|
|
case UnaryFn::negf:
|
|
return builder.create<arith::NegFOp>(arg.getLoc(), arg);
|
|
}
|
|
llvm_unreachable("unsupported unary function");
|
|
}
|
|
|
|
// Build the binary functions defined by OpDSL.
|
|
Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
|
|
bool allComplex = isComplex(arg0) && isComplex(arg1);
|
|
bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
|
|
bool allInteger = isInteger(arg0) && isInteger(arg1);
|
|
bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
|
|
arg1.getType().getIntOrFloatBitWidth() == 1;
|
|
if (!allComplex && !allFloatingPoint && !allInteger)
|
|
llvm_unreachable("unsupported non numeric type");
|
|
OpBuilder builder = getBuilder();
|
|
switch (binaryFn) {
|
|
case BinaryFn::add:
|
|
if (allComplex)
|
|
return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
|
|
if (allFloatingPoint)
|
|
return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
|
|
if (allBool)
|
|
return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
|
|
return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
|
|
case BinaryFn::sub:
|
|
if (allComplex)
|
|
return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
|
|
if (allFloatingPoint)
|
|
return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
|
|
if (allBool)
|
|
llvm_unreachable("unsupported operation: sub with bools");
|
|
return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
|
|
case BinaryFn::mul:
|
|
if (allComplex)
|
|
return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
|
|
if (allFloatingPoint)
|
|
return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
|
|
if (allBool)
|
|
return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
|
|
return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
|
|
case BinaryFn::max_signed:
|
|
assert(!allComplex);
|
|
if (allFloatingPoint)
|
|
return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
|
|
return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
|
|
case BinaryFn::min_signed:
|
|
assert(!allComplex);
|
|
if (allFloatingPoint)
|
|
return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
|
|
return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
|
|
case BinaryFn::max_unsigned:
|
|
assert(!allComplex);
|
|
if (allFloatingPoint)
|
|
return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1);
|
|
return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
|
|
case BinaryFn::min_unsigned:
|
|
assert(!allComplex);
|
|
if (allFloatingPoint)
|
|
return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1);
|
|
return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
|
|
}
|
|
llvm_unreachable("unsupported binary function");
|
|
}
|
|
|
|
// Build the type functions defined by OpDSL.
|
|
Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
|
|
switch (typeFn) {
|
|
case TypeFn::cast_signed:
|
|
return cast(toType, operand, false);
|
|
case TypeFn::cast_unsigned:
|
|
return cast(toType, operand, true);
|
|
}
|
|
llvm_unreachable("unsupported type conversion function");
|
|
}
|
|
|
|
void yieldOutputs(ValueRange values) {
|
|
OpBuilder builder = getBuilder();
|
|
Location loc = builder.getUnknownLoc();
|
|
builder.create<YieldOp>(loc, values);
|
|
}
|
|
|
|
Value constant(const std::string &value) {
|
|
OpBuilder builder = getBuilder();
|
|
Location loc = builder.getUnknownLoc();
|
|
Attribute valueAttr = parseAttribute(value, builder.getContext());
|
|
Type type = NoneType::get(builder.getContext());
|
|
if (auto typedAttr = valueAttr.dyn_cast<TypedAttr>())
|
|
type = typedAttr.getType();
|
|
return builder.create<arith::ConstantOp>(loc, type, valueAttr);
|
|
}
|
|
|
|
Value index(int64_t dim) {
|
|
OpBuilder builder = getBuilder();
|
|
return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
|
|
}
|
|
|
|
Type getIntegerType(unsigned width) {
|
|
return IntegerType::get(context, width);
|
|
}
|
|
|
|
Type getFloat32Type() { return Float32Type::get(context); }
|
|
Type getFloat64Type() { return Float64Type::get(context); }
|
|
|
|
private:
|
|
// Generates operations to cast the given operand to a specified type.
|
|
// If the cast cannot be performed, a warning will be issued and the
|
|
// operand returned as-is (which will presumably yield a verification
|
|
// issue downstream).
|
|
Value cast(Type toType, Value operand, bool isUnsignedCast) {
|
|
OpBuilder builder = getBuilder();
|
|
auto loc = operand.getLoc();
|
|
|
|
if (operand.getType() == toType)
|
|
return operand;
|
|
if (auto toIntType = toType.dyn_cast<IntegerType>()) {
|
|
// If operand is floating point, cast directly to the int type.
|
|
if (operand.getType().isa<FloatType>()) {
|
|
if (isUnsignedCast)
|
|
return builder.create<arith::FPToUIOp>(loc, toType, operand);
|
|
return builder.create<arith::FPToSIOp>(loc, toType, operand);
|
|
}
|
|
// Cast index operands directly to the int type.
|
|
if (operand.getType().isIndex())
|
|
return builder.create<arith::IndexCastOp>(loc, toType, operand);
|
|
if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
|
|
// Either extend or truncate.
|
|
if (toIntType.getWidth() > fromIntType.getWidth()) {
|
|
if (isUnsignedCast)
|
|
return builder.create<arith::ExtUIOp>(loc, toType, operand);
|
|
return builder.create<arith::ExtSIOp>(loc, toType, operand);
|
|
}
|
|
if (toIntType.getWidth() < fromIntType.getWidth())
|
|
return builder.create<arith::TruncIOp>(loc, toType, operand);
|
|
}
|
|
} else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
|
|
// If operand is integer, cast directly to the float type.
|
|
// Note that it is unclear how to cast from BF16<->FP16.
|
|
if (operand.getType().isa<IntegerType>()) {
|
|
if (isUnsignedCast)
|
|
return builder.create<arith::UIToFPOp>(loc, toFloatType, operand);
|
|
return builder.create<arith::SIToFPOp>(loc, toFloatType, operand);
|
|
}
|
|
if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
|
|
if (toFloatType.getWidth() > fromFloatType.getWidth())
|
|
return builder.create<arith::ExtFOp>(loc, toFloatType, operand);
|
|
if (toFloatType.getWidth() < fromFloatType.getWidth())
|
|
return builder.create<arith::TruncFOp>(loc, toFloatType, operand);
|
|
}
|
|
}
|
|
|
|
emitWarning(operand.getLoc()) << "could not cast operand of type "
|
|
<< operand.getType() << " to " << toType;
|
|
return operand;
|
|
}
|
|
|
|
bool isComplex(Value value) { return value.getType().isa<ComplexType>(); }
|
|
bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); }
|
|
bool isInteger(Value value) { return value.getType().isa<IntegerType>(); }
|
|
|
|
OpBuilder getBuilder() {
|
|
OpBuilder builder(context);
|
|
builder.setInsertionPointToEnd(&block);
|
|
return builder;
|
|
}
|
|
|
|
MLIRContext *context;
|
|
Block &block;
|
|
};
|
|
|
|
} // namespace
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// FillOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
|
|
/// Fold linalg.fill -> tensor.expand/collapse_shape chain.
|
|
///
|
|
/// For such op chains, we can create new linalg.fill ops with the result
|
|
/// type of the tensor.expand/collapse_shape op.
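///
/// Illustrative example (shapes and syntax are only a sketch):
///
///   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<6xf32>) -> tensor<6xf32>
///   %0 = tensor.expand_shape %fill [[0, 1]] : tensor<6xf32> into tensor<2x3xf32>
///
/// becomes
///
///   %init2 = tensor.expand_shape %init [[0, 1]] : tensor<6xf32> into tensor<2x3xf32>
///   %0 = linalg.fill ins(%cst : f32) outs(%init2 : tensor<2x3xf32>) -> tensor<2x3xf32>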
|
|
template <typename TensorReshapeOp>
|
|
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
|
|
using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
|
|
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
|
|
if (!oldFill)
|
|
return failure();
|
|
|
|
Location loc = oldFill.getLoc();
|
|
auto newInit = rewriter.create<TensorReshapeOp>(
|
|
loc, reshapeOp.getResultType(), oldFill.output(),
|
|
reshapeOp.getReassociation());
|
|
rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
|
|
ValueRange{newInit});
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
|
|
/// filling value are the same.
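///
/// Illustrative example (assuming the yielded pad value is the same %cst used
/// by the fill; syntax is only a sketch):
///
///   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
///   %pad = tensor.pad %fill low[2] high[2] {
///     ^bb0(%i: index):
///       tensor.yield %cst : f32
///   } : tensor<4xf32> to tensor<8xf32>
///
/// is rewritten into a linalg.fill of a new init tensor with the padded shape,
/// followed by a tensor.cast back to the original result type.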
|
|
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tensor::PadOp padOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
|
|
if (!fillOp)
|
|
return failure();
|
|
|
|
// We can only fold if the padding value is the same as the original
|
|
// filling value.
|
|
Value padValue = padOp.getConstantPaddingValue();
|
|
if (!padValue || fillOp.value() != padValue)
|
|
return failure();
|
|
|
|
ReifiedRankedShapedTypeDims reifiedShape;
|
|
ReifyRankedShapedTypeOpInterface interface =
|
|
cast<ReifyRankedShapedTypeOpInterface>(padOp.getOperation());
|
|
if (failed(interface.reifyResultShapes(rewriter, reifiedShape)))
|
|
return rewriter.notifyMatchFailure(
|
|
padOp, "failed to reify tensor.pad op result shape");
|
|
|
|
auto oldResultType = padOp.getResultType();
|
|
SmallVector<int64_t, 4> staticShape(oldResultType.getRank(),
|
|
ShapedType::kDynamicSize);
|
|
auto newInitOp = rewriter.create<InitTensorOp>(
|
|
padOp.getLoc(), reifiedShape.front(), staticShape,
|
|
oldResultType.getElementType());
|
|
auto newFillOp = rewriter.create<FillOp>(
|
|
fillOp.getLoc(), ValueRange{padValue}, ValueRange{newInitOp});
|
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(padOp, oldResultType,
|
|
newFillOp.result());
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
|
|
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
|
|
/// filling value are the same.
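///
/// Illustrative 1-D example with static offsets/sizes (syntax is a sketch):
///
///   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<8xf32>) -> tensor<8xf32>
///   %pad = tensor.pad %src low[1] high[1] { ... tensor.yield %cst ... }
///       : tensor<2xf32> to tensor<4xf32>
///   %0 = tensor.insert_slice %pad into %fill[2] [4] [1]
///       : tensor<4xf32> into tensor<8xf32>
///
/// Since the padded border equals the fill value, %src can be inserted
/// directly at offset 2 + 1 (old offset plus low pad) with its own size 2.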
|
|
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
|
|
using OpRewritePattern::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
|
|
if (!srcPadOp)
|
|
return failure();
|
|
|
|
if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
|
|
return failure();
|
|
|
|
// Walk back the tensor.insert_slice chain and find the first destination
|
|
// value at the start of the chain.
|
|
Value firstDest = insertOp.getDest();
|
|
while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
|
|
if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
|
|
return failure();
|
|
|
|
// Make sure the ranges of values accessed are disjoint. Without this, we
|
|
// cannot fold tensor.pad away.
|
|
bool disjoint = false;
|
|
for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
|
|
// If the dimension has a dynamic offset/size, we cannot guarantee
// disjointness. So just skip it.
|
|
if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
|
|
insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
|
|
prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
|
|
continue;
|
|
|
|
// Get the range start and end, inclusively for both.
|
|
int64_t prevStart = prevOp.getStaticOffset(i);
|
|
int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
|
|
prevOp.getStaticStride(i);
|
|
int64_t nextStart = insertOp.getStaticOffset(i);
|
|
int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
|
|
insertOp.getStaticStride(i);
|
|
if (prevEnd < nextStart || nextEnd < prevStart) {
|
|
disjoint = true;
|
|
break;
|
|
}
|
|
}
|
|
|
|
if (!disjoint)
|
|
break;
|
|
firstDest = prevOp.getDest();
|
|
}
|
|
|
|
// Check whether the first destination is a fill op. For overlapped cases,
|
|
// this also cannot be true.
|
|
auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
|
|
if (!dstFillOp)
|
|
return failure();
|
|
|
|
// We can only fold if the padding value is the same as the original
|
|
// filling value.
|
|
Value padValue = srcPadOp.getConstantPaddingValue();
|
|
if (!padValue || dstFillOp.value() != padValue)
|
|
return failure();
|
|
|
|
SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
|
|
SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();
|
|
|
|
Location loc = insertOp.getLoc();
|
|
MLIRContext *context = getContext();
|
|
|
|
AffineExpr sym0, sym1;
|
|
bindSymbols(context, sym0, sym1);
|
|
auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);
|
|
|
|
// Calculate the new offsets for the insert. It should be the old offsets
|
|
// plus low padding sizes.
|
|
SmallVector<OpFoldResult, 4> newOffsets;
|
|
for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
|
|
newOffsets.push_back(makeComposedFoldedAffineApply(
|
|
rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
|
|
}
|
|
|
|
SmallVector<OpFoldResult, 4> newSizes;
|
|
for (int i = 0, e = srcPadOp.getSourceType().getRank(); i < e; ++i) {
|
|
newSizes.push_back(
|
|
rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
|
|
.getResult());
|
|
}
|
|
|
|
rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
|
|
insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
|
|
newSizes, insertOp.getMixedStrides());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|
MLIRContext *context) {
|
|
results
|
|
.add<FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
|
|
FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
|
|
FoldInsertPadIntoFill>(context);
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// GenericOps
|
|
//===----------------------------------------------------------------------===//
|
|
void GenericOp::build(
|
|
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
|
|
ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
|
|
ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
|
|
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
|
|
iteratorTypes, doc, libraryCall);
|
|
result.addAttributes(attributes);
|
|
if (!bodyBuild)
|
|
return;
|
|
|
|
SmallVector<Type, 4> blockArgTypes;
|
|
SmallVector<Location, 4> blockArgLocs;
|
|
for (ValueRange container : {inputs, outputs}) {
|
|
for (Value v : container) {
|
|
blockArgTypes.push_back(getElementTypeOrSelf(v));
|
|
blockArgLocs.push_back(v.getLoc());
|
|
}
|
|
}
|
|
|
|
OpBuilder::InsertionGuard guard(builder);
|
|
auto &region = *result.regions.front();
|
|
Block *bodyBlock =
|
|
builder.createBlock(®ion, region.end(), blockArgTypes, blockArgLocs);
|
|
bodyBuild(builder, result.location, bodyBlock->getArguments());
|
|
}
|
|
|
|
void GenericOp::build(
|
|
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
|
|
ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
|
|
ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
|
|
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
build(builder, result, resultTensorTypes, inputs, outputs,
|
|
builder.getAffineMapArrayAttr(indexingMaps),
|
|
builder.getStrArrayAttr(iteratorTypes),
|
|
doc.empty() ? StringAttr() : builder.getStringAttr(doc),
|
|
libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
|
|
bodyBuild, attributes);
|
|
}
|
|
|
|
void GenericOp::build(
|
|
OpBuilder &builder, OperationState &result, ValueRange inputs,
|
|
ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
|
|
ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
|
|
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
|
|
iteratorTypes, doc, libraryCall, bodyBuild, attributes);
|
|
}
|
|
|
|
void GenericOp::build(
|
|
OpBuilder &builder, OperationState &result, ValueRange inputs,
|
|
ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
|
|
ArrayRef<StringRef> iteratorTypes,
|
|
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
|
|
/*doc=*/"",
|
|
/*libraryCall=*/"", bodyBuild, attributes);
|
|
}
|
|
|
|
void GenericOp::build(
|
|
OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
|
|
ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
|
|
ArrayRef<StringRef> iteratorTypes,
|
|
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
|
|
ArrayRef<NamedAttribute> attributes) {
|
|
build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
|
|
iteratorTypes,
|
|
/*doc=*/"",
|
|
/*libraryCall=*/"", bodyBuild, attributes);
|
|
}
|
|
|
|
void GenericOp::print(OpAsmPrinter &p) {
|
|
p << " ";
|
|
|
|
// Print extra attributes.
|
|
auto genericAttrNames = linalgTraitAttrNames();
|
|
|
|
llvm::StringSet<> genericAttrNamesSet;
|
|
genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
|
|
SmallVector<NamedAttribute, 8> genericAttrs;
|
|
for (auto attr : (*this)->getAttrs())
|
|
if (genericAttrNamesSet.count(attr.getName().strref()) > 0)
|
|
genericAttrs.push_back(attr);
|
|
if (!genericAttrs.empty()) {
|
|
auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
|
|
p << genericDictAttr;
|
|
}
|
|
|
|
// Printing is shared with named ops, except for the region and attributes
|
|
printCommonStructuredOpParts(p, getInputs(), getOutputs());
|
|
|
|
genericAttrNames.push_back("operand_segment_sizes");
|
|
genericAttrNamesSet.insert(genericAttrNames.back());
|
|
|
|
bool hasExtraAttrs = false;
|
|
for (NamedAttribute n : (*this)->getAttrs()) {
|
|
if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
|
|
break;
|
|
}
|
|
if (hasExtraAttrs) {
|
|
p << " attrs = ";
|
|
p.printOptionalAttrDict((*this)->getAttrs(),
|
|
/*elidedAttrs=*/genericAttrNames);
|
|
}
|
|
|
|
// Print region.
|
|
if (!getRegion().empty()) {
|
|
p << ' ';
|
|
p.printRegion(getRegion());
|
|
}
|
|
|
|
// Print results.
|
|
printNamedStructuredOpResults(p, getResultTensors().getTypes());
|
|
}
|
|
|
|
ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
|
|
DictionaryAttr dictAttr;
|
|
// Parse the core linalg traits that the verifier needs into a dictAttr.
|
|
// The name is unimportant as we will overwrite result.attributes.
|
|
// The core linalg traits must contain the information necessary to pass the
|
|
// verifier.
|
|
if (parser.parseAttribute(dictAttr, "_", result.attributes))
|
|
return failure();
|
|
result.attributes.assign(dictAttr.getValue().begin(),
|
|
dictAttr.getValue().end());
|
|
|
|
// Parsing is shared with named ops, except for the region.
|
|
SmallVector<Type, 1> inputTypes, outputTypes;
|
|
if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
|
|
return failure();
|
|
|
|
// Optional attributes may be added.
|
|
if (succeeded(parser.parseOptionalKeyword("attrs")))
|
|
if (failed(parser.parseEqual()) ||
|
|
failed(parser.parseOptionalAttrDict(result.attributes)))
|
|
return failure();
|
|
|
|
std::unique_ptr<Region> region = std::make_unique<Region>();
|
|
if (parser.parseRegion(*region, {}))
|
|
return failure();
|
|
result.addRegion(std::move(region));
|
|
|
|
// Generic ops may specify that a subset of their outputs are tensors. Such
|
|
// outputs are specified in the result type.
|
|
// TODO: may need to move output parsing before region parsing.
|
|
// Need to wait for declarative assembly resolution to decide.
|
|
SmallVector<Type, 1> outputTensorsTypes;
|
|
if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
|
|
return failure();
|
|
result.addTypes(outputTensorsTypes);
|
|
|
|
return success();
|
|
}
|
|
|
|
static void getGenericEffectsImpl(
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
&effects,
|
|
ValueRange results, ValueRange inputBuffers, ValueRange outputs) {
|
|
for (Value value : inputBuffers) {
|
|
effects.emplace_back(MemoryEffects::Read::get(), value,
|
|
SideEffects::DefaultResource::get());
|
|
}
|
|
for (Value value : outputs) {
|
|
effects.emplace_back(MemoryEffects::Read::get(), value,
|
|
SideEffects::DefaultResource::get());
|
|
effects.emplace_back(MemoryEffects::Write::get(), value,
|
|
SideEffects::DefaultResource::get());
|
|
}
|
|
}
|
|
|
|
void GenericOp::getEffects(
|
|
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
|
|
&effects) {
|
|
SmallVector<Value> inputBuffers = getInputBufferOperands();
|
|
SmallVector<Value> outputBuffers = getOutputBufferOperands();
|
|
getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
|
|
outputBuffers);
|
|
}
|
|
|
|
LogicalResult GenericOp::verify() { return success(); }
|
|
|
|
namespace {
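/// Canonicalization pattern that (1) deduplicates input (and, with tensor
/// semantics, output) operands that share the same value and indexing map and
/// (2) drops operands and results that are provably dead, rebuilding the
/// linalg.generic without them and remapping the payload block arguments.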
|
|
|
|
struct DeduplicateAndRemoveDeadOperandsAndResults
|
|
: public OpRewritePattern<GenericOp> {
|
|
using OpRewritePattern<GenericOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(GenericOp genericOp,
|
|
PatternRewriter &rewriter) const override {
|
|
// Create a map from argument position in the original op to the argument
|
|
// position in the new op. If the argument is dropped it won't have an entry.
|
|
SmallVector<OpOperand *> droppedOpOperands;
|
|
|
|
// Information needed to build the new op.
|
|
SmallVector<Value> newInputOperands, newOutputOperands;
|
|
SmallVector<AffineMap> newIndexingMaps;
|
|
|
|
// Gather information about duplicate input operands.
|
|
llvm::SmallDenseMap<unsigned, unsigned> origInsToNewInsPos =
|
|
deduplicateInputOperands(genericOp, droppedOpOperands, newInputOperands,
|
|
newIndexingMaps);
|
|
|
|
// Gather information about the dropped outputs.
|
|
llvm::SmallDenseMap<unsigned, unsigned> origOutsToNewOutsPos =
|
|
deduplicateOutputOperands(genericOp, droppedOpOperands,
|
|
newOutputOperands, newIndexingMaps);
|
|
|
|
// Check if there is any change to operands.
|
|
if (newInputOperands.size() + newOutputOperands.size() ==
|
|
static_cast<size_t>(genericOp.getNumInputsAndOutputs()))
|
|
return failure();
|
|
|
|
// Create the new op with the body being empty.
|
|
Location loc = genericOp.getLoc();
|
|
SmallVector<Type> newResultTypes;
|
|
if (genericOp.hasTensorSemantics()) {
|
|
newResultTypes = llvm::to_vector(llvm::map_range(
|
|
newOutputOperands, [](Value v) { return v.getType(); }));
|
|
}
|
|
auto newOp = rewriter.create<GenericOp>(
|
|
loc, newResultTypes, newInputOperands, newOutputOperands,
|
|
rewriter.getAffineMapArrayAttr(newIndexingMaps),
|
|
genericOp.getIteratorTypes(), genericOp.getDocAttr(),
|
|
genericOp.getLibraryCallAttr(),
|
|
[](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) {
|
|
return;
|
|
});
|
|
// Copy over unknown attributes. They might be load bearing for some flow.
|
|
ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
|
|
for (NamedAttribute kv : genericOp->getAttrs())
|
|
if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
|
|
newOp->setAttr(kv.getName(), kv.getValue());
|
|
|
|
// Fix up the payload of the canonicalized operation.
|
|
populateOpPayload(genericOp, newOp, origInsToNewInsPos,
|
|
origOutsToNewOutsPos, rewriter);
|
|
|
|
// Replace all live uses of the op.
|
|
SmallVector<Value> replacementsVals(genericOp->getNumResults(), nullptr);
|
|
for (const auto &result : llvm::enumerate(genericOp.getResults())) {
|
|
auto it = origOutsToNewOutsPos.find(result.index());
|
|
if (it == origOutsToNewOutsPos.end())
|
|
continue;
|
|
replacementsVals[result.index()] = newOp.getResult(it->second);
|
|
}
|
|
rewriter.replaceOp(genericOp, replacementsVals);
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
// Deduplicate input operands, and return the
|
|
// - Mapping from operand position in the original op, to operand position in
|
|
// the canonicalized op.
|
|
// - The preserved input operands list (by reference).
|
|
llvm::SmallDenseMap<unsigned, unsigned>
|
|
deduplicateInputOperands(GenericOp genericOp,
|
|
SmallVector<OpOperand *> &droppedOpOperands,
|
|
SmallVector<Value> &newInputOperands,
|
|
SmallVector<AffineMap> &newIndexingMaps) const {
|
|
llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
|
|
llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
|
|
for (const auto &inputOpOperand :
|
|
llvm::enumerate(genericOp.getInputOperands())) {
|
|
// Check if the operand is dead and if dropping its indexing map would
// invalidate the loop-to-shape computation.
|
|
if (!genericOp.payloadUsesValueFromOperand(inputOpOperand.value())) {
|
|
// Add the current operand to the list of potentially droppable
// operands. If it cannot be dropped, it needs to be popped back.
|
|
droppedOpOperands.push_back(inputOpOperand.value());
|
|
if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
|
|
continue;
|
|
droppedOpOperands.pop_back();
|
|
}
|
|
|
|
// Check if this operand is a duplicate.
|
|
AffineMap indexingMap =
|
|
genericOp.getTiedIndexingMap(inputOpOperand.value());
|
|
auto it = dedupedInputs.find(
|
|
std::make_pair(inputOpOperand.value()->get(), indexingMap));
|
|
if (it != dedupedInputs.end()) {
|
|
origToNewPos[inputOpOperand.index()] = it->second;
|
|
droppedOpOperands.push_back(inputOpOperand.value());
|
|
continue;
|
|
}
|
|
|
|
// This is a preserved argument.
|
|
origToNewPos[inputOpOperand.index()] = newInputOperands.size();
|
|
dedupedInputs[{inputOpOperand.value()->get(), indexingMap}] =
|
|
newInputOperands.size();
|
|
newInputOperands.push_back(inputOpOperand.value()->get());
|
|
newIndexingMaps.push_back(indexingMap);
|
|
}
|
|
return origToNewPos;
|
|
}
|
|
|
|
// Deduplicate output operands, and return the
|
|
// - Mapping from operand position in the original op, to operand position in
|
|
// the canonicalized op.
|
|
// - The preserved output operands list (by reference).
|
|
llvm::SmallDenseMap<unsigned, unsigned>
|
|
deduplicateOutputOperands(GenericOp genericOp,
|
|
SmallVector<OpOperand *> &droppedOpOperands,
|
|
SmallVector<Value> &newOutputOperands,
|
|
SmallVector<AffineMap> &newIndexingMaps) const {
|
|
llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
|
|
llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>, unsigned>
|
|
dedupedOutpts;
|
|
// If the op doesn't have tensor semantics, keep all the outputs as
|
|
// preserved.
|
|
if (!genericOp.hasTensorSemantics()) {
|
|
for (const auto &outputOpOperand :
|
|
llvm::enumerate(genericOp.getOutputOperands())) {
|
|
origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
|
|
newOutputOperands.push_back(outputOpOperand.value()->get());
|
|
newIndexingMaps.push_back(
|
|
genericOp.getTiedIndexingMap(outputOpOperand.value()));
|
|
}
|
|
} else {
|
|
// An output argument can be dropped if
// - the result has no users, and
// - it is not used in the payload, and
// - the corresponding indexing maps are not needed for loop bound
//   computation.
|
|
auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
|
|
for (const auto &outputOpOperand :
|
|
llvm::enumerate(genericOp.getOutputOperands())) {
|
|
Value result = genericOp.getResult(outputOpOperand.index());
|
|
AffineMap indexingMap =
|
|
genericOp.getTiedIndexingMap(outputOpOperand.value());
|
|
auto key =
|
|
std::make_tuple(outputOpOperand.value()->get(), indexingMap,
|
|
yieldOp->getOperand(outputOpOperand.index()));
|
|
|
|
// Do not drop an output operand if its value is used in the payload.
|
|
if (!genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) {
|
|
if (result.use_empty()) {
|
|
// Check if the OpOperand can be dropped without affecting loop
// bound computation. Add the operand to the list of dropped op
// operands for checking. If it cannot be dropped, it needs to be
// popped back.
|
|
droppedOpOperands.push_back(outputOpOperand.value());
|
|
if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
|
|
continue;
|
|
}
|
|
droppedOpOperands.pop_back();
|
|
}
|
|
|
|
// The output operand can also be dropped if it is computed redundantly
// by another result. The conditions for that are:
|
|
// - The same operand is used as the out operand
|
|
// - The same indexing map is used
|
|
// - The same yield value is used.
|
|
auto it = dedupedOutpts.find(key);
|
|
if (it != dedupedOutpts.end()) {
|
|
origToNewPos[outputOpOperand.index()] = it->second;
|
|
droppedOpOperands.push_back(outputOpOperand.value());
|
|
continue;
|
|
}
|
|
}
|
|
|
|
origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
|
|
dedupedOutpts[key] = newOutputOperands.size();
|
|
newOutputOperands.push_back(outputOpOperand.value()->get());
|
|
newIndexingMaps.push_back(
|
|
genericOp.getTiedIndexingMap(outputOpOperand.value()));
|
|
}
|
|
}
|
|
|
|
return origToNewPos;
|
|
}
|
|
|
|
// Populate the body of the canonicalized operation.
|
|
void populateOpPayload(
|
|
GenericOp genericOp, GenericOp newOp,
|
|
const llvm::SmallDenseMap<unsigned, unsigned> &origInsToNewInsPos,
|
|
const llvm::SmallDenseMap<unsigned, unsigned> &origOutsToNewOutsPos,
|
|
PatternRewriter &rewriter) const {
|
|
// Merge the body of the original op with the new op.
|
|
Block *newOpBlock = &newOp.getRegion().front();
|
|
assert(newOpBlock->empty() && "expected new op to have an empty payload");
|
|
Block *origOpBlock = &genericOp.getRegion().front();
|
|
SmallVector<Value> replacements(origOpBlock->getNumArguments(), nullptr);
|
|
|
|
// Replace all arguments in the original op, with arguments from the
|
|
// canonicalized op.
|
|
auto updateReplacements =
|
|
[&](OpOperandVector &origOperands, OpOperandVector &newOperands,
|
|
const llvm::SmallDenseMap<unsigned, unsigned> &map) {
|
|
for (const auto &origOperand : llvm::enumerate(origOperands)) {
|
|
auto it = map.find(origOperand.index());
|
|
if (it == map.end())
|
|
continue;
|
|
OpOperand *newOperand = newOperands[it->second];
|
|
replacements[origOperand.value()->getOperandNumber()] =
|
|
newOpBlock->getArgument(newOperand->getOperandNumber());
|
|
}
|
|
};
|
|
|
|
OpOperandVector origInputOperands = genericOp.getInputOperands();
|
|
OpOperandVector newInputOperands = newOp.getInputOperands();
|
|
updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);
|
|
|
|
OpOperandVector origOutputOperands = genericOp.getOutputOperands();
|
|
OpOperandVector newOutputOperands = newOp.getOutputOperands();
|
|
updateReplacements(origOutputOperands, newOutputOperands,
|
|
origOutsToNewOutsPos);
|
|
|
|
rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements);
|
|
|
|
// Drop the unused yield args.
|
|
if (newOp.getNumOutputs() != genericOp.getNumOutputs()) {
|
|
OpBuilder::InsertionGuard g(rewriter);
|
|
YieldOp origYieldOp = cast<YieldOp>(newOpBlock->getTerminator());
|
|
rewriter.setInsertionPoint(origYieldOp);
|
|
|
|
SmallVector<Value> newYieldVals(newOp.getNumOutputs(), nullptr);
|
|
for (const auto &yieldOpOperands :
|
|
llvm::enumerate(origYieldOp.getValues())) {
|
|
auto it = origOutsToNewOutsPos.find(yieldOpOperands.index());
|
|
if (it == origOutsToNewOutsPos.end())
|
|
continue;
|
|
newYieldVals[it->second] = yieldOpOperands.value();
|
|
}
|
|
rewriter.replaceOpWithNewOp<YieldOp>(origYieldOp, newYieldVals);
|
|
}
|
|
}
|
|
};
|
|
|
|
/// Remove generic operations (on tensors) that are just copying
|
|
/// the values from inputs to the results. Requirements are
|
|
/// 1) All iterator types are parallel
|
|
/// 2) The body contains just a yield operation with the yielded values being
|
|
/// the arguments corresponding to the operands.
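///
/// Illustrative example of an erasable op (sketch only):
///
///   %0 = linalg.generic {
///       indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
///       iterator_types = ["parallel"]}
///       ins(%arg0 : tensor<?xf32>) outs(%init : tensor<?xf32>) {
///     ^bb0(%in: f32, %out: f32):
///       linalg.yield %in : f32
///   } -> tensor<?xf32>
///
/// Each result is replaced by the corresponding input (%arg0 here), inserting
/// a tensor.cast or sparse_tensor.convert when the types differ.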
|
|
struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
|
|
using OpRewritePattern<GenericOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(GenericOp genericOp,
|
|
PatternRewriter &rewriter) const override {
|
|
// Check all indexing maps are identity.
|
|
if (llvm::any_of(genericOp.getIndexingMapsArray(),
|
|
[](AffineMap map) { return !map.isIdentity(); }))
|
|
return failure();
|
|
|
|
// Check that the body of the linalg operation is just a linalg.yield
|
|
// operation.
|
|
Block &body = genericOp.getRegion().front();
|
|
if (!llvm::hasSingleElement(body))
|
|
return failure();
|
|
auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
|
|
if (!yieldOp)
|
|
return failure();
|
|
|
|
// In the buffer case, we need to check exact buffer equality.
|
|
if (genericOp.hasBufferSemantics()) {
|
|
if (genericOp.getNumInputs() == 1 && genericOp.getNumOutputs() == 1 &&
|
|
genericOp.getInputOperand(0)->get() ==
|
|
genericOp.getOutputOperand(0)->get()) {
|
|
rewriter.eraseOp(genericOp);
|
|
return success();
|
|
}
|
|
return failure();
|
|
}
|
|
|
|
// Get the argument number of the returned values. That is the operand
|
|
// number to use for replacing uses of this operation.
|
|
SmallVector<Value> returnedArgs;
|
|
for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
|
|
auto yieldArg = yieldVal.value().dyn_cast<BlockArgument>();
|
|
if (!yieldArg || yieldArg.getOwner() != &body)
|
|
return failure();
|
|
unsigned argumentNumber = yieldArg.getArgNumber();
|
|
Value returnedArg = genericOp->getOperand(argumentNumber);
|
|
Type resultType = genericOp->getResult(yieldVal.index()).getType();
|
|
// The input can have a different type than the result, e.g. a dynamic
|
|
// input dimension can be turned into a static output dimension.
|
|
Type returnType = returnedArg.getType();
|
|
if (returnType != resultType) {
|
|
// Distinguish between sparse conversion or dense tensor casting.
|
|
// TODO: unify the two ops?
|
|
if (sparse_tensor::getSparseTensorEncoding(returnType) ||
|
|
sparse_tensor::getSparseTensorEncoding(resultType))
|
|
returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
|
|
genericOp.getLoc(), resultType, returnedArg);
|
|
else {
|
|
if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
|
|
resultType))
|
|
return failure();
|
|
returnedArg = rewriter.create<tensor::CastOp>(
|
|
genericOp.getLoc(), resultType, returnedArg);
|
|
}
|
|
}
|
|
returnedArgs.push_back(returnedArg);
|
|
}
|
|
|
|
if (returnedArgs.size() != genericOp->getNumResults())
|
|
return failure();
|
|
rewriter.replaceOp(genericOp, returnedArgs);
|
|
return success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
|
MLIRContext *context) {
|
|
results
|
|
.add<DeduplicateAndRemoveDeadOperandsAndResults, EraseIdentityGenericOp>(
|
|
context);
|
|
}
|
|
|
|
LogicalResult GenericOp::fold(ArrayRef<Attribute>,
|
|
SmallVectorImpl<OpFoldResult> &) {
|
|
return foldMemRefCast(*this);
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// InitTensorOp
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
void InitTensorOp::build(OpBuilder &b, OperationState &result,
|
|
ArrayRef<OpFoldResult> sizes, Type elementType,
|
|
ArrayRef<NamedAttribute> attrs) {
|
|
SmallVector<Value, 4> dynamicSizes;
|
|
SmallVector<int64_t, 4> staticSizes;
|
|
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
|
|
ShapedType::kDynamicSize);
|
|
auto resultType = RankedTensorType ::get(staticSizes, elementType);
|
|
build(b, result, resultType, dynamicSizes, b.getI64ArrayAttr(staticSizes));
|
|
result.addAttributes(attrs);
|
|
}
|
|
|
|
LogicalResult InitTensorOp::verify() {
|
|
RankedTensorType resultType = getType();
|
|
SmallVector<int64_t, 4> staticSizes = llvm::to_vector<4>(llvm::map_range(
|
|
getStaticSizes().cast<ArrayAttr>(),
|
|
[](Attribute a) -> int64_t { return a.cast<IntegerAttr>().getInt(); }));
|
|
|
|
if (failed(verifyListOfOperandsOrIntegers(
|
|
*this, "sizes", resultType.getRank(), getStaticSizes(), getSizes(),
|
|
ShapedType::isDynamic)))
|
|
return failure();
|
|
|
|
if (getStaticSizes().size() != static_cast<unsigned>(resultType.getRank()))
|
|
return emitError("expected ") << resultType.getRank() << " sizes values";
|
|
|
|
Type expectedType = InitTensorOp::inferResultType(
|
|
staticSizes, resultType.getElementType(), resultType.getEncoding());
|
|
if (resultType != expectedType) {
|
|
return emitError("specified type ")
|
|
<< resultType << " does not match the inferred type "
|
|
<< expectedType;
|
|
}
|
|
return success();
|
|
}
|
|
|
|
Type InitTensorOp::inferResultType(ArrayRef<int64_t> staticSizes,
|
|
Type elementType, Attribute encoding) {
|
|
return RankedTensorType::get(staticSizes, elementType, encoding);
|
|
}
|
|
|
|
SmallVector<OpFoldResult> InitTensorOp::getMixedSizes() {
|
|
SmallVector<OpFoldResult> mixedSizes;
|
|
mixedSizes.reserve(getType().getRank());
|
|
unsigned dynamicValIndex = 0;
|
|
for (Attribute attr : getStaticSizes()) {
|
|
auto intAttr = attr.cast<IntegerAttr>();
|
|
if (!ShapedType::isDynamic(intAttr.getInt())) {
|
|
mixedSizes.push_back(intAttr);
|
|
continue;
|
|
}
|
|
mixedSizes.push_back(getSizes()[dynamicValIndex++]);
|
|
}
|
|
return mixedSizes;
|
|
}
|
|
|
|
namespace {
|
|
/// Change the type of the result of a `linalg.init_tensor` by making the result
/// type statically sized along dimensions that were defined as dynamic in the
/// original operation but whose size was given by a `constant` op. For
/// example
|
|
///
|
|
/// %c5 = arith.constant 5: index
|
|
/// %0 = linalg.init_tensor [%arg0, %c5] : tensor<?x?xf32>
|
|
///
|
|
/// to
|
|
///
|
|
/// %0 = linalg.init_tensor [%arg0, 5] : tensor<?x5xf32>
|
|
struct ReplaceStaticShapeDims : OpRewritePattern<InitTensorOp> {
|
|
using OpRewritePattern<InitTensorOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(InitTensorOp op,
|
|
PatternRewriter &rewriter) const override {
|
|
SmallVector<Value, 4> dynamicSizes;
|
|
SmallVector<int64_t, 4> staticSizes;
|
|
for (unsigned i = 0, e = op.getType().getRank(); i != e; ++i) {
|
|
// If the size is already static, nothing to do.
|
|
if (!op.isDynamicSize(i)) {
|
|
staticSizes.push_back(op.getStaticSize(i));
|
|
continue;
|
|
}
|
|
|
|
// If the size is dynamic but defined using a `constant` op, get the
|
|
// constant value to find the static size to use.
|
|
unsigned operandNum = op.getIndexOfDynamicSize(i);
|
|
Value sizeOperand = op.getOperand(operandNum);
|
|
if (auto constantIndexOp =
|
|
sizeOperand.getDefiningOp<arith::ConstantIndexOp>()) {
|
|
staticSizes.push_back(constantIndexOp.value());
|
|
continue;
|
|
}
|
|
|
|
// Fallback case. Keep the size dynamic.
|
|
dynamicSizes.push_back(sizeOperand);
|
|
staticSizes.push_back(ShapedType::kDynamicSize);
|
|
}
|
|
RankedTensorType newType =
|
|
RankedTensorType::get(staticSizes, op.getType().getElementType());
|
|
if (newType == op.getType())
|
|
return failure();
|
|
auto newOp =
|
|
rewriter.create<InitTensorOp>(op.getLoc(), newType, dynamicSizes,
|
|
rewriter.getI64ArrayAttr(staticSizes));
|
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
|
|
return success();
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
namespace {
|
|
/// Since the `init_tensor` operation creates a tensor needed only for its shape, a
|
|
/// slice of this is also needed only for its shape. The result can be
|
|
/// replaced by a new init_tensor operation of the same size as the extract
|
|
/// slice op.
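///
/// Illustrative example (sketch only):
///
///   %0 = linalg.init_tensor [%d0, 16] : tensor<?x16xf32>
///   %1 = tensor.extract_slice %0[0, 0] [%sz, 8] [1, 1]
///       : tensor<?x16xf32> to tensor<?x8xf32>
///
/// becomes
///
///   %1 = linalg.init_tensor [%sz, 8] : tensor<?x8xf32>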
|
|
struct FoldInitTensorWithExtractSliceOp
|
|
: public OpRewritePattern<tensor::ExtractSliceOp> {
|
|
using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
|
|
PatternRewriter &rewriter) const override {
|
|
if (!sliceOp.getSource().getDefiningOp<linalg::InitTensorOp>())
|
|
return failure();
|
|
// ExtractSliceOp may be rank-reducing; its dynamic sizes must be preserved
|
|
// as well as its result type.
|
|
rewriter.replaceOpWithNewOp<linalg::InitTensorOp>(
|
|
sliceOp, sliceOp.getSizes(),
|
|
sliceOp.getResult().getType().cast<RankedTensorType>().getShape(),
|
|
sliceOp.getSourceType().getElementType());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
template <typename TensorReshapeOp>
|
|
struct FoldInitTensorWithTensorReshapeOp
|
|
: public OpRewritePattern<TensorReshapeOp> {
|
|
using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
|
|
PatternRewriter &rewriter) const override {
|
|
if (!reshapeOp.getSrc().template getDefiningOp<InitTensorOp>())
|
|
return failure();
|
|
Location loc = reshapeOp.getLoc();
|
|
ReifiedRankedShapedTypeDims resultShapes;
|
|
ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
|
|
cast<ReifyRankedShapedTypeOpInterface>(reshapeOp.getOperation());
|
|
if (failed(reifyShapedTypeInterface.reifyResultShapes(rewriter,
|
|
resultShapes)) ||
|
|
!llvm::hasSingleElement(resultShapes))
|
|
return failure();
|
|
Value initTensor = rewriter.create<InitTensorOp>(
|
|
loc, getAsOpFoldResult(resultShapes[0]),
|
|
reshapeOp.getResultType().getElementType());
|
|
if (initTensor.getType() != reshapeOp.getResultType()) {
|
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(
|
|
reshapeOp, reshapeOp.getResultType(), initTensor);
|
|
} else {
|
|
rewriter.replaceOp(reshapeOp, initTensor);
|
|
}
|
|
return success();
|
|
}
|
|
};
|
|
|
|
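/// Fold a tensor.dim op that queries a dynamic dimension of a
/// linalg.init_tensor directly to the SSA value passed as that dynamic size.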
struct FoldInitTensorWithDimOp : public OpRewritePattern<tensor::DimOp> {
|
|
using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tensor::DimOp dimOp,
|
|
PatternRewriter &rewriter) const override {
|
|
Optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
|
|
auto initTensorOp = dimOp.getSource().getDefiningOp<linalg::InitTensorOp>();
|
|
if (!initTensorOp || !maybeConstantIndex)
|
|
return failure();
|
|
if (!initTensorOp.isDynamicSize(*maybeConstantIndex))
|
|
return failure();
|
|
rewriter.replaceOp(dimOp, initTensorOp.getDynamicSize(*maybeConstantIndex));
|
|
return success();
|
|
}
|
|
};
|
|
|
|
/// Canonicalize
|
|
///
|
|
/// ```mlir
|
|
/// %0 = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
|
|
/// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
|
|
/// ```
|
|
///
|
|
/// into
|
|
///
|
|
/// ```mlir
|
|
/// %0 = linalg.init_tensor [4, %d1] : tensor<4x?xf32>
|
|
/// ```
|
|
///
|
|
/// This assumes the input program is correct in terms of its shape. So it
|
|
/// is safe to assume that `%d0` is in fact 4. If that was not the case, the
|
|
/// input program is wrong to begin with, so it's undefined behavior anyway (i.e.
/// this optimization can still trigger without violating program semantics).
|
|
struct FoldInitTensorWithTensorCastOp
|
|
: public OpRewritePattern<tensor::CastOp> {
|
|
using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(tensor::CastOp castOp,
|
|
PatternRewriter &rewriter) const override {
|
|
if (!canFoldIntoProducerOp(castOp))
|
|
return failure();
|
|
auto producer = castOp.getSource().getDefiningOp<InitTensorOp>();
|
|
if (!producer)
|
|
return failure();
|
|
|
|
auto resultType = castOp->getResult(0).getType().cast<RankedTensorType>();
|
|
ArrayRef<int64_t> resultShape = resultType.getShape();
|
|
SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
|
|
SmallVector<OpFoldResult> newMixedSizes;
|
|
newMixedSizes.reserve(currMixedSizes.size());
|
|
assert(resultShape.size() == currMixedSizes.size() &&
|
|
"mismatch in result shape and sizes of init_tensor op");
|
|
for (auto it : llvm::zip(resultShape, currMixedSizes)) {
|
|
int64_t newDim = std::get<0>(it);
|
|
OpFoldResult currDim = std::get<1>(it);
|
|
// Case 1: The init tensor dim is static. Check that the tensor cast
|
|
// result dim matches.
|
|
if (auto attr = currDim.dyn_cast<Attribute>()) {
|
|
if (ShapedType::isDynamic(newDim) ||
|
|
newDim != attr.cast<IntegerAttr>().getInt()) {
|
|
// Something is off: the cast result shape cannot be more dynamic than
// the init tensor result shape (enforced by `canFoldIntoProducerOp`).
|
|
// Abort for now.
|
|
return rewriter.notifyMatchFailure(
|
|
producer, "mismatch in static value of shape of init "
|
|
"tensor result and cast result");
|
|
}
|
|
newMixedSizes.push_back(attr);
|
|
continue;
|
|
}
|
|
|
|
// Case 2 : The tensor cast shape is static, but init tensor result shape
|
|
// is dynamic.
|
|
if (!ShapedType::isDynamic(newDim)) {
|
|
newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
|
|
continue;
|
|
}
|
|
|
|
// Case 3 : The tensor cast shape is dynamic and init tensor result shape
|
|
// is dynamic. Use the dynamic value from the init tensor op.
|
|
newMixedSizes.push_back(currDim);
|
|
}
|
|
|
|
rewriter.replaceOpWithNewOp<InitTensorOp>(castOp, newMixedSizes,
|
|
resultType.getElementType());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
void InitTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<FoldInitTensorWithTensorCastOp, FoldInitTensorWithDimOp,
              FoldInitTensorWithExtractSliceOp,
              FoldInitTensorWithTensorReshapeOp<tensor::ExpandShapeOp>,
              FoldInitTensorWithTensorReshapeOp<tensor::CollapseShapeOp>,
              ReplaceStaticShapeDims>(context);
}

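// Note on the implementation below: `reifyResultShapes` returns one entry per
// result containing each dimension size as a `Value`. Dynamic sizes reuse the
// corresponding size operand of the op, while static sizes are materialized as
// `arith.constant ... : index` ops. Rough sketch (value names hypothetical):
//
//   %0 = linalg.init_tensor [%sz, 4] : tensor<?x4xf32>
//   // reified shape of %0: {%sz, <arith.constant 4 : index>}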
LogicalResult InitTensorOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapes = llvm::to_vector<4>(llvm::map_range(
      llvm::seq<int64_t>(0, getType().getRank()), [&](int64_t dim) -> Value {
        if (isDynamicSize(dim))
          return getDynamicSize(dim);
        return builder.create<arith::ConstantIndexOp>(getLoc(),
                                                      getStaticSize(dim));
      }));
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}

//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//

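// The custom printed form handled below looks roughly like the following
// sketch (operand names are hypothetical); the operand list and the trailing
// type list are omitted when there are no operands:
//
//   linalg.yield %sum, %idx : f32, index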
void linalg::YieldOp::print(OpAsmPrinter &p) {
  if (getNumOperands() > 0)
    p << ' ' << getOperands();
  p.printOptionalAttrDict((*this)->getAttrs());
  if (getNumOperands() > 0)
    p << " : " << getOperandTypes();
}

ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
  SmallVector<Type, 2> types;
  SMLoc loc = parser.getCurrentLocation();
  return failure(parser.parseOperandList(opInfo) ||
                 parser.parseOptionalAttrDict(result.attributes) ||
                 (!opInfo.empty() && parser.parseColonTypeList(types)) ||
                 parser.resolveOperands(opInfo, types, loc, result.operands));
}

// Check that the operand number and types match the element types of the
// LinalgOp interface's shaped operands.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
  if (op.getNumOperands() != linalgOp.getNumOutputs())
    return op.emitOpError("expected number of yield values (")
           << linalgOp.getNumOutputs()
           << ") to match the number of operands of the enclosing "
           << "LinalgOp (" << op.getNumOperands() << ")";

  for (OpOperand &opOperand : op->getOpOperands()) {
    OpOperand *outputOperand =
        linalgOp.getOutputOperand(opOperand.getOperandNumber());
    Type elementType = getElementTypeOrSelf(outputOperand->get().getType());
    if (opOperand.get().getType() != elementType)
      return op.emitOpError("type of yield operand ")
             << (opOperand.getOperandNumber() + 1) << " ("
             << opOperand.get().getType() << ") doesn't match "
             << "the element type of the enclosing linalg.generic op ("
             << elementType << ")";
  }
  return success();
}

LogicalResult linalg::YieldOp::verify() {
  auto *parentOp = (*this)->getParentOp();
  if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
    return emitOpError("expected single non-empty parent region");

  if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
    return verifyYield(*this, linalgOp);

  return emitOpError("expected parent op with LinalgOp interface");
}

//===----------------------------------------------------------------------===//
// IndexOp
//===----------------------------------------------------------------------===//

LogicalResult IndexOp::verify() {
  auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
  if (!linalgOp)
    return emitOpError("expected parent op with LinalgOp interface");
  if (linalgOp.getNumLoops() <= getDim())
    return emitOpError("expected dim (")
           << getDim() << ") to be lower than the number of loops ("
           << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
  return success();
}

/////// Operations corresponding to library calls defined with Tablegen ////////

#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"

/// Return the dims that are `iteratorTypeName` loops in the LinalgOp `op`.
/// Assumes `op` is a LinalgOp.
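/// For illustration only (a sketch, not tied to a particular test): on a
/// `linalg.matmul`, whose iterator types are ["parallel", "parallel",
/// "reduction"], querying for "reduction" appends {2} to `res`.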
void mlir::linalg::getDimsOfType(Operation *op, StringRef iteratorTypeName,
                                 SmallVectorImpl<unsigned> &res) {
  if (!cast<LinalgOp>(op).iterator_types())
    return;

  unsigned dim = 0;
  for (auto tn :
       cast<LinalgOp>(op).iterator_types().getAsValueRange<StringAttr>()) {
    if (tn == iteratorTypeName)
      res.push_back(dim);
    ++dim;
  }
}

AffineMap mlir::linalg::extractOrIdentityMap(Optional<AffineMap> maybeMap,
                                             unsigned rank,
                                             MLIRContext *context) {
  if (maybeMap)
    return *maybeMap;
  if (rank == 0)
    return AffineMap::get(context);
  return AffineMap::getMultiDimIdentityMap(rank, context);
}

SmallVector<AffineExpr, 4>
mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
                                 MLIRContext *context) {
  SmallVector<AffineExpr, 4> res;
  res.reserve(num);
  for (unsigned i = 0; i < num; ++i)
    res.push_back(getAffineDimExpr(startIdx++, context));
  return res;
}

SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
                                                ArrayRef<AffineExpr> b) {
  auto rangeA = llvm::make_range(a.begin(), a.end());
  auto rangeB = llvm::make_range(b.begin(), b.end());
  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
  return llvm::to_vector<4>(concatRanges);
}

static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
  if (auto memref = t.dyn_cast<MemRefType>()) {
    ss << "view";
    for (auto size : memref.getShape())
      if (size < 0)
        ss << "sx";
      else
        ss << size << "x";
    appendMangledType(ss, memref.getElementType());
  } else if (auto vec = t.dyn_cast<VectorType>()) {
    ss << "vector";
    llvm::interleave(
        vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
    appendMangledType(ss, vec.getElementType());
  } else if (t.isSignlessIntOrIndexOrFloat()) {
    ss << t;
  } else {
    llvm_unreachable("Invalid type for linalg library name mangling");
  }
}

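// Illustrative sketch of the mangling performed by `generateLibraryCallName`
// below (not an exhaustive specification): for a `linalg.matmul` whose three
// operands all have type `memref<?x?xf32>`, the generated name would be
// "linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32".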
std::string mlir::linalg::generateLibraryCallName(Operation *op) {
  assert(isa<LinalgOp>(op));
  std::string name(op->getName().getStringRef().str());
  name.reserve(128);
  std::replace(name.begin(), name.end(), '.', '_');
  llvm::raw_string_ostream ss(name);
  ss << "_";
  auto types = op->getOperandTypes();
  llvm::interleave(
      types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); },
      [&]() { ss << "_"; });
  return ss.str();
}

//===----------------------------------------------------------------------===//
// Canonicalizers and Folders.
//===----------------------------------------------------------------------===//

namespace {
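/// Erase LinalgOps that are statically known to execute zero iterations
/// because one of their memref operands has a dimension of size 0 (e.g. an
/// operand of type `memref<4x0xf32>`). Tensor operands of type `tensor<0x...>`
/// are deliberately left alone, as explained in the pattern body.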
struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
      // Linalg "inputs" may be either tensor or memref type.
      // tensor<0xelt_type> is a convention that may not always mean
      // "0 iterations". Only erase in cases we see memref<...x0x...>.
      auto mt = opOperand->get().getType().dyn_cast<MemRefType>();
      if (!mt)
        continue;
      if (llvm::is_contained(op.getShape(opOperand), 0)) {
        rewriter.eraseOp(op);
        return success();
      }
    }
    return failure();
  }
};

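/// Fold `tensor.cast` producers into a consuming LinalgOp when the cast only
/// erases static shape information. Illustrative sketch (op and value names
/// are hypothetical):
///
/// ```mlir
///   %0 = tensor.cast %static : tensor<4x8xf32> to tensor<?x?xf32>
///   %1 = linalg.generic ... ins(%0 : tensor<?x?xf32>) ...
/// ```
///
/// The generic op can consume `%static` directly; a `tensor.cast` back to the
/// original result type is inserted for the remaining users if needed.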
struct FoldTensorCastProducerOp : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    // Fail if none of the operands is produced by a tensor::CastOp that can
    // be folded.
    bool hasTensorCastOperand =
        llvm::any_of(op.getInputAndOutputOperands(), [&](OpOperand *opOperand) {
          if (opOperand->get().isa<BlockArgument>())
            return false;
          auto castOp = opOperand->get().getDefiningOp<tensor::CastOp>();
          return castOp && canFoldIntoConsumerOp(castOp);
        });
    if (!hasTensorCastOperand)
      return failure();

    SmallVector<Type, 4> newResultTypes;
    newResultTypes.reserve(op->getNumResults());
    SmallVector<Value, 4> newOperands;
    newOperands.reserve(op->getNumOperands());
    // Inputs may fold.
    for (OpOperand *opOperand : op.getInputOperands()) {
      auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
      newOperands.push_back(canFoldIntoConsumerOp(tensorCastOp)
                                ? tensorCastOp.getSource()
                                : opOperand->get());
    }
    // Init tensors may fold, in which case the resultType must also change.
    for (OpOperand *opOperand : op.getOutputOperands()) {
      auto tensorCastOp = opOperand->get().getDefiningOp<tensor::CastOp>();
      bool fold = canFoldIntoConsumerOp(tensorCastOp);
      newOperands.push_back(fold ? tensorCastOp.getOperand()
                                 : opOperand->get());
      newResultTypes.push_back(newOperands.back().getType());
    }
    // Clone op.
    Operation *newOp =
        op.clone(rewriter, op->getLoc(), newResultTypes, newOperands);
    SmallVector<Value, 4> replacements;
    replacements.reserve(newOp->getNumResults());
    for (auto result : llvm::zip(op->getResults(), newOp->getResults())) {
      Value oldResult = std::get<0>(result);
      Value newResult = std::get<1>(result);
      if (newResult.getType() != oldResult.getType()) {
        replacements.push_back(rewriter.create<tensor::CastOp>(
            op->getLoc(), oldResult.getType(), newResult));
      } else {
        replacements.push_back(newResult);
      }
    }
    rewriter.replaceOp(op, replacements);

    return success();
  }
};

/// Fold LinalgOps with a `tensor.cast` consumer if the `tensor.cast` has a
/// result that is more static than the linalg op.
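///
/// Illustrative sketch (op and value names are hypothetical):
///
/// ```mlir
///   %0 = linalg.generic ... outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
///   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
/// ```
///
/// The linalg op is rewritten to produce `tensor<4x8xf32>` directly, with a
/// `tensor.cast` inserted on its `outs` operand and a cast back to the
/// original type for the other users of `%0`.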
struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp castOp,
                                PatternRewriter &rewriter) const override {
    if (!tensor::canFoldIntoProducerOp(castOp))
      return failure();

    auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
    if (!linalgOp)
      return failure();

    // The cast can be in a conditionally reachable region, in which case
    // folding will generate invalid code. Only conservatively fold ops in the
    // same block for now.
    if (castOp->getBlock() != linalgOp->getBlock())
      return failure();

    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(linalgOp);

    Location loc = linalgOp.getLoc();
    OpResult resultValue = castOp.getSource().cast<OpResult>();
    unsigned resultNumber = resultValue.getResultNumber();
    auto resultType = castOp->getResult(0).getType().cast<RankedTensorType>();
    // Replace the `outs` for the result with a `tensor.cast`. This cast is now
    // going from a more dynamic shape to a less dynamic shape. If the producer
    // for this cast, i.e. producer of the out operand, is also an operation
    // that folds with tensor.cast consumer (like this pattern), the cast will
    // continue to propagate as far up the stack as it can go.
    OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);
    Value newOperand =
        rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
    SmallVector<Value> newOperands = linalgOp.getInputOperands();
    SmallVector<Value> outputOperands = linalgOp.getOutputOperands();
    outputOperands[resultNumber] = newOperand;
    newOperands.append(outputOperands.begin(), outputOperands.end());

    SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
                                  linalgOp->result_type_end());
    resultTypes[resultNumber] = resultType;
    Operation *newOp = linalgOp.clone(rewriter, loc, resultTypes, newOperands);

    // Create a tensor.cast operation back to the original type.
    Value castBack = rewriter.create<tensor::CastOp>(
        loc, resultValue.getType(), newOp->getResult(resultNumber));

    SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
    results[resultNumber] = castBack;
    rewriter.replaceOp(linalgOp, results);
    rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
    return success();
  }
};

/// For each operand in `operands`, this function maps the static sizes of
/// dimensions to their affine dim expressions.
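///
/// For illustration (hypothetical operand): with the indexing map
/// `(d0, d1) -> (d0, d1)` and an operand of type `tensor<4x?xf32>`, the entry
/// `d0 -> 4` is recorded while `d1` is skipped because that dimension is
/// dynamic.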
static void populateMap(LinalgOp linalgOp, ArrayRef<OpOperand *> operands,
                        llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
  for (OpOperand *opOperand : operands) {
    if (linalgOp.isScalar(opOperand))
      continue;
    Value src = opOperand->get();
    auto sourceType = src.getType().cast<RankedTensorType>();
    auto sourceMap = linalgOp.getTiedIndexingMap(opOperand);

    // Get the `sourceShape` of the `sourceType`. If the operand is a result
    // of a `tensor.cast` operation and the source of the cast operation has a
    // static shape, then assign it to `sourceShape`.
    auto *parentOp = src.getDefiningOp();
    ArrayRef<int64_t> sourceShape = sourceType.getShape();
    if (parentOp) {
      if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
        Value castSource = castOp.getSource();
        auto castSourceType = castSource.getType().cast<RankedTensorType>();
        if (castSourceType.hasStaticShape())
          sourceShape = castSourceType.getShape();
      }
    }

    // If the source shape's dimension has a static shape, map the affine dim
    // expression to the known static size.
    for (unsigned i = 0; i < sourceShape.size(); i++) {
      if (sourceType.isDynamicDim(i))
        continue;
      if (auto affineDimExpr = sourceMap.getResult(i).dyn_cast<AffineDimExpr>())
        affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
    }
  }
}

/// Creates a new operand w.r.t. `opOperand` of `linalgOp` with static sizes
/// mapped in `affineExprToSize`. New operands are created in `newOperands` and
/// their result types are stored in `resultTypes`. If `opOperand` requires no
/// change, `changeNeeded` is left untouched and the same operand is added to
/// the `newOperands` list.
static void createNewOperandWithStaticSizes(
    Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
    llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
    SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
    bool &changeNeeded) {
  Value src = opOperand->get();
  newOperands.push_back(src);
  if (linalgOp.isScalar(opOperand))
    return;
  auto sourceType = src.getType().cast<RankedTensorType>();
  Type resultType = sourceType;
  if (sourceType.hasStaticShape() && linalgOp.isOutputTensor(opOperand)) {
    resultTypes.push_back(resultType);
    return;
  }
  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  AffineMap sourceMap = linalgOp.getTiedIndexingMap(opOperand);
  SmallVector<int64_t> newShape;
  // If the operand is updated with a new shape, `newOperandNeeded` will be
  // true.
  bool newOperandNeeded = false;
  for (unsigned i = 0; i < sourceShape.size(); i++) {
    int64_t dimShape = sourceShape[i];
    AffineExpr dimExpr = sourceMap.getResult(i);
    if (affineExprToSize.find(dimExpr) == affineExprToSize.end() ||
        !sourceType.isDynamicDim(i)) {
      newShape.push_back(dimShape);
      continue;
    }
    // The dimension has a dynamic shape and the corresponding affine dim
    // expression is present in the map. So assign the size for the given
    // affine dim expression to the dimension.
    newShape.push_back(affineExprToSize[dimExpr]);
    newOperandNeeded = true;
  }
  resultType = RankedTensorType::get(newShape, sourceType.getElementType());
  if (newOperandNeeded) {
    changeNeeded = true;
    // Get the new operand value given its size and element type by
    // casting it.
    Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
    unsigned index = opOperand->getOperandNumber();
    newOperands[index] = newOperand;
  }
  if (linalgOp.isOutputTensor(opOperand))
    resultTypes.push_back(resultType);
}

/// Static shapes for the operands can be inferred if any one of the operands
/// has a static shape. This can be done by referring to the affine dim
/// expressions for the operand.
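///
/// Illustrative sketch only (op and value names are hypothetical): for an
/// elementwise op with identity indexing maps,
///
/// ```mlir
///   %0 = linalg.generic ... ins(%a, %b : tensor<4x?xf32>, tensor<?x8xf32>)
///                           outs(%c : tensor<?x?xf32>) ...
/// ```
///
/// the dynamic sizes can be inferred to be 4 and 8, and `tensor.cast` ops are
/// inserted on the operands (and results) to use the more static
/// `tensor<4x8xf32>` type.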
struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    if (!linalgOp.hasTensorSemantics())
      return failure();

    // Maps must be projected permutations.
    if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
          return !map.isProjectedPermutation();
        }))
      return failure();

    // Maps affine dim expressions to the static size of that dimension.
    llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
    Location loc = linalgOp.getLoc();

    // For each affine dim expression, check if the size is known. If known,
    // add it to the map.
    populateMap(linalgOp, linalgOp.getInputAndOutputOperands(),
                affineExprToSize);

    SmallVector<Value> newOperands;
    SmallVector<Type> resultTypes;

    // `changeNeeded` is `false` if the operands of `linalgOp` require no
    // change in their types.
    bool changeNeeded = false;
    newOperands.reserve(linalgOp.getNumInputsAndOutputs());
    resultTypes.reserve(linalgOp.getNumOutputs());

    // Iterate over all the operands and update the static sizes.
    for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
      createNewOperandWithStaticSizes(loc, rewriter, opOperand,
                                      affineExprToSize, linalgOp, newOperands,
                                      resultTypes, changeNeeded);
    }

    // If the generic op already has all the required static information, no
    // canonicalization is needed.
    if (!changeNeeded)
      return failure();

    // Clone op.
    Operation *newOp =
        linalgOp.clone(rewriter, linalgOp->getLoc(), resultTypes, newOperands);
    SmallVector<Value> replacements;
    replacements.reserve(newOp->getNumResults());
    for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
      Value newResult = std::get<1>(it);
      Value oldResult = std::get<0>(it);
      Type newType = newResult.getType();
      Type oldType = oldResult.getType();
      replacements.push_back(
          (newType != oldType)
              ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
              : newResult);
    }
    rewriter.replaceOp(linalgOp, replacements);
    return success();
  }
};

} // namespace

// All named ops canonicalizers and folders are auto-generated in the
// .cpp.inc.

//===----------------------------------------------------------------------===//
// LinalgDialect
//===----------------------------------------------------------------------===//

void LinalgDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
              FoldTensorCastProducerOp, InferStaticShapeOfOperands>(
      getContext());
}

Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return builder.create<arith::ConstantOp>(loc, type, value);
}