[Linalg] Add a slice op

This CL adds a linalg.slice op with the proper roundtripping test.
    A slice op allows taking subviews that may be rank-reducing (if some indexing is of index type) or not (if all indexings are of linalg.range type).

    A slice must be constructed directly from a base view (no chains of slices may exist in the IR). Helper functions that fold will be provided for construction if/when necessary.

    This also renames base_view to view.

--

PiperOrigin-RevId: 244406827
This commit is contained in:
Nicolas Vasilache 2019-04-19 12:55:34 -07:00 committed by Mehdi Amini
parent 1d5dc840e7
commit 0b47f74037
5 changed files with 386 additions and 149 deletions

View File

@ -24,53 +24,14 @@
namespace mlir {
/// A `BaseViewOp` produces a `ViewType` which is a multi-dimensional range
/// abstraction on top of an underlying linalg.buffer. A BaseViewOp gives a
/// buffer an indexing structure.
///
/// A new value of ViewType is constructed from a buffer with a base_view op and
/// ranges:
/// The "buffer_alloc" op creates a 1-D linalg.buffer of the specified type,
/// upon which a base view can be laid out to give it indexing semantics.
/// "buffer_alloc" takes a single argument, the size of the buffer to allocate
/// (in number of elements).
///
/// ```{.mlir}
/// %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
/// %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
/// %3 = linalg.base_view %1[%2, %2] : !linalg.view<?x?xf32>
/// %0 = linalg.buffer_alloc %arg0 : !linalg.buffer<f32>
/// ```
class BaseViewOp : public mlir::Op<BaseViewOp, mlir::OpTrait::VariadicOperands,
mlir::OpTrait::OneResult,
mlir::OpTrait::HasNoSideEffect> {
enum { FirstIndexingOperand = 1 };
public:
using Op::Op;
// Hooks to customize the behavior of this op.
static llvm::StringRef getOperationName() { return "linalg.base_view"; }
static void build(mlir::Builder *b, mlir::OperationState *result,
mlir::Value *buffer,
llvm::ArrayRef<mlir::Value *> indexings);
mlir::LogicalResult verify();
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
void print(mlir::OpAsmPrinter *p);
// Op-specific functionality.
unsigned getRank() { return getViewType().getRank(); }
mlir::Type getElementType() { return getViewType().getElementType(); }
ViewType getViewType() { return getType().cast<ViewType>(); }
mlir::Value *getSupportingBuffer() { return getOperand(0); }
// Get the underlying indexing at a given rank.
mlir::Value *getIndexing(unsigned rank) {
return *(getIndexings().begin() + rank);
}
// Get all the indexings in this view.
mlir::Operation::operand_range getIndexings() {
return {operand_begin() + BaseViewOp::FirstIndexingOperand, operand_end()};
}
};
/// A BufferAllocOp is used to create a 1-D !linalg.buffer upon which a base
/// view can be laid out. The size argument is an `i64` (and not an index), so
/// that we can
class BufferAllocOp
: public Op<BufferAllocOp, OpTrait::OneOperand, OpTrait::OneResult> {
public:
@ -89,7 +50,11 @@ public:
Type getElementType() { return getBufferType().getElementType(); }
};
/// A BufferDeallocOp is used to free a !linalg.buffer.
/// The "buffer_dealloc" op frees a 1-D linalg.buffer of the specified type.
///
/// ```{.mlir}
/// linalg.buffer_dealloc %0 : !linalg.buffer<f32>
/// ```
class BufferDeallocOp
: public Op<BufferDeallocOp, OpTrait::OneOperand, OpTrait::ZeroResult> {
public:
@ -109,8 +74,12 @@ public:
}
};
/// A RangeOp is used to create a value of RangeType from 3 values of type index
/// The "linalg.range" op creates a linalg.range from 3 values of type `index`
/// that represent the min, max and step values of the range.
///
/// ```{.mlir}
/// %3 = linalg.range %0:%1:%2 : !linalg.range
/// ```
class RangeOp : public Op<RangeOp, OpTrait::NOperands<3>::Impl,
OpTrait::OneResult, OpTrait::HasNoSideEffect> {
public:
@ -130,6 +99,126 @@ public:
Value *step() { return getOperand(2); }
};
/// The "linalg.slice" op produces a linalg.view which is a subview of a given
/// base view. This allows defining a subregion within the underlying buffer to
/// operate on only a subset of the buffer.
///
/// A "linalg.slice" op takes a base view and a variadic number of indexings and
/// produces a linalg.view of the same elemental type as the buffer. An indexing
/// is either:
/// 1. a linalg.range, in which case it does not reduce the rank of the parent
/// view.
/// 2. an index, in which case it reduces the rank of the parent view by one.
///
/// The parent view must be a base view (i.e. either a function argument or has
/// been produced by a linalg.view op). In other words, chains of
/// linalg.slice operations cannot be constructed in the IR. This defines away
/// problems related to keeping track of which dimensions of the base view have
/// been rank-reduced.
///
/// Examples:
/// 1. rank-preserving slice:
///
/// ```{.mlir}
/// %4 = linalg.slice %0[%1, %2] : !linalg.view<?x?xf32>, !linalg.range,
/// !linalg.range, !linalg.view<?x?xf32>
/// ```
///
/// 2. rank-reducing slice (from 2-D to 1-D):
///
/// ```{.mlir}
/// %4 = linalg.slice %0[%1, %2] : !linalg.view<?x?xf32>, index,
/// !linalg.range, !linalg.view<?xf32>
/// ```
///
/// 3. rank-reducing slice (from 2-D to 0-D):
///
/// ```{.mlir}
/// %4 = linalg.slice %0[%1, %2] : !linalg.view<?x?xf32>, index, index,
/// !linalg.view<f32>
/// ```
class ViewOp;
class SliceOp : public mlir::Op<SliceOp, mlir::OpTrait::VariadicOperands,
mlir::OpTrait::OneResult,
mlir::OpTrait::HasNoSideEffect> {
enum { FirstIndexingOperand = 1 };
public:
using Op::Op;
// Hooks to customize the behavior of this op.
static llvm::StringRef getOperationName() { return "linalg.slice"; }
static void build(mlir::Builder *b, mlir::OperationState *result,
mlir::Value *base, llvm::ArrayRef<mlir::Value *> indexings);
mlir::LogicalResult verify();
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
void print(mlir::OpAsmPrinter *p);
// Op-specific functionality.
unsigned getRank() { return getViewType().getRank(); }
mlir::Type getElementType() { return getViewType().getElementType(); }
ViewType getViewType() { return getType().cast<ViewType>(); }
Value *getBaseView() { return getOperand(0); }
ViewOp getBaseViewOp();
ViewType getBaseViewType();
unsigned getBaseViewRank() { return getBaseViewType().getRank(); }
// Get the underlying indexing at a given rank.
mlir::Value *getIndexing(unsigned rank) {
return *(getIndexings().begin() + rank);
}
// Get all the indexings in this view.
mlir::Operation::operand_range getIndexings() {
return {operand_begin() + SliceOp::FirstIndexingOperand, operand_end()};
}
// Get the subset of indexings that are of RangeType.
SmallVector<Value *, 8> getRanges();
};
/// The "linalg.view" op produces a linalg.view which is a multi-dimensional
/// range abstraction on top of an underlying linalg.buffer. This gives an
/// indexing structure to an otherwise non-indexable linalg.buffer.
///
/// A "linalg.view" takes a buffer and a variadic number of ranges and produces
/// a `view` of the same elemental type as the buffer and of rank the number of
/// ranges:
///
/// ```{.mlir}
/// %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
/// %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
/// %3 = linalg.view %1[%2, %2] : !linalg.view<?x?xf32>
/// ```
class ViewOp : public mlir::Op<ViewOp, mlir::OpTrait::VariadicOperands,
mlir::OpTrait::OneResult,
mlir::OpTrait::HasNoSideEffect> {
enum { FirstIndexingOperand = 1 };
public:
using Op::Op;
// Hooks to customize the behavior of this op.
static llvm::StringRef getOperationName() { return "linalg.view"; }
static void build(mlir::Builder *b, mlir::OperationState *result,
mlir::Value *buffer,
llvm::ArrayRef<mlir::Value *> indexings);
mlir::LogicalResult verify();
static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
void print(mlir::OpAsmPrinter *p);
// Op-specific functionality.
unsigned getRank() { return getViewType().getRank(); }
mlir::Type getElementType() { return getViewType().getElementType(); }
ViewType getViewType() { return getType().cast<ViewType>(); }
mlir::Value *getSupportingBuffer() { return getOperand(0); }
// Get the underlying indexing at a given rank.
mlir::Value *getIndexing(unsigned rank) {
return *(getIndexings().begin() + rank);
}
// Get all the indexings in this view.
mlir::Operation::operand_range getIndexings() {
return {operand_begin() + ViewOp::FirstIndexingOperand, operand_end()};
}
};
} // namespace mlir
#endif // MLIR_LINALG_LINALGOPS_H_

View File

@ -42,7 +42,9 @@ public:
void printType(Type type, llvm::raw_ostream &os) const override;
};
/// A BufferType represents a minimal range abstraction (min, max, step).
/// A BufferType represents a contiguous block of memory that can be allocated
/// and deallocated. A buffer cannot be indexed directly, a view must be
/// laid out on a buffer to give it indexing semantics.
class BufferTypeStorage;
class BufferType : public Type::TypeBase<BufferType, Type, BufferTypeStorage> {
public:
@ -58,6 +60,14 @@ public:
};
/// A RangeType represents a minimal range abstraction (min, max, step).
/// It is constructed by calling the linalg.range op with three values index of
/// index type:
///
/// ```{.mlir}
/// func @foo(%arg0 : index, %arg1 : index, %arg2 : index) {
/// %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
/// }
/// ```
class RangeType : public Type::TypeBase<RangeType, Type> {
public:
// Used for generic hooks in TypeBase.
@ -74,13 +84,13 @@ public:
/// A ViewType represents a multi-dimensional range abstraction on top of an
/// underlying storage type. It is parameterizable by the underlying element
/// type and the rank of the view.
/// A new value of ViewType is constructed from a buffer with a base_view op and
/// A new value of ViewType is constructed from a buffer with a view op and
/// passing it ranges:
///
/// ```{.mlir}
/// %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
/// %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
/// %3 = linalg.base_view %1[%2, %2] : !linalg.view<?x?xf32>
/// %3 = linalg.view %1[%2, %2] : !linalg.view<?x?xf32>
/// ```
class ViewTypeStorage;
class ViewType

View File

@ -29,89 +29,6 @@
using namespace mlir;
//////////////////////////////////////////////////////////////////////////////
// BaseViewOp
//////////////////////////////////////////////////////////////////////////////
void mlir::BaseViewOp::build(Builder *b, OperationState *result, Value *buffer,
ArrayRef<Value *> indexings) {
BufferType bufferType = buffer->getType().cast<BufferType>();
result->addOperands({buffer});
result->addOperands(indexings);
assert(
std::none_of(indexings.begin(), indexings.end(),
[](Value *v) { return !v->getType().isa<RangeType>(); }) &&
"linalg.base_view takes only arguments of type linalg.range");
Type elementType = bufferType.getElementType();
result->addTypes(
{ViewType::get(b->getContext(), elementType, indexings.size())});
}
LogicalResult mlir::BaseViewOp::verify() {
if (llvm::empty(getOperands()))
return emitOpError(
"requires at least a buffer operand followed by indexings");
auto bufferType = getOperand(0)->getType().dyn_cast<BufferType>();
if (!bufferType)
return emitOpError("first operand must be of BufferType");
unsigned index = 0;
for (auto indexing : getIndexings()) {
if (!indexing->getType().isa<RangeType>()) {
return emitOpError(Twine(index) + "^th index must be of range type");
}
++index;
}
if (getViewType().getRank() != index)
return emitOpError(
"the rank of the base view must be the number of its indexings");
return success();
}
bool mlir::BaseViewOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType bufferInfo;
SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
Type type;
if (parser->parseOperand(bufferInfo) ||
parser->parseOperandList(indexingsInfo, -1,
OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type))
return true;
ViewType viewType = type.dyn_cast<ViewType>();
if (!viewType)
return parser->emitError(parser->getNameLoc(), "view type expected");
if (viewType.getRank() != indexingsInfo.size())
return parser->emitError(parser->getNameLoc(),
"expected" + Twine(viewType.getRank()) +
" range indexings");
return parser->resolveOperand(
bufferInfo,
BufferType::get(type.getContext(), viewType.getElementType()),
result->operands) ||
(!indexingsInfo.empty() &&
parser->resolveOperands(indexingsInfo,
RangeType::get(type.getContext()),
result->operands)) ||
parser->addTypeToList(viewType, result->types);
}
// A BaseViewOp prints as:
//
// ```{.mlir}
// linalg.base_view %0[%1, %2] : !linalg.view<?x?xf32>
// ```
//
// Where %0 is an ssa-value holding a buffer, %1 and %2 are ssa-value each
// holding a range.
void mlir::BaseViewOp::print(OpAsmPrinter *p) {
*p << getOperationName() << " " << *getSupportingBuffer() << "[";
interleave(
getIndexings().begin(), getIndexings().end(),
[&](mlir::Value *v) { *p << *v; }, [&]() { *p << ", "; });
*p << "] : " << getType();
}
//////////////////////////////////////////////////////////////////////////////
// BufferAllocOp
//////////////////////////////////////////////////////////////////////////////
@ -122,9 +39,8 @@ void mlir::BufferAllocOp::build(Builder *b, OperationState *result, Type type,
}
mlir::LogicalResult mlir::BufferAllocOp::verify() {
if (!size() || !size()->getType().isa<IntegerType>() ||
!size()->getType().cast<IntegerType>().isInteger(64))
return emitOpError("first operand should be of type i64");
if (!size() || !size()->getType().isa<IndexType>())
return emitOpError("first operand should be of type index");
if (!VectorType::isValidElementType(getElementType()) &&
!getElementType().isa<VectorType>())
return emitOpError("unsupported buffer element type");
@ -143,14 +59,14 @@ void mlir::BufferAllocOp::print(OpAsmPrinter *p) {
bool mlir::BufferAllocOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType sizeInfo;
BufferType bufferType;
auto int64Ty = parser->getBuilder().getIntegerType(64);
auto indexTy = parser->getBuilder().getIndexType();
if (parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType))
return true;
if (bufferType.getElementType() != parser->getBuilder().getF32Type())
return parser->emitError(
parser->getNameLoc(),
"Only buffer<f32> supported until mlir::Parser pieces are exposed");
return parser->resolveOperands(sizeInfo, int64Ty, result->operands) ||
return parser->resolveOperands(sizeInfo, indexTy, result->operands) ||
parser->addTypeToList(bufferType, result->types);
}
@ -183,7 +99,6 @@ bool mlir::BufferDeallocOp::parse(OpAsmParser *parser, OperationState *result) {
return parser->parseOperand(sizeInfo) || parser->parseColonType(bufferType) ||
parser->resolveOperands(sizeInfo, bufferType, result->operands);
}
//////////////////////////////////////////////////////////////////////////////
// RangeOp
//////////////////////////////////////////////////////////////////////////////
@ -224,3 +139,218 @@ bool mlir::RangeOp::parse(OpAsmParser *parser, OperationState *result) {
parser->resolveOperands(rangeInfo, affineIntTy, result->operands) ||
parser->addTypeToList(type, result->types);
}
//////////////////////////////////////////////////////////////////////////////
// SliceOp
//////////////////////////////////////////////////////////////////////////////
void mlir::SliceOp::build(Builder *b, OperationState *result, Value *base,
ArrayRef<Value *> indexings) {
result->addOperands({base});
result->addOperands(indexings);
ViewType viewType = base->getType().cast<ViewType>();
unsigned rank = viewType.getRank();
for (auto *i : indexings)
if (!i->getType().isa<RangeType>())
rank--;
Type elementType = viewType.getElementType();
result->addTypes(
{ViewType::get(b->getContext(), elementType, indexings.size())});
}
LogicalResult mlir::SliceOp::verify() {
if (llvm::empty(getOperands()))
return emitOpError(
"requires at least a view operand followed by 'rank' indices");
if (!getOperand(0)->getDefiningOp()->isa<ViewOp>())
return emitOpError(
"requires at least a view operand followed by 'rank' indices");
auto viewOp = getOperand(0)->getDefiningOp()->dyn_cast<ViewOp>();
if (!viewOp)
return emitOpError("first operand must come from a ViewOp");
unsigned rank = getBaseViewRank();
if (llvm::size(getIndexings()) != rank) {
return emitOpError("requires at least a view operand followed by " +
Twine(rank) + " indexings");
}
unsigned index = 0;
for (auto indexing : getIndexings()) {
if (!indexing->getType().isa<RangeType>() &&
!indexing->getType().isa<IndexType>()) {
return emitOpError(Twine(index) +
"^th index must be of range or index type");
}
if (indexing->getType().isa<IndexType>())
--rank;
++index;
}
if (getRank() != rank) {
return emitOpError("the rank of the view must be the number of its range "
"indices: " +
Twine(rank));
}
return success();
}
bool mlir::SliceOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType baseInfo;
SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
SmallVector<Type, 8> types;
if (parser->parseOperand(baseInfo) ||
parser->parseOperandList(indexingsInfo, -1,
OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonTypeList(types))
return true;
if (types.size() != 2 + indexingsInfo.size())
return parser->emitError(parser->getNameLoc(),
"unexpected number of types ");
ViewType baseViewType = types[0].dyn_cast<ViewType>();
if (!baseViewType)
return parser->emitError(parser->getNameLoc(),
"view type expected for first type");
if (indexingsInfo.size() != baseViewType.getRank())
return parser->emitError(parser->getNameLoc(),
"expected " + Twine(baseViewType.getRank()) +
" indexings");
ViewType viewType = types.back().dyn_cast<ViewType>();
if (!viewType)
return parser->emitError(parser->getNameLoc(), "view type expected");
ArrayRef<Type> indexingTypes =
ArrayRef<Type>(types).drop_front(1).drop_back(1);
if (indexingTypes.size() != baseViewType.getRank())
return parser->emitError(parser->getNameLoc(),
"expected " + Twine(baseViewType.getRank()) +
" indexing types");
return parser->resolveOperand(baseInfo, baseViewType, result->operands) ||
(!indexingsInfo.empty() &&
parser->resolveOperands(indexingsInfo, indexingTypes,
indexingsInfo.front().location,
result->operands)) ||
parser->addTypeToList(viewType, result->types);
}
// A SliceOp prints as:
//
// ```{.mlir}
// linalg.slice %0[%1, %2] :
// !linalg.view<?x?xf32>, [indexing-types], !linalg.view<?x?xf32>
// ```
//
// Where %0 is an ssa-value holding a view created from a buffer, %1 and %2 are
// ssa-value each holding a range.
void mlir::SliceOp::print(OpAsmPrinter *p) {
*p << getOperationName() << " " << *getBaseView() << "[";
interleave(
getIndexings().begin(), getIndexings().end(),
[&](mlir::Value *v) { *p << *v; }, [&]() { *p << ", "; });
*p << "] : " << getBaseViewType();
for (auto indexing : getIndexings()) {
*p << ", " << indexing->getType();
}
*p << ", " << getType();
}
ViewOp mlir::SliceOp::getBaseViewOp() {
return getOperand(0)->getDefiningOp()->cast<ViewOp>();
}
ViewType mlir::SliceOp::getBaseViewType() {
return getBaseViewOp().getType().cast<ViewType>();
}
SmallVector<Value *, 8> mlir::SliceOp::getRanges() {
llvm::SmallVector<Value *, 8> res;
for (auto *operand : getIndexings()) {
if (!operand->getType().isa<IndexType>()) {
res.push_back(operand);
}
}
return res;
}
//////////////////////////////////////////////////////////////////////////////
// ViewOp
//////////////////////////////////////////////////////////////////////////////
void mlir::ViewOp::build(Builder *b, OperationState *result, Value *buffer,
ArrayRef<Value *> indexings) {
BufferType bufferType = buffer->getType().cast<BufferType>();
result->addOperands({buffer});
result->addOperands(indexings);
assert(
std::none_of(indexings.begin(), indexings.end(),
[](Value *v) { return !v->getType().isa<RangeType>(); }) &&
"linalg.view takes only arguments of type linalg.range");
Type elementType = bufferType.getElementType();
result->addTypes(
{ViewType::get(b->getContext(), elementType, indexings.size())});
}
LogicalResult mlir::ViewOp::verify() {
if (llvm::empty(getOperands()))
return emitOpError(
"requires at least a buffer operand followed by indexings");
auto bufferType = getOperand(0)->getType().dyn_cast<BufferType>();
if (!bufferType)
return emitOpError("first operand must be of BufferType");
unsigned index = 0;
for (auto indexing : getIndexings()) {
if (!indexing->getType().isa<RangeType>()) {
return emitOpError(Twine(index) + "^th index must be of range type");
}
++index;
}
if (getViewType().getRank() != index)
return emitOpError(
"the rank of the view must be the number of its indexings");
return success();
}
bool mlir::ViewOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType bufferInfo;
SmallVector<OpAsmParser::OperandType, 8> indexingsInfo;
Type type;
if (parser->parseOperand(bufferInfo) ||
parser->parseOperandList(indexingsInfo, -1,
OpAsmParser::Delimiter::Square) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type))
return true;
ViewType viewType = type.dyn_cast<ViewType>();
if (!viewType)
return parser->emitError(parser->getNameLoc(), "view type expected");
if (viewType.getRank() != indexingsInfo.size())
return parser->emitError(parser->getNameLoc(),
"expected" + Twine(viewType.getRank()) +
" range indexings");
return parser->resolveOperand(
bufferInfo,
BufferType::get(type.getContext(), viewType.getElementType()),
result->operands) ||
(!indexingsInfo.empty() &&
parser->resolveOperands(indexingsInfo,
RangeType::get(type.getContext()),
result->operands)) ||
parser->addTypeToList(viewType, result->types);
}
// A ViewOp prints as:
//
// ```{.mlir}
// linalg.view %0[%1, %2] : !linalg.view<?x?xf32>
// ```
//
// Where %0 is an ssa-value holding a buffer, %1 and %2 are ssa-value each
// holding a range.
void mlir::ViewOp::print(OpAsmPrinter *p) {
*p << getOperationName() << " " << *getSupportingBuffer() << "[";
interleave(
getIndexings().begin(), getIndexings().end(),
[&](mlir::Value *v) { *p << *v; }, [&]() { *p << ", "; });
*p << "] : " << getType();
}

View File

@ -30,7 +30,7 @@ using namespace mlir;
mlir::LinalgDialect::LinalgDialect(MLIRContext *context)
: Dialect("linalg", context) {
addTypes<BufferType, RangeType, ViewType>();
addOperations<BaseViewOp, BufferAllocOp, BufferDeallocOp, RangeOp>();
addOperations<BufferAllocOp, BufferDeallocOp, RangeOp, SliceOp, ViewOp>();
}
struct mlir::BufferTypeStorage : public mlir::TypeStorage {

View File

@ -7,28 +7,36 @@ func @range(%arg0: index, %arg1: index, %arg2: index) {
// CHECK-LABEL: func @range(%arg0: index, %arg1: index, %arg2: index) {
// CHECK-NEXT: %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
func @buffer(%arg0: i64, %arg1: i64) {
%0 = muli %arg0, %arg0 : i64
func @buffer(%arg0: index, %arg1: index) {
%0 = muli %arg0, %arg0 : index
%1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
linalg.buffer_dealloc %1 : !linalg.buffer<f32>
return
}
// CHECK-LABEL: func @buffer(%arg0: i64, %arg1: i64) {
// CHECK-NEXT: %0 = muli %arg0, %arg0 : i64
// CHECK-LABEL: func @buffer(%arg0: index, %arg1: index) {
// CHECK-NEXT: %0 = muli %arg0, %arg0 : index
// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer<f32>
func @views(%arg0: i64, %arg1: i64, %arg2: index, %arg3: index, %arg4: index) {
%0 = muli %arg0, %arg0 : i64
func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
%0 = muli %arg0, %arg0 : index
%1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
%2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
%3 = linalg.base_view %1[%2, %2] : !linalg.view<?x?xf32>
%3 = linalg.view %1[%2, %2] : !linalg.view<?x?xf32>
%4 = linalg.slice %3[%2, %2] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
%5 = linalg.slice %3[%2, %arg2] : !linalg.view<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
%6 = linalg.slice %3[%arg2, %2] : !linalg.view<?x?xf32>, index, !linalg.range, !linalg.view<?xf32>
%7 = linalg.slice %3[%arg2, %arg3] : !linalg.view<?x?xf32>, index, index, !linalg.view<f32>
linalg.buffer_dealloc %1 : !linalg.buffer<f32>
return
}
// CHECK-LABEL: func @views(%arg0: i64, %arg1: i64, %arg2: index, %arg3: index, %arg4: index) {
// CHECK-NEXT: %0 = muli %arg0, %arg0 : i64
// CHECK-LABEL: func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
// CHECK-NEXT: %0 = muli %arg0, %arg0 : index
// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
// CHECK-NEXT: %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
// CHECK-NEXT: %3 = linalg.base_view %1[%2, %2] : !linalg.view<?x?xf32>
// CHECK-NEXT: %3 = linalg.view %1[%2, %2] : !linalg.view<?x?xf32>
// CHECK-NEXT: %4 = linalg.slice %3[%2, %2] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
// CHECK-NEXT: %5 = linalg.slice %3[%2, %arg2] : !linalg.view<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
// CHECK-NEXT: %6 = linalg.slice %3[%arg2, %2] : !linalg.view<?x?xf32>, index, !linalg.range, !linalg.view<?xf32>
// CHECK-NEXT: %7 = linalg.slice %3[%arg2, %arg3] : !linalg.view<?x?xf32>, index, index, !linalg.view<f32>
// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer<f32>