Handle dynamic shapes in Broadcastable op trait

This allows the TensorFlow Add and Div ops to use the Broadcastable op trait instead of
the more restrictive SameValueType op trait.

This in turn allows TensorFlow ops to be registered by defining GET_OP_LIST and
including the generated ops file. Currently, the tf-raise-control-flow pass tests
use dynamic shapes in the tf.Add op, and AddOp can't be registered without
support for dynamic shapes.
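For context, a minimal sketch of that registration pattern, assuming the usual MLIR
GET_OP_LIST convention; the dialect class, name, and include path below are
illustrative, not taken from this change:

// Hypothetical dialect constructor; the generated ops file name is a guess.
TensorFlowDialect::TensorFlowDialect(MLIRContext *context)
    : Dialect(/*name=*/"tf", context) {
  addOperations<
#define GET_OP_LIST
#include "tf_ops.cc.inc"  // illustrative path to the TableGen'd op list
      >();
}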

TESTED with unit tests

PiperOrigin-RevId: 232927998
Smit Hinsu 2019-02-07 12:56:12 -08:00 committed by jpienaar
parent 13a45c7194
commit c201e6ef05
3 changed files with 30 additions and 7 deletions

View File

@@ -37,8 +37,9 @@ bool verifyCompatibleOperandBroadcast(const Instruction *op);
 namespace util {
 
 /// Returns the result broadcast composition type from the two given types by
-/// following NumPy broadcast semantics. Returns null type if the two given
-/// types are not broadcast-compatible.
+/// following NumPy broadcast semantics. Returned type may have dynamic shape if
+/// either of the input types has dynamic shape. Returns null type if the two
+/// given types are not broadcast-compatible.
 Type getBroadcastedType(Type type1, Type type2);
 
 } // namespace util
@@ -46,7 +47,10 @@ Type getBroadcastedType(Type type1, Type type2);
 /// compatible operand and result types. Specifically, starting from the
 /// most varying dimension, each dimension pair of the two operands' types
 /// should either be the same or one of them is one. Also, the result type
-/// should be the same as the operand type with larger dimensions.
+/// should be the same as the operand type with larger dimensions. Shapes are
+/// checked partially if ranks or dimensions are not known. For example,
+/// an op with tensor<? x 2 x f32> and tensor<2 x f32> as operand types and
+/// tensor<3 x 2 x f32> as the result type is broadcast-compatible.
 ///
 /// This trait assumes the op has two operands and one result, and it asserts
 /// if the pre-condition is not satisfied.
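To make the documented behavior concrete, here is a hedged sketch of how the helper
composes shapes, assuming the MLIR C++ API of this revision (dynamic dimensions
encoded as -1 in a RankedTensorType shape; the exact headers and builder helpers
are assumptions and may differ):

#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"

using namespace mlir;

void broadcastExamples(MLIRContext *context) {
  Builder b(context);
  Type f32 = b.getF32Type();

  // tensor<?x2xf32> with tensor<2xf32>: the known dimensions match, so the
  // result keeps the dynamic leading dimension, i.e. tensor<?x2xf32>.
  Type dynamic = RankedTensorType::get({-1, 2}, f32);
  Type small = RankedTensorType::get({2}, f32);
  Type result = OpTrait::util::getBroadcastedType(dynamic, small);

  // tensor<3x2xf32> with tensor<4xf32>: trailing dimensions 2 and 4 are
  // neither equal nor 1, so a null Type is returned.
  Type incompatible = OpTrait::util::getBroadcastedType(
      RankedTensorType::get({3, 2}, f32), RankedTensorType::get({4}, f32));
  (void)result;
  (void)incompatible;
}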

View File

@@ -33,7 +33,8 @@ static bool isBroadcastableType(Type type) {
   case StandardTypes::Vector:
     return true;
   case StandardTypes::RankedTensor:
-    return type.cast<RankedTensorType>().getElementType().isIntOrFloat();
+  case StandardTypes::UnrankedTensor:
+    return type.cast<TensorType>().getElementType().isIntOrFloat();
   default:
     break;
   }
@@ -41,8 +42,9 @@
 }
 
 /// Returns the result broadcast composition type from the two given types by
-/// following NumPy broadcast semantics. Returns null type if the two given
-/// types are not broadcast-compatible.
+/// following NumPy broadcast semantics. Returned type may have dynamic shape if
+/// either of the input types has dynamic shape. Returns null type if the two
+/// given types are not broadcast-compatible.
 Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
   // Make sure both types are able to participate in broadcasting.
   if (!isBroadcastableType(type1) || !isBroadcastableType(type2))
@@ -60,6 +62,14 @@ Type OpTrait::util::getBroadcastedType(Type type1, Type type2) {
   if (scalarType != getScalarType(type2))
     return {};
 
+  // If one of the types is unranked tensor, then the other type shouldn't be
+  // vector and the result should have unranked tensor type.
+  if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>()) {
+    if (type1.isa<VectorType>() || type2.isa<VectorType>())
+      return {};
+    return UnrankedTensorType::get(scalarType);
+  }
+
   // Returns the type kind if the given type is a vector or ranked tensor type.
   // Returns llvm::None otherwise.
   auto getCompositeTypeKind =
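A similar hedged sketch of the unranked-tensor rule added above, continuing the
sketch after the header diff (same assumed includes and era API; the setup names
are illustrative):

void unrankedExamples(MLIRContext *context) {
  Builder b(context);
  Type i32 = b.getIntegerType(32);

  // Unranked tensor with a ranked tensor: the combined shape cannot be known
  // statically, so the result is the unranked tensor type tensor<*xi32>.
  Type unranked = UnrankedTensorType::get(i32);
  Type ranked = RankedTensorType::get({4, 3, 2}, i32);
  Type result = OpTrait::util::getBroadcastedType(unranked, ranked);

  // Unranked tensor with a vector: rejected, so a null Type is returned.
  Type vec = VectorType::get({4}, i32);
  Type invalid = OpTrait::util::getBroadcastedType(unranked, vec);
  (void)result;
  (void)invalid;
}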

View File

@@ -167,9 +167,18 @@ func @broadcast_tensor_tensor_tensor(tensor<vector<4xi32>>, tensor<vector<4xi32>
 // -----
 
 // Check unranked types
+func @broadcast_tensor_tensor_tensor(tensor<4x3x2xi32>, tensor<*xi32>) -> tensor<*xi32> {
+^bb0(%arg0: tensor<4x3x2xi32>, %arg1: tensor<*xi32>):
+  %0 = "tf.Add"(%arg0, %arg1) : (tensor<4x3x2xi32>, tensor<*xi32>) -> tensor<*xi32>
+  return %0 : tensor<*xi32>
+}
+
+// -----
+
+// Check unranked operand but ranked result
 func @broadcast_tensor_tensor_tensor(tensor<4x3x2xi32>, tensor<*xi32>) -> tensor<4x3x2xi32> {
 ^bb0(%arg0: tensor<4x3x2xi32>, %arg1: tensor<*xi32>):
-  // expected-error @+1 {{operands don't have broadcast-compatible types}}
+  // expected-error @+1 {{result type is not broadcast-compatible with operand types}}
   %0 = "tfl.add"(%arg0, %arg1) {fused_activation_function: "RELU6"} : (tensor<4x3x2xi32>, tensor<*xi32>) -> tensor<4x3x2xi32>
   return %0 : tensor<4x3x2xi32>
 }