[Linalg] Add a slice op
This CL adds a linalg.slice op with the proper roundtripping test. A slice op allows taking subviews that may be rank-reducing (if some indexing is of index type) or not (if all indexings are of linalg.range type). A slice must be constructed directly from a base view (no chains of slices may exist in the IR). Helper functions that fold will be provided for construction if/when necessary. This also renames base_view to view. -- PiperOrigin-RevId: 244406827
This commit is contained in:
parent
1d5dc840e7
commit
0b47f74037
|
@ -24,53 +24,14 @@
|
|||
|
||||
namespace mlir {
|
||||
|
||||
/// A `BaseViewOp` produces a `ViewType` which is a multi-dimensional range
|
||||
/// abstraction on top of an underlying linalg.buffer. A BaseViewOp gives a
|
||||
/// buffer an indexing structure.
|
||||
///
|
||||
/// A new value of ViewType is constructed from a buffer with a base_view op and
|
||||
/// ranges:
|
||||
/// The "buffer_alloc" op creates a 1-D linalg.buffer of the specified type,
|
||||
/// upon which a base view can be laid out to give it indexing semantics.
|
||||
/// "buffer_alloc" takes a single argument, the size of the buffer to allocate
|
||||
/// (in number of elements).
|
||||
///
|
||||
/// ```{.mlir}
|
||||
/// %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
|
||||
/// %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
|
||||
/// %3 = linalg.base_view %1[%2, %2] : !linalg.view<?x?xf32>
|
||||
/// %0 = linalg.buffer_alloc %arg0 : !linalg.buffer<f32>
|
||||
/// ```
|
||||
class BaseViewOp : public mlir::Op<BaseViewOp, mlir::OpTrait::VariadicOperands,
|
||||
mlir::OpTrait::OneResult,
|
||||
mlir::OpTrait::HasNoSideEffect> {
|
||||
enum { FirstIndexingOperand = 1 };
|
||||
|
||||
public:
|
||||
using Op::Op;
|
||||
|
||||
// Hooks to customize the behavior of this op.
|
||||
static llvm::StringRef getOperationName() { return "linalg.base_view"; }
|
||||
static void build(mlir::Builder *b, mlir::OperationState *result,
|
||||
mlir::Value *buffer,
|
||||
llvm::ArrayRef<mlir::Value *> indexings);
|
||||
mlir::LogicalResult verify();
|
||||
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
|
||||
void print(mlir::OpAsmPrinter *p);
|
||||
|
||||
// Op-specific functionality.
|
||||
unsigned getRank() { return getViewType().getRank(); }
|
||||
mlir::Type getElementType() { return getViewType().getElementType(); }
|
||||
ViewType getViewType() { return getType().cast<ViewType>(); }
|
||||
mlir::Value *getSupportingBuffer() { return getOperand(0); }
|
||||
// Get the underlying indexing at a given rank.
|
||||
mlir::Value *getIndexing(unsigned rank) {
|
||||
return *(getIndexings().begin() + rank);
|
||||
}
|
||||
// Get all the indexings in this view.
|
||||
mlir::Operation::operand_range getIndexings() {
|
||||
return {operand_begin() + BaseViewOp::FirstIndexingOperand, operand_end()};
|
||||
}
|
||||
};
|
||||
|
||||
/// A BufferAllocOp is used to create a 1-D !linalg.buffer upon which a base
|
||||
/// view can be laid out. The size argument is an `i64` (and not an index), so
|
||||
/// that we can
|
||||
class BufferAllocOp
|
||||
: public Op<BufferAllocOp, OpTrait::OneOperand, OpTrait::OneResult> {
|
||||
public:
|
||||
|
@ -89,7 +50,11 @@ public:
|
|||
Type getElementType() { return getBufferType().getElementType(); }
|
||||
};
|
||||
|
||||
/// A BufferDeallocOp is used to free a !linalg.buffer.
|
||||
/// The "buffer_dealloc" op frees a 1-D linalg.buffer of the specified type.
|
||||
///
|
||||
/// ```{.mlir}
|
||||
/// linalg.buffer_dealloc %0 : !linalg.buffer<f32>
|
||||
/// ```
|
||||
class BufferDeallocOp
|
||||
: public Op<BufferDeallocOp, OpTrait::OneOperand, OpTrait::ZeroResult> {
|
||||
public:
|
||||
|
@ -109,8 +74,12 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
/// A RangeOp is used to create a value of RangeType from 3 values of type index
|
||||
/// The "linalg.range" op creates a linalg.range from 3 values of type `index`
|
||||
/// that represent the min, max and step values of the range.
|
||||
///
|
||||
/// ```{.mlir}
|
||||
/// %3 = linalg.range %0:%1:%2 : !linalg.range
|
||||
/// ```
|
||||
class RangeOp : public Op<RangeOp, OpTrait::NOperands<3>::Impl,
|
||||
OpTrait::OneResult, OpTrait::HasNoSideEffect> {
|
||||
public:
|
||||
|
@ -130,6 +99,126 @@ public:
|
|||
Value *step() { return getOperand(2); }
|
||||
};
|
||||
|
||||
/// The "linalg.slice" op produces a linalg.view which is a subview of a given
|
||||
/// base view. This allows defining a subregion within the underlying buffer to
|
||||
/// operate on only a subset of the buffer.
|
||||
///
|
||||
/// A "linalg.slice" op takes a base view and a variadic number of indexings and
|
||||
/// produces a linalg.view of the same elemental type as the buffer. An indexing
|
||||
/// is either:
|
||||
/// 1. a linalg.range, in which case it does not reduce the rank of the parent
|
||||
/// view.
|
||||
/// 2. an index, in which case it reduces the rank of the parent view by one.
|
||||
///
|
||||
/// The parent view must be a base view (i.e. either a function argument or has
|
||||
/// been produced by a linalg.view op). In other words, chains of
|
||||
/// linalg.slice operations cannot be constructed in the IR. This defines away
|
||||
/// problems related to keeping track of which dimensions of the base view have
|
||||
/// been rank-reduced.
|
||||
///
|
||||
/// Examples:
|
||||
/// 1. rank-preserving slice:
|
||||
///
|
||||
/// ```{.mlir}
|
||||
/// %4 = linalg.slice %0[%1, %2] : !linalg.view<?x?xf32>, !linalg.range,
|
||||
/// !linalg.range, !linalg.view<?x?xf32>
|
||||
/// ```
|
||||
///
|
||||
/// 2. rank-reducing slice (from 2-D to 1-D):
|
||||
///
|
||||
/// ```{.mlir}
|
||||
/// %4 = linalg.slice %0[%1, %2] : !linalg.view<?x?xf32>, index,
|
||||
/// !linalg.range, !linalg.view<?xf32>
|
||||
/// ```
|
||||
///
|
||||
/// 3. rank-reducing slice (from 2-D to 0-D):
|
||||
///
|
||||
/// ```{.mlir}
|
||||
/// %4 = linalg.slice %0[%1, %2] : !linalg.view<?x?xf32>, index, index,
|
||||
/// !linalg.view<f32>
|
||||
/// ```
|
||||
class ViewOp;
|
||||
class SliceOp : public mlir::Op<SliceOp, mlir::OpTrait::VariadicOperands,
|
||||
mlir::OpTrait::OneResult,
|
||||
mlir::OpTrait::HasNoSideEffect> {
|
||||
enum { FirstIndexingOperand = 1 };
|
||||
|
||||
public:
|
||||
using Op::Op;
|
||||
|
||||
// Hooks to customize the behavior of this op.
|
||||
static llvm::StringRef getOperationName() { return "linalg.slice"; }
|
||||
static void build(mlir::Builder *b, mlir::OperationState *result,
|
||||
mlir::Value *base, llvm::ArrayRef<mlir::Value *> indexings);
|
||||
mlir::LogicalResult verify();
|
||||
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
|
||||
void print(mlir::OpAsmPrinter *p);
|
||||
|
||||
// Op-specific functionality.
|
||||
unsigned getRank() { return getViewType().getRank(); }
|
||||
mlir::Type getElementType() { return getViewType().getElementType(); }
|
||||
ViewType getViewType() { return getType().cast<ViewType>(); }
|
||||
Value *getBaseView() { return getOperand(0); }
|
||||
ViewOp getBaseViewOp();
|
||||
ViewType getBaseViewType();
|
||||
unsigned getBaseViewRank() { return getBaseViewType().getRank(); }
|
||||
// Get the underlying indexing at a given rank.
|
||||
mlir::Value *getIndexing(unsigned rank) {
|
||||
return *(getIndexings().begin() + rank);
|
||||
}
|
||||
// Get all the indexings in this view.
|
||||
mlir::Operation::operand_range getIndexings() {
|
||||
return {operand_begin() + SliceOp::FirstIndexingOperand, operand_end()};
|
||||
}
|
||||
// Get the subset of indexings that are of RangeType.
|
||||
SmallVector<Value *, 8> getRanges();
|
||||
};
|
||||
|
||||
/// The "linalg.view" op produces a linalg.view which is a multi-dimensional
|
||||
/// range abstraction on top of an underlying linalg.buffer. This gives an
|
||||
/// indexing structure to an otherwise non-indexable linalg.buffer.
|
||||
///
|
||||
/// A "linalg.view" takes a buffer and a variadic number of ranges and produces
|
||||
/// a `view` of the same elemental type as the buffer and of rank the number of
|
||||
/// ranges:
|
||||
///
|
||||
/// ```{.mlir}
|
||||
/// %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
|
||||
/// %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
|
||||
/// %3 = linalg.view %1[%2, %2] : !linalg.view<?x?xf32>
|
||||
/// ```
|
||||
class ViewOp : public mlir::Op<ViewOp, mlir::OpTrait::VariadicOperands,
|
||||
mlir::OpTrait::OneResult,
|
||||
mlir::OpTrait::HasNoSideEffect> {
|
||||
enum { FirstIndexingOperand = 1 };
|
||||
|
||||
public:
|
||||
using Op::Op;
|
||||
|
||||
// Hooks to customize the behavior of this op.
|
||||
static llvm::StringRef getOperationName() { return "linalg.view"; }
|
||||
static void build(mlir::Builder *b, mlir::OperationState *result,
|
||||
mlir::Value *buffer,
|
||||
llvm::ArrayRef<mlir::Value *> indexings);
|
||||
mlir::LogicalResult verify();
|
||||
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
|
||||
void print(mlir::OpAsmPrinter *p);
|
||||
|
||||
// Op-specific functionality.
|
||||
unsigned getRank() { return getViewType().getRank(); }
|
||||
mlir::Type getElementType() { return getViewType().getElementType(); }
|
||||
ViewType getViewType() { return getType().cast<ViewType>(); }
|
||||
mlir::Value *getSupportingBuffer() { return getOperand(0); }
|
||||
// Get the underlying indexing at a given rank.
|
||||
mlir::Value *getIndexing(unsigned rank) {
|
||||
return *(getIndexings().begin() + rank);
|
||||
}
|
||||
// Get all the indexings in this view.
|
||||
mlir::Operation::operand_range getIndexings() {
|
||||
return {operand_begin() + ViewOp::FirstIndexingOperand, operand_end()};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // MLIR_LINALG_LINALGOPS_H_
|
||||
|
|
|
@ -42,7 +42,9 @@ public:
|
|||
void printType(Type type, llvm::raw_ostream &os) const override;
|
||||
};
|
||||
|
||||
/// A BufferType represents a minimal range abstraction (min, max, step).
|
||||
/// A BufferType represents a contiguous block of memory that can be allocated
|
||||
/// and deallocated. A buffer cannot be indexed directly, a view must be
|
||||
/// laid out on a buffer to give it indexing semantics.
|
||||
class BufferTypeStorage;
|
||||
class BufferType : public Type::TypeBase<BufferType, Type, BufferTypeStorage> {
|
||||
public:
|
||||
|
@ -58,6 +60,14 @@ public:
|
|||
};
|
||||
|
||||
/// A RangeType represents a minimal range abstraction (min, max, step).
|
||||
/// It is constructed by calling the linalg.range op with three values index of
|
||||
/// index type:
|
||||
///
|
||||
/// ```{.mlir}
|
||||
/// func @foo(%arg0 : index, %arg1 : index, %arg2 : index) {
|
||||
/// %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
|
||||
/// }
|
||||
/// ```
|
||||
class RangeType : public Type::TypeBase<RangeType, Type> {
|
||||
public:
|
||||
// Used for generic hooks in TypeBase.
|
||||
|
@ -74,13 +84,13 @@ public:
|
|||
/// A ViewType represents a multi-dimensional range abstraction on top of an
|
||||
/// underlying storage type. It is parameterizable by the underlying element
|
||||
/// type and the rank of the view.
|
||||
/// A new value of ViewType is constructed from a buffer with a base_view op and
|
||||
/// A new value of ViewType is constructed from a buffer with a view op and
|
||||
/// passing it ranges:
|
||||
///
|
||||
/// ```{.mlir}
|
||||
/// %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
|
||||
/// %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
|
||||
/// %3 = linalg.base_view %1[%2, %2] : !linalg.view<?x?xf32>
|
||||
/// %3 = linalg.view %1[%2, %2] : !linalg.view<?x?xf32>
|
||||
/// ```
|
||||
class ViewTypeStorage;
|
||||
class ViewType
|
||||
|
|
|
@ -29,89 +29,6 @@
|
|||
|
||||
using namespace mlir;
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// BaseViewOp
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
void mlir::BaseViewOp::build(Builder *b, OperationState *result, Value *buffer,
|
||||
ArrayRef<Value *> indexings) {
|
||||
BufferType bufferType = buffer->getType().cast<BufferType>();
|
||||
result->addOperands({buffer});
|
||||
result->addOperands(indexings);
|
||||
assert(
|
||||
std::none_of(indexings.begin(), indexings.end(),
|
||||
[](Value *v) { return !v->getType().isa<RangeType>(); }) &&
|
||||
"linalg.base_view takes only arguments of type linalg.range");
|
||||
|
||||
Type elementType = bufferType.getElementType();
|
||||
result->addTypes(
|
||||
{ViewType::get(b->getContext(), elementType, indexings.size())});
|
||||
}
|
||||
|
||||
LogicalResult mlir::BaseViewOp::verify() {
|
||||
if (llvm::empty(getOperands()))
|
||||
return emitOpError(
|
||||
"requires at least a buffer operand followed by indexings");
|
||||
auto bufferType = getOperand(0)->getType().dyn_cast<BufferType>();
|
||||
if (!bufferType)
|
||||
return emitOpError("first operand must be of BufferType");
|
||||
unsigned index = 0;
|
||||
for (auto indexing : getIndexings()) {
|
||||
if (!indexing->getType().isa<RangeType>()) {
|
||||
return emitOpError(Twine(index) + "^th index must be of range type");
|
||||
}
|
||||
++index;
|
||||
}
|
||||
if (getViewType().getRank() != index)
|
||||
return emitOpError(
|
||||
"the rank of the base view must be the number of its indexings");
|
||||
return success();
|
||||
}
|
||||
|
||||
bool mlir::BaseViewOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
OpAsmParser::OperandType bufferInfo;
|
||||
SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
|
||||
Type type;
|
||||
if (parser->parseOperand(bufferInfo) ||
|
||||
parser->parseOperandList(indexingsInfo, -1,
|
||||
OpAsmParser::Delimiter::Square) ||
|
||||
parser->parseOptionalAttributeDict(result->attributes) ||
|
||||
parser->parseColonType(type))
|
||||
return true;
|
||||
|
||||
ViewType viewType = type.dyn_cast<ViewType>();
|
||||
if (!viewType)
|
||||
return parser->emitError(parser->getNameLoc(), "view type expected");
|
||||
if (viewType.getRank() != indexingsInfo.size())
|
||||
return parser->emitError(parser->getNameLoc(),
|
||||
"expected" + Twine(viewType.getRank()) +
|
||||
" range indexings");
|
||||
return parser->resolveOperand(
|
||||
bufferInfo,
|
||||
BufferType::get(type.getContext(), viewType.getElementType()),
|
||||
result->operands) ||
|
||||
(!indexingsInfo.empty() &&
|
||||
parser->resolveOperands(indexingsInfo,
|
||||
RangeType::get(type.getContext()),
|
||||
result->operands)) ||
|
||||
parser->addTypeToList(viewType, result->types);
|
||||
}
|
||||
|
||||
// A BaseViewOp prints as:
|
||||
//
|
||||
// ```{.mlir}
|
||||
// linalg.base_view %0[%1, %2] : !linalg.view<?x?xf32>
|
||||
// ```
|
||||
//
|
||||
// Where %0 is an ssa-value holding a buffer, %1 and %2 are ssa-value each
|
||||
// holding a range.
|
||||
void mlir::BaseViewOp::print(OpAsmPrinter *p) {
|
||||
*p << getOperationName() << " " << *getSupportingBuffer() << "[";
|
||||
interleave(
|
||||
getIndexings().begin(), getIndexings().end(),
|
||||
[&](mlir::Value *v) { *p << *v; }, [&]() { *p << ", "; });
|
||||
*p << "] : " << getType();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// BufferAllocOp
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -122,9 +39,8 @@ void mlir::BufferAllocOp::build(Builder *b, OperationState *result, Type type,
|
|||
}
|
||||
|
||||
mlir::LogicalResult mlir::BufferAllocOp::verify() {
|
||||
if (!size() || !size()->getType().isa<IntegerType>() ||
|
||||
!size()->getType().cast<IntegerType>().isInteger(64))
|
||||
return emitOpError("first operand should be of type i64");
|
||||
if (!size() || !size()->getType().isa<IndexType>())
|
||||
return emitOpError("first operand should be of type index");
|
||||
if (!VectorType::isValidElementType(getElementType()) &&
|
||||
!getElementType().isa<VectorType>())
|
||||
return emitOpError("unsupported buffer element type");
|
||||
|
@ -143,14 +59,14 @@ void mlir::BufferAllocOp::print(OpAsmPrinter *p) {
|
|||
bool mlir::BufferAllocOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
OpAsmParser::OperandType sizeInfo;
|
||||
BufferType bufferType;
|
||||
auto int64Ty = parser->getBuilder().getIntegerType(64);
|
||||
auto indexTy = parser->getBuilder().getIndexType();
|
||||
if (parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType))
|
||||
return true;
|
||||
if (bufferType.getElementType() != parser->getBuilder().getF32Type())
|
||||
return parser->emitError(
|
||||
parser->getNameLoc(),
|
||||
"Only buffer<f32> supported until mlir::Parser pieces are exposed");
|
||||
return parser->resolveOperands(sizeInfo, int64Ty, result->operands) ||
|
||||
return parser->resolveOperands(sizeInfo, indexTy, result->operands) ||
|
||||
parser->addTypeToList(bufferType, result->types);
|
||||
}
|
||||
|
||||
|
@ -183,7 +99,6 @@ bool mlir::BufferDeallocOp::parse(OpAsmParser *parser, OperationState *result) {
|
|||
return parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType) ||
|
||||
parser->resolveOperands(sizeInfo, bufferType, result->operands);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// RangeOp
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -224,3 +139,218 @@ bool mlir::RangeOp::parse(OpAsmParser *parser, OperationState *result) {
|
|||
parser->resolveOperands(rangeInfo, affineIntTy, result->operands) ||
|
||||
parser->addTypeToList(type, result->types);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// SliceOp
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
void mlir::SliceOp::build(Builder *b, OperationState *result, Value *base,
|
||||
ArrayRef<Value *> indexings) {
|
||||
result->addOperands({base});
|
||||
result->addOperands(indexings);
|
||||
|
||||
ViewType viewType = base->getType().cast<ViewType>();
|
||||
unsigned rank = viewType.getRank();
|
||||
for (auto *i : indexings)
|
||||
if (!i->getType().isa<RangeType>())
|
||||
rank--;
|
||||
Type elementType = viewType.getElementType();
|
||||
result->addTypes(
|
||||
{ViewType::get(b->getContext(), elementType, indexings.size())});
|
||||
}
|
||||
|
||||
LogicalResult mlir::SliceOp::verify() {
|
||||
if (llvm::empty(getOperands()))
|
||||
return emitOpError(
|
||||
"requires at least a view operand followed by 'rank' indices");
|
||||
if (!getOperand(0)->getDefiningOp()->isa<ViewOp>())
|
||||
return emitOpError(
|
||||
"requires at least a view operand followed by 'rank' indices");
|
||||
|
||||
auto viewOp = getOperand(0)->getDefiningOp()->dyn_cast<ViewOp>();
|
||||
if (!viewOp)
|
||||
return emitOpError("first operand must come from a ViewOp");
|
||||
unsigned rank = getBaseViewRank();
|
||||
if (llvm::size(getIndexings()) != rank) {
|
||||
return emitOpError("requires at least a view operand followed by " +
|
||||
Twine(rank) + " indexings");
|
||||
}
|
||||
unsigned index = 0;
|
||||
for (auto indexing : getIndexings()) {
|
||||
if (!indexing->getType().isa<RangeType>() &&
|
||||
!indexing->getType().isa<IndexType>()) {
|
||||
return emitOpError(Twine(index) +
|
||||
"^th index must be of range or index type");
|
||||
}
|
||||
if (indexing->getType().isa<IndexType>())
|
||||
--rank;
|
||||
++index;
|
||||
}
|
||||
if (getRank() != rank) {
|
||||
return emitOpError("the rank of the view must be the number of its range "
|
||||
"indices: " +
|
||||
Twine(rank));
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
bool mlir::SliceOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
OpAsmParser::OperandType baseInfo;
|
||||
SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
|
||||
SmallVector<Type, 8> types;
|
||||
if (parser->parseOperand(baseInfo) ||
|
||||
parser->parseOperandList(indexingsInfo, -1,
|
||||
OpAsmParser::Delimiter::Square) ||
|
||||
parser->parseOptionalAttributeDict(result->attributes) ||
|
||||
parser->parseColonTypeList(types))
|
||||
return true;
|
||||
|
||||
if (types.size() != 2 + indexingsInfo.size())
|
||||
return parser->emitError(parser->getNameLoc(),
|
||||
"unexpected number of types ");
|
||||
ViewType baseViewType = types[0].dyn_cast<ViewType>();
|
||||
if (!baseViewType)
|
||||
return parser->emitError(parser->getNameLoc(),
|
||||
"view type expected for first type");
|
||||
if (indexingsInfo.size() != baseViewType.getRank())
|
||||
return parser->emitError(parser->getNameLoc(),
|
||||
"expected " + Twine(baseViewType.getRank()) +
|
||||
" indexings");
|
||||
ViewType viewType = types.back().dyn_cast<ViewType>();
|
||||
if (!viewType)
|
||||
return parser->emitError(parser->getNameLoc(), "view type expected");
|
||||
|
||||
ArrayRef<Type> indexingTypes =
|
||||
ArrayRef<Type>(types).drop_front(1).drop_back(1);
|
||||
if (indexingTypes.size() != baseViewType.getRank())
|
||||
return parser->emitError(parser->getNameLoc(),
|
||||
"expected " + Twine(baseViewType.getRank()) +
|
||||
" indexing types");
|
||||
return parser->resolveOperand(baseInfo, baseViewType, result->operands) ||
|
||||
(!indexingsInfo.empty() &&
|
||||
parser->resolveOperands(indexingsInfo, indexingTypes,
|
||||
indexingsInfo.front().location,
|
||||
result->operands)) ||
|
||||
parser->addTypeToList(viewType, result->types);
|
||||
}
|
||||
|
||||
// A SliceOp prints as:
|
||||
//
|
||||
// ```{.mlir}
|
||||
// linalg.slice %0[%1, %2] :
|
||||
// !linalg.view<?x?xf32>, [indexing-types], !linalg.view<?x?xf32>
|
||||
// ```
|
||||
//
|
||||
// Where %0 is an ssa-value holding a view created from a buffer, %1 and %2 are
|
||||
// ssa-value each holding a range.
|
||||
void mlir::SliceOp::print(OpAsmPrinter *p) {
|
||||
*p << getOperationName() << " " << *getBaseView() << "[";
|
||||
interleave(
|
||||
getIndexings().begin(), getIndexings().end(),
|
||||
[&](mlir::Value *v) { *p << *v; }, [&]() { *p << ", "; });
|
||||
*p << "] : " << getBaseViewType();
|
||||
for (auto indexing : getIndexings()) {
|
||||
*p << ", " << indexing->getType();
|
||||
}
|
||||
*p << ", " << getType();
|
||||
}
|
||||
|
||||
ViewOp mlir::SliceOp::getBaseViewOp() {
|
||||
return getOperand(0)->getDefiningOp()->cast<ViewOp>();
|
||||
}
|
||||
|
||||
ViewType mlir::SliceOp::getBaseViewType() {
|
||||
return getBaseViewOp().getType().cast<ViewType>();
|
||||
}
|
||||
|
||||
SmallVector<Value *, 8> mlir::SliceOp::getRanges() {
|
||||
llvm::SmallVector<Value *, 8> res;
|
||||
for (auto *operand : getIndexings()) {
|
||||
if (!operand->getType().isa<IndexType>()) {
|
||||
res.push_back(operand);
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// ViewOp
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
void mlir::ViewOp::build(Builder *b, OperationState *result, Value *buffer,
|
||||
ArrayRef<Value *> indexings) {
|
||||
BufferType bufferType = buffer->getType().cast<BufferType>();
|
||||
result->addOperands({buffer});
|
||||
result->addOperands(indexings);
|
||||
assert(
|
||||
std::none_of(indexings.begin(), indexings.end(),
|
||||
[](Value *v) { return !v->getType().isa<RangeType>(); }) &&
|
||||
"linalg.view takes only arguments of type linalg.range");
|
||||
|
||||
Type elementType = bufferType.getElementType();
|
||||
result->addTypes(
|
||||
{ViewType::get(b->getContext(), elementType, indexings.size())});
|
||||
}
|
||||
|
||||
LogicalResult mlir::ViewOp::verify() {
|
||||
if (llvm::empty(getOperands()))
|
||||
return emitOpError(
|
||||
"requires at least a buffer operand followed by indexings");
|
||||
auto bufferType = getOperand(0)->getType().dyn_cast<BufferType>();
|
||||
if (!bufferType)
|
||||
return emitOpError("first operand must be of BufferType");
|
||||
unsigned index = 0;
|
||||
for (auto indexing : getIndexings()) {
|
||||
if (!indexing->getType().isa<RangeType>()) {
|
||||
return emitOpError(Twine(index) + "^th index must be of range type");
|
||||
}
|
||||
++index;
|
||||
}
|
||||
if (getViewType().getRank() != index)
|
||||
return emitOpError(
|
||||
"the rank of the view must be the number of its indexings");
|
||||
return success();
|
||||
}
|
||||
|
||||
bool mlir::ViewOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
OpAsmParser::OperandType bufferInfo;
|
||||
SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
|
||||
Type type;
|
||||
if (parser->parseOperand(bufferInfo) ||
|
||||
parser->parseOperandList(indexingsInfo, -1,
|
||||
OpAsmParser::Delimiter::Square) ||
|
||||
parser->parseOptionalAttributeDict(result->attributes) ||
|
||||
parser->parseColonType(type))
|
||||
return true;
|
||||
|
||||
ViewType viewType = type.dyn_cast<ViewType>();
|
||||
if (!viewType)
|
||||
return parser->emitError(parser->getNameLoc(), "view type expected");
|
||||
if (viewType.getRank() != indexingsInfo.size())
|
||||
return parser->emitError(parser->getNameLoc(),
|
||||
"expected" + Twine(viewType.getRank()) +
|
||||
" range indexings");
|
||||
return parser->resolveOperand(
|
||||
bufferInfo,
|
||||
BufferType::get(type.getContext(), viewType.getElementType()),
|
||||
result->operands) ||
|
||||
(!indexingsInfo.empty() &&
|
||||
parser->resolveOperands(indexingsInfo,
|
||||
RangeType::get(type.getContext()),
|
||||
result->operands)) ||
|
||||
parser->addTypeToList(viewType, result->types);
|
||||
}
|
||||
|
||||
// A ViewOp prints as:
|
||||
//
|
||||
// ```{.mlir}
|
||||
// linalg.view %0[%1, %2] : !linalg.view<?x?xf32>
|
||||
// ```
|
||||
//
|
||||
// Where %0 is an ssa-value holding a buffer, %1 and %2 are ssa-value each
|
||||
// holding a range.
|
||||
void mlir::ViewOp::print(OpAsmPrinter *p) {
|
||||
*p << getOperationName() << " " << *getSupportingBuffer() << "[";
|
||||
interleave(
|
||||
getIndexings().begin(), getIndexings().end(),
|
||||
[&](mlir::Value *v) { *p << *v; }, [&]() { *p << ", "; });
|
||||
*p << "] : " << getType();
|
||||
}
|
||||
|
|
|
@ -30,7 +30,7 @@ using namespace mlir;
|
|||
mlir::LinalgDialect::LinalgDialect(MLIRContext *context)
|
||||
: Dialect("linalg", context) {
|
||||
addTypes<BufferType, RangeType, ViewType>();
|
||||
addOperations<BaseViewOp, BufferAllocOp, BufferDeallocOp, RangeOp>();
|
||||
addOperations<BufferAllocOp, BufferDeallocOp, RangeOp, SliceOp, ViewOp>();
|
||||
}
|
||||
|
||||
struct mlir::BufferTypeStorage : public mlir::TypeStorage {
|
||||
|
|
|
@ -7,28 +7,36 @@ func @range(%arg0: index, %arg1: index, %arg2: index) {
|
|||
// CHECK-LABEL: func @range(%arg0: index, %arg1: index, %arg2: index) {
|
||||
// CHECK-NEXT: %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
|
||||
|
||||
func @buffer(%arg0: i64, %arg1: i64) {
|
||||
%0 = muli %arg0, %arg0 : i64
|
||||
func @buffer(%arg0: index, %arg1: index) {
|
||||
%0 = muli %arg0, %arg0 : index
|
||||
%1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
|
||||
linalg.buffer_dealloc %1 : !linalg.buffer<f32>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @buffer(%arg0: i64, %arg1: i64) {
|
||||
// CHECK-NEXT: %0 = muli %arg0, %arg0 : i64
|
||||
// CHECK-LABEL: func @buffer(%arg0: index, %arg1: index) {
|
||||
// CHECK-NEXT: %0 = muli %arg0, %arg0 : index
|
||||
// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
|
||||
// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer<f32>
|
||||
|
||||
func @views(%arg0: i64, %arg1: i64, %arg2: index, %arg3: index, %arg4: index) {
|
||||
%0 = muli %arg0, %arg0 : i64
|
||||
func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
|
||||
%0 = muli %arg0, %arg0 : index
|
||||
%1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
|
||||
%2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
|
||||
%3 = linalg.base_view %1[%2, %2] : !linalg.view<?x?xf32>
|
||||
%3 = linalg.view %1[%2, %2] : !linalg.view<?x?xf32>
|
||||
%4 = linalg.slice %3[%2, %2] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
|
||||
%5 = linalg.slice %3[%2, %arg2] : !linalg.view<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
|
||||
%6 = linalg.slice %3[%arg2, %2] : !linalg.view<?x?xf32>, index, !linalg.range, !linalg.view<?xf32>
|
||||
%7 = linalg.slice %3[%arg2, %arg3] : !linalg.view<?x?xf32>, index, index, !linalg.view<f32>
|
||||
linalg.buffer_dealloc %1 : !linalg.buffer<f32>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @views(%arg0: i64, %arg1: i64, %arg2: index, %arg3: index, %arg4: index) {
|
||||
// CHECK-NEXT: %0 = muli %arg0, %arg0 : i64
|
||||
// CHECK-LABEL: func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
|
||||
// CHECK-NEXT: %0 = muli %arg0, %arg0 : index
|
||||
// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
|
||||
// CHECK-NEXT: %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
|
||||
// CHECK-NEXT: %3 = linalg.base_view %1[%2, %2] : !linalg.view<?x?xf32>
|
||||
// CHECK-NEXT: %3 = linalg.view %1[%2, %2] : !linalg.view<?x?xf32>
|
||||
// CHECK-NEXT: %4 = linalg.slice %3[%2, %2] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
|
||||
// CHECK-NEXT: %5 = linalg.slice %3[%2, %arg2] : !linalg.view<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
|
||||
// CHECK-NEXT: %6 = linalg.slice %3[%arg2, %2] : !linalg.view<?x?xf32>, index, !linalg.range, !linalg.view<?xf32>
|
||||
// CHECK-NEXT: %7 = linalg.slice %3[%arg2, %arg3] : !linalg.view<?x?xf32>, index, index, !linalg.view<f32>
|
||||
// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer<f32>
|
Loading…
Reference in New Issue