Introduce splat op + provide its LLVM lowering
- introduce splat op in standard dialect (currently for int/float/index input
  type, output type can be vector or statically shaped tensor)
- implement LLVM lowering (when result type is 1-d vector)
- add constant folding hook for it
- while on Ops.cpp, fix some stale names

Signed-off-by: Uday Bondhugula <uday@polymagelabs.com>

Closes tensorflow/mlir#141

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/141 from bondhugula:splat 48976a6aa0a75be6d91187db6418de989e03eb51
PiperOrigin-RevId: 270965304
commit 458ede8775
parent 42d8fa667b
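For orientation, a minimal sketch of what the patch adds, using the syntax introduced in the LangRef change below (`%s`, `%v`, `%t` are illustrative names):

```mlir
%s = constant 0 : i32
// Broadcast a scalar into a 1-d vector; this case also gets an LLVM lowering.
%v = splat %s : vector<8xi32>
// Broadcast into a statically shaped tensor; no LLVM lowering yet.
%t = splat %s : tensor<4x8xi32>
```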
@@ -352,6 +352,32 @@ because of the
 [restrictions on dimensions and symbols](Affine.md#restrictions-on-dimensions-and-symbols)
 in these contexts.
 
+### 'splat' operation
+
+Syntax:
+
+``` {.ebnf}
+operation ::= `splat` ssa-use `:` ( vector-type | tensor-type )
+```
+
+Broadcast the operand to all elements of the result vector or tensor. The
+operand has to be of either integer or float type. When the result is a tensor,
+it has to be statically shaped.
+
+Example:
+
+```mlir {.mlir}
+%s = load %A[%i] : memref<128xf32>
+%v = splat %s : vector<4xf32>
+%t = splat %s : tensor<8x16xi32>
+```
+
+TODO: This operation is easy to extend to broadcast to dynamically shaped
+tensors in the same way dynamically shaped memrefs are handled.
+
+```mlir {.mlir}
+// Broadcasts %s to a 2-d dynamically shaped tensor, with %m, %n binding
+// to the sizes of the two dynamic dimensions.
+%m = "foo"() : () -> (index)
+%n = "bar"() : () -> (index)
+%t = splat %s [%m, %n] : tensor<?x?xi32>
+```
+
 ### 'store' operation
 
 Syntax:
@@ -692,7 +692,7 @@ index-type ::= `index`
 
 The `index` type is a signless integer whose size is equal to the natural
 machine word of the target ([rationale](Rationale.md#signless-types)) and is
-used by the affine constructs in MLIR. Unlike fixed-size integers. It cannot be
+used by the affine constructs in MLIR. Unlike fixed-size integers, it cannot be
 used as an element of vector, tensor or memref type
 ([rationale](Rationale.md#index-type-disallowed-in-vectortensormemref-types)).
 
@@ -881,6 +881,32 @@ def ShlISOp : IntArithmeticOp<"shlis"> {
   let summary = "signed integer shift left";
 }
 
+def SplatOp : Std_Op<"splat", [NoSideEffect]> {
+  let summary = "splat or broadcast operation";
+  let description = [{
+    The "splat" op reads a value of integer or float type and broadcasts it
+    into a vector or a tensor. The result is thus a new value of either vector
+    or tensor type whose element type is that of its operand.
+    When the result is a tensor, it has to be statically shaped.
+
+      %1 = splat %0 : vector<8xi32>
+      %2 = splat %0 : tensor<4x8xi32>
+
+    // TODO: handle broadcast to dynamically shaped tensors.
+  }];
+
+  let arguments = (ins AnyTypeOf<[AnyInteger, AnyFloat],
+                                 "integer or float type">:$input);
+  let results = (outs AnyTypeOf<[AnyVector, AnyStaticShapeTensor]>:$aggregate);
+
+  let builders =
+      [OpBuilder<"Builder *builder, OperationState &result, Value *element, "
+                 "Type aggregateType",
+                 [{ build(builder, result, aggregateType, element); }]>];
+
+  let hasFolder = 1;
+}
+
 def SubFOp : FloatArithmeticOp<"subf"> {
   let summary = "floating point subtraction operation";
   let hasFolder = 1;
@@ -248,17 +248,6 @@ public:
     return builder.create<LLVM::ConstantOp>(loc, getIndexType(), attr);
   }
 
-  // Get the array attribute named "position" containing the given list of
-  // integers as integer attribute elements.
-  static ArrayAttr getIntegerArrayAttr(ConversionPatternRewriter &builder,
-                                       ArrayRef<int64_t> values) {
-    SmallVector<Attribute, 4> attrs;
-    attrs.reserve(values.size());
-    for (int64_t pos : values)
-      attrs.push_back(builder.getIntegerAttr(builder.getIndexType(), pos));
-    return builder.getArrayAttr(attrs);
-  }
-
   // Extract raw data pointer value from a value representing a memref.
   static Value *extractMemRefElementPtr(ConversionPatternRewriter &builder,
                                         Location loc,
@@ -269,9 +258,9 @@ public:
     if (hasStaticShape)
       return convertedMemRefValue;
     else
-      return builder.create<LLVM::ExtractValueOp>(
-          loc, elementTypePtr, convertedMemRefValue,
-          getIntegerArrayAttr(builder, 0));
+      return builder.create<LLVM::ExtractValueOp>(loc, elementTypePtr,
+                                                  convertedMemRefValue,
+                                                  builder.getIndexArrayAttr(0));
     return buffer;
   }
 
@@ -1028,6 +1017,39 @@ struct CondBranchOpLowering
   using Super::Super;
 };
 
+// The Splat operation is lowered to an insertelement + a shufflevector
+// operation. Only splats to 1-d vector result types are lowered.
+struct SplatOpLowering : public LLVMLegalizationPattern<SplatOp> {
+  using LLVMLegalizationPattern<SplatOp>::LLVMLegalizationPattern;
+
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto splatOp = cast<SplatOp>(op);
+    VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
+    if (!resultType || resultType.getRank() != 1)
+      return matchFailure();
+
+    // First insert it into an undef vector so we can shuffle it.
+    auto vectorType = lowering.convertType(splatOp.getType());
+    Value *undef = rewriter.create<LLVM::UndefOp>(op->getLoc(), vectorType);
+    auto zero = rewriter.create<LLVM::ConstantOp>(
+        op->getLoc(), lowering.convertType(rewriter.getIntegerType(32)),
+        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
+
+    auto v = rewriter.create<LLVM::InsertElementOp>(
+        op->getLoc(), vectorType, undef, splatOp.getOperand(), zero);
+
+    int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
+    SmallVector<int32_t, 4> zeroValues(width, 0);
+
+    // Shuffle the value across the desired number of elements.
+    ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
+    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(op, v, undef, zeroAttrs);
+    return matchSuccess();
+  }
+};
+
 } // namespace
 
 static void ensureDistinctSuccessors(Block &bb) {
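For reference, a sketch of the LLVM dialect IR this pattern is meant to produce for a `vector<4xf32>` splat of a scalar `%elt` (SSA names illustrative; the CHECK lines in the lowering test below pin down the exact form):

```mlir
%undef = llvm.mlir.undef : !llvm<"<4 x float>">
%zero = llvm.mlir.constant(0 : i32) : !llvm.i32
%e = llvm.insertelement %undef, %elt, %zero : !llvm<"<4 x float>">
%v = llvm.shufflevector %e, %undef [0 : i32, 0 : i32, 0 : i32, 0 : i32]
```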
@@ -1089,9 +1111,9 @@ void mlir::populateStdToLLVMConversionPatterns(
       DivFOpLowering, FuncOpConversion, IndexCastOpLowering, LoadOpLowering,
       MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering,
       RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering,
-      SelectOpLowering, SignExtendIOpLowering, SIToFPLowering, StoreOpLowering,
-      SubFOpLowering, SubIOpLowering, TruncateIOpLowering, XOrOpLowering,
-      ZeroExtendIOpLowering>(*converter.getDialect(), converter);
+      SelectOpLowering, SIToFPLowering, SignExtendIOpLowering, SplatOpLowering,
+      StoreOpLowering, SubFOpLowering, SubIOpLowering, TruncateIOpLowering,
+      XOrOpLowering, ZeroExtendIOpLowering>(*converter.getDialect(), converter);
 }
 
 // Convert types using the stored LLVM IR module.
@@ -202,10 +202,10 @@ ParseResult mlir::parseDimAndSymbolList(OpAsmParser &parser,
   numDims = opInfos.size();
 
   // Parse the optional symbol operands.
-  auto affineIntTy = parser.getBuilder().getIndexType();
+  auto indexTy = parser.getBuilder().getIndexType();
   if (parser.parseOperandList(opInfos,
                               OpAsmParser::Delimiter::OptionalSquare) ||
-      parser.resolveOperands(opInfos, affineIntTy, operands))
+      parser.resolveOperands(opInfos, indexTy, operands))
     return failure();
   return success();
 }
@@ -1658,14 +1658,14 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser,
   SmallVector<OpAsmParser::OperandType, 4> indexInfo;
   ShapedType type;
 
-  auto affineIntTy = parser.getBuilder().getIndexType();
+  auto indexTy = parser.getBuilder().getIndexType();
   return failure(
       parser.parseOperand(aggregateInfo) ||
       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
       parser.parseOptionalAttributeDict(result.attributes) ||
       parser.parseColonType(type) ||
       parser.resolveOperand(aggregateInfo, type, result.operands) ||
-      parser.resolveOperands(indexInfo, affineIntTy, result.operands) ||
+      parser.resolveOperands(indexInfo, indexTy, result.operands) ||
      parser.addTypeToList(type.getElementType(), result.types));
 }
 
@@ -1739,14 +1739,14 @@ static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::OperandType, 4> indexInfo;
   MemRefType type;
 
-  auto affineIntTy = parser.getBuilder().getIndexType();
+  auto indexTy = parser.getBuilder().getIndexType();
   return failure(
       parser.parseOperand(memrefInfo) ||
       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
       parser.parseOptionalAttributeDict(result.attributes) ||
       parser.parseColonType(type) ||
       parser.resolveOperand(memrefInfo, type, result.operands) ||
-      parser.resolveOperands(indexInfo, affineIntTy, result.operands) ||
+      parser.resolveOperands(indexInfo, indexTy, result.operands) ||
       parser.addTypeToList(type.getElementType(), result.types));
 }
 
@@ -2043,6 +2043,55 @@ static LogicalResult verify(SignExtendIOp op) {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// SplatOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter &p, SplatOp op) {
+  p << "splat " << *op.getOperand();
+  p.printOptionalAttrDict(op.getAttrs());
+  p << " : " << op.getType();
+}
+
+static ParseResult parseSplatOp(OpAsmParser &parser, OperationState &result) {
+  OpAsmParser::OperandType splatValueInfo;
+  ShapedType shapedType;
+
+  return failure(parser.parseOperand(splatValueInfo) ||
+                 parser.parseOptionalAttributeDict(result.attributes) ||
+                 parser.parseColonType(shapedType) ||
+                 parser.resolveOperand(splatValueInfo,
+                                       shapedType.getElementType(),
+                                       result.operands) ||
+                 parser.addTypeToList(shapedType, result.types));
+}
+
+static LogicalResult verify(SplatOp op) {
+  // TODO: we could replace this by a trait.
+  if (op.getOperand()->getType() !=
+      op.getType().cast<ShapedType>().getElementType())
+    return op.emitError("operand should be of elemental type of result type");
+
+  return success();
+}
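Note that the custom-form parser above already resolves the operand against the result's element type, so in pretty-printed IR a mismatch surfaces as a parser error (see the invalid-ops tests below); this verifier catches the mismatch in the generic form, e.g. (a hypothetical input, with `%d` of type `f64`):

```mlir
// Rejected by verify(SplatOp):
// "operand should be of elemental type of result type"
%w = "std.splat"(%d) : (f64) -> vector<8xf32>
```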
+
+// Constant folding hook for SplatOp.
+OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
+  assert(operands.size() == 1 && "splat takes one operand");
+
+  auto constOperand = operands.front();
+  if (!constOperand ||
+      (!constOperand.isa<IntegerAttr>() && !constOperand.isa<FloatAttr>()))
+    return {};
+
+  auto shapedType = getType().cast<ShapedType>();
+  assert(shapedType.getElementType() == constOperand.getType() &&
+         "incorrect input attribute type for folding");
+
+  // SplatElementsAttr::get treats a single value for the second arg as a splat.
+  return SplatElementsAttr::get(shapedType, {constOperand});
+}
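With the hook in place, a splat of a foldable constant materializes as a dense splat attribute; a sketch of what the canonicalization test below verifies:

```mlir
%c = constant 1.0 : f32
%v = splat %c : vector<4xf32>
// ... folds to:
%v = constant dense<1.000000e+00> : vector<4xf32>
```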
+
 //===----------------------------------------------------------------------===//
 // StoreOp
 //===----------------------------------------------------------------------===//
 
@@ -2062,7 +2111,7 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
   SmallVector<OpAsmParser::OperandType, 4> indexInfo;
   MemRefType memrefType;
 
-  auto affineIntTy = parser.getBuilder().getIndexType();
+  auto indexTy = parser.getBuilder().getIndexType();
   return failure(
       parser.parseOperand(storeValueInfo) || parser.parseComma() ||
       parser.parseOperand(memrefInfo) ||
@@ -2072,7 +2121,7 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
       parser.resolveOperand(storeValueInfo, memrefType.getElementType(),
                             result.operands) ||
       parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
-      parser.resolveOperands(indexInfo, affineIntTy, result.operands));
+      parser.resolveOperands(indexInfo, indexTy, result.operands));
 }
 
 static LogicalResult verify(StoreOp op) {
@@ -552,3 +552,18 @@ func @vec_bin(%arg0: vector<2x2x2xf32>) -> vector<2x2x2xf32> {
   // And we're done
   // CHECK-NEXT: return
 }
+
+// CHECK-LABEL: @splat
+// CHECK-SAME: [[A:%arg[0-9]+]]: !llvm<"<4 x float>">
+// CHECK-SAME: [[ELT:%arg[0-9]+]]: !llvm.float
+func @splat(%a: vector<4xf32>, %b: f32) -> vector<4xf32> {
+  %vb = splat %b : vector<4xf32>
+  %r = mulf %a, %vb : vector<4xf32>
+  return %r : vector<4xf32>
+}
+// CHECK-NEXT: [[UNDEF:%[0-9]+]] = llvm.mlir.undef : !llvm<"<4 x float>">
+// CHECK-NEXT: [[ZERO:%[0-9]+]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+// CHECK-NEXT: [[V:%[0-9]+]] = llvm.insertelement [[UNDEF]], [[ELT]], [[ZERO]] : !llvm<"<4 x float>">
+// CHECK-NEXT: [[SPLAT:%[0-9]+]] = llvm.shufflevector [[V]], [[UNDEF]] [0 : i32, 0 : i32, 0 : i32, 0 : i32]
+// CHECK-NEXT: [[SCALE:%[0-9]+]] = llvm.fmul [[A]], [[SPLAT]] : !llvm<"<4 x float>">
+// CHECK-NEXT: llvm.return [[SCALE]] : !llvm<"<4 x float>">
@@ -467,6 +467,17 @@ func @test_dimop(%arg0: tensor<4x4x?xf32>) {
   return
 }
 
+// CHECK-LABEL: func @test_splat_op
+// CHECK-SAME: [[S:%arg[0-9]+]]: f32
+func @test_splat_op(%s : f32) {
+  %v = splat %s : vector<8xf32>
+  // CHECK: splat [[S]] : vector<8xf32>
+  %t = splat %s : tensor<8xf32>
+  // CHECK: splat [[S]] : tensor<8xf32>
+  %u = "std.splat"(%s) : (f32) -> vector<4xf32>
+  // CHECK: splat [[S]] : vector<4xf32>
+  return
+}
+
 // CHECK-LABEL: func @test_vector.transfer_ops(%arg0
 func @test_vector.transfer_ops(%arg0: memref<?x?xf32>) {
@@ -821,3 +821,27 @@ func @return_not_in_function() {
   }): () -> ()
   return
 }
+
+// -----
+
+func @invalid_splat(%v : f32) {
+  splat %v : memref<8xf32>
+  // expected-error@-1 {{must be vector of any type values or statically shaped tensor of any type values}}
+  return
+}
+
+// -----
+
+func @invalid_splat(%v : vector<8xf32>) {
+  %w = splat %v : tensor<8xvector<8xf32>>
+  // expected-error@-1 {{must be integer or float type}}
+  return
+}
+
+// -----
+
+func @invalid_splat(%v : f32) { // expected-note {{prior use here}}
+  splat %v : vector<8xf64>
+  // expected-error@-1 {{expects different type than prior uses}}
+  return
+}
@@ -540,3 +540,15 @@ func @custom_insertion_position() {
   }) : () -> ()
   return
 }
+
+// CHECK-LABEL: func @splat_fold
+func @splat_fold() -> (vector<4xf32>, tensor<4xf32>) {
+  %c = constant 1.0 : f32
+  %v = splat %c : vector<4xf32>
+  %t = splat %c : tensor<4xf32>
+  return %v, %t : vector<4xf32>, tensor<4xf32>
+
+  // CHECK-NEXT: [[V:%.*]] = constant dense<1.000000e+00> : vector<4xf32>
+  // CHECK-NEXT: [[T:%.*]] = constant dense<1.000000e+00> : tensor<4xf32>
+  // CHECK-NEXT: return [[V]], [[T]] : vector<4xf32>, tensor<4xf32>
+}
@@ -31,7 +31,7 @@ syn match mlirType /x\s*\zsvector/
 " Operations.
-" Core ops (not exhaustive yet).
-syn keyword mlirOps alloc addf addi call call_indirect cmpi constant dealloc dma_start dma_wait dim extract_element for getTensor if load memref_cast mulf muli store select subf subi tensor_cast
+" TODO: the list is not exhaustive.
+syn keyword mlirOps alloc addf addi call call_indirect cmpi constant dealloc dma_start dma_wait dim extract_element for getTensor if load memref_cast mulf muli splat store select subf subi tensor_cast
 
 " Affine ops.
 syn match mlirOps /\<affine\.apply\>/