Introduce splat op + provide its LLVM lowering
- introduce splat op in standard dialect (currently for int/float/index input type, output type can be vector or statically shaped tensor) - implement LLVM lowering (when result type is 1-d vector) - add constant folding hook for it - while on Ops.cpp, fix some stale names Signed-off-by: Uday Bondhugula <uday@polymagelabs.com> Closes tensorflow/mlir#141 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/141 from bondhugula:splat 48976a6aa0a75be6d91187db6418de989e03eb51 PiperOrigin-RevId: 270965304
This commit is contained in:
parent
42d8fa667b
commit
458ede8775
|
@ -352,6 +352,32 @@ because of the
|
||||||
[restrictions on dimensions and symbols](Affine.md#restrictions-on-dimensions-and-symbols)
|
[restrictions on dimensions and symbols](Affine.md#restrictions-on-dimensions-and-symbols)
|
||||||
in these contexts.
|
in these contexts.
|
||||||
|
|
||||||
|
### 'splat' operation
|
||||||
|
|
||||||
|
Syntax:
|
||||||
|
|
||||||
|
``` {.ebnf}
|
||||||
|
operation ::= `splat` ssa-use `:` ( vector-type | tensor-type )
|
||||||
|
```
|
||||||
|
|
||||||
|
Broadcast the operand to all elements of the result vector or tensor. The
|
||||||
|
operand has to be of either integer or float type. When the result is a tensor,
|
||||||
|
it has to be statically shaped.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```mlir {.mlir}
|
||||||
|
%s = load %A[%i] : memref<128xf32>
|
||||||
|
%v = splat %s : vector<4xf32>
|
||||||
|
%t = splat %s : tensor<8x16xi32>
|
||||||
|
```
|
||||||
|
|
||||||
|
TODO: This operation is easy to extend to broadcast to dynamically shaped
|
||||||
|
tensors in the same way dynamically shaped memrefs are handled. `mlir {.mlir} //
|
||||||
|
Broadcasts %s to a 2-d dynamically shaped tensor, with %m, %n binding // to the
|
||||||
|
sizes of the two dynamic dimensions. %m = "foo"() : () -> (index) %n = "bar"() :
|
||||||
|
() -> (index) %t = splat %s [%m, %n] : tensor<?x?xi32>`
|
||||||
|
|
||||||
### 'store' operation
|
### 'store' operation
|
||||||
|
|
||||||
Syntax:
|
Syntax:
|
||||||
|
|
|
@ -692,7 +692,7 @@ index-type ::= `index`
|
||||||
|
|
||||||
The `index` type is a signless integer whose size is equal to the natural
|
The `index` type is a signless integer whose size is equal to the natural
|
||||||
machine word of the target ([rationale](Rationale.md#signless-types)) and is
|
machine word of the target ([rationale](Rationale.md#signless-types)) and is
|
||||||
used by the affine constructs in MLIR. Unlike fixed-size integers. It cannot be
|
used by the affine constructs in MLIR. Unlike fixed-size integers, it cannot be
|
||||||
used as an element of vector, tensor or memref type
|
used as an element of vector, tensor or memref type
|
||||||
([rationale](Rationale.md#index-type-disallowed-in-vectortensormemref-types)).
|
([rationale](Rationale.md#index-type-disallowed-in-vectortensormemref-types)).
|
||||||
|
|
||||||
|
|
|
@ -881,6 +881,32 @@ def ShlISOp : IntArithmeticOp<"shlis"> {
|
||||||
let summary = "signed integer shift left";
|
let summary = "signed integer shift left";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def SplatOp : Std_Op<"splat", [NoSideEffect]> {
|
||||||
|
let summary = "splat or broadcast operation";
|
||||||
|
let description = [{
|
||||||
|
The "splat" op reads a value of integer or float type and broadcasts it into
|
||||||
|
a vector or a tensor. The output of splat is thus a new value of either
|
||||||
|
vector or tensor type with elemental type being its operand's type.
|
||||||
|
When the result is a tensor, it has to be statically shaped.
|
||||||
|
|
||||||
|
%1 = splat %0 : vector<8xi32>
|
||||||
|
%2 = splat %0 : tensor<4x8xi32>
|
||||||
|
|
||||||
|
// TODO: handle broadcast to dynamically shaped tensors.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins AnyTypeOf<[AnyInteger, AnyFloat],
|
||||||
|
"integer or float type">:$input);
|
||||||
|
let results = (outs AnyTypeOf<[AnyVector, AnyStaticShapeTensor]>:$aggregate);
|
||||||
|
|
||||||
|
let builders =
|
||||||
|
[OpBuilder<"Builder *builder, OperationState &result, Value *element, "
|
||||||
|
"Type aggregateType",
|
||||||
|
[{ build(builder, result, aggregateType, element); }]>];
|
||||||
|
|
||||||
|
let hasFolder = 1;
|
||||||
|
}
|
||||||
|
|
||||||
def SubFOp : FloatArithmeticOp<"subf"> {
|
def SubFOp : FloatArithmeticOp<"subf"> {
|
||||||
let summary = "floating point subtraction operation";
|
let summary = "floating point subtraction operation";
|
||||||
let hasFolder = 1;
|
let hasFolder = 1;
|
||||||
|
|
|
@ -248,17 +248,6 @@ public:
|
||||||
return builder.create<LLVM::ConstantOp>(loc, getIndexType(), attr);
|
return builder.create<LLVM::ConstantOp>(loc, getIndexType(), attr);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the array attribute named "position" containing the given list of
|
|
||||||
// integers as integer attribute elements.
|
|
||||||
static ArrayAttr getIntegerArrayAttr(ConversionPatternRewriter &builder,
|
|
||||||
ArrayRef<int64_t> values) {
|
|
||||||
SmallVector<Attribute, 4> attrs;
|
|
||||||
attrs.reserve(values.size());
|
|
||||||
for (int64_t pos : values)
|
|
||||||
attrs.push_back(builder.getIntegerAttr(builder.getIndexType(), pos));
|
|
||||||
return builder.getArrayAttr(attrs);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract raw data pointer value from a value representing a memref.
|
// Extract raw data pointer value from a value representing a memref.
|
||||||
static Value *extractMemRefElementPtr(ConversionPatternRewriter &builder,
|
static Value *extractMemRefElementPtr(ConversionPatternRewriter &builder,
|
||||||
Location loc,
|
Location loc,
|
||||||
|
@ -269,9 +258,9 @@ public:
|
||||||
if (hasStaticShape)
|
if (hasStaticShape)
|
||||||
return convertedMemRefValue;
|
return convertedMemRefValue;
|
||||||
else
|
else
|
||||||
return builder.create<LLVM::ExtractValueOp>(
|
return builder.create<LLVM::ExtractValueOp>(loc, elementTypePtr,
|
||||||
loc, elementTypePtr, convertedMemRefValue,
|
convertedMemRefValue,
|
||||||
getIntegerArrayAttr(builder, 0));
|
builder.getIndexArrayAttr(0));
|
||||||
return buffer;
|
return buffer;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1028,6 +1017,39 @@ struct CondBranchOpLowering
|
||||||
using Super::Super;
|
using Super::Super;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// The Splat operation is lowered to an insertelement + a shufflevector
|
||||||
|
// operation. Splat to only 1-d vector result types are lowered.
|
||||||
|
struct SplatOpLowering : public LLVMLegalizationPattern<SplatOp> {
|
||||||
|
using LLVMLegalizationPattern<SplatOp>::LLVMLegalizationPattern;
|
||||||
|
|
||||||
|
PatternMatchResult
|
||||||
|
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
auto splatOp = cast<SplatOp>(op);
|
||||||
|
VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
|
||||||
|
if (!resultType || resultType.getRank() != 1)
|
||||||
|
return matchFailure();
|
||||||
|
|
||||||
|
// First insert it into an undef vector so we can shuffle it.
|
||||||
|
auto vectorType = lowering.convertType(splatOp.getType());
|
||||||
|
Value *undef = rewriter.create<LLVM::UndefOp>(op->getLoc(), vectorType);
|
||||||
|
auto zero = rewriter.create<LLVM::ConstantOp>(
|
||||||
|
op->getLoc(), lowering.convertType(rewriter.getIntegerType(32)),
|
||||||
|
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
|
||||||
|
|
||||||
|
auto v = rewriter.create<LLVM::InsertElementOp>(
|
||||||
|
op->getLoc(), vectorType, undef, splatOp.getOperand(), zero);
|
||||||
|
|
||||||
|
int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
|
||||||
|
SmallVector<int32_t, 4> zeroValues(width, 0);
|
||||||
|
|
||||||
|
// Shuffle the value across the desired number of elements.
|
||||||
|
ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
|
||||||
|
rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(op, v, undef, zeroAttrs);
|
||||||
|
return matchSuccess();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
static void ensureDistinctSuccessors(Block &bb) {
|
static void ensureDistinctSuccessors(Block &bb) {
|
||||||
|
@ -1089,9 +1111,9 @@ void mlir::populateStdToLLVMConversionPatterns(
|
||||||
DivFOpLowering, FuncOpConversion, IndexCastOpLowering, LoadOpLowering,
|
DivFOpLowering, FuncOpConversion, IndexCastOpLowering, LoadOpLowering,
|
||||||
MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering,
|
MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering,
|
||||||
RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering,
|
RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering,
|
||||||
SelectOpLowering, SignExtendIOpLowering, SIToFPLowering, StoreOpLowering,
|
SelectOpLowering, SIToFPLowering, SignExtendIOpLowering, SplatOpLowering,
|
||||||
SubFOpLowering, SubIOpLowering, TruncateIOpLowering, XOrOpLowering,
|
StoreOpLowering, SubFOpLowering, SubIOpLowering, TruncateIOpLowering,
|
||||||
ZeroExtendIOpLowering>(*converter.getDialect(), converter);
|
XOrOpLowering, ZeroExtendIOpLowering>(*converter.getDialect(), converter);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert types using the stored LLVM IR module.
|
// Convert types using the stored LLVM IR module.
|
||||||
|
|
|
@ -202,10 +202,10 @@ ParseResult mlir::parseDimAndSymbolList(OpAsmParser &parser,
|
||||||
numDims = opInfos.size();
|
numDims = opInfos.size();
|
||||||
|
|
||||||
// Parse the optional symbol operands.
|
// Parse the optional symbol operands.
|
||||||
auto affineIntTy = parser.getBuilder().getIndexType();
|
auto indexTy = parser.getBuilder().getIndexType();
|
||||||
if (parser.parseOperandList(opInfos,
|
if (parser.parseOperandList(opInfos,
|
||||||
OpAsmParser::Delimiter::OptionalSquare) ||
|
OpAsmParser::Delimiter::OptionalSquare) ||
|
||||||
parser.resolveOperands(opInfos, affineIntTy, operands))
|
parser.resolveOperands(opInfos, indexTy, operands))
|
||||||
return failure();
|
return failure();
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -1658,14 +1658,14 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser,
|
||||||
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
|
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
|
||||||
ShapedType type;
|
ShapedType type;
|
||||||
|
|
||||||
auto affineIntTy = parser.getBuilder().getIndexType();
|
auto indexTy = parser.getBuilder().getIndexType();
|
||||||
return failure(
|
return failure(
|
||||||
parser.parseOperand(aggregateInfo) ||
|
parser.parseOperand(aggregateInfo) ||
|
||||||
parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
|
parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
|
||||||
parser.parseOptionalAttributeDict(result.attributes) ||
|
parser.parseOptionalAttributeDict(result.attributes) ||
|
||||||
parser.parseColonType(type) ||
|
parser.parseColonType(type) ||
|
||||||
parser.resolveOperand(aggregateInfo, type, result.operands) ||
|
parser.resolveOperand(aggregateInfo, type, result.operands) ||
|
||||||
parser.resolveOperands(indexInfo, affineIntTy, result.operands) ||
|
parser.resolveOperands(indexInfo, indexTy, result.operands) ||
|
||||||
parser.addTypeToList(type.getElementType(), result.types));
|
parser.addTypeToList(type.getElementType(), result.types));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1739,14 +1739,14 @@ static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
|
||||||
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
|
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
|
||||||
MemRefType type;
|
MemRefType type;
|
||||||
|
|
||||||
auto affineIntTy = parser.getBuilder().getIndexType();
|
auto indexTy = parser.getBuilder().getIndexType();
|
||||||
return failure(
|
return failure(
|
||||||
parser.parseOperand(memrefInfo) ||
|
parser.parseOperand(memrefInfo) ||
|
||||||
parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
|
parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
|
||||||
parser.parseOptionalAttributeDict(result.attributes) ||
|
parser.parseOptionalAttributeDict(result.attributes) ||
|
||||||
parser.parseColonType(type) ||
|
parser.parseColonType(type) ||
|
||||||
parser.resolveOperand(memrefInfo, type, result.operands) ||
|
parser.resolveOperand(memrefInfo, type, result.operands) ||
|
||||||
parser.resolveOperands(indexInfo, affineIntTy, result.operands) ||
|
parser.resolveOperands(indexInfo, indexTy, result.operands) ||
|
||||||
parser.addTypeToList(type.getElementType(), result.types));
|
parser.addTypeToList(type.getElementType(), result.types));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2043,6 +2043,55 @@ static LogicalResult verify(SignExtendIOp op) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// SplatOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
static void print(OpAsmPrinter &p, SplatOp op) {
|
||||||
|
p << "splat " << *op.getOperand();
|
||||||
|
p.printOptionalAttrDict(op.getAttrs());
|
||||||
|
p << " : " << op.getType();
|
||||||
|
}
|
||||||
|
|
||||||
|
static ParseResult parseSplatOp(OpAsmParser &parser, OperationState &result) {
|
||||||
|
OpAsmParser::OperandType splatValueInfo;
|
||||||
|
ShapedType shapedType;
|
||||||
|
|
||||||
|
return failure(parser.parseOperand(splatValueInfo) ||
|
||||||
|
parser.parseOptionalAttributeDict(result.attributes) ||
|
||||||
|
parser.parseColonType(shapedType) ||
|
||||||
|
parser.resolveOperand(splatValueInfo,
|
||||||
|
shapedType.getElementType(),
|
||||||
|
result.operands) ||
|
||||||
|
parser.addTypeToList(shapedType, result.types));
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult verify(SplatOp op) {
|
||||||
|
// TODO: we could replace this by a trait.
|
||||||
|
if (op.getOperand()->getType() !=
|
||||||
|
op.getType().cast<ShapedType>().getElementType())
|
||||||
|
return op.emitError("operand should be of elemental type of result type");
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Constant folding hook for SplatOp.
|
||||||
|
OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
assert(operands.size() == 1 && "splat takes one operand");
|
||||||
|
|
||||||
|
auto constOperand = operands.front();
|
||||||
|
if (!constOperand ||
|
||||||
|
(!constOperand.isa<IntegerAttr>() && !constOperand.isa<FloatAttr>()))
|
||||||
|
return {};
|
||||||
|
|
||||||
|
auto shapedType = getType().cast<ShapedType>();
|
||||||
|
assert(shapedType.getElementType() == constOperand.getType() &&
|
||||||
|
"incorrect input attribute type for folding");
|
||||||
|
|
||||||
|
// SplatElementsAttr::get treats single value for second arg as being a splat.
|
||||||
|
return SplatElementsAttr::get(shapedType, {constOperand});
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// StoreOp
|
// StoreOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -2062,7 +2111,7 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
|
||||||
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
|
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
|
||||||
MemRefType memrefType;
|
MemRefType memrefType;
|
||||||
|
|
||||||
auto affineIntTy = parser.getBuilder().getIndexType();
|
auto indexTy = parser.getBuilder().getIndexType();
|
||||||
return failure(
|
return failure(
|
||||||
parser.parseOperand(storeValueInfo) || parser.parseComma() ||
|
parser.parseOperand(storeValueInfo) || parser.parseComma() ||
|
||||||
parser.parseOperand(memrefInfo) ||
|
parser.parseOperand(memrefInfo) ||
|
||||||
|
@ -2072,7 +2121,7 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
|
||||||
parser.resolveOperand(storeValueInfo, memrefType.getElementType(),
|
parser.resolveOperand(storeValueInfo, memrefType.getElementType(),
|
||||||
result.operands) ||
|
result.operands) ||
|
||||||
parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
|
parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
|
||||||
parser.resolveOperands(indexInfo, affineIntTy, result.operands));
|
parser.resolveOperands(indexInfo, indexTy, result.operands));
|
||||||
}
|
}
|
||||||
|
|
||||||
static LogicalResult verify(StoreOp op) {
|
static LogicalResult verify(StoreOp op) {
|
||||||
|
|
|
@ -552,3 +552,18 @@ func @vec_bin(%arg0: vector<2x2x2xf32>) -> vector<2x2x2xf32> {
|
||||||
// And we're done
|
// And we're done
|
||||||
// CHECK-NEXT: return
|
// CHECK-NEXT: return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: @splat
|
||||||
|
// CHECK-SAME: [[A:%arg[0-9]+]]: !llvm<"<4 x float>">
|
||||||
|
// CHECK-SAME: [[ELT:%arg[0-9]+]]: !llvm.float
|
||||||
|
func @splat(%a: vector<4xf32>, %b: f32) -> vector<4xf32> {
|
||||||
|
%vb = splat %b : vector<4xf32>
|
||||||
|
%r = mulf %a, %vb : vector<4xf32>
|
||||||
|
return %r : vector<4xf32>
|
||||||
|
}
|
||||||
|
// CHECK-NEXT: [[UNDEF:%[0-9]+]] = llvm.mlir.undef : !llvm<"<4 x float>">
|
||||||
|
// CHECK-NEXT: [[ZERO:%[0-9]+]] = llvm.mlir.constant(0 : i32) : !llvm.i32
|
||||||
|
// CHECK-NEXT: [[V:%[0-9]+]] = llvm.insertelement [[UNDEF]], [[ELT]], [[ZERO]] : !llvm<"<4 x float>">
|
||||||
|
// CHECK-NEXT: [[SPLAT:%[0-9]+]] = llvm.shufflevector [[V]], [[UNDEF]] [0 : i32, 0 : i32, 0 : i32, 0 : i32]
|
||||||
|
// CHECK-NEXT: [[SCALE:%[0-9]+]] = llvm.fmul [[A]], [[SPLAT]] : !llvm<"<4 x float>">
|
||||||
|
// CHECK-NEXT: llvm.return [[SCALE]] : !llvm<"<4 x float>">
|
||||||
|
|
|
@ -467,6 +467,17 @@ func @test_dimop(%arg0: tensor<4x4x?xf32>) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @test_splat_op
|
||||||
|
// CHECK-SAME: [[S:%arg[0-9]+]]: f32
|
||||||
|
func @test_splat_op(%s : f32) {
|
||||||
|
%v = splat %s : vector<8xf32>
|
||||||
|
// CHECK: splat [[S]] : vector<8xf32>
|
||||||
|
%t = splat %s : tensor<8xf32>
|
||||||
|
// CHECK: splat [[S]] : tensor<8xf32>
|
||||||
|
%u = "std.splat"(%s) : (f32) -> vector<4xf32>
|
||||||
|
// CHECK: splat [[S]] : vector<4xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: func @test_vector.transfer_ops(%arg0
|
// CHECK-LABEL: func @test_vector.transfer_ops(%arg0
|
||||||
func @test_vector.transfer_ops(%arg0: memref<?x?xf32>) {
|
func @test_vector.transfer_ops(%arg0: memref<?x?xf32>) {
|
||||||
|
|
|
@ -821,3 +821,27 @@ func @return_not_in_function() {
|
||||||
}): () -> ()
|
}): () -> ()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @invalid_splat(%v : f32) {
|
||||||
|
splat %v : memref<8xf32>
|
||||||
|
// expected-error@-1 {{must be vector of any type values or statically shaped tensor of any type values}}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @invalid_splat(%v : vector<8xf32>) {
|
||||||
|
%w = splat %v : tensor<8xvector<8xf32>>
|
||||||
|
// expected-error@-1 {{must be integer or float type}}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @invalid_splat(%v : f32) { // expected-note {{prior use here}}
|
||||||
|
splat %v : vector<8xf64>
|
||||||
|
// expected-error@-1 {{expects different type than prior uses}}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -540,3 +540,15 @@ func @custom_insertion_position() {
|
||||||
}) : () -> ()
|
}) : () -> ()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @splat_fold
|
||||||
|
func @splat_fold() -> (vector<4xf32>, tensor<4xf32>) {
|
||||||
|
%c = constant 1.0 : f32
|
||||||
|
%v = splat %c : vector<4xf32>
|
||||||
|
%t = splat %c : tensor<4xf32>
|
||||||
|
return %v, %t : vector<4xf32>, tensor<4xf32>
|
||||||
|
|
||||||
|
// CHECK-NEXT: [[V:%.*]] = constant dense<1.000000e+00> : vector<4xf32>
|
||||||
|
// CHECK-NEXT: [[T:%.*]] = constant dense<1.000000e+00> : tensor<4xf32>
|
||||||
|
// CHECK-NEXT: return [[V]], [[T]] : vector<4xf32>, tensor<4xf32>
|
||||||
|
}
|
||||||
|
|
|
@ -31,7 +31,7 @@ syn match mlirType /x\s*\zsvector/
|
||||||
" Operations.
|
" Operations.
|
||||||
" Core ops (not exhaustive yet).
|
" Core ops (not exhaustive yet).
|
||||||
" TODO: the list is not exhaustive.
|
" TODO: the list is not exhaustive.
|
||||||
syn keyword mlirOps alloc addf addi call call_indirect cmpi constant dealloc dma_start dma_wait dim extract_element for getTensor if load memref_cast mulf muli store select subf subi tensor_cast
|
syn keyword mlirOps alloc addf addi call call_indirect cmpi constant dealloc dma_start dma_wait dim extract_element for getTensor if load memref_cast mulf muli splat store select subf subi tensor_cast
|
||||||
|
|
||||||
" Affine ops.
|
" Affine ops.
|
||||||
syn match mlirOps /\<affine\.apply\>/
|
syn match mlirOps /\<affine\.apply\>/
|
||||||
|
|
Loading…
Reference in New Issue