Remove the explicit attribute kinds for DenseIntElementsAttr and DenseFPElementsAttr in favor of just one DenseElementsAttr. Now that attribute has the ability to define 'classof(Attribute attr)' methods, these derived classes can just be specializations of the main attribute class.
PiperOrigin-RevId: 251948820
This commit is contained in:
parent
cc8a8fa76a
commit
b790a2f396
|
@ -154,8 +154,7 @@ enum Kind {
|
|||
Function,
|
||||
|
||||
SplatElements,
|
||||
DenseIntElements,
|
||||
DenseFPElements,
|
||||
DenseElements,
|
||||
OpaqueElements,
|
||||
SparseElements,
|
||||
FIRST_ELEMENTS_ATTR = SplatElements,
|
||||
|
@ -497,10 +496,11 @@ public:
|
|||
|
||||
/// An attribute that represents a reference to a dense vector or tensor object.
|
||||
///
|
||||
class DenseElementsAttr : public ElementsAttr {
|
||||
class DenseElementsAttr
|
||||
: public Attribute::AttrBase<DenseElementsAttr, ElementsAttr,
|
||||
detail::DenseElementsAttributeStorage> {
|
||||
public:
|
||||
using ElementsAttr::ElementsAttr;
|
||||
using ImplType = detail::DenseElementsAttributeStorage;
|
||||
using Base::Base;
|
||||
|
||||
/// It assumes the elements in the input array have been truncated to the bits
|
||||
/// width specified by the element type. 'type' must be a vector or tensor
|
||||
|
@ -547,8 +547,7 @@ public:
|
|||
|
||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||
static bool classof(Attribute attr) {
|
||||
return attr.getKind() == StandardAttributes::DenseIntElements ||
|
||||
attr.getKind() == StandardAttributes::DenseFPElements;
|
||||
return attr.getKind() == StandardAttributes::DenseElements;
|
||||
}
|
||||
|
||||
protected:
|
||||
|
@ -609,15 +608,13 @@ protected:
|
|||
|
||||
/// An attribute that represents a reference to a dense integer vector or tensor
|
||||
/// object.
|
||||
class DenseIntElementsAttr
|
||||
: public Attribute::AttrBase<DenseIntElementsAttr, DenseElementsAttr,
|
||||
detail::DenseElementsAttributeStorage> {
|
||||
class DenseIntElementsAttr : public DenseElementsAttr {
|
||||
public:
|
||||
/// DenseIntElementsAttr iterates on APInt, so we can use the raw element
|
||||
/// iterator directly.
|
||||
using iterator = DenseElementsAttr::RawElementIterator;
|
||||
|
||||
using Base::Base;
|
||||
using DenseElementsAttr::DenseElementsAttr;
|
||||
using DenseElementsAttr::get;
|
||||
using DenseElementsAttr::getValues;
|
||||
|
||||
|
@ -645,17 +642,13 @@ public:
|
|||
iterator begin() const { return raw_begin(); }
|
||||
iterator end() const { return raw_end(); }
|
||||
|
||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||
static bool kindof(unsigned kind) {
|
||||
return kind == StandardAttributes::DenseIntElements;
|
||||
}
|
||||
/// Method for supporting type inquiry through isa, cast and dyn_cast.
|
||||
static bool classof(Attribute attr);
|
||||
};
|
||||
|
||||
/// An attribute that represents a reference to a dense float vector or tensor
|
||||
/// object. Each element is stored as a double.
|
||||
class DenseFPElementsAttr
|
||||
: public Attribute::AttrBase<DenseFPElementsAttr, DenseElementsAttr,
|
||||
detail::DenseElementsAttributeStorage> {
|
||||
class DenseFPElementsAttr : public DenseElementsAttr {
|
||||
public:
|
||||
/// DenseFPElementsAttr iterates on APFloat, so we need to wrap the raw
|
||||
/// element iterator.
|
||||
|
@ -669,7 +662,7 @@ public:
|
|||
};
|
||||
using iterator = ElementIterator;
|
||||
|
||||
using Base::Base;
|
||||
using DenseElementsAttr::DenseElementsAttr;
|
||||
using DenseElementsAttr::get;
|
||||
using DenseElementsAttr::getValues;
|
||||
|
||||
|
@ -692,10 +685,8 @@ public:
|
|||
iterator begin() const;
|
||||
iterator end() const;
|
||||
|
||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||
static bool kindof(unsigned kind) {
|
||||
return kind == StandardAttributes::DenseFPElements;
|
||||
}
|
||||
/// Method for supporting type inquiry through isa, cast and dyn_cast.
|
||||
static bool classof(Attribute attr);
|
||||
};
|
||||
|
||||
/// An opaque attribute that represents a reference to a vector or tensor
|
||||
|
|
|
@ -692,8 +692,7 @@ void ModulePrinter::printAttributeOptionalType(Attribute attr,
|
|||
os << ", " << '"' << "0x" << llvm::toHex(eltsAttr.getValue()) << '"' << '>';
|
||||
break;
|
||||
}
|
||||
case StandardAttributes::DenseIntElements:
|
||||
case StandardAttributes::DenseFPElements: {
|
||||
case StandardAttributes::DenseElements: {
|
||||
auto eltsAttr = attr.cast<DenseElementsAttr>();
|
||||
os << "dense<";
|
||||
printType(eltsAttr.getType());
|
||||
|
|
|
@ -367,8 +367,7 @@ Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
|
|||
switch (getKind()) {
|
||||
case StandardAttributes::SplatElements:
|
||||
return cast<SplatElementsAttr>().getValue();
|
||||
case StandardAttributes::DenseFPElements:
|
||||
case StandardAttributes::DenseIntElements:
|
||||
case StandardAttributes::DenseElements:
|
||||
return cast<DenseElementsAttr>().getValue(index);
|
||||
case StandardAttributes::OpaqueElements:
|
||||
return cast<OpaqueElementsAttr>().getValue(index);
|
||||
|
@ -383,8 +382,7 @@ ElementsAttr ElementsAttr::mapValues(
|
|||
Type newElementType,
|
||||
llvm::function_ref<APInt(const APInt &)> mapping) const {
|
||||
switch (getKind()) {
|
||||
case StandardAttributes::DenseIntElements:
|
||||
case StandardAttributes::DenseFPElements:
|
||||
case StandardAttributes::DenseElements:
|
||||
return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
|
||||
case StandardAttributes::SplatElements:
|
||||
return cast<SplatElementsAttr>().mapValues(newElementType, mapping);
|
||||
|
@ -397,8 +395,7 @@ ElementsAttr ElementsAttr::mapValues(
|
|||
Type newElementType,
|
||||
llvm::function_ref<APInt(const APFloat &)> mapping) const {
|
||||
switch (getKind()) {
|
||||
case StandardAttributes::DenseIntElements:
|
||||
case StandardAttributes::DenseFPElements:
|
||||
case StandardAttributes::DenseElements:
|
||||
return cast<DenseElementsAttr>().mapValues(newElementType, mapping);
|
||||
case StandardAttributes::SplatElements:
|
||||
return cast<SplatElementsAttr>().mapValues(newElementType, mapping);
|
||||
|
@ -542,19 +539,8 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef<char> data) {
|
|||
assert((type.isa<RankedTensorType>() || type.isa<VectorType>()) &&
|
||||
"type must be ranked tensor or vector");
|
||||
assert(type.hasStaticShape() && "type must have static shape");
|
||||
switch (type.getElementType().getKind()) {
|
||||
case StandardTypes::BF16:
|
||||
case StandardTypes::F16:
|
||||
case StandardTypes::F32:
|
||||
case StandardTypes::F64:
|
||||
return AttributeUniquer::get<DenseFPElementsAttr>(
|
||||
type.getContext(), StandardAttributes::DenseFPElements, type, data);
|
||||
case StandardTypes::Integer:
|
||||
return AttributeUniquer::get<DenseIntElementsAttr>(
|
||||
type.getContext(), StandardAttributes::DenseIntElements, type, data);
|
||||
default:
|
||||
llvm_unreachable("unexpected element type");
|
||||
}
|
||||
return Base::get(type.getContext(), StandardAttributes::DenseElements, type,
|
||||
data);
|
||||
}
|
||||
|
||||
DenseElementsAttr DenseElementsAttr::get(ShapedType type,
|
||||
|
@ -631,22 +617,17 @@ Attribute DenseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
|
|||
readBits(getRawData().data(), valueIndex * storageBitWidth, bitWidth);
|
||||
|
||||
// Convert the raw value data to an attribute value.
|
||||
switch (getKind()) {
|
||||
case StandardAttributes::DenseIntElements:
|
||||
if (elementType.isa<IntegerType>())
|
||||
return IntegerAttr::get(elementType, rawValueData);
|
||||
case StandardAttributes::DenseFPElements:
|
||||
return FloatAttr::get(
|
||||
elementType, APFloat(elementType.cast<FloatType>().getFloatSemantics(),
|
||||
rawValueData));
|
||||
default:
|
||||
if (auto fType = elementType.dyn_cast<FloatType>())
|
||||
return FloatAttr::get(elementType,
|
||||
APFloat(fType.getFloatSemantics(), rawValueData));
|
||||
llvm_unreachable("unexpected element type");
|
||||
}
|
||||
}
|
||||
|
||||
void DenseElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
|
||||
auto elementType = getType().getElementType();
|
||||
switch (getKind()) {
|
||||
case StandardAttributes::DenseIntElements: {
|
||||
if (elementType.isa<IntegerType>()) {
|
||||
// Get the raw APInt values.
|
||||
SmallVector<APInt, 8> intValues;
|
||||
cast<DenseIntElementsAttr>().getValues(intValues);
|
||||
|
@ -656,7 +637,7 @@ void DenseElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
|
|||
values.push_back(IntegerAttr::get(elementType, intVal));
|
||||
return;
|
||||
}
|
||||
case StandardAttributes::DenseFPElements: {
|
||||
if (elementType.isa<FloatType>()) {
|
||||
// Get the raw APFloat values.
|
||||
SmallVector<APFloat, 8> floatValues;
|
||||
cast<DenseFPElementsAttr>().getValues(floatValues);
|
||||
|
@ -666,10 +647,8 @@ void DenseElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
|
|||
values.push_back(FloatAttr::get(elementType, floatVal));
|
||||
return;
|
||||
}
|
||||
default:
|
||||
llvm_unreachable("unexpected element type");
|
||||
}
|
||||
}
|
||||
|
||||
DenseElementsAttr DenseElementsAttr::mapValues(
|
||||
Type newElementType,
|
||||
|
@ -810,6 +789,12 @@ DenseElementsAttr DenseIntElementsAttr::mapValues(
|
|||
return get(newArrayType, elementData);
|
||||
}
|
||||
|
||||
/// Method for supporting type inquiry through isa, cast and dyn_cast.
|
||||
bool DenseIntElementsAttr::classof(Attribute attr) {
|
||||
return attr.isa<DenseElementsAttr>() &&
|
||||
attr.getType().cast<ShapedType>().getElementType().isa<IntegerType>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// DenseFPElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -859,6 +844,12 @@ DenseFPElementsAttr::iterator DenseFPElementsAttr::end() const {
|
|||
return {elementSemantics, raw_end()};
|
||||
}
|
||||
|
||||
/// Method for supporting type inquiry through isa, cast and dyn_cast.
|
||||
bool DenseFPElementsAttr::classof(Attribute attr) {
|
||||
return attr.isa<DenseElementsAttr>() &&
|
||||
attr.getType().cast<ShapedType>().getElementType().isa<FloatType>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpaqueElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -135,9 +135,9 @@ namespace {
|
|||
/// the IR.
|
||||
struct BuiltinDialect : public Dialect {
|
||||
BuiltinDialect(MLIRContext *context) : Dialect(/*name=*/"", context) {
|
||||
addAttributes<AffineMapAttr, ArrayAttr, BoolAttr, DenseIntElementsAttr,
|
||||
DenseFPElementsAttr, DictionaryAttr, FloatAttr, FunctionAttr,
|
||||
IntegerAttr, IntegerSetAttr, OpaqueAttr, OpaqueElementsAttr,
|
||||
addAttributes<AffineMapAttr, ArrayAttr, BoolAttr, DenseElementsAttr,
|
||||
DictionaryAttr, FloatAttr, FunctionAttr, IntegerAttr,
|
||||
IntegerSetAttr, OpaqueAttr, OpaqueElementsAttr,
|
||||
SparseElementsAttr, SplatElementsAttr, StringAttr, TypeAttr,
|
||||
UnitAttr>();
|
||||
addTypes<ComplexType, FloatType, FunctionType, IndexType, IntegerType,
|
||||
|
|
|
@ -116,7 +116,7 @@ TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
|
|||
auto expectedTensorType = realValue.getType().cast<TensorType>();
|
||||
EXPECT_EQ(tensorType.getShape(), expectedTensorType.getShape());
|
||||
EXPECT_EQ(tensorType.getElementType(), convertedType);
|
||||
EXPECT_EQ(returnedValue.getKind(), StandardAttributes::DenseIntElements);
|
||||
EXPECT_TRUE(returnedValue.isa<DenseIntElementsAttr>());
|
||||
|
||||
// Check Elements attribute element value is expected.
|
||||
auto firstValue = returnedValue.cast<ElementsAttr>().getValue({0, 0});
|
||||
|
|
Loading…
Reference in New Issue