[mlir] Add I1 support to DenseArrayAttr
This patch adds a DenseI1ArrayAttr to support arrays of i1. Importantly, the implementation is as a simple `ArrayRef<bool>` instead of using bit compression, which was problematic in DenseElementsAttr. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D130957
This commit is contained in:
parent
74940d2668
commit
d0541b4700
|
@ -791,8 +791,11 @@ public:
|
|||
static bool classof(Attribute attr);
|
||||
};
|
||||
template <>
|
||||
void DenseArrayAttr<bool>::printWithoutBraces(raw_ostream &os) const;
|
||||
template <>
|
||||
void DenseArrayAttr<int8_t>::printWithoutBraces(raw_ostream &os) const;
|
||||
|
||||
extern template class DenseArrayAttr<bool>;
|
||||
extern template class DenseArrayAttr<int8_t>;
|
||||
extern template class DenseArrayAttr<int16_t>;
|
||||
extern template class DenseArrayAttr<int32_t>;
|
||||
|
@ -802,6 +805,7 @@ extern template class DenseArrayAttr<double>;
|
|||
} // namespace detail
|
||||
|
||||
// Public name for all the supported DenseArrayAttr
|
||||
using DenseBoolArrayAttr = detail::DenseArrayAttr<bool>;
|
||||
using DenseI8ArrayAttr = detail::DenseArrayAttr<int8_t>;
|
||||
using DenseI16ArrayAttr = detail::DenseArrayAttr<int16_t>;
|
||||
using DenseI32ArrayAttr = detail::DenseArrayAttr<int32_t>;
|
||||
|
|
|
@ -180,7 +180,7 @@ def Builtin_DenseArrayBase : Builtin_Attr<
|
|||
ArrayRefParameter<"char">:$elements);
|
||||
let extraClassDeclaration = [{
|
||||
// All possible supported element type.
|
||||
enum class EltType { I8, I16, I32, I64, F32, F64 };
|
||||
enum class EltType { I1, I8, I16, I32, I64, F32, F64 };
|
||||
|
||||
/// Allow implicit conversion to ElementsAttr.
|
||||
operator ElementsAttr() const {
|
||||
|
@ -189,7 +189,8 @@ def Builtin_DenseArrayBase : Builtin_Attr<
|
|||
|
||||
/// ElementsAttr implementation.
|
||||
using ContiguousIterableTypesT =
|
||||
std::tuple<int8_t, int16_t, int32_t, int64_t, float, double>;
|
||||
std::tuple<bool, int8_t, int16_t, int32_t, int64_t, float, double>;
|
||||
const bool *value_begin_impl(OverloadToken<bool>) const;
|
||||
const int8_t *value_begin_impl(OverloadToken<int8_t>) const;
|
||||
const int16_t *value_begin_impl(OverloadToken<int16_t>) const;
|
||||
const int32_t *value_begin_impl(OverloadToken<int32_t>) const;
|
||||
|
|
|
@ -1282,6 +1282,7 @@ class DenseArrayAttrBase<string denseAttrName, string cppType, string summaryNam
|
|||
let storageType = "::mlir::" # denseAttrName;
|
||||
let returnType = "::llvm::ArrayRef<" # cppType # ">";
|
||||
}
|
||||
def DenseBoolArrayAttr : DenseArrayAttrBase<"DenseBoolArrayAttr", "bool", "i1">;
|
||||
def DenseI8ArrayAttr : DenseArrayAttrBase<"DenseI8ArrayAttr", "int8_t", "i8">;
|
||||
def DenseI16ArrayAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
|
||||
def DenseI32ArrayAttr : DenseArrayAttrBase<"DenseI32ArrayAttr", "int32_t", "i32">;
|
||||
|
|
|
@ -845,6 +845,12 @@ Attribute Parser::parseDenseArrayAttr() {
|
|||
|
||||
if (auto intType = type.dyn_cast<IntegerType>()) {
|
||||
switch (type.getIntOrFloatBitWidth()) {
|
||||
case 1:
|
||||
if (isEmptyList)
|
||||
result = DenseBoolArrayAttr::get(parser.getContext(), {});
|
||||
else
|
||||
result = DenseBoolArrayAttr::parseWithoutBraces(parser, Type{});
|
||||
break;
|
||||
case 8:
|
||||
if (isEmptyList)
|
||||
result = DenseI8ArrayAttr::get(parser.getContext(), {});
|
||||
|
@ -870,7 +876,7 @@ Attribute Parser::parseDenseArrayAttr() {
|
|||
result = DenseI64ArrayAttr::parseWithoutBraces(parser, Type{});
|
||||
break;
|
||||
default:
|
||||
emitError(typeLoc, "expected i8, i16, i32, or i64 but got: ") << type;
|
||||
emitError(typeLoc, "expected i1, i8, i16, i32, or i64 but got: ") << type;
|
||||
return {};
|
||||
}
|
||||
} else if (auto floatType = type.dyn_cast<FloatType>()) {
|
||||
|
|
|
@ -238,6 +238,15 @@ ParseResult Parser::parseToken(Token::Kind expectedToken,
|
|||
|
||||
/// Parse an optional integer value from the stream.
|
||||
OptionalParseResult Parser::parseOptionalInteger(APInt &result) {
|
||||
// Parse `false` and `true` keywords as 0 and 1 respectively.
|
||||
if (consumeIf(Token::kw_false)) {
|
||||
result = false;
|
||||
return success();
|
||||
} else if (consumeIf(Token::kw_true)) {
|
||||
result = true;
|
||||
return success();
|
||||
}
|
||||
|
||||
Token curToken = getToken();
|
||||
if (curToken.isNot(Token::integer, Token::minus))
|
||||
return llvm::None;
|
||||
|
|
|
@ -1860,26 +1860,7 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
|
|||
}
|
||||
} else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayBaseAttr>()) {
|
||||
typeElision = AttrTypeElision::Must;
|
||||
switch (denseArrayAttr.getElementType()) {
|
||||
case DenseArrayBaseAttr::EltType::I8:
|
||||
os << "[:i8";
|
||||
break;
|
||||
case DenseArrayBaseAttr::EltType::I16:
|
||||
os << "[:i16";
|
||||
break;
|
||||
case DenseArrayBaseAttr::EltType::I32:
|
||||
os << "[:i32";
|
||||
break;
|
||||
case DenseArrayBaseAttr::EltType::I64:
|
||||
os << "[:i64";
|
||||
break;
|
||||
case DenseArrayBaseAttr::EltType::F32:
|
||||
os << "[:f32";
|
||||
break;
|
||||
case DenseArrayBaseAttr::EltType::F64:
|
||||
os << "[:f64";
|
||||
break;
|
||||
}
|
||||
os << "[:" << denseArrayAttr.getType().getElementType();
|
||||
if (denseArrayAttr.size())
|
||||
os << " ";
|
||||
denseArrayAttr.printWithoutBraces(os);
|
||||
|
|
|
@ -732,6 +732,9 @@ DenseArrayBaseAttr::EltType DenseArrayBaseAttr::getElementType() const {
|
|||
|
||||
ShapedType DenseArrayBaseAttr::getType() const { return getImpl()->type; }
|
||||
|
||||
const bool *DenseArrayBaseAttr::value_begin_impl(OverloadToken<bool>) const {
|
||||
return cast<DenseBoolArrayAttr>().asArrayRef().begin();
|
||||
}
|
||||
const int8_t *
|
||||
DenseArrayBaseAttr::value_begin_impl(OverloadToken<int8_t>) const {
|
||||
return cast<DenseI8ArrayAttr>().asArrayRef().begin();
|
||||
|
@ -762,6 +765,9 @@ void DenseArrayBaseAttr::print(AsmPrinter &printer) const {
|
|||
|
||||
void DenseArrayBaseAttr::printWithoutBraces(raw_ostream &os) const {
|
||||
switch (getElementType()) {
|
||||
case DenseArrayBaseAttr::EltType::I1:
|
||||
this->cast<DenseBoolArrayAttr>().printWithoutBraces(os);
|
||||
return;
|
||||
case DenseArrayBaseAttr::EltType::I8:
|
||||
this->cast<DenseI8ArrayAttr>().printWithoutBraces(os);
|
||||
return;
|
||||
|
@ -797,15 +803,20 @@ void DenseArrayAttr<T>::print(AsmPrinter &printer) const {
|
|||
|
||||
template <typename T>
|
||||
void DenseArrayAttr<T>::printWithoutBraces(raw_ostream &os) const {
|
||||
ArrayRef<T> values{*this};
|
||||
llvm::interleaveComma(values, os);
|
||||
llvm::interleaveComma(asArrayRef(), os);
|
||||
}
|
||||
|
||||
/// Specialization for bool to print `true` or `false`.
|
||||
template <>
|
||||
void DenseArrayAttr<bool>::printWithoutBraces(raw_ostream &os) const {
|
||||
llvm::interleaveComma(asArrayRef(), os,
|
||||
[&](bool v) { os << (v ? "true" : "false"); });
|
||||
}
|
||||
|
||||
/// Specialization for int8_t for forcing printing as number instead of chars.
|
||||
template <>
|
||||
void DenseArrayAttr<int8_t>::printWithoutBraces(raw_ostream &os) const {
|
||||
ArrayRef<int8_t> values{*this};
|
||||
llvm::interleaveComma(values, os, [&](int64_t v) { os << v; });
|
||||
llvm::interleaveComma(asArrayRef(), os, [&](int64_t v) { os << v; });
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -816,7 +827,7 @@ void DenseArrayAttr<T>::print(raw_ostream &os) const {
|
|||
}
|
||||
|
||||
/// Parse a single element: generic template for int types, specialized for
|
||||
/// floating points below.
|
||||
/// floating point and boolean values below.
|
||||
template <typename T>
|
||||
static ParseResult parseDenseArrayAttrElt(AsmParser &parser, T &value) {
|
||||
return parser.parseInteger(value);
|
||||
|
@ -880,6 +891,14 @@ namespace {
|
|||
template <typename T>
|
||||
struct denseArrayAttrEltTypeBuilder;
|
||||
template <>
|
||||
struct denseArrayAttrEltTypeBuilder<bool> {
|
||||
constexpr static auto eltType = DenseArrayBaseAttr::EltType::I1;
|
||||
static ShapedType getShapedType(MLIRContext *context,
|
||||
ArrayRef<int64_t> shape) {
|
||||
return RankedTensorType::get(shape, IntegerType::get(context, 1));
|
||||
}
|
||||
};
|
||||
template <>
|
||||
struct denseArrayAttrEltTypeBuilder<int8_t> {
|
||||
constexpr static auto eltType = DenseArrayBaseAttr::EltType::I8;
|
||||
static ShapedType getShapedType(MLIRContext *context,
|
||||
|
@ -953,6 +972,7 @@ bool DenseArrayAttr<T>::classof(Attribute attr) {
|
|||
namespace mlir {
|
||||
namespace detail {
|
||||
// Explicit instantiation for all the supported DenseArrayAttr.
|
||||
template class DenseArrayAttr<bool>;
|
||||
template class DenseArrayAttr<int8_t>;
|
||||
template class DenseArrayAttr<int16_t>;
|
||||
template class DenseArrayAttr<int32_t>;
|
||||
|
|
|
@ -521,13 +521,15 @@ func.func @simple_scalar_example() {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK-LABEL: func @dense_array_attr
|
||||
func.func @dense_array_attr() attributes{
|
||||
func.func @dense_array_attr() attributes {
|
||||
// CHECK-SAME: emptyf32attr = [:f32],
|
||||
emptyf32attr = [:f32],
|
||||
// CHECK-SAME: emptyf64attr = [:f64],
|
||||
emptyf64attr = [:f64],
|
||||
// CHECK-SAME: emptyi16attr = [:i16],
|
||||
emptyi16attr = [:i16],
|
||||
// CHECK-SAME: emptyi1attr = [:i1],
|
||||
emptyi1attr = [:i1],
|
||||
// CHECK-SAME: emptyi32attr = [:i32],
|
||||
emptyi32attr = [:i32],
|
||||
// CHECK-SAME: emptyi64attr = [:i64],
|
||||
|
@ -540,6 +542,8 @@ func.func @dense_array_attr() attributes{
|
|||
f64attr = [:f64 -142.],
|
||||
// CHECK-SAME: i16attr = [:i16 3, 5, -4, 10],
|
||||
i16attr = [:i16 3, 5, -4, 10],
|
||||
// CHECK-SAME: i1attr = [:i1 true, false, true],
|
||||
i1attr = [:i1 true, false, true],
|
||||
// CHECK-SAME: i32attr = [:i32 1024, 453, -6435],
|
||||
i32attr = [:i32 1024, 453, -6435],
|
||||
// CHECK-SAME: i64attr = [:i64 -142],
|
||||
|
@ -549,6 +553,8 @@ func.func @dense_array_attr() attributes{
|
|||
} {
|
||||
// CHECK: test.dense_array_attr
|
||||
test.dense_array_attr
|
||||
// CHECK-SAME: i1attr = [true, false, true]
|
||||
i1attr = [true, false, true]
|
||||
// CHECK-SAME: i8attr = [1, -2, 3]
|
||||
i8attr = [1, -2, 3]
|
||||
// CHECK-SAME: i16attr = [3, 5, -4, 10]
|
||||
|
|
|
@ -27,6 +27,8 @@ arith.constant dense<[10, 11, 12, 13, 14]> : tensor<5xi64>
|
|||
// expected-error@below {{Test iterating `IntegerAttr`: }}
|
||||
arith.constant dense<> : tensor<0xi64>
|
||||
|
||||
// expected-error@below {{Test iterating `bool`: true, false, true, false, true, false}}
|
||||
arith.constant [:i1 true, false, true, false, true, false]
|
||||
// expected-error@below {{Test iterating `int8_t`: 10, 11, -12, 13, 14}}
|
||||
arith.constant [:i8 10, 11, -12, 13, 14]
|
||||
// expected-error@below {{Test iterating `int16_t`: 10, 11, -12, 13, 14}}
|
||||
|
|
|
@ -272,6 +272,7 @@ def StringElementsAttrOp : TEST_Op<"string_elements_attr"> {
|
|||
|
||||
def DenseArrayAttrOp : TEST_Op<"dense_array_attr"> {
|
||||
let arguments = (ins
|
||||
DenseBoolArrayAttr:$i1attr,
|
||||
DenseI8ArrayAttr:$i8attr,
|
||||
DenseI16ArrayAttr:$i16attr,
|
||||
DenseI32ArrayAttr:$i32attr,
|
||||
|
@ -281,10 +282,9 @@ def DenseArrayAttrOp : TEST_Op<"dense_array_attr"> {
|
|||
DenseI32ArrayAttr:$emptyattr
|
||||
);
|
||||
let assemblyFormat = [{
|
||||
`i8attr` `=` $i8attr `i16attr` `=` $i16attr `i32attr` `=` $i32attr
|
||||
`i64attr` `=` $i64attr `f32attr` `=` $f32attr `f64attr` `=` $f64attr
|
||||
`emptyattr` `=` $emptyattr
|
||||
attr-dict
|
||||
`i1attr` `=` $i1attr `i8attr` `=` $i8attr `i16attr` `=` $i16attr
|
||||
`i32attr` `=` $i32attr `i64attr` `=` $i64attr `f32attr` `=` $f32attr
|
||||
`f64attr` `=` $f64attr `emptyattr` `=` $emptyattr attr-dict
|
||||
}];
|
||||
}
|
||||
|
||||
|
|
|
@ -43,6 +43,9 @@ struct TestElementsAttrInterface
|
|||
if (auto concreteAttr =
|
||||
attr.getValue().dyn_cast<DenseArrayBaseAttr>()) {
|
||||
switch (concreteAttr.getElementType()) {
|
||||
case DenseArrayBaseAttr::EltType::I1:
|
||||
testElementsAttrIteration<bool>(op, elementsAttr, "bool");
|
||||
break;
|
||||
case DenseArrayBaseAttr::EltType::I8:
|
||||
testElementsAttrIteration<int8_t>(op, elementsAttr, "int8_t");
|
||||
break;
|
||||
|
|
Loading…
Reference in New Issue