[mlir] Generalize broadcastable trait operands

Summary:
Generalize the broadcastable trait to variadic operands. Update the
documentation that still talked about the element type as part of the
broadcastable trait (that bug was already fixed). Also rename
Broadcastable to ResultsBroadcastableShape to be more explicit that the
trait affects the result shape (it is possible for an op to allow
broadcastable operands but not have a result shape that is broadcast
compatible with its operands).

This also does some intermediate work to have getBroadcastedType take an
optional elementType as input and use it, if specified, instead of the
common element type of type1 and type2.

Differential Revision: https://reviews.llvm.org/D72559
Jacques Pienaar 2020-01-11 09:42:18 -08:00
parent f6418d72f5
commit b70e4efb75
7 changed files with 111 additions and 92 deletions


@ -137,20 +137,20 @@ section goes as follows:
### Broadcastable
* `OpTrait::BroadcastableTwoOperandsOneResult` -- `Broadcastable`
* `OpTrait::ResultsBroadcastableShape` -- `ResultsBroadcastableShape`
This trait provides the API for operations that are known to have
This trait adds the property that the operation is known to have
[broadcast-compatible](https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
operand and result types. Specifically, starting from the most varying
dimension, each dimension pair of the two operands' types should either be the
same or one of them is one. Also, the result type should have the corresponding
operands and that its result types' shapes are broadcast compatible with the shape
of the broadcasted operands. Specifically, starting from the most varying
dimension, each dimension pair of the two operands' shapes should either be the
same or one of them is one. Also, the result shape should have the corresponding
dimension equal to the larger one, if known. Shapes are checked partially if
ranks or dimensions are not known. For example, an op with `tensor<?x2xf32>` and
`tensor<2xf32>` as operand types and `tensor<3x2xf32>` as the result type is
broadcast-compatible.
This trait assumes the op has two operands and one result, and it asserts if the
pre-condition is not satisfied.
This trait requires that the operands are either vector or tensor types.
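
To make the shape rule concrete, here is a small standalone sketch (not part of
this change) that reproduces the `tensor<?x2xf32>` / `tensor<2xf32>` example
above with the utility from `mlir/Dialect/Traits.h`. It assumes the MLIR C++
API as of this commit (e.g. `StandardTypes.h` and `-1` for a dynamic
dimension), which may differ in later versions.

```c++
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardTypes.h"

using namespace mlir;

// Broadcast the shapes of tensor<?x2xf32> and tensor<2xf32>; -1 encodes a
// dynamic dimension in the shape of a RankedTensorType.
void broadcastExample(MLIRContext *ctx) {
  Type f32 = FloatType::getF32(ctx);
  Type lhs = RankedTensorType::get({-1, 2}, f32);
  Type rhs = RankedTensorType::get({2}, f32);

  // Yields tensor<?x2xf32>. A result type of tensor<3x2xf32> is broadcast
  // compatible with it: the known dimensions agree and `?` matches 3.
  Type broadcasted = OpTrait::util::getBroadcastedType(lhs, rhs);
  (void)broadcasted;
}
```
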
### Commutative


@ -51,23 +51,26 @@ bool getBroadcastedShape(ArrayRef<int64_t> shape1, ArrayRef<int64_t> shape2,
/// following NumPy broadcast semantics. Returned type may have dynamic shape if
/// either of the input types has dynamic shape. Returns null type if the two
/// given types are not broadcast-compatible.
Type getBroadcastedType(Type type1, Type type2);
///
/// elementType, if specified, will be used as the element type of the
/// broadcasted result type. Otherwise it is required that the element type of
/// type1 and type2 is the same and this element type will be used as the
/// resultant element type.
Type getBroadcastedType(Type type1, Type type2, Type elementType = nullptr);
} // namespace util
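
As a sketch of how the new optional elementType argument might be used: the
helper below (inferCmpResultType) is hypothetical and not part of this patch.
A comparison-style op can broadcast its operand shapes while always producing
an i1 element type, so the common-element-type requirement of the two-argument
form does not apply.

```c++
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Hypothetical helper: infer the result type of a broadcasting comparison op.
// The operand shapes are broadcast, but the result element type is always i1.
static Type inferCmpResultType(Type lhs, Type rhs, Builder &builder) {
  return OpTrait::util::getBroadcastedType(lhs, rhs, builder.getI1Type());
}
```
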
/// This class provides the API for ops that are known to have broadcast-
/// compatible operand and result types. Specifically, starting from the
/// most varying dimension, each dimension pair of the two operands' types
/// should either be the same or one of them is one. Also, the result type
/// should have the corresponding dimension equal to the larger one, if known.
/// Shapes are checked partially if ranks or dimensions are not known. For
/// example, an op with tensor<? x 2 x f32> and tensor <2 x f32> as operand
/// types and tensor<3 x 2 x f32> as the result type is broadcast-compatible.
///
/// This trait assumes the op has two operands and one result, and it asserts
/// if the pre-condition is not satisfied.
/// Trait for ops that are known to have broadcast compatible operands and
/// result types. Specifically, starting from the most varying dimension, each
/// dimension pair of the operands' shapes should either be the same or one
/// of them is one. Also, the results' shapes should have the corresponding
/// dimension equal to the larger one, if known. Shapes are checked partially if
/// ranks or dimensions are not known. For example, an op with tensor<?x2xf32>
/// and tensor<2xf32> as operand types and tensor<5x3x2xi16> as the result
/// type has broadcast compatible operands and result types.
template <typename ConcreteType>
class BroadcastableTwoOperandsOneResult
: public TraitBase<ConcreteType, BroadcastableTwoOperandsOneResult> {
class ResultsBroadcastableShape
: public TraitBase<ConcreteType, ResultsBroadcastableShape> {
public:
static LogicalResult verifyTrait(Operation *op) {
return impl::verifyCompatibleOperandBroadcast(op);
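
For illustration, a sketch of how a C++-defined op could attach the renamed
trait directly, outside of ODS. The op below (mydialect.add) is hypothetical
and follows the usual OpDefinition.h pattern of the time, so details may
differ in newer MLIR versions.

```c++
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/OpDefinition.h"

using namespace mlir;

// Hypothetical op: two operands, one result, with the renamed trait attached.
// verifyTrait() above then runs as part of this op's verifier.
class MyAddOp : public Op<MyAddOp, OpTrait::NOperands<2>::Impl,
                          OpTrait::OneResult,
                          OpTrait::ResultsBroadcastableShape> {
public:
  using Op::Op;
  static StringRef getOperationName() { return "mydialect.add"; }
};
```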


@ -1327,7 +1327,10 @@ class PredOpTrait<string descr, Pred pred> : OpTrait {
}
// Op supports operand broadcast behavior.
def Broadcastable : NativeOpTrait<"BroadcastableTwoOperandsOneResult">;
def ResultsBroadcastableShape :
NativeOpTrait<"ResultsBroadcastableShape">;
// TODO: Alias of the above, remove post integrate.
def Broadcastable : NativeOpTrait<"ResultsBroadcastableShape">;
// X op Y == Y op X
def Commutative : NativeOpTrait<"IsCommutative">;
// Op behaves like a function.


@ -8,6 +8,7 @@
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/Support/FormatVariadic.h"
using namespace mlir;
@ -80,25 +81,27 @@ static ArrayRef<int64_t> getShape(Type type) {
/// following NumPy broadcast semantics. Returned type may have dynamic shape if
/// either of the input types has dynamic shape. Returns null type if the two
/// given types are not broadcast-compatible.
Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
// Returns the scalar type out of the given type.
auto getScalarType = [](Type type) -> Type {
if (auto shapedType = type.dyn_cast<ShapedType>())
return shapedType.getElementType();
return type;
};
// Make sure underlying scalar type is the same.
auto scalarType = getScalarType(type1);
if (scalarType != getScalarType(type2))
return {};
///
/// elementType, if specified, will be used as the element type of the
/// broadcasted result type. Otherwise it is required that the element type of
/// type1 and type2 is the same and this element type will be used as the
/// resultant element type.
Type OpTrait::util::getBroadcastedType(Type type1, Type type2,
Type elementType) {
// If elementType is not specified, then use the common element type
// of the inputs or fail if there is no common element type.
if (!elementType) {
elementType = getElementTypeOrSelf(type1);
if (elementType != getElementTypeOrSelf(type2))
return {};
}
// If one of the types is unranked tensor, then the other type shouldn't be
// vector and the result should have unranked tensor type.
if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) {
if (type1.isa<VectorType>() || type2.isa<VectorType>())
return {};
return UnrankedTensorType::get(scalarType);
return UnrankedTensorType::get(elementType);
}
// Returns the type kind if the given type is a vector or ranked tensor type.
@ -132,16 +135,18 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
// Compose the final broadcasted type
if (resultCompositeKind == StandardTypes::Vector)
return VectorType::get(resultShape, scalarType);
return VectorType::get(resultShape, elementType);
if (resultCompositeKind == StandardTypes::RankedTensor)
return RankedTensorType::get(resultShape, scalarType);
return scalarType;
return RankedTensorType::get(resultShape, elementType);
return elementType;
}
/// Returns true if the given types have both vector types and tensor types.
static bool hasBothVectorAndTensorType(ArrayRef<Type> types) {
return llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }) &&
llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); });
/// Returns a tuple indicating whether the given range contains tensor or vector types.
template <typename iterator_range>
static std::tuple<bool, bool> hasTensorOrVectorType(iterator_range types) {
return std::make_tuple(
llvm::any_of(types, [](Type t) { return t.isa<TensorType>(); }),
llvm::any_of(types, [](Type t) { return t.isa<VectorType>(); }));
}
static bool areCompatibleShapes(ArrayRef<int64_t> shape1,
@ -157,55 +162,57 @@ static bool areCompatibleShapes(ArrayRef<int64_t> shape1,
return true;
}
static std::string getShapeString(ArrayRef<int64_t> shape) {
// TODO: this should be replaced with printing the shape more uniformly here
// and where it is printed as part of a type.
return formatv("'{0:$[x]}'", llvm::make_range(shape.begin(), shape.end()));
}
LogicalResult OpTrait::impl::verifyCompatibleOperandBroadcast(Operation *op) {
assert(op->getNumOperands() == 2 &&
"only support broadcast check on two operands");
assert(op->getNumResults() == 1 &&
"only support broadcast check on one result");
auto type1 = op->getOperand(0).getType();
auto type2 = op->getOperand(1).getType();
auto retType = op->getResult(0).getType();
// We forbid broadcasting vector and tensor.
if (hasBothVectorAndTensorType({type1, type2, retType}))
// Ensure broadcasting only tensor or only vector types.
auto operandsHasTensorVectorType =
hasTensorOrVectorType(op->getOperandTypes());
auto resultsHasTensorVectorType = hasTensorOrVectorType(op->getResultTypes());
if ((std::get<0>(operandsHasTensorVectorType) ||
std::get<0>(resultsHasTensorVectorType)) &&
(std::get<1>(operandsHasTensorVectorType) ||
std::get<1>(resultsHasTensorVectorType)))
return op->emitError("cannot broadcast vector with tensor");
if (retType.isa<UnrankedTensorType>())
auto rankedOperands = make_filter_range(
op->getOperandTypes(), [](Type t) { return t.isa<RankedTensorType>(); });
// If all operands are unranked, then all result shapes are possible.
if (rankedOperands.empty())
return success();
bool isUnranked1 = type1.isa<UnrankedTensorType>();
bool isUnranked2 = type2.isa<UnrankedTensorType>();
// If both operands are unranked, then all result shapes are possible.
if (isUnranked1 && isUnranked2)
return success();
// If one of the operands is unranked, then the known dimensions in the result
// should be compatible with the other shaped operand.
if (isUnranked1 || isUnranked2) {
// Result should have higher rank than the shaped operand's rank and then
// the result's trailing dimensions should be compatible with the operand
// shape.
ArrayRef<int64_t> shape = getShape(!isUnranked1 ? type1 : type2);
ArrayRef<int64_t> actualSuffix = getShape(retType).take_back(shape.size());
if (!areCompatibleShapes(actualSuffix, shape))
return op->emitOpError()
<< "result type " << retType
<< " has shape incompatible with a ranked operand type";
return success();
// Compute broadcasted shape of operands (which requires that operands are
// broadcast compatible). The results need to be broadcast compatible with
// this result shape.
SmallVector<int64_t, 4> resultShape;
(void)util::getBroadcastedShape(getShape(*rankedOperands.begin()), {},
resultShape);
for (auto other : make_early_inc_range(rankedOperands)) {
SmallVector<int64_t, 4> temp = resultShape;
if (!util::getBroadcastedShape(temp, getShape(other), resultShape))
return op->emitOpError("operands don't have broadcast-compatible shapes");
}
// If both operands are shaped, then the computed broadcasted shape should be
// compatible with the result shape.
SmallVector<int64_t, 4> resultShape;
if (!util::getBroadcastedShape(getShape(type1), getShape(type2), resultShape))
return op->emitOpError("operands don't have broadcast-compatible shapes");
auto rankedResults = make_filter_range(
op->getResultTypes(), [](Type t) { return t.isa<RankedTensorType>(); });
if (!areCompatibleShapes(resultShape, getShape(retType)))
return op->emitOpError() << "result type " << retType
<< " does not have shape compatible with the one "
"computed from the operand types";
// If all of the results are unranked then no further verification is needed.
if (rankedResults.empty())
return success();
for (auto type : rankedResults) {
ArrayRef<int64_t> actualSuffix =
getShape(type).take_back(resultShape.size());
if (!areCompatibleShapes(actualSuffix, resultShape))
return op->emitOpError()
<< "result type " << getShapeString(getShape(type))
<< " not broadcast compatible with broadcasted operands's shapes "
<< getShapeString(resultShape);
}
return success();
}
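
The verifier above folds a pairwise broadcast over all ranked operand shapes
and then checks each ranked result against the accumulated shape. Below is a
rough, self-contained reimplementation of that pairwise rule in plain C++ (a
sketch, not the MLIR code itself), with -1 standing in for a dynamic
dimension.

```c++
#include <algorithm>
#include <cstdint>
#include <vector>

// Broadcast two shapes following NumPy semantics: walk both shapes from the
// trailing dimension, require equal extents or an extent of 1, and take the
// larger extent for the result. Returns false if the shapes are incompatible.
static bool broadcastShapes(const std::vector<int64_t> &a,
                            const std::vector<int64_t> &b,
                            std::vector<int64_t> &result) {
  result.clear();
  auto ia = a.rbegin(), ib = b.rbegin();
  while (ia != a.rend() || ib != b.rend()) {
    int64_t da = (ia != a.rend()) ? *ia++ : 1;
    int64_t db = (ib != b.rend()) ? *ib++ : 1;
    if (da == -1 || db == -1) {
      // At least one extent is unknown. If the other extent is exactly 1 the
      // broadcast follows the unknown side, otherwise it stays unknown.
      result.push_back(da == 1 ? db : (db == 1 ? da : -1));
    } else if (da == db || da == 1 || db == 1) {
      result.push_back(std::max(da, db));
    } else {
      return false; // Incompatible static extents, e.g. 3 vs. 5.
    }
  }
  std::reverse(result.begin(), result.end());
  return true;
}
```

Folding this helper over the operand shapes in turn yields the shape that
every ranked result must be compatible with, which mirrors the loop over
rankedOperands above.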


@ -78,7 +78,7 @@ func @broadcast_tensor_tensor_tensor(tensor<4x3x2xi32>, tensor<3x3xi32>) -> tens
// Check incompatible result type with known dimension
func @broadcast_tensor_tensor_tensor(tensor<4x3x2xi32>, tensor<3x1xi32>) -> tensor<4x3x3xi32> {
^bb0(%arg0: tensor<4x3x2xi32>, %arg1: tensor<3x1xi32>):
// expected-error @+1 {{does not have shape compatible with the one computed}}
// expected-error @+1 {{op result type '4x3x3' not broadcast compatible with broadcasted operands' shapes '4x3x2'}}
%0 = "test.broadcastable"(%arg0, %arg1) : (tensor<4x3x2xi32>, tensor<3x1xi32>) -> tensor<4x3x3xi32>
return %0 : tensor<4x3x3xi32>
}
@ -88,7 +88,7 @@ func @broadcast_tensor_tensor_tensor(tensor<4x3x2xi32>, tensor<3x1xi32>) -> tens
// Check incompatible result type with known dimension
func @broadcast_tensor_tensor_tensor(tensor<8x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x1xi32> {
^bb0(%arg0: tensor<8x1x6x1xi32>, %arg1: tensor<7x1x5xi32>):
// expected-error @+1 {{does not have shape compatible with the one computed}}
// expected-error @+1 {{op result type '8x7x6x1' not broadcast compatible with broadcasted operands' shapes '8x7x6x5'}}
%0 = "test.broadcastable"(%arg0, %arg1) : (tensor<8x1x6x1xi32>, tensor<7x1x5xi32>) -> tensor<8x7x6x1xi32>
return %0 : tensor<8x7x6x1xi32>
}
@ -123,7 +123,7 @@ func @broadcast_tensor_tensor_tensor(tensor<*xi32>, tensor<*xi32>) -> tensor<2xi
// Unranked operand and compatible ranked result
func @broadcast_tensor_tensor_tensor(tensor<3x2xi32>, tensor<*xi32>) -> tensor<4x3x2xi32> {
^bb0(%arg0: tensor<3x2xi32>, %arg1: tensor<*xi32>):
%0 = "test.broadcastable"(%arg0, %arg1) : (tensor<3x2xi32>, tensor<*xi32>) -> tensor<4x3x2xi32>
%0 = "test.broadcastable"(%arg0, %arg0, %arg1) : (tensor<3x2xi32>, tensor<3x2xi32>, tensor<*xi32>) -> tensor<4x3x2xi32>
return %0 : tensor<4x3x2xi32>
}
@ -131,7 +131,7 @@ func @broadcast_tensor_tensor_tensor(tensor<3x2xi32>, tensor<*xi32>) -> tensor<4
func @broadcast_tensor_tensor_tensor(tensor<3x2xi32>, tensor<*xi32>) -> tensor<2xi32> {
^bb0(%arg0: tensor<3x2xi32>, %arg1: tensor<*xi32>):
// expected-error @+1 {{shape incompatible with a ranked operand type}}
// expected-error @+1 {{op result type '2' not broadcast compatible with broadcasted operands' shapes '3x2'}}
%0 = "test.broadcastable"(%arg0, %arg1) : (tensor<3x2xi32>, tensor<*xi32>) -> tensor<2xi32>
return %0 : tensor<2xi32>
}


@ -376,8 +376,8 @@ def IfFirstOperandIsNoneThenSoIsSecond :
let arguments = (ins AnyType:$x, AnyType:$y);
}
def BroadcastableOp : TEST_Op<"broadcastable", [Broadcastable]> {
let arguments = (ins AnyTensor, AnyTensor);
def BroadcastableOp : TEST_Op<"broadcastable", [ResultsBroadcastableShape]> {
let arguments = (ins Variadic<AnyTensor>);
let results = (outs AnyTensor);
}


@ -781,11 +781,17 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex,
return resultValue;
}
bool isBroadcastable =
resultOp.getTrait("OpTrait::BroadcastableTwoOperandsOneResult");
// TODO: Remove once broadcastable has been updated. This query here is not
// really about broadcastable or not, it is about which build method to invoke
// and that requires knowledge of whether ODS generated a builder that need
// not take return types. That knowledge should be captured in one place
// rather than duplicated.
bool isResultsBroadcastableShape =
resultOp.getTrait("OpTrait::ResultsBroadcastableShape");
bool usePartialResults = valuePackName != resultValue;
if (isBroadcastable || usePartialResults || depth > 0 || resultIndex < 0) {
if (isResultsBroadcastableShape || usePartialResults || depth > 0 ||
resultIndex < 0) {
// For these cases (broadcastable ops, op results used both as auxiliary
// values and replacement values, ops in nested patterns, auxiliary ops), we
// still need to supply the result types when building the op. But because