Introduce splat op + provide its LLVM lowering

- introduce splat op in the standard dialect (currently for int/float/index
  input types; the output type can be a vector or a statically shaped tensor)
- implement its LLVM lowering (when the result type is a 1-d vector)
- add a constant folding hook for it
- while in Ops.cpp, fix some stale names
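
For a quick illustration (mirroring the examples added in the docs and tests
below; `%A` and `%i` are assumed to be bound earlier):

```mlir {.mlir}
%s = load %A[%i] : memref<128xf32>
// Broadcast the scalar %s to all four vector lanes.
%v = splat %s : vector<4xf32>
// Tensor results must be statically shaped.
%t = splat %s : tensor<8x16xf32>
```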

Signed-off-by: Uday Bondhugula <uday@polymagelabs.com>

Closes tensorflow/mlir#141

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/141 from bondhugula:splat 48976a6aa0a75be6d91187db6418de989e03eb51
PiperOrigin-RevId: 270965304
Uday Bondhugula 2019-09-24 12:44:11 -07:00 committed by A. Unique TensorFlower
parent 42d8fa667b
commit 458ede8775
10 changed files with 212 additions and 27 deletions


@@ -352,6 +352,32 @@ because of the
[restrictions on dimensions and symbols](Affine.md#restrictions-on-dimensions-and-symbols)
in these contexts.

### 'splat' operation

Syntax:

``` {.ebnf}
operation ::= `splat` ssa-use `:` ( vector-type | tensor-type )
```

Broadcast the operand to all elements of the result vector or tensor. The
operand has to be of either integer or float type. When the result is a tensor,
it has to be statically shaped.

Example:

```mlir {.mlir}
%s = load %A[%i] : memref<128xf32>
%v = splat %s : vector<4xf32>
%t = splat %s : tensor<8x16xf32>
```

TODO: This operation is easy to extend to broadcast to dynamically shaped
tensors in the same way dynamically shaped memrefs are handled.

```mlir {.mlir}
// Broadcasts %s to a 2-d dynamically shaped tensor, with %m, %n binding
// to the sizes of the two dynamic dimensions.
%m = "foo"() : () -> (index)
%n = "bar"() : () -> (index)
%t = splat %s [%m, %n] : tensor<?x?xi32>
```

### 'store' operation

Syntax:


@@ -692,7 +692,7 @@ index-type ::= `index`

The `index` type is a signless integer whose size is equal to the natural
machine word of the target ([rationale](Rationale.md#signless-types)) and is
used by the affine constructs in MLIR. Unlike fixed-size integers, it cannot be
used as an element of vector, tensor or memref type
([rationale](Rationale.md#index-type-disallowed-in-vectortensormemref-types)).


@@ -881,6 +881,32 @@ def ShlISOp : IntArithmeticOp<"shlis"> {
  let summary = "signed integer shift left";
}

def SplatOp : Std_Op<"splat", [NoSideEffect]> {
  let summary = "splat or broadcast operation";
  let description = [{
    The "splat" op reads a value of integer or float type and broadcasts it
    into a vector or a tensor. The result of splat is thus a new value of
    vector or tensor type whose element type is its operand's type. When the
    result is a tensor, it has to be statically shaped.

      %1 = splat %0 : vector<8xi32>
      %2 = splat %0 : tensor<4x8xi32>

    // TODO: handle broadcast to dynamically shaped tensors.
  }];

  let arguments = (ins AnyTypeOf<[AnyInteger, AnyFloat],
                                 "integer or float type">:$input);
  let results = (outs AnyTypeOf<[AnyVector, AnyStaticShapeTensor]>:$aggregate);

  let builders =
      [OpBuilder<"Builder *builder, OperationState &result, Value *element, "
                 "Type aggregateType",
                 [{ build(builder, result, aggregateType, element); }]>];

  let hasFolder = 1;
}
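
// Note: the op uses a custom assembly form (defined in Ops.cpp); the generic
// form `"std.splat"(%s) : (f32) -> vector<4xf32>` round-trips and prints back
// as `splat %s : vector<4xf32>`, as exercised by the core-ops test below.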

def SubFOp : FloatArithmeticOp<"subf"> {
  let summary = "floating point subtraction operation";
  let hasFolder = 1;


@@ -248,17 +248,6 @@ public:
    return builder.create<LLVM::ConstantOp>(loc, getIndexType(), attr);
  }
  // Get the array attribute named "position" containing the given list of
  // integers as integer attribute elements.
  static ArrayAttr getIntegerArrayAttr(ConversionPatternRewriter &builder,
                                       ArrayRef<int64_t> values) {
    SmallVector<Attribute, 4> attrs;
    attrs.reserve(values.size());
    for (int64_t pos : values)
      attrs.push_back(builder.getIntegerAttr(builder.getIndexType(), pos));
    return builder.getArrayAttr(attrs);
  }
  // Extract raw data pointer value from a value representing a memref.
  static Value *extractMemRefElementPtr(ConversionPatternRewriter &builder,
                                        Location loc,
@@ -269,9 +258,9 @@ public:
    if (hasStaticShape)
      return convertedMemRefValue;
    else
      return builder.create<LLVM::ExtractValueOp>(loc, elementTypePtr,
                                                  convertedMemRefValue,
                                                  builder.getIndexArrayAttr(0));
    return buffer;
  }
@@ -1028,6 +1017,39 @@ struct CondBranchOpLowering
  using Super::Super;
};
// The Splat operation is lowered to an insertelement + a shufflevector
// operation. Only splats to 1-d vector result types are lowered.
struct SplatOpLowering : public LLVMLegalizationPattern<SplatOp> {
  using LLVMLegalizationPattern<SplatOp>::LLVMLegalizationPattern;

  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto splatOp = cast<SplatOp>(op);
    VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
    if (!resultType || resultType.getRank() != 1)
      return matchFailure();

    // First insert the value into an undef vector so we can shuffle it.
    auto vectorType = lowering.convertType(splatOp.getType());
    Value *undef = rewriter.create<LLVM::UndefOp>(op->getLoc(), vectorType);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        op->getLoc(), lowering.convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
    auto v = rewriter.create<LLVM::InsertElementOp>(
        op->getLoc(), vectorType, undef, splatOp.getOperand(), zero);

    // Shuffle the value across the desired number of elements.
    int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
    SmallVector<int32_t, 4> zeroValues(width, 0);
    ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(op, v, undef, zeroAttrs);
    return matchSuccess();
  }
};
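
// For reference (SSA names illustrative), `%v = splat %b : vector<4xf32>`
// lowers through this pattern to roughly:
//   %0 = llvm.mlir.undef : !llvm<"<4 x float>">
//   %1 = llvm.mlir.constant(0 : i32) : !llvm.i32
//   %2 = llvm.insertelement %0, %b, %1 : !llvm<"<4 x float>">
//   %3 = llvm.shufflevector %2, %0 [0 : i32, 0 : i32, 0 : i32, 0 : i32]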

} // namespace

static void ensureDistinctSuccessors(Block &bb) {
@@ -1089,9 +1111,9 @@ void mlir::populateStdToLLVMConversionPatterns(
      DivFOpLowering, FuncOpConversion, IndexCastOpLowering, LoadOpLowering,
      MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering,
      RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering,
      SelectOpLowering, SIToFPLowering, SignExtendIOpLowering, SplatOpLowering,
      StoreOpLowering, SubFOpLowering, SubIOpLowering, TruncateIOpLowering,
      XOrOpLowering, ZeroExtendIOpLowering>(*converter.getDialect(), converter);
}

// Convert types using the stored LLVM IR module.


@@ -202,10 +202,10 @@ ParseResult mlir::parseDimAndSymbolList(OpAsmParser &parser,
  numDims = opInfos.size();

  // Parse the optional symbol operands.
  auto indexTy = parser.getBuilder().getIndexType();
  if (parser.parseOperandList(opInfos,
                              OpAsmParser::Delimiter::OptionalSquare) ||
      parser.resolveOperands(opInfos, indexTy, operands))
    return failure();
  return success();
}
@@ -1658,14 +1658,14 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser,
  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
  ShapedType type;

  auto indexTy = parser.getBuilder().getIndexType();
  return failure(
      parser.parseOperand(aggregateInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseOptionalAttributeDict(result.attributes) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(aggregateInfo, type, result.operands) ||
      parser.resolveOperands(indexInfo, indexTy, result.operands) ||
      parser.addTypeToList(type.getElementType(), result.types));
}
@@ -1739,14 +1739,14 @@ static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
  MemRefType type;

  auto indexTy = parser.getBuilder().getIndexType();
  return failure(
      parser.parseOperand(memrefInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseOptionalAttributeDict(result.attributes) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(indexInfo, indexTy, result.operands) ||
      parser.addTypeToList(type.getElementType(), result.types));
}
@@ -2043,6 +2043,55 @@ static LogicalResult verify(SignExtendIOp op) {
  return success();
}

//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//

static void print(OpAsmPrinter &p, SplatOp op) {
  p << "splat " << *op.getOperand();
  p.printOptionalAttrDict(op.getAttrs());
  p << " : " << op.getType();
}

static ParseResult parseSplatOp(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType splatValueInfo;
  ShapedType shapedType;

  return failure(parser.parseOperand(splatValueInfo) ||
                 parser.parseOptionalAttributeDict(result.attributes) ||
                 parser.parseColonType(shapedType) ||
                 parser.resolveOperand(splatValueInfo,
                                       shapedType.getElementType(),
                                       result.operands) ||
                 parser.addTypeToList(shapedType, result.types));
}

static LogicalResult verify(SplatOp op) {
  // TODO: we could replace this by a trait.
  if (op.getOperand()->getType() !=
      op.getType().cast<ShapedType>().getElementType())
    return op.emitError("operand should be of elemental type of result type");
  return success();
}
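
// For instance, with `%d` of type f64, `splat %d : vector<8xf32>` is rejected
// by the check above, since the operand type must equal the result's element
// type.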

// Constant folding hook for SplatOp.
OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.size() == 1 && "splat takes one operand");

  auto constOperand = operands.front();
  if (!constOperand ||
      (!constOperand.isa<IntegerAttr>() && !constOperand.isa<FloatAttr>()))
    return {};

  auto shapedType = getType().cast<ShapedType>();
  assert(shapedType.getElementType() == constOperand.getType() &&
         "incorrect input attribute type for folding");

  // SplatElementsAttr::get treats a single value for the second arg as a
  // splat.
  return SplatElementsAttr::get(shapedType, {constOperand});
}
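
// For example, folding rewrites
//   %c = constant 1.0 : f32
//   %v = splat %c : vector<4xf32>
// into
//   %v = constant dense<1.000000e+00> : vector<4xf32>
// as exercised by the canonicalize test below.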

//===----------------------------------------------------------------------===//
// StoreOp
//===----------------------------------------------------------------------===//
@@ -2062,7 +2111,7 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
  MemRefType memrefType;

  auto indexTy = parser.getBuilder().getIndexType();
  return failure(
      parser.parseOperand(storeValueInfo) || parser.parseComma() ||
      parser.parseOperand(memrefInfo) ||
@@ -2072,7 +2121,7 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
      parser.resolveOperand(storeValueInfo, memrefType.getElementType(),
                            result.operands) ||
      parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
      parser.resolveOperands(indexInfo, indexTy, result.operands));
}

static LogicalResult verify(StoreOp op) {


@@ -552,3 +552,18 @@ func @vec_bin(%arg0: vector<2x2x2xf32>) -> vector<2x2x2xf32> {
// And we're done
// CHECK-NEXT: return
}

// CHECK-LABEL: @splat
// CHECK-SAME: [[A:%arg[0-9]+]]: !llvm<"<4 x float>">
// CHECK-SAME: [[ELT:%arg[0-9]+]]: !llvm.float
func @splat(%a: vector<4xf32>, %b: f32) -> vector<4xf32> {
  %vb = splat %b : vector<4xf32>
  %r = mulf %a, %vb : vector<4xf32>
  return %r : vector<4xf32>
}
// CHECK-NEXT: [[UNDEF:%[0-9]+]] = llvm.mlir.undef : !llvm<"<4 x float>">
// CHECK-NEXT: [[ZERO:%[0-9]+]] = llvm.mlir.constant(0 : i32) : !llvm.i32
// CHECK-NEXT: [[V:%[0-9]+]] = llvm.insertelement [[UNDEF]], [[ELT]], [[ZERO]] : !llvm<"<4 x float>">
// CHECK-NEXT: [[SPLAT:%[0-9]+]] = llvm.shufflevector [[V]], [[UNDEF]] [0 : i32, 0 : i32, 0 : i32, 0 : i32]
// CHECK-NEXT: [[SCALE:%[0-9]+]] = llvm.fmul [[A]], [[SPLAT]] : !llvm<"<4 x float>">
// CHECK-NEXT: llvm.return [[SCALE]] : !llvm<"<4 x float>">


@@ -467,6 +467,17 @@ func @test_dimop(%arg0: tensor<4x4x?xf32>) {
  return
}

// CHECK-LABEL: func @test_splat_op
// CHECK-SAME: [[S:%arg[0-9]+]]: f32
func @test_splat_op(%s : f32) {
  %v = splat %s : vector<8xf32>
  // CHECK: splat [[S]] : vector<8xf32>
  %t = splat %s : tensor<8xf32>
  // CHECK: splat [[S]] : tensor<8xf32>
  %u = "std.splat"(%s) : (f32) -> vector<4xf32>
  // CHECK: splat [[S]] : vector<4xf32>
  return
}

// CHECK-LABEL: func @test_vector.transfer_ops(%arg0
func @test_vector.transfer_ops(%arg0: memref<?x?xf32>) {


@@ -821,3 +821,27 @@ func @return_not_in_function() {
  }): () -> ()
  return
}

// -----

func @invalid_splat(%v : f32) {
  splat %v : memref<8xf32>
  // expected-error@-1 {{must be vector of any type values or statically shaped tensor of any type values}}
  return
}

// -----

func @invalid_splat(%v : vector<8xf32>) {
  %w = splat %v : tensor<8xvector<8xf32>>
  // expected-error@-1 {{must be integer or float type}}
  return
}

// -----

func @invalid_splat(%v : f32) { // expected-note {{prior use here}}
  splat %v : vector<8xf64>
  // expected-error@-1 {{expects different type than prior uses}}
  return
}


@@ -540,3 +540,15 @@ func @custom_insertion_position() {
  }) : () -> ()
  return
}

// CHECK-LABEL: func @splat_fold
func @splat_fold() -> (vector<4xf32>, tensor<4xf32>) {
  %c = constant 1.0 : f32
  %v = splat %c : vector<4xf32>
  %t = splat %c : tensor<4xf32>
  return %v, %t : vector<4xf32>, tensor<4xf32>

  // CHECK-NEXT: [[V:%.*]] = constant dense<1.000000e+00> : vector<4xf32>
  // CHECK-NEXT: [[T:%.*]] = constant dense<1.000000e+00> : tensor<4xf32>
  // CHECK-NEXT: return [[V]], [[T]] : vector<4xf32>, tensor<4xf32>
}


@@ -31,7 +31,7 @@ syn match mlirType /x\s*\zsvector/

" Operations.

" Core ops (not exhaustive yet).
" TODO: the list is not exhaustive.
syn keyword mlirOps alloc addf addi call call_indirect cmpi constant dealloc dma_start dma_wait dim extract_element for getTensor if load memref_cast mulf muli splat store select subf subi tensor_cast

" Affine ops.
syn match mlirOps /\<affine\.apply\>/