Implement value type abstraction for attributes.

This is done by changing Attribute to be a POD interface around an underlying pointer storage and adding in-class support for isa/dyn_cast/cast.

PiperOrigin-RevId: 218764173
This commit is contained in:
River Riddle 2018-10-25 15:46:10 -07:00 committed by jpienaar
parent 64d52014bd
commit 792d1c25e4
26 changed files with 971 additions and 692 deletions

View File

@ -95,8 +95,8 @@ public:
/// Folds the results of the application of an affine map on the provided /// Folds the results of the application of an affine map on the provided
/// operands to a constant if possible. Returns false if the folding happens, /// operands to a constant if possible. Returns false if the folding happens,
/// true otherwise. /// true otherwise.
bool constantFold(ArrayRef<Attribute *> operandConstants, bool constantFold(ArrayRef<Attribute> operandConstants,
SmallVectorImpl<Attribute *> &results) const; SmallVectorImpl<Attribute> &results) const;
friend ::llvm::hash_code hash_value(AffineMap arg); friend ::llvm::hash_code hash_value(AffineMap arg);

View File

@ -30,10 +30,32 @@ class MLIRContext;
class Type; class Type;
class VectorOrTensorType; class VectorOrTensorType;
namespace detail {
struct AttributeStorage;
struct BoolAttributeStorage;
struct IntegerAttributeStorage;
struct FloatAttributeStorage;
struct StringAttributeStorage;
struct ArrayAttributeStorage;
struct AffineMapAttributeStorage;
struct TypeAttributeStorage;
struct FunctionAttributeStorage;
struct ElementsAttributeStorage;
struct SplatElementsAttributeStorage;
struct DenseElementsAttributeStorage;
struct DenseIntElementsAttributeStorage;
struct DenseFPElementsAttributeStorage;
struct OpaqueElementsAttributeStorage;
struct SparseElementsAttributeStorage;
} // namespace detail
/// Attributes are known-constant values of operations and functions. /// Attributes are known-constant values of operations and functions.
/// ///
/// Instances of the Attribute class are immutable, uniqued, immortal, and owned /// Instances of the Attribute class are immutable, uniqued, immortal, and owned
/// by MLIRContext. As such, they are passed around by raw non-const pointer. /// by MLIRContext. As such, an Attribute is a POD interface to an underlying
/// storage pointer.
class Attribute { class Attribute {
public: public:
enum class Kind { enum class Kind {
@ -55,177 +77,151 @@ public:
LAST_ELEMENTS_ATTR = SparseElements, LAST_ELEMENTS_ATTR = SparseElements,
}; };
typedef detail::AttributeStorage ImplType;
Attribute() : attr(nullptr) {}
/* implicit */ Attribute(const ImplType *attr)
: attr(const_cast<ImplType *>(attr)) {}
Attribute(const Attribute &other) : attr(other.attr) {}
Attribute &operator=(Attribute other) {
attr = other.attr;
return *this;
}
bool operator==(Attribute other) const { return attr == other.attr; }
bool operator!=(Attribute other) const { return !(*this == other); }
explicit operator bool() const { return attr; }
bool operator!() const { return attr == nullptr; }
template <typename U> bool isa() const;
template <typename U> U dyn_cast() const;
template <typename U> U dyn_cast_or_null() const;
template <typename U> U cast() const;
/// Return the classification for this attribute. /// Return the classification for this attribute.
Kind getKind() const { return kind; } Kind getKind() const;
/// Return true if this field is, or contains, a function attribute. /// Return true if this field is, or contains, a function attribute.
bool isOrContainsFunction() const { return isOrContainsFunctionCache; } bool isOrContainsFunction() const;
/// Print the attribute. /// Print the attribute.
void print(raw_ostream &os) const; void print(raw_ostream &os) const;
void dump() const; void dump() const;
friend ::llvm::hash_code hash_value(Attribute arg);
protected: protected:
explicit Attribute(Kind kind, bool isOrContainsFunction) ImplType *attr;
: kind(kind), isOrContainsFunctionCache(isOrContainsFunction) {}
~Attribute() {}
private:
/// Classification of the subclass, used for type checking.
Kind kind : 8;
/// This field is true if this is, or contains, a function attribute.
bool isOrContainsFunctionCache : 1;
Attribute(const Attribute &) = delete;
void operator=(const Attribute &) = delete;
}; };
inline raw_ostream &operator<<(raw_ostream &os, const Attribute &attr) { inline raw_ostream &operator<<(raw_ostream &os, Attribute attr) {
attr.print(os); attr.print(os);
return os; return os;
} }
class BoolAttr : public Attribute { class BoolAttr : public Attribute {
public: public:
static BoolAttr *get(bool value, MLIRContext *context); typedef detail::BoolAttributeStorage ImplType;
BoolAttr() = default;
/* implicit */ BoolAttr(Attribute::ImplType *ptr);
bool getValue() const { return value; } static BoolAttr get(bool value, MLIRContext *context);
bool getValue() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast. /// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Attribute *attr) { static bool kindof(Kind kind) { return kind == Kind::Bool; }
return attr->getKind() == Kind::Bool;
}
private:
BoolAttr(bool value)
: Attribute(Kind::Bool, /*isOrContainsFunction=*/false), value(value) {}
~BoolAttr() = delete;
bool value;
}; };
class IntegerAttr : public Attribute { class IntegerAttr : public Attribute {
public: public:
static IntegerAttr *get(int64_t value, MLIRContext *context); typedef detail::IntegerAttributeStorage ImplType;
IntegerAttr() = default;
/* implicit */ IntegerAttr(Attribute::ImplType *ptr);
int64_t getValue() const { return value; } static IntegerAttr get(int64_t value, MLIRContext *context);
int64_t getValue() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast. /// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Attribute *attr) { static bool kindof(Kind kind) { return kind == Kind::Integer; }
return attr->getKind() == Kind::Integer;
}
private:
IntegerAttr(int64_t value)
: Attribute(Kind::Integer, /*isOrContainsFunction=*/false), value(value) {
}
~IntegerAttr() = delete;
int64_t value;
}; };
class FloatAttr final : public Attribute, class FloatAttr final : public Attribute {
public llvm::TrailingObjects<FloatAttr, uint64_t> {
public: public:
static FloatAttr *get(double value, MLIRContext *context); typedef detail::FloatAttributeStorage ImplType;
static FloatAttr *get(const APFloat &value, MLIRContext *context); FloatAttr() = default;
/* implicit */ FloatAttr(Attribute::ImplType *ptr);
static FloatAttr get(double value, MLIRContext *context);
static FloatAttr get(const APFloat &value, MLIRContext *context);
APFloat getValue() const; APFloat getValue() const;
double getDouble() const { return getValue().convertToDouble(); } double getDouble() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast. /// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Attribute *attr) { static bool kindof(Kind kind) { return kind == Kind::Float; }
return attr->getKind() == Kind::Float;
}
private:
FloatAttr(const llvm::fltSemantics &semantics, size_t numObjects)
: Attribute(Kind::Float, /*isOrContainsFunction=*/false),
semantics(semantics), numObjects(numObjects) {}
FloatAttr(const FloatAttr &value) = delete;
~FloatAttr() = delete;
size_t numTrailingObjects(OverloadToken<uint64_t>) const {
return numObjects;
}
const llvm::fltSemantics &semantics;
size_t numObjects;
}; };
class StringAttr : public Attribute { class StringAttr : public Attribute {
public: public:
static StringAttr *get(StringRef bytes, MLIRContext *context); typedef detail::StringAttributeStorage ImplType;
StringAttr() = default;
/* implicit */ StringAttr(Attribute::ImplType *ptr);
StringRef getValue() const { return value; } static StringAttr get(StringRef bytes, MLIRContext *context);
StringRef getValue() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast. /// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Attribute *attr) { static bool kindof(Kind kind) { return kind == Kind::String; }
return attr->getKind() == Kind::String;
}
private:
StringAttr(StringRef value)
: Attribute(Kind::String, /*isOrContainsFunction=*/false), value(value) {}
~StringAttr() = delete;
StringRef value;
}; };
/// Array attributes are lists of other attributes. They are not necessarily /// Array attributes are lists of other attributes. They are not necessarily
/// type homogenous given that attributes don't, in general, carry types. /// type homogenous given that attributes don't, in general, carry types.
class ArrayAttr : public Attribute { class ArrayAttr : public Attribute {
public: public:
static ArrayAttr *get(ArrayRef<Attribute *> value, MLIRContext *context); typedef detail::ArrayAttributeStorage ImplType;
ArrayAttr() = default;
/* implicit */ ArrayAttr(Attribute::ImplType *ptr);
ArrayRef<Attribute *> getValue() const { return value; } static ArrayAttr get(ArrayRef<Attribute> value, MLIRContext *context);
ArrayRef<Attribute> getValue() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast. /// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Attribute *attr) { static bool kindof(Kind kind) { return kind == Kind::Array; }
return attr->getKind() == Kind::Array;
}
private:
ArrayAttr(ArrayRef<Attribute *> value, bool isOrContainsFunction)
: Attribute(Kind::Array, isOrContainsFunction), value(value) {}
~ArrayAttr() = delete;
ArrayRef<Attribute *> value;
}; };
class AffineMapAttr : public Attribute { class AffineMapAttr : public Attribute {
public: public:
static AffineMapAttr *get(AffineMap value); typedef detail::AffineMapAttributeStorage ImplType;
AffineMapAttr() = default;
/* implicit */ AffineMapAttr(Attribute::ImplType *ptr);
AffineMap getValue() const { return value; } static AffineMapAttr get(AffineMap value);
AffineMap getValue() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast. /// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Attribute *attr) { static bool kindof(Kind kind) { return kind == Kind::AffineMap; }
return attr->getKind() == Kind::AffineMap;
}
private:
AffineMapAttr(AffineMap value)
: Attribute(Kind::AffineMap, /*isOrContainsFunction=*/false),
value(value) {}
~AffineMapAttr() = delete;
AffineMap value;
}; };
class TypeAttr : public Attribute { class TypeAttr : public Attribute {
public: public:
static TypeAttr *get(Type *type, MLIRContext *context); typedef detail::TypeAttributeStorage ImplType;
TypeAttr() = default;
/* implicit */ TypeAttr(Attribute::ImplType *ptr);
Type *getValue() const { return value; } static TypeAttr get(Type *type, MLIRContext *context);
Type *getValue() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast. /// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Attribute *attr) { static bool kindof(Kind kind) { return kind == Kind::Type; }
return attr->getKind() == Kind::Type;
}
private:
TypeAttr(Type *value)
: Attribute(Kind::Type, /*isOrContainsFunction=*/false), value(value) {}
~TypeAttr() = delete;
Type *value;
}; };
/// A function attribute represents a reference to a function object. /// A function attribute represents a reference to a function object.
@ -237,63 +233,53 @@ private:
/// remain in MLIRContext. /// remain in MLIRContext.
class FunctionAttr : public Attribute { class FunctionAttr : public Attribute {
public: public:
static FunctionAttr *get(const Function *value, MLIRContext *context); typedef detail::FunctionAttributeStorage ImplType;
FunctionAttr() = default;
/* implicit */ FunctionAttr(Attribute::ImplType *ptr);
Function *getValue() const { return value; } static FunctionAttr get(const Function *value, MLIRContext *context);
Function *getValue() const;
FunctionType *getType() const; FunctionType *getType() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast. /// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Attribute *attr) { static bool kindof(Kind kind) { return kind == Kind::Function; }
return attr->getKind() == Kind::Function;
}
/// This function is used by the internals of the Function class to null out /// This function is used by the internals of the Function class to null out
/// attributes refering to functions that are about to be deleted. /// attributes refering to functions that are about to be deleted.
static void dropFunctionReference(Function *value); static void dropFunctionReference(Function *value);
private:
FunctionAttr(Function *value)
: Attribute(Kind::Function, /*isOrContainsFunction=*/true), value(value) {
}
~FunctionAttr() = delete;
Function *value;
}; };
/// A base attribute represents a reference to a vector or tensor constant. /// A base attribute represents a reference to a vector or tensor constant.
class ElementsAttr : public Attribute { class ElementsAttr : public Attribute {
public: public:
ElementsAttr(Kind kind, VectorOrTensorType *type) typedef detail::ElementsAttributeStorage ImplType;
: Attribute(kind, /*isOrContainsFunction=*/false), type(type) {} ElementsAttr() = default;
/* implicit */ ElementsAttr(Attribute::ImplType *ptr);
VectorOrTensorType *getType() const { return type; } VectorOrTensorType *getType() const;
/// Method for support type inquiry through isa, cast and dyn_cast. /// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(const Attribute *attr) { static bool kindof(Kind kind) {
return attr->getKind() >= Kind::FIRST_ELEMENTS_ATTR && return kind >= Kind::FIRST_ELEMENTS_ATTR &&
attr->getKind() <= Kind::LAST_ELEMENTS_ATTR; kind <= Kind::LAST_ELEMENTS_ATTR;
} }
private:
VectorOrTensorType *type;
}; };
/// An attribute represents a reference to a splat vecctor or tensor constant, /// An attribute represents a reference to a splat vecctor or tensor constant,
/// meaning all of the elements have the same value. /// meaning all of the elements have the same value.
class SplatElementsAttr : public ElementsAttr { class SplatElementsAttr : public ElementsAttr {
public: public:
static SplatElementsAttr *get(VectorOrTensorType *type, Attribute *elt); typedef detail::SplatElementsAttributeStorage ImplType;
Attribute *getValue() const { return elt; } SplatElementsAttr() = default;
/* implicit */ SplatElementsAttr(Attribute::ImplType *ptr);
static SplatElementsAttr get(VectorOrTensorType *type, Attribute elt);
Attribute getValue() const;
/// Method for support type inquiry through isa, cast and dyn_cast. /// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(const Attribute *attr) { static bool kindof(Kind kind) { return kind == Kind::SplatElements; }
return attr->getKind() == Kind::SplatElements;
}
private:
SplatElementsAttr(VectorOrTensorType *type, Attribute *elt)
: ElementsAttr(Kind::SplatElements, type), elt(elt) {}
Attribute *elt;
}; };
/// An attribute represents a reference to a dense vector or tensor object. /// An attribute represents a reference to a dense vector or tensor object.
@ -302,42 +288,42 @@ private:
/// than 64. /// than 64.
class DenseElementsAttr : public ElementsAttr { class DenseElementsAttr : public ElementsAttr {
public: public:
typedef detail::DenseElementsAttributeStorage ImplType;
DenseElementsAttr() = default;
/* implicit */ DenseElementsAttr(Attribute::ImplType *ptr);
/// It assumes the elements in the input array have been truncated to the bits /// It assumes the elements in the input array have been truncated to the bits
/// width specified by the element type (note all float type are 64 bits). /// width specified by the element type (note all float type are 64 bits).
/// When the value is retrieved, the bits are read from the storage and extend /// When the value is retrieved, the bits are read from the storage and extend
/// to 64 bits if necessary. /// to 64 bits if necessary.
static DenseElementsAttr *get(VectorOrTensorType *type, ArrayRef<char> data); static DenseElementsAttr get(VectorOrTensorType *type, ArrayRef<char> data);
// TODO: Read the data from the attribute list and compress them // TODO: Read the data from the attribute list and compress them
// to a character array. Then call the above method to construct the // to a character array. Then call the above method to construct the
// attribute. // attribute.
static DenseElementsAttr *get(VectorOrTensorType *type, static DenseElementsAttr get(VectorOrTensorType *type,
ArrayRef<Attribute *> values); ArrayRef<Attribute> values);
void getValues(SmallVectorImpl<Attribute *> &values) const; void getValues(SmallVectorImpl<Attribute> &values) const;
ArrayRef<char> getRawData() const { return data; } ArrayRef<char> getRawData() const;
/// Method for support type inquiry through isa, cast and dyn_cast. /// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(const Attribute *attr) { static bool kindof(Kind kind) {
return attr->getKind() == Kind::DenseIntElements || return kind == Kind::DenseIntElements || kind == Kind::DenseFPElements;
attr->getKind() == Kind::DenseFPElements;
} }
protected:
DenseElementsAttr(Kind kind, VectorOrTensorType *type, ArrayRef<char> data)
: ElementsAttr(kind, type), data(data) {}
private:
ArrayRef<char> data;
}; };
/// An attribute represents a reference to a dense integer vector or tensor /// An attribute represents a reference to a dense integer vector or tensor
/// object. /// object.
class DenseIntElementsAttr : public DenseElementsAttr { class DenseIntElementsAttr : public DenseElementsAttr {
public: public:
typedef detail::DenseIntElementsAttributeStorage ImplType;
DenseIntElementsAttr() = default;
/* implicit */ DenseIntElementsAttr(Attribute::ImplType *ptr);
// TODO: returns APInts instead of IntegerAttr. // TODO: returns APInts instead of IntegerAttr.
void getValues(SmallVectorImpl<Attribute *> &values) const; void getValues(SmallVectorImpl<Attribute> &values) const;
APInt getValue(ArrayRef<unsigned> indices) const; APInt getValue(ArrayRef<unsigned> indices) const;
@ -352,41 +338,24 @@ public:
size_t bitsWidth); size_t bitsWidth);
/// Method for support type inquiry through isa, cast and dyn_cast. /// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(const Attribute *attr) { static bool kindof(Kind kind) { return kind == Kind::DenseIntElements; }
return attr->getKind() == Kind::DenseIntElements;
}
private:
friend class DenseElementsAttr;
DenseIntElementsAttr(VectorOrTensorType *type, ArrayRef<char> data,
size_t bitsWidth)
: DenseElementsAttr(Kind::DenseIntElements, type, data),
bitsWidth(bitsWidth) {}
~DenseIntElementsAttr() = delete;
size_t bitsWidth;
}; };
/// An attribute represents a reference to a dense float vector or tensor /// An attribute represents a reference to a dense float vector or tensor
/// object. Each element is stored as a double. /// object. Each element is stored as a double.
class DenseFPElementsAttr : public DenseElementsAttr { class DenseFPElementsAttr : public DenseElementsAttr {
public: public:
typedef detail::DenseFPElementsAttributeStorage ImplType;
DenseFPElementsAttr() = default;
/* implicit */ DenseFPElementsAttr(Attribute::ImplType *ptr);
// TODO: returns APFPs instead of FloatAttr. // TODO: returns APFPs instead of FloatAttr.
void getValues(SmallVectorImpl<Attribute *> &values) const; void getValues(SmallVectorImpl<Attribute> &values) const;
APFloat getValue(ArrayRef<unsigned> indices) const; APFloat getValue(ArrayRef<unsigned> indices) const;
/// Method for support type inquiry through isa, cast and dyn_cast. /// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(const Attribute *attr) { static bool kindof(Kind kind) { return kind == Kind::DenseFPElements; }
return attr->getKind() == Kind::DenseFPElements;
}
private:
friend class DenseElementsAttr;
DenseFPElementsAttr(VectorOrTensorType *type, ArrayRef<char> data)
: DenseElementsAttr(Kind::DenseFPElements, type, data) {}
~DenseFPElementsAttr() = delete;
}; };
/// An attribute represents a reference to a tensor constant with opaque /// An attribute represents a reference to a tensor constant with opaque
@ -394,20 +363,16 @@ private:
/// doesn't need to interpret. /// doesn't need to interpret.
class OpaqueElementsAttr : public ElementsAttr { class OpaqueElementsAttr : public ElementsAttr {
public: public:
static OpaqueElementsAttr *get(VectorOrTensorType *type, StringRef bytes); typedef detail::OpaqueElementsAttributeStorage ImplType;
OpaqueElementsAttr() = default;
/* implicit */ OpaqueElementsAttr(Attribute::ImplType *ptr);
StringRef getValue() const { return bytes; } static OpaqueElementsAttr get(VectorOrTensorType *type, StringRef bytes);
StringRef getValue() const;
/// Method for support type inquiry through isa, cast and dyn_cast. /// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(const Attribute *attr) { static bool kindof(Kind kind) { return kind == Kind::OpaqueElements; }
return attr->getKind() == Kind::OpaqueElements;
}
private:
OpaqueElementsAttr(VectorOrTensorType *type, StringRef bytes)
: ElementsAttr(Kind::OpaqueElements, type), bytes(bytes) {}
~OpaqueElementsAttr() = delete;
StringRef bytes;
}; };
/// An attribute represents a reference to a sparse vector or tensor object. /// An attribute represents a reference to a sparse vector or tensor object.
@ -427,32 +392,67 @@ private:
/// [0, 0, 0, 0]]. /// [0, 0, 0, 0]].
class SparseElementsAttr : public ElementsAttr { class SparseElementsAttr : public ElementsAttr {
public: public:
static SparseElementsAttr *get(VectorOrTensorType *type, typedef detail::SparseElementsAttributeStorage ImplType;
DenseIntElementsAttr *indices, SparseElementsAttr() = default;
DenseElementsAttr *values); /* implicit */ SparseElementsAttr(Attribute::ImplType *ptr);
DenseIntElementsAttr *getIndices() const { return indices; } static SparseElementsAttr get(VectorOrTensorType *type,
DenseIntElementsAttr indices,
DenseElementsAttr values);
DenseElementsAttr *getValues() const { return values; } DenseIntElementsAttr getIndices() const;
DenseElementsAttr getValues() const;
/// Return the value at the given index. /// Return the value at the given index.
Attribute *getValue(ArrayRef<unsigned> index) const; Attribute getValue(ArrayRef<unsigned> index) const;
/// Method for support type inquiry through isa, cast and dyn_cast. /// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(const Attribute *attr) { static bool kindof(Kind kind) { return kind == Kind::SparseElements; }
return attr->getKind() == Kind::SparseElements;
}
private:
SparseElementsAttr(VectorOrTensorType *type, DenseIntElementsAttr *indices,
DenseElementsAttr *values)
: ElementsAttr(Kind::SparseElements, type), indices(indices),
values(values) {}
~SparseElementsAttr() = delete;
DenseIntElementsAttr *const indices;
DenseElementsAttr *const values;
}; };
template <typename U> bool Attribute::isa() const {
assert(attr && "isa<> used on a null attribute.");
return U::kindof(getKind());
}
template <typename U> U Attribute::dyn_cast() const {
return isa<U>() ? U(attr) : U(nullptr);
}
template <typename U> U Attribute::dyn_cast_or_null() const {
return (attr && isa<U>()) ? U(attr) : U(nullptr);
}
template <typename U> U Attribute::cast() const {
assert(isa<U>());
return U(attr);
}
// Make Attribute hashable.
inline ::llvm::hash_code hash_value(Attribute arg) {
return ::llvm::hash_value(arg.attr);
}
} // end namespace mlir. } // end namespace mlir.
namespace llvm {
// Attribute hash just like pointers
template <> struct DenseMapInfo<mlir::Attribute> {
static mlir::Attribute getEmptyKey() {
auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer));
}
static mlir::Attribute getTombstoneKey() {
auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
return mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer));
}
static unsigned getHashValue(mlir::Attribute val) {
return mlir::hash_value(val);
}
static bool isEqual(mlir::Attribute LHS, mlir::Attribute RHS) {
return LHS == RHS;
}
};
} // namespace llvm
#endif #endif

View File

@ -93,23 +93,23 @@ public:
UnrankedTensorType *getTensorType(Type *elementType); UnrankedTensorType *getTensorType(Type *elementType);
// Attributes. // Attributes.
BoolAttr *getBoolAttr(bool value);
IntegerAttr *getIntegerAttr(int64_t value); BoolAttr getBoolAttr(bool value);
FloatAttr *getFloatAttr(double value); IntegerAttr getIntegerAttr(int64_t value);
FloatAttr *getFloatAttr(const APFloat &value); FloatAttr getFloatAttr(double value);
StringAttr *getStringAttr(StringRef bytes); FloatAttr getFloatAttr(const APFloat &value);
ArrayAttr *getArrayAttr(ArrayRef<Attribute *> value); StringAttr getStringAttr(StringRef bytes);
AffineMapAttr *getAffineMapAttr(AffineMap map); ArrayAttr getArrayAttr(ArrayRef<Attribute> value);
TypeAttr *getTypeAttr(Type *type); AffineMapAttr getAffineMapAttr(AffineMap map);
FunctionAttr *getFunctionAttr(const Function *value); TypeAttr getTypeAttr(Type *type);
ElementsAttr *getSplatElementsAttr(VectorOrTensorType *type, Attribute *elt); FunctionAttr getFunctionAttr(const Function *value);
ElementsAttr *getDenseElementsAttr(VectorOrTensorType *type, ElementsAttr getSplatElementsAttr(VectorOrTensorType *type, Attribute elt);
ArrayRef<char> data); ElementsAttr getDenseElementsAttr(VectorOrTensorType *type,
ElementsAttr *getSparseElementsAttr(VectorOrTensorType *type, ArrayRef<char> data);
DenseIntElementsAttr *indices, ElementsAttr getSparseElementsAttr(VectorOrTensorType *type,
DenseElementsAttr *values); DenseIntElementsAttr indices,
ElementsAttr *getOpaqueElementsAttr(VectorOrTensorType *type, DenseElementsAttr values);
StringRef bytes); ElementsAttr getOpaqueElementsAttr(VectorOrTensorType *type, StringRef bytes);
// Affine expressions and affine maps. // Affine expressions and affine maps.
AffineExpr getAffineDimExpr(unsigned position); AffineExpr getAffineDimExpr(unsigned position);

View File

@ -60,7 +60,7 @@ public:
/// Returns the affine map to be applied by this operation. /// Returns the affine map to be applied by this operation.
AffineMap getAffineMap() const { AffineMap getAffineMap() const {
return getAttrOfType<AffineMapAttr>("map")->getValue(); return getAttrOfType<AffineMapAttr>("map").getValue();
} }
/// Returns true if the result of this operation can be used as dimension id. /// Returns true if the result of this operation can be used as dimension id.
@ -75,8 +75,8 @@ public:
static bool parse(OpAsmParser *parser, OperationState *result); static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const; void print(OpAsmPrinter *p) const;
bool verify() const; bool verify() const;
bool constantFold(ArrayRef<Attribute *> operands, bool constantFold(ArrayRef<Attribute> operandConstants,
SmallVectorImpl<Attribute *> &results, SmallVectorImpl<Attribute> &results,
MLIRContext *context) const; MLIRContext *context) const;
private: private:
@ -94,10 +94,10 @@ class ConstantOp : public Op<ConstantOp, OpTrait::ZeroOperands,
OpTrait::OneResult, OpTrait::HasNoSideEffect> { OpTrait::OneResult, OpTrait::HasNoSideEffect> {
public: public:
/// Builds a constant op with the specified attribute value and result type. /// Builds a constant op with the specified attribute value and result type.
static void build(Builder *builder, OperationState *result, Attribute *value, static void build(Builder *builder, OperationState *result, Attribute value,
Type *type); Type *type);
Attribute *getValue() const { return getAttr("value"); } Attribute getValue() const { return getAttr("value"); }
static StringRef getOperationName() { return "constant"; } static StringRef getOperationName() { return "constant"; }
@ -105,8 +105,8 @@ public:
static bool parse(OpAsmParser *parser, OperationState *result); static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const; void print(OpAsmPrinter *p) const;
bool verify() const; bool verify() const;
Attribute *constantFold(ArrayRef<Attribute *> operands, Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const; MLIRContext *context) const;
protected: protected:
friend class Operation; friend class Operation;
@ -125,7 +125,7 @@ public:
const APFloat &value, FloatType *type); const APFloat &value, FloatType *type);
APFloat getValue() const { APFloat getValue() const {
return getAttrOfType<FloatAttr>("value")->getValue(); return getAttrOfType<FloatAttr>("value").getValue();
} }
static bool isClassFor(const Operation *op); static bool isClassFor(const Operation *op);
@ -152,7 +152,7 @@ public:
Type *type); Type *type);
int64_t getValue() const { int64_t getValue() const {
return getAttrOfType<IntegerAttr>("value")->getValue(); return getAttrOfType<IntegerAttr>("value").getValue();
} }
static bool isClassFor(const Operation *op); static bool isClassFor(const Operation *op);
@ -173,7 +173,7 @@ public:
static void build(Builder *builder, OperationState *result, int64_t value); static void build(Builder *builder, OperationState *result, int64_t value);
int64_t getValue() const { int64_t getValue() const {
return getAttrOfType<IntegerAttr>("value")->getValue(); return getAttrOfType<IntegerAttr>("value").getValue();
} }
static bool isClassFor(const Operation *op); static bool isClassFor(const Operation *op);

View File

@ -24,12 +24,12 @@
#ifndef MLIR_IR_FUNCTION_H #ifndef MLIR_IR_FUNCTION_H
#define MLIR_IR_FUNCTION_H #define MLIR_IR_FUNCTION_H
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Identifier.h" #include "mlir/IR/Identifier.h"
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
#include "llvm/ADT/ilist.h" #include "llvm/ADT/ilist.h"
namespace mlir { namespace mlir {
class Attribute;
class AttributeListStorage; class AttributeListStorage;
class FunctionType; class FunctionType;
class Location; class Location;
@ -39,7 +39,7 @@ class Module;
/// NamedAttribute is used for function attribute lists, it holds an /// NamedAttribute is used for function attribute lists, it holds an
/// identifier for the name and a value for the attribute. The attribute /// identifier for the name and a value for the attribute. The attribute
/// pointer should always be non-null. /// pointer should always be non-null.
using NamedAttribute = std::pair<Identifier, Attribute *>; using NamedAttribute = std::pair<Identifier, Attribute>;
/// This is the base class for all of the MLIR function types. /// This is the base class for all of the MLIR function types.
class Function : public llvm::ilist_node_with_parent<Function, Module> { class Function : public llvm::ilist_node_with_parent<Function, Module> {

View File

@ -138,17 +138,16 @@ public:
ArrayRef<NamedAttribute> getAttrs() const { return state->getAttrs(); } ArrayRef<NamedAttribute> getAttrs() const { return state->getAttrs(); }
/// Return an attribute with the specified name. /// Return an attribute with the specified name.
Attribute *getAttr(StringRef name) const { return state->getAttr(name); } Attribute getAttr(StringRef name) const { return state->getAttr(name); }
/// If the operation has an attribute of the specified type, return it. /// If the operation has an attribute of the specified type, return it.
template <typename AttrClass> template <typename AttrClass> AttrClass getAttrOfType(StringRef name) const {
AttrClass *getAttrOfType(StringRef name) const { return getAttr(name).dyn_cast_or_null<AttrClass>();
return dyn_cast_or_null<AttrClass>(getAttr(name));
} }
/// If the an attribute exists with the specified name, change it to the new /// If the an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value. /// value. Otherwise, add a new attribute with the specified name/value.
void setAttr(Identifier name, Attribute *value) { void setAttr(Identifier name, Attribute value) {
state->setAttr(name, value); state->setAttr(name, value);
} }
@ -211,8 +210,8 @@ public:
/// true if folding failed, or returns false and fills in `results` on /// true if folding failed, or returns false and fills in `results` on
/// success. /// success.
static bool constantFoldHook(const Operation *op, static bool constantFoldHook(const Operation *op,
ArrayRef<Attribute *> operands, ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute *> &results) { SmallVectorImpl<Attribute> &results) {
return op->cast<ConcreteType>()->constantFold(operands, results, return op->cast<ConcreteType>()->constantFold(operands, results,
op->getContext()); op->getContext());
} }
@ -226,8 +225,8 @@ public:
/// ///
/// If not overridden, this fallback implementation always fails to fold. /// If not overridden, this fallback implementation always fails to fold.
/// ///
bool constantFold(ArrayRef<Attribute *> operands, bool constantFold(ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute *> &results, SmallVectorImpl<Attribute> &results,
MLIRContext *context) const { MLIRContext *context) const {
return true; return true;
} }
@ -244,9 +243,9 @@ public:
/// true if folding failed, or returns false and fills in `results` on /// true if folding failed, or returns false and fills in `results` on
/// success. /// success.
static bool constantFoldHook(const Operation *op, static bool constantFoldHook(const Operation *op,
ArrayRef<Attribute *> operands, ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute *> &results) { SmallVectorImpl<Attribute> &results) {
auto *result = auto result =
op->cast<ConcreteType>()->constantFold(operands, op->getContext()); op->cast<ConcreteType>()->constantFold(operands, op->getContext());
if (!result) if (!result)
return true; return true;
@ -511,8 +510,8 @@ public:
/// ///
/// If not overridden, this fallback implementation always fails to fold. /// If not overridden, this fallback implementation always fails to fold.
/// ///
Attribute *constantFold(ArrayRef<Attribute *> operands, Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const { MLIRContext *context) const {
return nullptr; return nullptr;
} }
}; };

View File

@ -69,7 +69,7 @@ public:
} }
virtual void printType(const Type *type) = 0; virtual void printType(const Type *type) = 0;
virtual void printFunctionReference(const Function *func) = 0; virtual void printFunctionReference(const Function *func) = 0;
virtual void printAttribute(const Attribute *attr) = 0; virtual void printAttribute(Attribute attr) = 0;
virtual void printAffineMap(AffineMap map) = 0; virtual void printAffineMap(AffineMap map) = 0;
virtual void printAffineExpr(AffineExpr expr) = 0; virtual void printAffineExpr(AffineExpr expr) = 0;
@ -100,8 +100,8 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const Type &type) {
return p; return p;
} }
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const Attribute &attr) { inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Attribute attr) {
p.printAttribute(&attr); p.printAttribute(attr);
return p; return p;
} }
@ -210,24 +210,24 @@ public:
/// Parse an arbitrary attribute and return it in result. This also adds the /// Parse an arbitrary attribute and return it in result. This also adds the
/// attribute to the specified attribute list with the specified name. this /// attribute to the specified attribute list with the specified name. this
/// captures the location of the attribute in 'loc' if it is non-null. /// captures the location of the attribute in 'loc' if it is non-null.
virtual bool parseAttribute(Attribute *&result, const char *attrName, virtual bool parseAttribute(Attribute &result, const char *attrName,
SmallVectorImpl<NamedAttribute> &attrs) = 0; SmallVectorImpl<NamedAttribute> &attrs) = 0;
/// Parse an attribute of a specific kind, capturing the location into `loc` /// Parse an attribute of a specific kind, capturing the location into `loc`
/// if specified. /// if specified.
template <typename AttrType> template <typename AttrType>
bool parseAttribute(AttrType *&result, const char *attrName, bool parseAttribute(AttrType &result, const char *attrName,
SmallVectorImpl<NamedAttribute> &attrs) { SmallVectorImpl<NamedAttribute> &attrs) {
llvm::SMLoc loc; llvm::SMLoc loc;
getCurrentLocation(&loc); getCurrentLocation(&loc);
// Parse any kind of attribute. // Parse any kind of attribute.
Attribute *attr; Attribute attr;
if (parseAttribute(attr, attrName, attrs)) if (parseAttribute(attr, attrName, attrs))
return true; return true;
// Check for the right kind of attribute. // Check for the right kind of attribute.
result = dyn_cast<AttrType>(attr); result = attr.dyn_cast<AttrType>();
if (!result) { if (!result) {
emitError(loc, "invalid kind of constant specified"); emitError(loc, "invalid kind of constant specified");
return true; return true;

View File

@ -113,33 +113,31 @@ public:
ArrayRef<NamedAttribute> getAttrs() const; ArrayRef<NamedAttribute> getAttrs() const;
/// Return the specified attribute if present, null otherwise. /// Return the specified attribute if present, null otherwise.
Attribute *getAttr(Identifier name) const { Attribute getAttr(Identifier name) const {
for (auto elt : getAttrs()) for (auto elt : getAttrs())
if (elt.first == name) if (elt.first == name)
return elt.second; return elt.second;
return nullptr; return nullptr;
} }
Attribute *getAttr(StringRef name) const { Attribute getAttr(StringRef name) const {
for (auto elt : getAttrs()) for (auto elt : getAttrs())
if (elt.first.is(name)) if (elt.first.is(name))
return elt.second; return elt.second;
return nullptr; return nullptr;
} }
template <typename AttrClass> template <typename AttrClass> AttrClass getAttrOfType(Identifier name) const {
AttrClass *getAttrOfType(Identifier name) const { return getAttr(name).dyn_cast_or_null<AttrClass>();
return dyn_cast_or_null<AttrClass>(getAttr(name));
} }
template <typename AttrClass> template <typename AttrClass> AttrClass getAttrOfType(StringRef name) const {
AttrClass *getAttrOfType(StringRef name) const { return getAttr(name).dyn_cast_or_null<AttrClass>();
return dyn_cast_or_null<AttrClass>(getAttr(name));
} }
/// If the an attribute exists with the specified name, change it to the new /// If the an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value. /// value. Otherwise, add a new attribute with the specified name/value.
void setAttr(Identifier name, Attribute *value); void setAttr(Identifier name, Attribute value);
enum class RemoveResult { enum class RemoveResult {
Removed, NotFound Removed, NotFound
@ -250,8 +248,8 @@ public:
/// the operands of the operation, but may be null if non-constant. If /// the operands of the operation, but may be null if non-constant. If
/// constant folding is successful, this returns false and fills in the /// constant folding is successful, this returns false and fills in the
/// `results` vector. If not, this returns true and `results` is unspecified. /// `results` vector. If not, this returns true and `results` is unspecified.
bool constantFold(ArrayRef<Attribute *> operands, bool constantFold(ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute *> &results) const; SmallVectorImpl<Attribute> &results) const;
/// Methods for support type inquiry through isa, cast, and dyn_cast. /// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Instruction *inst); static bool classof(const Instruction *inst);

View File

@ -23,11 +23,11 @@
#ifndef MLIR_IR_OPERATION_SUPPORT_H #ifndef MLIR_IR_OPERATION_SUPPORT_H
#define MLIR_IR_OPERATION_SUPPORT_H #define MLIR_IR_OPERATION_SUPPORT_H
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Identifier.h" #include "mlir/IR/Identifier.h"
#include "llvm/ADT/PointerUnion.h" #include "llvm/ADT/PointerUnion.h"
namespace mlir { namespace mlir {
class Attribute;
class Dialect; class Dialect;
class Location; class Location;
class Operation; class Operation;
@ -80,8 +80,8 @@ public:
/// This hook implements a constant folder for this operation. It returns /// This hook implements a constant folder for this operation. It returns
/// true if folding failed, or returns false and fills in `results` on /// true if folding failed, or returns false and fills in `results` on
/// success. /// success.
bool (&constantFoldHook)(const Operation *op, ArrayRef<Attribute *> operands, bool (&constantFoldHook)(const Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute *> &results); SmallVectorImpl<Attribute> &results);
// Returns whether the operation has a particular property. // Returns whether the operation has a particular property.
bool hasProperty(OperationProperty property) const { bool hasProperty(OperationProperty property) const {
@ -110,8 +110,8 @@ private:
void (&printAssembly)(const Operation *op, OpAsmPrinter *p), void (&printAssembly)(const Operation *op, OpAsmPrinter *p),
bool (&verifyInvariants)(const Operation *op), bool (&verifyInvariants)(const Operation *op),
bool (&constantFoldHook)(const Operation *op, bool (&constantFoldHook)(const Operation *op,
ArrayRef<Attribute *> operands, ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute *> &results)) SmallVectorImpl<Attribute> &results))
: name(name), dialect(dialect), isClassFor(isClassFor), : name(name), dialect(dialect), isClassFor(isClassFor),
parseAssembly(parseAssembly), printAssembly(printAssembly), parseAssembly(parseAssembly), printAssembly(printAssembly),
verifyInvariants(verifyInvariants), constantFoldHook(constantFoldHook), verifyInvariants(verifyInvariants), constantFoldHook(constantFoldHook),
@ -124,7 +124,7 @@ private:
/// NamedAttribute is a used for operation attribute lists, it holds an /// NamedAttribute is a used for operation attribute lists, it holds an
/// identifier for the name and a value for the attribute. The attribute /// identifier for the name and a value for the attribute. The attribute
/// pointer should always be non-null. /// pointer should always be non-null.
using NamedAttribute = std::pair<Identifier, Attribute *>; using NamedAttribute = std::pair<Identifier, Attribute>;
class OperationName { class OperationName {
public: public:
@ -204,7 +204,7 @@ public:
types.append(newTypes.begin(), newTypes.end()); types.append(newTypes.begin(), newTypes.end());
} }
void addAttribute(StringRef name, Attribute *attr) { void addAttribute(StringRef name, Attribute attr) {
attributes.push_back({Identifier::get(name, context), attr}); attributes.push_back({Identifier::get(name, context), attr});
} }
}; };

View File

@ -49,8 +49,8 @@ class AddFOp
public: public:
static StringRef getOperationName() { return "addf"; } static StringRef getOperationName() { return "addf"; }
Attribute *constantFold(ArrayRef<Attribute *> operands, Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const; MLIRContext *context) const;
private: private:
friend class Operation; friend class Operation;
@ -70,8 +70,8 @@ class AddIOp
public: public:
static StringRef getOperationName() { return "addi"; } static StringRef getOperationName() { return "addi"; }
Attribute *constantFold(ArrayRef<Attribute *> operands, Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const; MLIRContext *context) const;
private: private:
friend class Operation; friend class Operation;
@ -134,7 +134,7 @@ public:
ArrayRef<SSAValue *> operands); ArrayRef<SSAValue *> operands);
Function *getCallee() const { Function *getCallee() const {
return getAttrOfType<FunctionAttr>("callee")->getValue(); return getAttrOfType<FunctionAttr>("callee").getValue();
} }
// Hooks to customize behavior of this op. // Hooks to customize behavior of this op.
@ -218,13 +218,13 @@ public:
static void build(Builder *builder, OperationState *result, static void build(Builder *builder, OperationState *result,
SSAValue *memrefOrTensor, unsigned index); SSAValue *memrefOrTensor, unsigned index);
Attribute *constantFold(ArrayRef<Attribute *> operands, Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const; MLIRContext *context) const;
/// This returns the dimension number that the 'dim' is inspecting. /// This returns the dimension number that the 'dim' is inspecting.
unsigned getIndex() const { unsigned getIndex() const {
return static_cast<unsigned>( return static_cast<unsigned>(
getAttrOfType<IntegerAttr>("index")->getValue()); getAttrOfType<IntegerAttr>("index").getValue());
} }
static StringRef getOperationName() { return "dim"; } static StringRef getOperationName() { return "dim"; }
@ -513,8 +513,8 @@ class MulFOp
public: public:
static StringRef getOperationName() { return "mulf"; } static StringRef getOperationName() { return "mulf"; }
Attribute *constantFold(ArrayRef<Attribute *> operands, Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const; MLIRContext *context) const;
private: private:
friend class Operation; friend class Operation;
@ -534,8 +534,8 @@ class MulIOp
public: public:
static StringRef getOperationName() { return "muli"; } static StringRef getOperationName() { return "muli"; }
Attribute *constantFold(ArrayRef<Attribute *> operands, Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const; MLIRContext *context) const;
private: private:
friend class Operation; friend class Operation;
@ -597,8 +597,8 @@ class SubFOp : public BinaryOp<SubFOp, OpTrait::ResultsAreFloatLike,
public: public:
static StringRef getOperationName() { return "subf"; } static StringRef getOperationName() { return "subf"; }
Attribute *constantFold(ArrayRef<Attribute *> operands, Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const; MLIRContext *context) const;
private: private:
friend class Operation; friend class Operation;
@ -617,8 +617,8 @@ class SubIOp : public BinaryOp<SubIOp, OpTrait::ResultsAreIntegerLike,
public: public:
static StringRef getOperationName() { return "subi"; } static StringRef getOperationName() { return "subi"; }
Attribute *constantFold(ArrayRef<Attribute *> operands, Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const; MLIRContext *context) const;
private: private:
friend class Operation; friend class Operation;

View File

@ -82,7 +82,7 @@ public:
} }
bool verifyOperation(const Operation &op); bool verifyOperation(const Operation &op);
bool verifyAttribute(Attribute *attr, const Operation &op); bool verifyAttribute(Attribute attr, const Operation &op);
protected: protected:
explicit Verifier(const Function &fn) : fn(fn) {} explicit Verifier(const Function &fn) : fn(fn) {}
@ -94,26 +94,26 @@ private:
} // end anonymous namespace } // end anonymous namespace
// Check that function attributes are all well formed. // Check that function attributes are all well formed.
bool Verifier::verifyAttribute(Attribute *attr, const Operation &op) { bool Verifier::verifyAttribute(Attribute attr, const Operation &op) {
if (!attr->isOrContainsFunction()) if (!attr.isOrContainsFunction())
return false; return false;
// If we have a function attribute, check that it is non-null and in the // If we have a function attribute, check that it is non-null and in the
// same module as the operation that refers to it. // same module as the operation that refers to it.
if (auto *fnAttr = dyn_cast<FunctionAttr>(attr)) { if (auto fnAttr = attr.dyn_cast<FunctionAttr>()) {
if (!fnAttr->getValue()) if (!fnAttr.getValue())
return failure("attribute refers to deallocated function!", op); return failure("attribute refers to deallocated function!", op);
if (fnAttr->getValue()->getModule() != fn.getModule()) if (fnAttr.getValue()->getModule() != fn.getModule())
return failure("attribute refers to function '" + return failure("attribute refers to function '" +
Twine(fnAttr->getValue()->getName()) + Twine(fnAttr.getValue()->getName()) +
"' defined in another module!", "' defined in another module!",
op); op);
return false; return false;
} }
// Otherwise, we must have an array attribute, remap the elements. // Otherwise, we must have an array attribute, remap the elements.
for (auto *elt : cast<ArrayAttr>(attr)->getValue()) { for (auto elt : attr.cast<ArrayAttr>().getValue()) {
if (verifyAttribute(elt, op)) if (verifyAttribute(elt, op))
return true; return true;
} }

View File

@ -32,13 +32,12 @@ namespace {
// evaluated on constant 'operandConsts'. // evaluated on constant 'operandConsts'.
class AffineExprConstantFolder { class AffineExprConstantFolder {
public: public:
AffineExprConstantFolder(unsigned numDims, AffineExprConstantFolder(unsigned numDims, ArrayRef<Attribute> operandConsts)
ArrayRef<Attribute *> operandConsts)
: numDims(numDims), operandConsts(operandConsts) {} : numDims(numDims), operandConsts(operandConsts) {}
/// Attempt to constant fold the specified affine expr, or return null on /// Attempt to constant fold the specified affine expr, or return null on
/// failure. /// failure.
IntegerAttr *constantFold(AffineExpr expr) { IntegerAttr constantFold(AffineExpr expr) {
switch (expr.getKind()) { switch (expr.getKind()) {
case AffineExprKind::Add: case AffineExprKind::Add:
return constantFoldBinExpr( return constantFoldBinExpr(
@ -59,31 +58,32 @@ public:
return IntegerAttr::get(expr.cast<AffineConstantExpr>().getValue(), return IntegerAttr::get(expr.cast<AffineConstantExpr>().getValue(),
expr.getContext()); expr.getContext());
case AffineExprKind::DimId: case AffineExprKind::DimId:
return dyn_cast_or_null<IntegerAttr>( return operandConsts[expr.cast<AffineDimExpr>().getPosition()]
operandConsts[expr.cast<AffineDimExpr>().getPosition()]); .dyn_cast_or_null<IntegerAttr>();
case AffineExprKind::SymbolId: case AffineExprKind::SymbolId:
return dyn_cast_or_null<IntegerAttr>( return operandConsts[numDims +
operandConsts[numDims + expr.cast<AffineSymbolExpr>().getPosition()]); expr.cast<AffineSymbolExpr>().getPosition()]
.dyn_cast_or_null<IntegerAttr>();
} }
} }
private: private:
IntegerAttr * IntegerAttr
constantFoldBinExpr(AffineExpr expr, constantFoldBinExpr(AffineExpr expr,
std::function<uint64_t(int64_t, uint64_t)> op) { std::function<uint64_t(int64_t, uint64_t)> op) {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>(); auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
auto *lhs = constantFold(binOpExpr.getLHS()); auto lhs = constantFold(binOpExpr.getLHS());
auto *rhs = constantFold(binOpExpr.getRHS()); auto rhs = constantFold(binOpExpr.getRHS());
if (!lhs || !rhs) if (!lhs || !rhs)
return nullptr; return nullptr;
return IntegerAttr::get(op(lhs->getValue(), rhs->getValue()), return IntegerAttr::get(op(lhs.getValue(), rhs.getValue()),
expr.getContext()); expr.getContext());
} }
// The number of dimension operands in AffineMap containing this expression. // The number of dimension operands in AffineMap containing this expression.
unsigned numDims; unsigned numDims;
// The constant valued operands used to evaluate this AffineExpr. // The constant valued operands used to evaluate this AffineExpr.
ArrayRef<Attribute *> operandConsts; ArrayRef<Attribute> operandConsts;
}; };
} // end anonymous namespace } // end anonymous namespace
@ -137,15 +137,15 @@ ArrayRef<AffineExpr> AffineMap::getRangeSizes() const {
/// Folds the results of the application of an affine map on the provided /// Folds the results of the application of an affine map on the provided
/// operands to a constant if possible. Returns false if the folding happens, /// operands to a constant if possible. Returns false if the folding happens,
/// true otherwise. /// true otherwise.
bool AffineMap::constantFold(ArrayRef<Attribute *> operandConstants, bool AffineMap::constantFold(ArrayRef<Attribute> operandConstants,
SmallVectorImpl<Attribute *> &results) const { SmallVectorImpl<Attribute> &results) const {
assert(getNumInputs() == operandConstants.size()); assert(getNumInputs() == operandConstants.size());
// Fold each of the result expressions. // Fold each of the result expressions.
AffineExprConstantFolder exprFolder(getNumDims(), operandConstants); AffineExprConstantFolder exprFolder(getNumDims(), operandConstants);
// Constant fold each AffineExpr in AffineMap and add to 'results'. // Constant fold each AffineExpr in AffineMap and add to 'results'.
for (auto expr : getResults()) { for (auto expr : getResults()) {
auto *folded = exprFolder.constantFold(expr); auto folded = exprFolder.constantFold(expr);
// If we didn't fold to a constant, then folding fails. // If we didn't fold to a constant, then folding fails.
if (!folded) if (!folded)
return true; return true;

View File

@ -123,7 +123,7 @@ private:
void visitIfStmt(const IfStmt *ifStmt); void visitIfStmt(const IfStmt *ifStmt);
void visitOperationStmt(const OperationStmt *opStmt); void visitOperationStmt(const OperationStmt *opStmt);
void visitType(const Type *type); void visitType(const Type *type);
void visitAttribute(const Attribute *attr); void visitAttribute(Attribute attr);
void visitOperation(const Operation *op); void visitOperation(const Operation *op);
DenseMap<AffineMap, int> affineMapIds; DenseMap<AffineMap, int> affineMapIds;
@ -150,11 +150,11 @@ void ModuleState::visitType(const Type *type) {
} }
} }
void ModuleState::visitAttribute(const Attribute *attr) { void ModuleState::visitAttribute(Attribute attr) {
if (auto *mapAttr = dyn_cast<AffineMapAttr>(attr)) { if (auto mapAttr = attr.dyn_cast<AffineMapAttr>()) {
recordAffineMapReference(mapAttr->getValue()); recordAffineMapReference(mapAttr.getValue());
} else if (auto *arrayAttr = dyn_cast<ArrayAttr>(attr)) { } else if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
for (auto elt : arrayAttr->getValue()) { for (auto elt : arrayAttr.getValue()) {
visitAttribute(elt); visitAttribute(elt);
} }
} }
@ -268,7 +268,7 @@ public:
void print(const Module *module); void print(const Module *module);
void printFunctionReference(const Function *func); void printFunctionReference(const Function *func);
void printAttribute(const Attribute *attr); void printAttribute(Attribute attr);
void printType(const Type *type); void printType(const Type *type);
void print(const Function *fn); void print(const Function *fn);
void print(const ExtFunction *fn); void print(const ExtFunction *fn);
@ -293,7 +293,7 @@ protected:
void printAffineMapReference(AffineMap affineMap); void printAffineMapReference(AffineMap affineMap);
void printIntegerSetId(int integerSetId) const; void printIntegerSetId(int integerSetId) const;
void printIntegerSetReference(IntegerSet integerSet); void printIntegerSetReference(IntegerSet integerSet);
void printDenseElementsAttr(const DenseElementsAttr *attr); void printDenseElementsAttr(DenseElementsAttr attr);
/// This enum is used to represent the binding stength of the enclosing /// This enum is used to represent the binding stength of the enclosing
/// context that an AffineExprStorage is being printed in, so we can /// context that an AffineExprStorage is being printed in, so we can
@ -404,36 +404,36 @@ void ModulePrinter::printFunctionReference(const Function *func) {
os << '@' << func->getName(); os << '@' << func->getName();
} }
void ModulePrinter::printAttribute(const Attribute *attr) { void ModulePrinter::printAttribute(Attribute attr) {
switch (attr->getKind()) { switch (attr.getKind()) {
case Attribute::Kind::Bool: case Attribute::Kind::Bool:
os << (cast<BoolAttr>(attr)->getValue() ? "true" : "false"); os << (attr.cast<BoolAttr>().getValue() ? "true" : "false");
break; break;
case Attribute::Kind::Integer: case Attribute::Kind::Integer:
os << cast<IntegerAttr>(attr)->getValue(); os << attr.cast<IntegerAttr>().getValue();
break; break;
case Attribute::Kind::Float: case Attribute::Kind::Float:
printFloatValue(cast<FloatAttr>(attr)->getValue(), os); printFloatValue(attr.cast<FloatAttr>().getValue(), os);
break; break;
case Attribute::Kind::String: case Attribute::Kind::String:
os << '"'; os << '"';
printEscapedString(cast<StringAttr>(attr)->getValue(), os); printEscapedString(attr.cast<StringAttr>().getValue(), os);
os << '"'; os << '"';
break; break;
case Attribute::Kind::Array: case Attribute::Kind::Array:
os << '['; os << '[';
interleaveComma(cast<ArrayAttr>(attr)->getValue(), interleaveComma(attr.cast<ArrayAttr>().getValue(),
[&](Attribute *attr) { printAttribute(attr); }); [&](Attribute attr) { printAttribute(attr); });
os << ']'; os << ']';
break; break;
case Attribute::Kind::AffineMap: case Attribute::Kind::AffineMap:
printAffineMapReference(cast<AffineMapAttr>(attr)->getValue()); printAffineMapReference(attr.cast<AffineMapAttr>().getValue());
break; break;
case Attribute::Kind::Type: case Attribute::Kind::Type:
printType(cast<TypeAttr>(attr)->getValue()); printType(attr.cast<TypeAttr>().getValue());
break; break;
case Attribute::Kind::Function: { case Attribute::Kind::Function: {
auto *function = cast<FunctionAttr>(attr)->getValue(); auto *function = attr.cast<FunctionAttr>().getValue();
if (!function) { if (!function) {
os << "<<FUNCTION ATTR FOR DELETED FUNCTION>>"; os << "<<FUNCTION ATTR FOR DELETED FUNCTION>>";
} else { } else {
@ -444,53 +444,52 @@ void ModulePrinter::printAttribute(const Attribute *attr) {
break; break;
} }
case Attribute::Kind::OpaqueElements: { case Attribute::Kind::OpaqueElements: {
auto *eltsAttr = cast<OpaqueElementsAttr>(attr); auto eltsAttr = attr.cast<OpaqueElementsAttr>();
os << "opaque<"; os << "opaque<";
printType(eltsAttr->getType()); printType(eltsAttr.getType());
os << ", " << '"' << "0x" << llvm::toHex(eltsAttr->getValue()) << '"' os << ", " << '"' << "0x" << llvm::toHex(eltsAttr.getValue()) << '"' << '>';
<< '>';
break; break;
} }
case Attribute::Kind::DenseIntElements: case Attribute::Kind::DenseIntElements:
case Attribute::Kind::DenseFPElements: { case Attribute::Kind::DenseFPElements: {
auto *eltsAttr = cast<DenseElementsAttr>(attr); auto eltsAttr = attr.cast<DenseElementsAttr>();
os << "dense<"; os << "dense<";
printType(eltsAttr->getType()); printType(eltsAttr.getType());
os << ", "; os << ", ";
printDenseElementsAttr(eltsAttr); printDenseElementsAttr(eltsAttr);
os << '>'; os << '>';
break; break;
} }
case Attribute::Kind::SplatElements: { case Attribute::Kind::SplatElements: {
auto *elementsAttr = cast<SplatElementsAttr>(attr); auto elementsAttr = attr.cast<SplatElementsAttr>();
os << "splat<"; os << "splat<";
printType(elementsAttr->getType()); printType(elementsAttr.getType());
os << ", "; os << ", ";
printAttribute(elementsAttr->getValue()); printAttribute(elementsAttr.getValue());
os << '>'; os << '>';
break; break;
} }
case Attribute::Kind::SparseElements: { case Attribute::Kind::SparseElements: {
auto *elementsAttr = cast<SparseElementsAttr>(attr); auto elementsAttr = attr.cast<SparseElementsAttr>();
os << "sparse<"; os << "sparse<";
printType(elementsAttr->getType()); printType(elementsAttr.getType());
os << ", "; os << ", ";
printDenseElementsAttr(elementsAttr->getIndices()); printDenseElementsAttr(elementsAttr.getIndices());
os << ", "; os << ", ";
printDenseElementsAttr(elementsAttr->getValues()); printDenseElementsAttr(elementsAttr.getValues());
os << '>'; os << '>';
break; break;
} }
} }
} }
void ModulePrinter::printDenseElementsAttr(const DenseElementsAttr *attr) { void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
auto *type = attr->getType(); auto *type = attr.getType();
auto shape = type->getShape(); auto shape = type->getShape();
auto rank = type->getRank(); auto rank = type->getRank();
SmallVector<Attribute *, 16> elements; SmallVector<Attribute, 16> elements;
attr->getValues(elements); attr.getValues(elements);
// Special case for degenerate tensors. // Special case for degenerate tensors.
if (elements.empty()) { if (elements.empty()) {
@ -934,9 +933,7 @@ public:
// Implement OpAsmPrinter. // Implement OpAsmPrinter.
raw_ostream &getStream() const { return os; } raw_ostream &getStream() const { return os; }
void printType(const Type *type) { ModulePrinter::printType(type); } void printType(const Type *type) { ModulePrinter::printType(type); }
void printAttribute(const Attribute *attr) { void printAttribute(Attribute attr) { ModulePrinter::printAttribute(attr); }
ModulePrinter::printAttribute(attr);
}
void printAffineMap(AffineMap map) { void printAffineMap(AffineMap map) {
return ModulePrinter::printAffineMapReference(map); return ModulePrinter::printAffineMapReference(map);
} }
@ -980,7 +977,7 @@ protected:
} else if (auto intOp = op->dyn_cast<ConstantIndexOp>()) { } else if (auto intOp = op->dyn_cast<ConstantIndexOp>()) {
specialName << 'c' << intOp->getValue(); specialName << 'c' << intOp->getValue();
} else if (auto constant = op->dyn_cast<ConstantOp>()) { } else if (auto constant = op->dyn_cast<ConstantOp>()) {
if (isa<FunctionAttr>(constant->getValue())) if (constant->getValue().isa<FunctionAttr>())
specialName << 'f'; specialName << 'f';
else else
specialName << "cst"; specialName << "cst";
@ -1570,7 +1567,7 @@ void ModulePrinter::print(const MLFunction *fn) {
void Attribute::print(raw_ostream &os) const { void Attribute::print(raw_ostream &os) const {
ModuleState state(/*no context is known*/ nullptr); ModuleState state(/*no context is known*/ nullptr);
ModulePrinter(os, state).printAttribute(this); ModulePrinter(os, state).printAttribute(*this);
} }
void Attribute::dump() const { print(llvm::errs()); } void Attribute::dump() const { print(llvm::errs()); }

View File

@ -0,0 +1,131 @@
//===- AttributeDetail.h - MLIR Affine Map details Class --------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This holds implementation details of Attribute.
//
//===----------------------------------------------------------------------===//
#ifndef ATTRIBUTEDETAIL_H_
#define ATTRIBUTEDETAIL_H_
#include "mlir/IR/Attributes.h"
#include "mlir/IR/MLIRContext.h"
namespace mlir {
namespace detail {
/// Base storage class appearing in an attribute.
struct AttributeStorage {
Attribute::Kind kind : 8;
/// This field is true if this is, or contains, a function attribute.
bool isOrContainsFunctionCache : 1;
};
/// An attribute representing a boolean value.
struct BoolAttributeStorage : public AttributeStorage {
bool value;
};
/// An attribute representing a integral value.
struct IntegerAttributeStorage : public AttributeStorage {
int64_t value;
};
/// An attribute representing a floating point value.
struct FloatAttributeStorage final
: public AttributeStorage,
public llvm::TrailingObjects<FloatAttributeStorage, uint64_t> {
const llvm::fltSemantics &semantics;
size_t numObjects;
/// Returns an APFloat representing the stored value.
APFloat getValue() const {
auto val = APInt(APFloat::getSizeInBits(semantics),
{getTrailingObjects<uint64_t>(), numObjects});
return APFloat(semantics, val);
}
};
/// An attribute representing a string value.
struct StringAttributeStorage : public AttributeStorage {
StringRef value;
};
/// An attribute representing an array of other attributes.
struct ArrayAttributeStorage : public AttributeStorage {
ArrayRef<Attribute> value;
};
// An attribute representing a reference to an affine map.
struct AffineMapAttributeStorage : public AttributeStorage {
AffineMap value;
};
/// An attribute representing a reference to a type.
struct TypeAttributeStorage : public AttributeStorage {
Type *value;
};
/// An attribute representing a reference to a function.
struct FunctionAttributeStorage : public AttributeStorage {
Function *value;
};
/// A base attribute representing a reference to a vector or tensor constant.
struct ElementsAttributeStorage : public AttributeStorage {
VectorOrTensorType *type;
};
/// An attribute representing a reference to a vector or tensor constant,
/// inwhich all elements have the same value.
struct SplatElementsAttributeStorage : public ElementsAttributeStorage {
Attribute elt;
};
/// An attribute representing a reference to a dense vector or tensor object.
struct DenseElementsAttributeStorage : public ElementsAttributeStorage {
ArrayRef<char> data;
};
/// An attribute representing a reference to a dense integer vector or tensor
/// object.
struct DenseIntElementsAttributeStorage : public DenseElementsAttributeStorage {
size_t bitsWidth;
};
/// An attribute representing a reference to a dense float vector or tensor
/// object.
struct DenseFPElementsAttributeStorage : public DenseElementsAttributeStorage {
};
/// An attribute representing a reference to a tensor constant with opaque
/// content.
struct OpaqueElementsAttributeStorage : public ElementsAttributeStorage {
StringRef bytes;
};
/// An attribute representing a reference to a sparse vector or tensor object.
struct SparseElementsAttributeStorage : public ElementsAttributeStorage {
DenseIntElementsAttr indices;
DenseElementsAttr values;
};
} // namespace detail
} // namespace mlir
#endif // ATTRIBUTEDETAIL_H_

214
mlir/lib/IR/Attributes.cpp Normal file
View File

@ -0,0 +1,214 @@
//===- Attributes.cpp - MLIR Affine Expr Classes --------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/IR/Attributes.h"
#include "AttributeDetail.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Types.h"
using namespace mlir;
using namespace mlir::detail;
Attribute::Kind Attribute::getKind() const { return attr->kind; }
bool Attribute::isOrContainsFunction() const {
return attr->isOrContainsFunctionCache;
}
BoolAttr::BoolAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
bool BoolAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
IntegerAttr::IntegerAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
int64_t IntegerAttr::getValue() const {
return static_cast<ImplType *>(attr)->value;
}
FloatAttr::FloatAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
APFloat FloatAttr::getValue() const {
return static_cast<ImplType *>(attr)->getValue();
}
double FloatAttr::getDouble() const { return getValue().convertToDouble(); }
StringAttr::StringAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
StringRef StringAttr::getValue() const {
return static_cast<ImplType *>(attr)->value;
}
ArrayAttr::ArrayAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
ArrayRef<Attribute> ArrayAttr::getValue() const {
return static_cast<ImplType *>(attr)->value;
}
AffineMapAttr::AffineMapAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
AffineMap AffineMapAttr::getValue() const {
return static_cast<ImplType *>(attr)->value;
}
TypeAttr::TypeAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
Type *TypeAttr::getValue() const {
return static_cast<ImplType *>(attr)->value;
}
FunctionAttr::FunctionAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
Function *FunctionAttr::getValue() const {
return static_cast<ImplType *>(attr)->value;
}
FunctionType *FunctionAttr::getType() const { return getValue()->getType(); }
ElementsAttr::ElementsAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
VectorOrTensorType *ElementsAttr::getType() const {
return static_cast<ImplType *>(attr)->type;
}
SplatElementsAttr::SplatElementsAttr(Attribute::ImplType *ptr)
: ElementsAttr(ptr) {}
Attribute SplatElementsAttr::getValue() const {
return static_cast<ImplType *>(attr)->elt;
}
DenseElementsAttr::DenseElementsAttr(Attribute::ImplType *ptr)
: ElementsAttr(ptr) {}
void DenseElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
switch (getKind()) {
case Attribute::Kind::DenseIntElements:
cast<DenseIntElementsAttr>().getValues(values);
return;
case Attribute::Kind::DenseFPElements:
cast<DenseFPElementsAttr>().getValues(values);
return;
default:
llvm_unreachable("unexpected element type");
}
}
ArrayRef<char> DenseElementsAttr::getRawData() const {
return static_cast<ImplType *>(attr)->data;
}
DenseIntElementsAttr::DenseIntElementsAttr(Attribute::ImplType *ptr)
: DenseElementsAttr(ptr) {}
/// Writes the lowest `bitWidth` bits of `value` to bit position `bitPos`
/// starting from `rawData`.
void DenseIntElementsAttr::writeBits(char *data, size_t bitPos, size_t bitWidth,
uint64_t value) {
// Read the destination bytes which will be written to.
uint64_t dst = 0;
auto dstData = reinterpret_cast<char *>(&dst);
auto endPos = bitPos + bitWidth;
auto start = data + bitPos / 8;
auto end = data + endPos / 8 + (endPos % 8 != 0);
std::copy(start, end, dstData);
// Clean up the invalid bits in the destination bytes.
dst &= ~(-1UL << (bitPos % 8));
// Get the valid bits of the source value, shift them to right position,
// then add them to the destination bytes.
value <<= bitPos % 8;
dst |= value;
// Write the destination bytes back.
ArrayRef<char> range({dstData, (size_t)(end - start)});
std::copy(range.begin(), range.end(), start);
}
/// Reads the next `bitWidth` bits from the bit position `bitPos` of `rawData`
/// and put them in the lowest bits.
uint64_t DenseIntElementsAttr::readBits(const char *rawData, size_t bitPos,
size_t bitsWidth) {
uint64_t dst = 0;
auto dstData = reinterpret_cast<char *>(&dst);
auto endPos = bitPos + bitsWidth;
auto start = rawData + bitPos / 8;
auto end = rawData + endPos / 8 + (endPos % 8 != 0);
std::copy(start, end, dstData);
dst >>= bitPos % 8;
dst &= ~(-1UL << bitsWidth);
return dst;
}
void DenseIntElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
auto bitsWidth = static_cast<ImplType *>(attr)->bitsWidth;
auto elementNum = getType()->getNumElements();
auto context = getType()->getContext();
values.reserve(elementNum);
if (bitsWidth == 64) {
ArrayRef<int64_t> vs(
{reinterpret_cast<const int64_t *>(getRawData().data()),
getRawData().size() / 8});
for (auto value : vs) {
auto attr = IntegerAttr::get(value, context);
values.push_back(attr);
}
} else {
const auto *rawData = getRawData().data();
for (size_t pos = 0; pos < elementNum * bitsWidth; pos += bitsWidth) {
uint64_t bits = readBits(rawData, pos, bitsWidth);
APInt value(bitsWidth, bits, /*isSigned=*/true);
auto attr = IntegerAttr::get(value.getSExtValue(), context);
values.push_back(attr);
}
}
}
DenseFPElementsAttr::DenseFPElementsAttr(Attribute::ImplType *ptr)
: DenseElementsAttr(ptr) {}
void DenseFPElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
auto elementNum = getType()->getNumElements();
auto context = getType()->getContext();
ArrayRef<double> vs({reinterpret_cast<const double *>(getRawData().data()),
getRawData().size() / 8});
values.reserve(elementNum);
for (auto v : vs) {
auto attr = FloatAttr::get(v, context);
values.push_back(attr);
}
}
OpaqueElementsAttr::OpaqueElementsAttr(Attribute::ImplType *ptr)
: ElementsAttr(ptr) {}
StringRef OpaqueElementsAttr::getValue() const {
return static_cast<ImplType *>(attr)->bytes;
}
SparseElementsAttr::SparseElementsAttr(Attribute::ImplType *ptr)
: ElementsAttr(ptr) {}
DenseIntElementsAttr SparseElementsAttr::getIndices() const {
return static_cast<ImplType *>(attr)->indices;
}
DenseElementsAttr SparseElementsAttr::getValues() const {
return static_cast<ImplType *>(attr)->values;
}

View File

@ -112,60 +112,60 @@ UnrankedTensorType *Builder::getTensorType(Type *elementType) {
// Attributes. // Attributes.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
BoolAttr *Builder::getBoolAttr(bool value) { BoolAttr Builder::getBoolAttr(bool value) {
return BoolAttr::get(value, context); return BoolAttr::get(value, context);
} }
IntegerAttr *Builder::getIntegerAttr(int64_t value) { IntegerAttr Builder::getIntegerAttr(int64_t value) {
return IntegerAttr::get(value, context); return IntegerAttr::get(value, context);
} }
FloatAttr *Builder::getFloatAttr(double value) { FloatAttr Builder::getFloatAttr(double value) {
return FloatAttr::get(APFloat(value), context); return FloatAttr::get(APFloat(value), context);
} }
FloatAttr *Builder::getFloatAttr(const APFloat &value) { FloatAttr Builder::getFloatAttr(const APFloat &value) {
return FloatAttr::get(value, context); return FloatAttr::get(value, context);
} }
StringAttr *Builder::getStringAttr(StringRef bytes) { StringAttr Builder::getStringAttr(StringRef bytes) {
return StringAttr::get(bytes, context); return StringAttr::get(bytes, context);
} }
ArrayAttr *Builder::getArrayAttr(ArrayRef<Attribute *> value) { ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
return ArrayAttr::get(value, context); return ArrayAttr::get(value, context);
} }
AffineMapAttr *Builder::getAffineMapAttr(AffineMap map) { AffineMapAttr Builder::getAffineMapAttr(AffineMap map) {
return AffineMapAttr::get(map); return AffineMapAttr::get(map);
} }
TypeAttr *Builder::getTypeAttr(Type *type) { TypeAttr Builder::getTypeAttr(Type *type) {
return TypeAttr::get(type, context); return TypeAttr::get(type, context);
} }
FunctionAttr *Builder::getFunctionAttr(const Function *value) { FunctionAttr Builder::getFunctionAttr(const Function *value) {
return FunctionAttr::get(value, context); return FunctionAttr::get(value, context);
} }
ElementsAttr *Builder::getSplatElementsAttr(VectorOrTensorType *type, ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType *type,
Attribute *elt) { Attribute elt) {
return SplatElementsAttr::get(type, elt); return SplatElementsAttr::get(type, elt);
} }
ElementsAttr *Builder::getDenseElementsAttr(VectorOrTensorType *type, ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType *type,
ArrayRef<char> data) { ArrayRef<char> data) {
return DenseElementsAttr::get(type, data); return DenseElementsAttr::get(type, data);
} }
ElementsAttr *Builder::getSparseElementsAttr(VectorOrTensorType *type, ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType *type,
DenseIntElementsAttr *indices, DenseIntElementsAttr indices,
DenseElementsAttr *values) { DenseElementsAttr values) {
return SparseElementsAttr::get(type, indices, values); return SparseElementsAttr::get(type, indices, values);
} }
ElementsAttr *Builder::getOpaqueElementsAttr(VectorOrTensorType *type, ElementsAttr Builder::getOpaqueElementsAttr(VectorOrTensorType *type,
StringRef bytes) { StringRef bytes) {
return OpaqueElementsAttr::get(type, bytes); return OpaqueElementsAttr::get(type, bytes);
} }

View File

@ -86,13 +86,13 @@ bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
auto &builder = parser->getBuilder(); auto &builder = parser->getBuilder();
auto *affineIntTy = builder.getIndexType(); auto *affineIntTy = builder.getIndexType();
AffineMapAttr *mapAttr; AffineMapAttr mapAttr;
unsigned numDims; unsigned numDims;
if (parser->parseAttribute(mapAttr, "map", result->attributes) || if (parser->parseAttribute(mapAttr, "map", result->attributes) ||
parseDimAndSymbolList(parser, result->operands, numDims) || parseDimAndSymbolList(parser, result->operands, numDims) ||
parser->parseOptionalAttributeDict(result->attributes)) parser->parseOptionalAttributeDict(result->attributes))
return true; return true;
auto map = mapAttr->getValue(); auto map = mapAttr.getValue();
if (map.getNumDims() != numDims || if (map.getNumDims() != numDims ||
numDims + map.getNumSymbols() != result->operands.size()) { numDims + map.getNumSymbols() != result->operands.size()) {
@ -113,12 +113,12 @@ void AffineApplyOp::print(OpAsmPrinter *p) const {
bool AffineApplyOp::verify() const { bool AffineApplyOp::verify() const {
// Check that affine map attribute was specified. // Check that affine map attribute was specified.
auto *affineMapAttr = getAttrOfType<AffineMapAttr>("map"); auto affineMapAttr = getAttrOfType<AffineMapAttr>("map");
if (!affineMapAttr) if (!affineMapAttr)
return emitOpError("requires an affine map"); return emitOpError("requires an affine map");
// Check input and output dimensions match. // Check input and output dimensions match.
auto map = affineMapAttr->getValue(); auto map = affineMapAttr.getValue();
// Verify that operand count matches affine map dimension and symbol count. // Verify that operand count matches affine map dimension and symbol count.
if (getNumOperands() != map.getNumDims() + map.getNumSymbols()) if (getNumOperands() != map.getNumDims() + map.getNumSymbols())
@ -155,8 +155,8 @@ bool AffineApplyOp::isValidSymbol() const {
return true; return true;
} }
bool AffineApplyOp::constantFold(ArrayRef<Attribute *> operandConstants, bool AffineApplyOp::constantFold(ArrayRef<Attribute> operandConstants,
SmallVectorImpl<Attribute *> &results, SmallVectorImpl<Attribute> &results,
MLIRContext *context) const { MLIRContext *context) const {
auto map = getAffineMap(); auto map = getAffineMap();
if (map.constantFold(operandConstants, results)) if (map.constantFold(operandConstants, results))
@ -171,21 +171,21 @@ bool AffineApplyOp::constantFold(ArrayRef<Attribute *> operandConstants,
/// Builds a constant op with the specified attribute value and result type. /// Builds a constant op with the specified attribute value and result type.
void ConstantOp::build(Builder *builder, OperationState *result, void ConstantOp::build(Builder *builder, OperationState *result,
Attribute *value, Type *type) { Attribute value, Type *type) {
result->addAttribute("value", value); result->addAttribute("value", value);
result->types.push_back(type); result->types.push_back(type);
} }
void ConstantOp::print(OpAsmPrinter *p) const { void ConstantOp::print(OpAsmPrinter *p) const {
*p << "constant " << *getValue(); *p << "constant " << getValue();
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value"); p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value");
if (!isa<FunctionAttr>(getValue())) if (!getValue().isa<FunctionAttr>())
*p << " : " << *getType(); *p << " : " << *getType();
} }
bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) { bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
Attribute *valueAttr; Attribute valueAttr;
Type *type; Type *type;
if (parser->parseAttribute(valueAttr, "value", result->attributes) || if (parser->parseAttribute(valueAttr, "value", result->attributes) ||
@ -194,8 +194,8 @@ bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
// 'constant' taking a function reference doesn't get a redundant type // 'constant' taking a function reference doesn't get a redundant type
// specifier. The attribute itself carries it. // specifier. The attribute itself carries it.
if (auto *fnAttr = dyn_cast<FunctionAttr>(valueAttr)) if (auto fnAttr = valueAttr.dyn_cast<FunctionAttr>())
return parser->addTypeToList(fnAttr->getValue()->getType(), result->types); return parser->addTypeToList(fnAttr.getValue()->getType(), result->types);
return parser->parseColonType(type) || return parser->parseColonType(type) ||
parser->addTypeToList(type, result->types); parser->addTypeToList(type, result->types);
@ -204,32 +204,32 @@ bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
/// The constant op requires an attribute, and furthermore requires that it /// The constant op requires an attribute, and furthermore requires that it
/// matches the return type. /// matches the return type.
bool ConstantOp::verify() const { bool ConstantOp::verify() const {
auto *value = getValue(); auto value = getValue();
if (!value) if (!value)
return emitOpError("requires a 'value' attribute"); return emitOpError("requires a 'value' attribute");
auto *type = this->getType(); auto *type = this->getType();
if (isa<IntegerType>(type) || type->isIndex()) { if (isa<IntegerType>(type) || type->isIndex()) {
if (!isa<IntegerAttr>(value)) if (!value.isa<IntegerAttr>())
return emitOpError( return emitOpError(
"requires 'value' to be an integer for an integer result type"); "requires 'value' to be an integer for an integer result type");
return false; return false;
} }
if (isa<FloatType>(type)) { if (isa<FloatType>(type)) {
if (!isa<FloatAttr>(value)) if (!value.isa<FloatAttr>())
return emitOpError("requires 'value' to be a floating point constant"); return emitOpError("requires 'value' to be a floating point constant");
return false; return false;
} }
if (type->isTFString()) { if (type->isTFString()) {
if (!isa<StringAttr>(value)) if (!value.isa<StringAttr>())
return emitOpError("requires 'value' to be a string constant"); return emitOpError("requires 'value' to be a string constant");
return false; return false;
} }
if (isa<FunctionType>(type)) { if (isa<FunctionType>(type)) {
if (!isa<FunctionAttr>(value)) if (!value.isa<FunctionAttr>())
return emitOpError("requires 'value' to be a function reference"); return emitOpError("requires 'value' to be a function reference");
return false; return false;
} }
@ -238,8 +238,8 @@ bool ConstantOp::verify() const {
"requires a result type that aligns with the 'value' attribute"); "requires a result type that aligns with the 'value' attribute");
} }
Attribute *ConstantOp::constantFold(ArrayRef<Attribute *> operands, Attribute ConstantOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const { MLIRContext *context) const {
assert(operands.empty() && "constant has no operands"); assert(operands.empty() && "constant has no operands");
return getValue(); return getValue();
} }

View File

@ -18,6 +18,7 @@
#include "mlir/IR/MLIRContext.h" #include "mlir/IR/MLIRContext.h"
#include "AffineExprDetail.h" #include "AffineExprDetail.h"
#include "AffineMapDetail.h" #include "AffineMapDetail.h"
#include "AttributeDetail.h"
#include "AttributeListStorage.h" #include "AttributeListStorage.h"
#include "IntegerSetDetail.h" #include "IntegerSetDetail.h"
#include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineExpr.h"
@ -169,35 +170,35 @@ struct MemRefTypeKeyInfo : DenseMapInfo<MemRefType *> {
} }
}; };
struct FloatAttrKeyInfo : DenseMapInfo<FloatAttr *> { struct FloatAttrKeyInfo : DenseMapInfo<FloatAttributeStorage *> {
// Float attributes are uniqued based on wrapped APFloat. // Float attributes are uniqued based on wrapped APFloat.
using KeyTy = APFloat; using KeyTy = APFloat;
using DenseMapInfo<FloatAttr *>::getHashValue; using DenseMapInfo<FloatAttributeStorage *>::getHashValue;
using DenseMapInfo<FloatAttr *>::isEqual; using DenseMapInfo<FloatAttributeStorage *>::isEqual;
static unsigned getHashValue(KeyTy key) { return llvm::hash_value(key); } static unsigned getHashValue(KeyTy key) { return llvm::hash_value(key); }
static bool isEqual(const KeyTy &lhs, const FloatAttr *rhs) { static bool isEqual(const KeyTy &lhs, const FloatAttributeStorage *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey()) if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false; return false;
return lhs.bitwiseIsEqual(rhs->getValue()); return lhs.bitwiseIsEqual(rhs->getValue());
} }
}; };
struct ArrayAttrKeyInfo : DenseMapInfo<ArrayAttr *> { struct ArrayAttrKeyInfo : DenseMapInfo<ArrayAttributeStorage *> {
// Array attributes are uniqued based on their elements. // Array attributes are uniqued based on their elements.
using KeyTy = ArrayRef<Attribute *>; using KeyTy = ArrayRef<Attribute>;
using DenseMapInfo<ArrayAttr *>::getHashValue; using DenseMapInfo<ArrayAttributeStorage *>::getHashValue;
using DenseMapInfo<ArrayAttr *>::isEqual; using DenseMapInfo<ArrayAttributeStorage *>::isEqual;
static unsigned getHashValue(KeyTy key) { static unsigned getHashValue(KeyTy key) {
return hash_combine_range(key.begin(), key.end()); return hash_combine_range(key.begin(), key.end());
} }
static bool isEqual(const KeyTy &lhs, const ArrayAttr *rhs) { static bool isEqual(const KeyTy &lhs, const ArrayAttributeStorage *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey()) if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false; return false;
return lhs == rhs->getValue(); return lhs == rhs->value;
} }
}; };
@ -218,37 +219,39 @@ struct AttributeListKeyInfo : DenseMapInfo<AttributeListStorage *> {
} }
}; };
struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttr *> { struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttributeStorage *> {
using KeyTy = std::pair<VectorOrTensorType *, ArrayRef<char>>; using KeyTy = std::pair<VectorOrTensorType *, ArrayRef<char>>;
using DenseMapInfo<DenseElementsAttr *>::getHashValue; using DenseMapInfo<DenseElementsAttributeStorage *>::getHashValue;
using DenseMapInfo<DenseElementsAttr *>::isEqual; using DenseMapInfo<DenseElementsAttributeStorage *>::isEqual;
static unsigned getHashValue(KeyTy key) { static unsigned getHashValue(KeyTy key) {
return hash_combine( return hash_combine(
key.first, hash_combine_range(key.second.begin(), key.second.end())); key.first, hash_combine_range(key.second.begin(), key.second.end()));
} }
static bool isEqual(const KeyTy &lhs, const DenseElementsAttr *rhs) { static bool isEqual(const KeyTy &lhs,
const DenseElementsAttributeStorage *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey()) if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false; return false;
return lhs == std::make_pair(rhs->getType(), rhs->getRawData()); return lhs == std::make_pair(rhs->type, rhs->data);
} }
}; };
struct OpaqueElementsAttrInfo : DenseMapInfo<OpaqueElementsAttr *> { struct OpaqueElementsAttrInfo : DenseMapInfo<OpaqueElementsAttributeStorage *> {
using KeyTy = std::pair<VectorOrTensorType *, StringRef>; using KeyTy = std::pair<VectorOrTensorType *, StringRef>;
using DenseMapInfo<OpaqueElementsAttr *>::getHashValue; using DenseMapInfo<OpaqueElementsAttributeStorage *>::getHashValue;
using DenseMapInfo<OpaqueElementsAttr *>::isEqual; using DenseMapInfo<OpaqueElementsAttributeStorage *>::isEqual;
static unsigned getHashValue(KeyTy key) { static unsigned getHashValue(KeyTy key) {
return hash_combine( return hash_combine(
key.first, hash_combine_range(key.second.begin(), key.second.end())); key.first, hash_combine_range(key.second.begin(), key.second.end()));
} }
static bool isEqual(const KeyTy &lhs, const OpaqueElementsAttr *rhs) { static bool isEqual(const KeyTy &lhs,
const OpaqueElementsAttributeStorage *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey()) if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false; return false;
return lhs == std::make_pair(rhs->getType(), rhs->getValue()); return lhs == std::make_pair(rhs->type, rhs->bytes);
} }
}; };
} // end anonymous namespace. } // end anonymous namespace.
@ -343,28 +346,29 @@ public:
MemRefTypeSet memrefs; MemRefTypeSet memrefs;
// Attribute uniquing. // Attribute uniquing.
BoolAttr *boolAttrs[2] = {nullptr}; BoolAttributeStorage *boolAttrs[2] = {nullptr};
DenseMap<int64_t, IntegerAttr *> integerAttrs; DenseMap<int64_t, IntegerAttributeStorage *> integerAttrs;
DenseSet<FloatAttr *, FloatAttrKeyInfo> floatAttrs; DenseSet<FloatAttributeStorage *, FloatAttrKeyInfo> floatAttrs;
StringMap<StringAttr *> stringAttrs; StringMap<StringAttributeStorage *> stringAttrs;
using ArrayAttrSet = DenseSet<ArrayAttr *, ArrayAttrKeyInfo>; using ArrayAttrSet = DenseSet<ArrayAttributeStorage *, ArrayAttrKeyInfo>;
ArrayAttrSet arrayAttrs; ArrayAttrSet arrayAttrs;
DenseMap<AffineMap, AffineMapAttr *> affineMapAttrs; DenseMap<AffineMap, AffineMapAttributeStorage *> affineMapAttrs;
DenseMap<Type *, TypeAttr *> typeAttrs; DenseMap<Type *, TypeAttributeStorage *> typeAttrs;
using AttributeListSet = using AttributeListSet =
DenseSet<AttributeListStorage *, AttributeListKeyInfo>; DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
AttributeListSet attributeLists; AttributeListSet attributeLists;
DenseMap<const Function *, FunctionAttr *> functionAttrs; DenseMap<const Function *, FunctionAttributeStorage *> functionAttrs;
DenseMap<std::pair<VectorOrTensorType *, Attribute *>, SplatElementsAttr *> DenseMap<std::pair<VectorOrTensorType *, Attribute>,
SplatElementsAttributeStorage *>
splatElementsAttrs; splatElementsAttrs;
using DenseElementsAttrSet = using DenseElementsAttrSet =
DenseSet<DenseElementsAttr *, DenseElementsAttrInfo>; DenseSet<DenseElementsAttributeStorage *, DenseElementsAttrInfo>;
DenseElementsAttrSet denseElementsAttrs; DenseElementsAttrSet denseElementsAttrs;
using OpaqueElementsAttrSet = using OpaqueElementsAttrSet =
DenseSet<OpaqueElementsAttr *, OpaqueElementsAttrInfo>; DenseSet<OpaqueElementsAttributeStorage *, OpaqueElementsAttrInfo>;
OpaqueElementsAttrSet opaqueElementsAttrs; OpaqueElementsAttrSet opaqueElementsAttrs;
DenseMap<std::tuple<Type *, DenseElementsAttr *, DenseElementsAttr *>, DenseMap<std::tuple<Type *, Attribute, Attribute>,
SparseElementsAttr *> SparseElementsAttributeStorage *>
sparseElementsAttrs; sparseElementsAttrs;
public: public:
@ -716,31 +720,36 @@ MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType,
// Attribute uniquing // Attribute uniquing
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
BoolAttr *BoolAttr::get(bool value, MLIRContext *context) { BoolAttr BoolAttr::get(bool value, MLIRContext *context) {
auto *&result = context->getImpl().boolAttrs[value]; auto *&result = context->getImpl().boolAttrs[value];
if (result) if (result)
return result; return result;
result = context->getImpl().allocator.Allocate<BoolAttr>(); result = context->getImpl().allocator.Allocate<BoolAttributeStorage>();
new (result) BoolAttr(value); new (result) BoolAttributeStorage{{Attribute::Kind::Bool,
/*isOrContainsFunction=*/false},
value};
return result; return result;
} }
IntegerAttr *IntegerAttr::get(int64_t value, MLIRContext *context) { IntegerAttr IntegerAttr::get(int64_t value, MLIRContext *context) {
auto *&result = context->getImpl().integerAttrs[value]; auto *&result = context->getImpl().integerAttrs[value];
if (result) if (result)
return result; return result;
result = context->getImpl().allocator.Allocate<IntegerAttr>(); result = context->getImpl().allocator.Allocate<IntegerAttributeStorage>();
new (result) IntegerAttr(value); new (result) IntegerAttributeStorage{{Attribute::Kind::Integer,
/*isOrContainsFunction=*/false},
value};
result->value = value;
return result; return result;
} }
FloatAttr *FloatAttr::get(double value, MLIRContext *context) { FloatAttr FloatAttr::get(double value, MLIRContext *context) {
return get(APFloat(value), context); return get(APFloat(value), context);
} }
FloatAttr *FloatAttr::get(const APFloat &value, MLIRContext *context) { FloatAttr FloatAttr::get(const APFloat &value, MLIRContext *context) {
auto &impl = context->getImpl(); auto &impl = context->getImpl();
// Look to see if the float attribute has been created already. // Look to see if the float attribute has been created already.
@ -755,33 +764,35 @@ FloatAttr *FloatAttr::get(const APFloat &value, MLIRContext *context) {
// Here one word's bitwidth equals to that of uint64_t. // Here one word's bitwidth equals to that of uint64_t.
auto elements = ArrayRef<uint64_t>(apint.getRawData(), apint.getNumWords()); auto elements = ArrayRef<uint64_t>(apint.getRawData(), apint.getNumWords());
auto byteSize = FloatAttr::totalSizeToAlloc<uint64_t>(elements.size()); auto byteSize =
auto rawMem = impl.allocator.Allocate(byteSize, alignof(FloatAttr)); FloatAttributeStorage::totalSizeToAlloc<uint64_t>(elements.size());
auto result = ::new (rawMem) FloatAttr(value.getSemantics(), elements.size()); auto rawMem =
impl.allocator.Allocate(byteSize, alignof(FloatAttributeStorage));
auto result = ::new (rawMem) FloatAttributeStorage{
{Attribute::Kind::Float, /*isOrContainsFunction=*/false},
{},
value.getSemantics(),
elements.size()};
std::uninitialized_copy(elements.begin(), elements.end(), std::uninitialized_copy(elements.begin(), elements.end(),
result->getTrailingObjects<uint64_t>()); result->getTrailingObjects<uint64_t>());
return *existing.first = result; return *existing.first = result;
} }
APFloat FloatAttr::getValue() const { StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
auto val = APInt(APFloat::getSizeInBits(semantics),
{getTrailingObjects<uint64_t>(), numObjects});
return APFloat(semantics, val);
}
StringAttr *StringAttr::get(StringRef bytes, MLIRContext *context) {
auto it = context->getImpl().stringAttrs.insert({bytes, nullptr}).first; auto it = context->getImpl().stringAttrs.insert({bytes, nullptr}).first;
if (it->second) if (it->second)
return it->second; return it->second;
auto result = context->getImpl().allocator.Allocate<StringAttr>(); auto result = context->getImpl().allocator.Allocate<StringAttributeStorage>();
new (result) StringAttr(it->first()); new (result) StringAttributeStorage{{Attribute::Kind::String,
/*isOrContainsFunction=*/false},
it->first()};
it->second = result; it->second = result;
return result; return result;
} }
ArrayAttr *ArrayAttr::get(ArrayRef<Attribute *> value, MLIRContext *context) { ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
auto &impl = context->getImpl(); auto &impl = context->getImpl();
// Look to see if we already have this. // Look to see if we already have this.
@ -792,61 +803,66 @@ ArrayAttr *ArrayAttr::get(ArrayRef<Attribute *> value, MLIRContext *context) {
return *existing.first; return *existing.first;
// On the first use, we allocate them into the bump pointer. // On the first use, we allocate them into the bump pointer.
auto *result = impl.allocator.Allocate<ArrayAttr>(); auto *result = impl.allocator.Allocate<ArrayAttributeStorage>();
// Copy the elements into the bump pointer. // Copy the elements into the bump pointer.
value = impl.copyInto(value); value = impl.copyInto(value);
// Check to see if any of the elements have a function attr. // Check to see if any of the elements have a function attr.
bool hasFunctionAttr = false; bool hasFunctionAttr = false;
for (auto *elt : value) for (auto elt : value)
if (elt->isOrContainsFunction()) { if (elt.isOrContainsFunction()) {
hasFunctionAttr = true; hasFunctionAttr = true;
break; break;
} }
// Initialize the memory using placement new. // Initialize the memory using placement new.
new (result) ArrayAttr(value, hasFunctionAttr); new (result)
ArrayAttributeStorage{{Attribute::Kind::Array, hasFunctionAttr}, value};
// Cache and return it. // Cache and return it.
return *existing.first = result; return *existing.first = result;
} }
AffineMapAttr *AffineMapAttr::get(AffineMap value) { AffineMapAttr AffineMapAttr::get(AffineMap value) {
auto *context = value.getResult(0).getContext(); auto *context = value.getResult(0).getContext();
auto &result = context->getImpl().affineMapAttrs[value]; auto &result = context->getImpl().affineMapAttrs[value];
if (result) if (result)
return result; return result;
result = context->getImpl().allocator.Allocate<AffineMapAttr>(); result = context->getImpl().allocator.Allocate<AffineMapAttributeStorage>();
new (result) AffineMapAttr(value); new (result) AffineMapAttributeStorage{{Attribute::Kind::AffineMap,
/*isOrContainsFunction=*/false},
value};
return result; return result;
} }
TypeAttr *TypeAttr::get(Type *type, MLIRContext *context) { TypeAttr TypeAttr::get(Type *type, MLIRContext *context) {
auto *&result = context->getImpl().typeAttrs[type]; auto *&result = context->getImpl().typeAttrs[type];
if (result) if (result)
return result; return result;
result = context->getImpl().allocator.Allocate<TypeAttr>(); result = context->getImpl().allocator.Allocate<TypeAttributeStorage>();
new (result) TypeAttr(type); new (result) TypeAttributeStorage{{Attribute::Kind::Type,
/*isOrContainsFunction=*/false},
type};
return result; return result;
} }
FunctionAttr *FunctionAttr::get(const Function *value, MLIRContext *context) { FunctionAttr FunctionAttr::get(const Function *value, MLIRContext *context) {
assert(value && "Cannot get FunctionAttr for a null function"); assert(value && "Cannot get FunctionAttr for a null function");
auto *&result = context->getImpl().functionAttrs[value]; auto *&result = context->getImpl().functionAttrs[value];
if (result) if (result)
return result; return result;
result = context->getImpl().allocator.Allocate<FunctionAttr>(); result = context->getImpl().allocator.Allocate<FunctionAttributeStorage>();
new (result) FunctionAttr(const_cast<Function *>(value)); new (result) FunctionAttributeStorage{{Attribute::Kind::Function,
/*isOrContainsFunction=*/true},
const_cast<Function *>(value)};
return result; return result;
} }
FunctionType *FunctionAttr::getType() const { return getValue()->getType(); }
/// This function is used by the internals of the Function class to null out /// This function is used by the internals of the Function class to null out
/// attributes refering to functions that are about to be deleted. /// attributes refering to functions that are about to be deleted.
void FunctionAttr::dropFunctionReference(Function *value) { void FunctionAttr::dropFunctionReference(Function *value) {
@ -935,30 +951,29 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs,
return *existing.first = result; return *existing.first = result;
} }
OpaqueElementsAttr *OpaqueElementsAttr::get(VectorOrTensorType *type, SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType *type,
StringRef bytes) { Attribute elt) {
assert(isValidTensorElementType(type->getElementType()) &&
"Input element type should be a valid tensor element type");
auto &impl = type->getContext()->getImpl(); auto &impl = type->getContext()->getImpl();
// Look to see if this constant is already defined. // Look to see if we already have this.
OpaqueElementsAttrInfo::KeyTy key({type, bytes}); auto *&result = impl.splatElementsAttrs[{type, elt}];
auto existing = impl.opaqueElementsAttrs.insert_as(nullptr, key);
// If we already have it, return that value. // If we already have it, return that value.
if (!existing.second) if (result)
return *existing.first; return result;
// Otherwise, allocate a new one, unique it and return it. // Otherwise, allocate them into the bump pointer.
auto *result = impl.allocator.Allocate<OpaqueElementsAttr>(); result = impl.allocator.Allocate<SplatElementsAttributeStorage>();
bytes = bytes.copy(impl.allocator); new (result) SplatElementsAttributeStorage{{{Attribute::Kind::SplatElements,
new (result) OpaqueElementsAttr(type, bytes); /*isOrContainsFunction=*/false},
return *existing.first = result; type},
elt};
return result;
} }
DenseElementsAttr *DenseElementsAttr::get(VectorOrTensorType *type, DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type,
ArrayRef<char> data) { ArrayRef<char> data) {
auto bitsRequired = (long)type->getBitWidth() * type->getNumElements(); auto bitsRequired = (long)type->getBitWidth() * type->getNumElements();
(void)(bitsRequired); (void)(bitsRequired);
assert((bitsRequired <= data.size() * 8L) && assert((bitsRequired <= data.size() * 8L) &&
@ -981,18 +996,25 @@ DenseElementsAttr *DenseElementsAttr::get(VectorOrTensorType *type,
case Type::Kind::F16: case Type::Kind::F16:
case Type::Kind::F32: case Type::Kind::F32:
case Type::Kind::F64: { case Type::Kind::F64: {
auto *result = impl.allocator.Allocate<DenseFPElementsAttr>(); auto *result = impl.allocator.Allocate<DenseFPElementsAttributeStorage>();
auto *copy = (char *)impl.allocator.Allocate(data.size(), 64); auto *copy = (char *)impl.allocator.Allocate(data.size(), 64);
std::uninitialized_copy(data.begin(), data.end(), copy); std::uninitialized_copy(data.begin(), data.end(), copy);
new (result) DenseFPElementsAttr(type, {copy, data.size()}); new (result) DenseFPElementsAttributeStorage{
{{{Attribute::Kind::DenseFPElements, /*isOrContainsFunction=*/false},
type},
{copy, data.size()}}};
return *existing.first = result; return *existing.first = result;
} }
case Type::Kind::Integer: { case Type::Kind::Integer: {
auto width = cast<IntegerType>(eltType)->getWidth(); auto width = ::cast<IntegerType>(eltType)->getWidth();
auto *result = impl.allocator.Allocate<DenseIntElementsAttr>(); auto *result = impl.allocator.Allocate<DenseIntElementsAttributeStorage>();
auto *copy = (char *)impl.allocator.Allocate(data.size(), 64); auto *copy = (char *)impl.allocator.Allocate(data.size(), 64);
std::uninitialized_copy(data.begin(), data.end(), copy); std::uninitialized_copy(data.begin(), data.end(), copy);
new (result) DenseIntElementsAttr(type, {copy, data.size()}, width); new (result) DenseIntElementsAttributeStorage{
{{{Attribute::Kind::DenseIntElements, /*isOrContainsFunction=*/false},
type},
{copy, data.size()}},
width};
return *existing.first = result; return *existing.first = result;
} }
default: default:
@ -1000,118 +1022,33 @@ DenseElementsAttr *DenseElementsAttr::get(VectorOrTensorType *type,
} }
} }
/// Writes the lowest `bitWidth` bits of `value` to bit position `bitPos` OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType *type,
/// starting from `rawData`. StringRef bytes) {
void DenseIntElementsAttr::writeBits(char *data, size_t bitPos, size_t bitWidth, assert(isValidTensorElementType(type->getElementType()) &&
uint64_t value) { "Input element type should be a valid tensor element type");
// Read the destination bytes which will be written to.
uint64_t dst = 0;
auto dstData = reinterpret_cast<char *>(&dst);
auto endPos = bitPos + bitWidth;
auto start = data + bitPos / 8;
auto end = data + endPos / 8 + (endPos % 8 != 0);
std::copy(start, end, dstData);
// Clean up the invalid bits in the destination bytes.
dst &= ~(-1UL << (bitPos % 8));
// Get the valid bits of the source value, shift them to right position,
// then add them to the destination bytes.
value <<= bitPos % 8;
dst |= value;
// Write the destination bytes back.
ArrayRef<char> range({dstData, (size_t)(end - start)});
std::copy(range.begin(), range.end(), start);
}
/// Reads the next `bitWidth` bits from the bit position `bitPos` of `rawData`
/// and put them in the lowest bits.
uint64_t DenseIntElementsAttr::readBits(const char *rawData, size_t bitPos,
size_t bitsWidth) {
uint64_t dst = 0;
auto dstData = reinterpret_cast<char *>(&dst);
auto endPos = bitPos + bitsWidth;
auto start = rawData + bitPos / 8;
auto end = rawData + endPos / 8 + (endPos % 8 != 0);
std::copy(start, end, dstData);
dst >>= bitPos % 8;
dst &= ~(-1UL << bitsWidth);
return dst;
}
void DenseElementsAttr::getValues(SmallVectorImpl<Attribute *> &values) const {
switch (getKind()) {
case Attribute::Kind::DenseIntElements:
cast<DenseIntElementsAttr>(this)->getValues(values);
return;
case Attribute::Kind::DenseFPElements:
cast<DenseFPElementsAttr>(this)->getValues(values);
return;
default:
llvm_unreachable("unexpected element type");
}
}
void DenseIntElementsAttr::getValues(
SmallVectorImpl<Attribute *> &values) const {
auto elementNum = getType()->getNumElements();
auto context = getType()->getContext();
values.reserve(elementNum);
if (bitsWidth == 64) {
ArrayRef<int64_t> vs(
{reinterpret_cast<const int64_t *>(getRawData().data()),
getRawData().size() / 8});
for (auto value : vs) {
auto *attr = IntegerAttr::get(value, context);
values.push_back(attr);
}
} else {
const auto *rawData = getRawData().data();
for (size_t pos = 0; pos < elementNum * bitsWidth; pos += bitsWidth) {
uint64_t bits = readBits(rawData, pos, bitsWidth);
APInt value(bitsWidth, bits, /*isSigned=*/true);
auto *attr = IntegerAttr::get(value.getSExtValue(), context);
values.push_back(attr);
}
}
}
void DenseFPElementsAttr::getValues(
SmallVectorImpl<Attribute *> &values) const {
auto elementNum = getType()->getNumElements();
auto context = getType()->getContext();
ArrayRef<double> vs({reinterpret_cast<const double *>(getRawData().data()),
getRawData().size() / 8});
values.reserve(elementNum);
for (auto v : vs) {
auto *attr = FloatAttr::get(v, context);
values.push_back(attr);
}
}
SplatElementsAttr *SplatElementsAttr::get(VectorOrTensorType *type,
Attribute *elt) {
auto &impl = type->getContext()->getImpl(); auto &impl = type->getContext()->getImpl();
// Look to see if we already have this. // Look to see if this constant is already defined.
auto *&result = impl.splatElementsAttrs[{type, elt}]; OpaqueElementsAttrInfo::KeyTy key({type, bytes});
auto existing = impl.opaqueElementsAttrs.insert_as(nullptr, key);
// If we already have it, return that value. // If we already have it, return that value.
if (result) if (!existing.second)
return result; return *existing.first;
// Otherwise, allocate them into the bump pointer. // Otherwise, allocate a new one, unique it and return it.
result = impl.allocator.Allocate<SplatElementsAttr>(); auto *result = impl.allocator.Allocate<OpaqueElementsAttributeStorage>();
new (result) SplatElementsAttr(type, elt); bytes = bytes.copy(impl.allocator);
new (result) OpaqueElementsAttributeStorage{
return result; {{Attribute::Kind::OpaqueElements, /*isOrContainsFunction=*/false}, type},
bytes};
return *existing.first = result;
} }
SparseElementsAttr *SparseElementsAttr::get(VectorOrTensorType *type, SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType *type,
DenseIntElementsAttr *indices, DenseIntElementsAttr indices,
DenseElementsAttr *values) { DenseElementsAttr values) {
auto &impl = type->getContext()->getImpl(); auto &impl = type->getContext()->getImpl();
// Look to see if we already have this. // Look to see if we already have this.
@ -1123,8 +1060,12 @@ SparseElementsAttr *SparseElementsAttr::get(VectorOrTensorType *type,
return result; return result;
// Otherwise, allocate them into the bump pointer. // Otherwise, allocate them into the bump pointer.
result = impl.allocator.Allocate<SparseElementsAttr>(); result = impl.allocator.Allocate<SparseElementsAttributeStorage>();
new (result) SparseElementsAttr(type, indices, values); new (result) SparseElementsAttributeStorage{{{Attribute::Kind::SparseElements,
/*isOrContainsFunction=*/false},
type},
indices,
values};
return result; return result;
} }

View File

@ -148,7 +148,7 @@ ArrayRef<NamedAttribute> Operation::getAttrs() const {
/// If an attribute exists with the specified name, change it to the new /// If an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value. /// value. Otherwise, add a new attribute with the specified name/value.
void Operation::setAttr(Identifier name, Attribute *value) { void Operation::setAttr(Identifier name, Attribute value) {
assert(value && "attributes may never be null"); assert(value && "attributes may never be null");
auto origAttrs = getAttrs(); auto origAttrs = getAttrs();
@ -225,8 +225,8 @@ void Operation::erase() {
/// Attempt to constant fold this operation with the specified constant /// Attempt to constant fold this operation with the specified constant
/// operand values. If successful, this returns false and fills in the /// operand values. If successful, this returns false and fills in the
/// results vector. If not, this returns true and results is unspecified. /// results vector. If not, this returns true and results is unspecified.
bool Operation::constantFold(ArrayRef<Attribute *> operands, bool Operation::constantFold(ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute *> &results) const { SmallVectorImpl<Attribute> &results) const {
// If we have a registered operation definition matching this one, use it to // If we have a registered operation definition matching this one, use it to
// try to constant fold the operation. // try to constant fold the operation.
if (auto *abstractOp = getAbstractOperation()) if (auto *abstractOp = getAbstractOperation())

View File

@ -195,7 +195,7 @@ public:
// Attribute parsing. // Attribute parsing.
Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc, Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
FunctionType *type); FunctionType *type);
Attribute *parseAttribute(); Attribute parseAttribute();
ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes); ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);
@ -204,8 +204,8 @@ public:
AffineMap parseAffineMapReference(); AffineMap parseAffineMapReference();
IntegerSet parseIntegerSetInline(); IntegerSet parseIntegerSetInline();
IntegerSet parseIntegerSetReference(); IntegerSet parseIntegerSetReference();
DenseElementsAttr *parseDenseElementsAttr(VectorOrTensorType *type); DenseElementsAttr parseDenseElementsAttr(VectorOrTensorType *type);
DenseElementsAttr *parseDenseElementsAttr(Type *eltType, bool isVector); DenseElementsAttr parseDenseElementsAttr(Type *eltType, bool isVector);
VectorOrTensorType *parseVectorOrTensorType(); VectorOrTensorType *parseVectorOrTensorType();
private: private:
@ -684,7 +684,7 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int> &dims) {
case Token::floatliteral: case Token::floatliteral:
case Token::integer: case Token::integer:
case Token::minus: { case Token::minus: {
auto *result = p.parseAttribute(); auto result = p.parseAttribute();
if (!result) if (!result)
return p.emitError("expected tensor element"); return p.emitError("expected tensor element");
// check result matches the element type. // check result matches the element type.
@ -693,16 +693,16 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int> &dims) {
case Type::Kind::F16: case Type::Kind::F16:
case Type::Kind::F32: case Type::Kind::F32:
case Type::Kind::F64: { case Type::Kind::F64: {
if (!isa<FloatAttr>(result)) if (!result.isa<FloatAttr>())
return p.emitError("expected tensor literal element has float type"); return p.emitError("expected tensor literal element has float type");
double value = cast<FloatAttr>(result)->getDouble(); double value = result.cast<FloatAttr>().getDouble();
addToStorage(*(uint64_t *)(&value)); addToStorage(*(uint64_t *)(&value));
break; break;
} }
case Type::Kind::Integer: { case Type::Kind::Integer: {
if (!isa<IntegerAttr>(result)) if (!result.isa<IntegerAttr>())
return p.emitError("expected tensor literal element has integer type"); return p.emitError("expected tensor literal element has integer type");
auto value = cast<IntegerAttr>(result)->getValue(); auto value = result.cast<IntegerAttr>().getValue();
// If we couldn't successfully round trip the value, it means some bits // If we couldn't successfully round trip the value, it means some bits
// are truncated and we should give up here. // are truncated and we should give up here.
llvm::APInt apint(bitsWidth, (uint64_t)value, /*isSigned=*/true); llvm::APInt apint(bitsWidth, (uint64_t)value, /*isSigned=*/true);
@ -804,7 +804,7 @@ Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
/// | `sparse<` (tensor-type | vector-type)`,` /// | `sparse<` (tensor-type | vector-type)`,`
/// attribute-value`, ` attribute-value `>` /// attribute-value`, ` attribute-value `>`
/// ///
Attribute *Parser::parseAttribute() { Attribute Parser::parseAttribute() {
switch (getToken().getKind()) { switch (getToken().getKind()) {
case Token::kw_true: case Token::kw_true:
consumeToken(Token::kw_true); consumeToken(Token::kw_true);
@ -859,7 +859,7 @@ Attribute *Parser::parseAttribute() {
case Token::l_square: { case Token::l_square: {
consumeToken(Token::l_square); consumeToken(Token::l_square);
SmallVector<Attribute *, 4> elements; SmallVector<Attribute, 4> elements;
auto parseElt = [&]() -> ParseResult { auto parseElt = [&]() -> ParseResult {
elements.push_back(parseAttribute()); elements.push_back(parseAttribute());
@ -928,7 +928,7 @@ Attribute *Parser::parseAttribute() {
case Token::floatliteral: case Token::floatliteral:
case Token::integer: case Token::integer:
case Token::minus: { case Token::minus: {
auto *scalar = parseAttribute(); auto scalar = parseAttribute();
if (parseToken(Token::greater, "expected '>'")) if (parseToken(Token::greater, "expected '>'"))
return nullptr; return nullptr;
return builder.getSplatElementsAttr(type, scalar); return builder.getSplatElementsAttr(type, scalar);
@ -973,7 +973,7 @@ Attribute *Parser::parseAttribute() {
case Token::l_square: { case Token::l_square: {
/// Parse indices /// Parse indices
auto *indicesEltType = builder.getIntegerType(32); auto *indicesEltType = builder.getIntegerType(32);
auto *indices = auto indices =
parseDenseElementsAttr(indicesEltType, isa<VectorType>(type)); parseDenseElementsAttr(indicesEltType, isa<VectorType>(type));
if (parseToken(Token::comma, "expected ','")) if (parseToken(Token::comma, "expected ','"))
@ -981,12 +981,12 @@ Attribute *Parser::parseAttribute() {
/// Parse values. /// Parse values.
auto *valuesEltType = type->getElementType(); auto *valuesEltType = type->getElementType();
auto *values = auto values =
parseDenseElementsAttr(valuesEltType, isa<VectorType>(type)); parseDenseElementsAttr(valuesEltType, isa<VectorType>(type));
/// Sanity check. /// Sanity check.
auto *indicesType = indices->getType(); auto *indicesType = indices.getType();
auto *valuesType = values->getType(); auto *valuesType = values.getType();
auto sameShape = (indicesType->getRank() == 1) || auto sameShape = (indicesType->getRank() == 1) ||
(type->getRank() == indicesType->getDimSize(1)); (type->getRank() == indicesType->getDimSize(1));
auto sameElementNum = auto sameElementNum =
@ -1009,7 +1009,7 @@ Attribute *Parser::parseAttribute() {
// Build the sparse elements attribute by the indices and values. // Build the sparse elements attribute by the indices and values.
return builder.getSparseElementsAttr( return builder.getSparseElementsAttr(
type, cast<DenseIntElementsAttr>(indices), values); type, indices.cast<DenseIntElementsAttr>(), values);
} }
default: default:
return (emitError("expected '[' to start sparse tensor literal"), return (emitError("expected '[' to start sparse tensor literal"),
@ -1035,8 +1035,7 @@ Attribute *Parser::parseAttribute() {
/// ///
/// This method returns a constructed dense elements attribute with the shape /// This method returns a constructed dense elements attribute with the shape
/// from the parsing result. /// from the parsing result.
DenseElementsAttr *Parser::parseDenseElementsAttr(Type *eltType, DenseElementsAttr Parser::parseDenseElementsAttr(Type *eltType, bool isVector) {
bool isVector) {
TensorLiteralParser literalParser(*this, eltType); TensorLiteralParser literalParser(*this, eltType);
if (literalParser.parse()) if (literalParser.parse())
return nullptr; return nullptr;
@ -1047,8 +1046,8 @@ DenseElementsAttr *Parser::parseDenseElementsAttr(Type *eltType,
} else { } else {
type = builder.getTensorType(literalParser.getShape(), eltType); type = builder.getTensorType(literalParser.getShape(), eltType);
} }
return (DenseElementsAttr *)builder.getDenseElementsAttr( return builder.getDenseElementsAttr(type, literalParser.getValues())
type, literalParser.getValues()); .cast<DenseElementsAttr>();
} }
/// Dense elements attribute. /// Dense elements attribute.
@ -1061,7 +1060,7 @@ DenseElementsAttr *Parser::parseDenseElementsAttr(Type *eltType,
/// This method compares the shapes from the parsing result and that from the /// This method compares the shapes from the parsing result and that from the
/// input argument. It returns a constructed dense elements attribute if both /// input argument. It returns a constructed dense elements attribute if both
/// match. /// match.
DenseElementsAttr *Parser::parseDenseElementsAttr(VectorOrTensorType *type) { DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType *type) {
auto *eltTy = type->getElementType(); auto *eltTy = type->getElementType();
TensorLiteralParser literalParser(*this, eltTy); TensorLiteralParser literalParser(*this, eltTy);
if (literalParser.parse()) if (literalParser.parse())
@ -1076,8 +1075,8 @@ DenseElementsAttr *Parser::parseDenseElementsAttr(VectorOrTensorType *type) {
s << "])"; s << "])";
return (emitError(s.str()), nullptr); return (emitError(s.str()), nullptr);
} }
return (DenseElementsAttr *)builder.getDenseElementsAttr( return builder.getDenseElementsAttr(type, literalParser.getValues())
type, literalParser.getValues()); .cast<DenseElementsAttr>();
} }
/// Vector or tensor type for elements attribute. /// Vector or tensor type for elements attribute.
@ -2133,7 +2132,7 @@ public:
/// Parse an arbitrary attribute and return it in result. This also adds /// Parse an arbitrary attribute and return it in result. This also adds
/// the attribute to the specified attribute list with the specified name. /// the attribute to the specified attribute list with the specified name.
/// this captures the location of the attribute in 'loc' if it is non-null. /// this captures the location of the attribute in 'loc' if it is non-null.
bool parseAttribute(Attribute *&result, const char *attrName, bool parseAttribute(Attribute &result, const char *attrName,
SmallVectorImpl<NamedAttribute> &attrs) override { SmallVectorImpl<NamedAttribute> &attrs) override {
result = parser.parseAttribute(); result = parser.parseAttribute();
if (!result) if (!result)
@ -3336,27 +3335,27 @@ ParseResult ModuleParser::parseMLFunc() {
/// Given an attribute that could refer to a function attribute in the /// Given an attribute that could refer to a function attribute in the
/// remapping table, walk it and rewrite it to use the mapped function. If it /// remapping table, walk it and rewrite it to use the mapped function. If it
/// doesn't refer to anything in the table, then it is returned unmodified. /// doesn't refer to anything in the table, then it is returned unmodified.
static Attribute * static Attribute
remapFunctionAttrs(Attribute *input, remapFunctionAttrs(Attribute input,
DenseMap<FunctionAttr *, FunctionAttr *> &remappingTable, DenseMap<Attribute, FunctionAttr> &remappingTable,
MLIRContext *context) { MLIRContext *context) {
// Most attributes are trivially unrelated to function attributes, skip them // Most attributes are trivially unrelated to function attributes, skip them
// rapidly. // rapidly.
if (!input->isOrContainsFunction()) if (!input.isOrContainsFunction())
return input; return input;
// If we have a function attribute, remap it. // If we have a function attribute, remap it.
if (auto *fnAttr = dyn_cast<FunctionAttr>(input)) { if (auto fnAttr = input.dyn_cast<FunctionAttr>()) {
auto it = remappingTable.find(fnAttr); auto it = remappingTable.find(fnAttr);
return it != remappingTable.end() ? it->second : input; return it != remappingTable.end() ? it->second : input;
} }
// Otherwise, we must have an array attribute, remap the elements. // Otherwise, we must have an array attribute, remap the elements.
auto *arrayAttr = cast<ArrayAttr>(input); auto arrayAttr = input.cast<ArrayAttr>();
SmallVector<Attribute *, 8> remappedElts; SmallVector<Attribute, 8> remappedElts;
bool anyChange = false; bool anyChange = false;
for (auto *elt : arrayAttr->getValue()) { for (auto elt : arrayAttr.getValue()) {
auto *newElt = remapFunctionAttrs(elt, remappingTable, context); auto newElt = remapFunctionAttrs(elt, remappingTable, context);
remappedElts.push_back(newElt); remappedElts.push_back(newElt);
anyChange |= (elt != newElt); anyChange |= (elt != newElt);
} }
@ -3370,11 +3369,11 @@ remapFunctionAttrs(Attribute *input,
/// Remap function attributes to resolve forward references to their actual /// Remap function attributes to resolve forward references to their actual
/// definition. /// definition.
static void remapFunctionAttrsInOperation( static void remapFunctionAttrsInOperation(
Operation *op, DenseMap<FunctionAttr *, FunctionAttr *> &remappingTable) { Operation *op, DenseMap<Attribute, FunctionAttr> &remappingTable) {
for (auto attr : op->getAttrs()) { for (auto attr : op->getAttrs()) {
// Do the remapping, if we got the same thing back, then it must contain // Do the remapping, if we got the same thing back, then it must contain
// functions that aren't getting remapped. // functions that aren't getting remapped.
auto *newVal = auto newVal =
remapFunctionAttrs(attr.second, remappingTable, op->getContext()); remapFunctionAttrs(attr.second, remappingTable, op->getContext());
if (newVal == attr.second) if (newVal == attr.second)
continue; continue;
@ -3391,7 +3390,7 @@ static void remapFunctionAttrsInOperation(
ParseResult ModuleParser::finalizeModule() { ParseResult ModuleParser::finalizeModule() {
// Resolve all forward references, building a remapping table of attributes. // Resolve all forward references, building a remapping table of attributes.
DenseMap<FunctionAttr *, FunctionAttr *> remappingTable; DenseMap<Attribute, FunctionAttr> remappingTable;
for (auto forwardRef : getState().functionForwardRefs) { for (auto forwardRef : getState().functionForwardRefs) {
auto name = forwardRef.first; auto name = forwardRef.first;
@ -3428,13 +3427,13 @@ ParseResult ModuleParser::finalizeModule() {
continue; continue;
struct MLFnWalker : public StmtWalker<MLFnWalker> { struct MLFnWalker : public StmtWalker<MLFnWalker> {
MLFnWalker(DenseMap<FunctionAttr *, FunctionAttr *> &remappingTable) MLFnWalker(DenseMap<Attribute, FunctionAttr> &remappingTable)
: remappingTable(remappingTable) {} : remappingTable(remappingTable) {}
void visitOperationStmt(OperationStmt *opStmt) { void visitOperationStmt(OperationStmt *opStmt) {
remapFunctionAttrsInOperation(opStmt, remappingTable); remapFunctionAttrsInOperation(opStmt, remappingTable);
} }
DenseMap<FunctionAttr *, FunctionAttr *> &remappingTable; DenseMap<Attribute, FunctionAttr> &remappingTable;
}; };
MLFnWalker(remappingTable).walk(mlFn); MLFnWalker(remappingTable).walk(mlFn);

View File

@ -44,13 +44,13 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
// AddFOp // AddFOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
Attribute *AddFOp::constantFold(ArrayRef<Attribute *> operands, Attribute AddFOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const { MLIRContext *context) const {
assert(operands.size() == 2 && "addf takes two operands"); assert(operands.size() == 2 && "addf takes two operands");
if (auto *lhs = dyn_cast_or_null<FloatAttr>(operands[0])) { if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
if (auto *rhs = dyn_cast_or_null<FloatAttr>(operands[1])) if (auto rhs = operands[1].dyn_cast_or_null<FloatAttr>())
return FloatAttr::get(lhs->getValue() + rhs->getValue(), context); return FloatAttr::get(lhs.getValue() + rhs.getValue(), context);
} }
return nullptr; return nullptr;
@ -60,13 +60,13 @@ Attribute *AddFOp::constantFold(ArrayRef<Attribute *> operands,
// AddIOp // AddIOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
Attribute *AddIOp::constantFold(ArrayRef<Attribute *> operands, Attribute AddIOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const { MLIRContext *context) const {
assert(operands.size() == 2 && "addi takes two operands"); assert(operands.size() == 2 && "addi takes two operands");
if (auto *lhs = dyn_cast_or_null<IntegerAttr>(operands[0])) { if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
if (auto *rhs = dyn_cast_or_null<IntegerAttr>(operands[1])) if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>())
return IntegerAttr::get(lhs->getValue() + rhs->getValue(), context); return IntegerAttr::get(lhs.getValue() + rhs.getValue(), context);
} }
return nullptr; return nullptr;
@ -192,12 +192,12 @@ void CallOp::print(OpAsmPrinter *p) const {
bool CallOp::verify() const { bool CallOp::verify() const {
// Check that the callee attribute was specified. // Check that the callee attribute was specified.
auto *fnAttr = getAttrOfType<FunctionAttr>("callee"); auto fnAttr = getAttrOfType<FunctionAttr>("callee");
if (!fnAttr) if (!fnAttr)
return emitOpError("requires a 'callee' function attribute"); return emitOpError("requires a 'callee' function attribute");
// Verify that the operand and result types match the callee. // Verify that the operand and result types match the callee.
auto *fnType = fnAttr->getValue()->getType(); auto *fnType = fnAttr.getValue()->getType();
if (fnType->getNumInputs() != getNumOperands()) if (fnType->getNumInputs() != getNumOperands())
return emitOpError("incorrect number of operands for callee"); return emitOpError("incorrect number of operands for callee");
@ -329,7 +329,7 @@ void DimOp::print(OpAsmPrinter *p) const {
bool DimOp::parse(OpAsmParser *parser, OperationState *result) { bool DimOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType operandInfo; OpAsmParser::OperandType operandInfo;
IntegerAttr *indexAttr; IntegerAttr indexAttr;
Type *type; Type *type;
return parser->parseOperand(operandInfo) || parser->parseComma() || return parser->parseOperand(operandInfo) || parser->parseComma() ||
@ -346,7 +346,7 @@ bool DimOp::verify() const {
auto indexAttr = getAttrOfType<IntegerAttr>("index"); auto indexAttr = getAttrOfType<IntegerAttr>("index");
if (!indexAttr) if (!indexAttr)
return emitOpError("requires an integer attribute named 'index'"); return emitOpError("requires an integer attribute named 'index'");
uint64_t index = (uint64_t)indexAttr->getValue(); uint64_t index = (uint64_t)indexAttr.getValue();
auto *type = getOperand()->getType(); auto *type = getOperand()->getType();
if (auto *tensorType = dyn_cast<RankedTensorType>(type)) { if (auto *tensorType = dyn_cast<RankedTensorType>(type)) {
@ -365,8 +365,8 @@ bool DimOp::verify() const {
return false; return false;
} }
Attribute *DimOp::constantFold(ArrayRef<Attribute *> operands, Attribute DimOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const { MLIRContext *context) const {
// Constant fold dim when the size along the index referred to is a constant. // Constant fold dim when the size along the index referred to is a constant.
auto *opType = getOperand()->getType(); auto *opType = getOperand()->getType();
int indexSize = -1; int indexSize = -1;
@ -671,13 +671,13 @@ bool MemRefCastOp::verify() const {
// MulFOp // MulFOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
Attribute *MulFOp::constantFold(ArrayRef<Attribute *> operands, Attribute MulFOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const { MLIRContext *context) const {
assert(operands.size() == 2 && "mulf takes two operands"); assert(operands.size() == 2 && "mulf takes two operands");
if (auto *lhs = dyn_cast_or_null<FloatAttr>(operands[0])) { if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
if (auto *rhs = dyn_cast_or_null<FloatAttr>(operands[1])) if (auto rhs = operands[1].dyn_cast_or_null<FloatAttr>())
return FloatAttr::get(lhs->getValue() * rhs->getValue(), context); return FloatAttr::get(lhs.getValue() * rhs.getValue(), context);
} }
return nullptr; return nullptr;
@ -687,23 +687,23 @@ Attribute *MulFOp::constantFold(ArrayRef<Attribute *> operands,
// MulIOp // MulIOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
Attribute *MulIOp::constantFold(ArrayRef<Attribute *> operands, Attribute MulIOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const { MLIRContext *context) const {
assert(operands.size() == 2 && "muli takes two operands"); assert(operands.size() == 2 && "muli takes two operands");
if (auto *lhs = dyn_cast_or_null<IntegerAttr>(operands[0])) { if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
// 0*x == 0 // 0*x == 0
if (lhs->getValue() == 0) if (lhs.getValue() == 0)
return lhs; return lhs;
if (auto *rhs = dyn_cast_or_null<IntegerAttr>(operands[1])) if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>())
// TODO: Handle the overflow case. // TODO: Handle the overflow case.
return IntegerAttr::get(lhs->getValue() * rhs->getValue(), context); return IntegerAttr::get(lhs.getValue() * rhs.getValue(), context);
} }
// x*0 == 0 // x*0 == 0
if (auto *rhs = dyn_cast_or_null<IntegerAttr>(operands[1])) if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>())
if (rhs->getValue() == 0) if (rhs.getValue() == 0)
return rhs; return rhs;
return nullptr; return nullptr;
@ -817,13 +817,13 @@ bool StoreOp::verify() const {
// SubFOp // SubFOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
Attribute *SubFOp::constantFold(ArrayRef<Attribute *> operands, Attribute SubFOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const { MLIRContext *context) const {
assert(operands.size() == 2 && "subf takes two operands"); assert(operands.size() == 2 && "subf takes two operands");
if (auto *lhs = dyn_cast_or_null<FloatAttr>(operands[0])) { if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
if (auto *rhs = dyn_cast_or_null<FloatAttr>(operands[1])) if (auto rhs = operands[1].dyn_cast_or_null<FloatAttr>())
return FloatAttr::get(lhs->getValue() - rhs->getValue(), context); return FloatAttr::get(lhs.getValue() - rhs.getValue(), context);
} }
return nullptr; return nullptr;
@ -833,13 +833,13 @@ Attribute *SubFOp::constantFold(ArrayRef<Attribute *> operands,
// SubIOp // SubIOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
Attribute *SubIOp::constantFold(ArrayRef<Attribute *> operands, Attribute SubIOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const { MLIRContext *context) const {
assert(operands.size() == 2 && "subi takes two operands"); assert(operands.size() == 2 && "subi takes two operands");
if (auto *lhs = dyn_cast_or_null<IntegerAttr>(operands[0])) { if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
if (auto *rhs = dyn_cast_or_null<IntegerAttr>(operands[1])) if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>())
return IntegerAttr::get(lhs->getValue() - rhs->getValue(), context); return IntegerAttr::get(lhs.getValue() - rhs.getValue(), context);
} }
return nullptr; return nullptr;

View File

@ -31,7 +31,7 @@ struct ConstantFold : public FunctionPass, StmtWalker<ConstantFold> {
SmallVector<SSAValue *, 8> existingConstants; SmallVector<SSAValue *, 8> existingConstants;
// Operation statements that were folded and that need to be erased. // Operation statements that were folded and that need to be erased.
std::vector<OperationStmt *> opStmtsToErase; std::vector<OperationStmt *> opStmtsToErase;
using ConstantFactoryType = std::function<SSAValue *(Attribute *, Type *)>; using ConstantFactoryType = std::function<SSAValue *(Attribute, Type *)>;
bool foldOperation(Operation *op, bool foldOperation(Operation *op,
SmallVectorImpl<SSAValue *> &existingConstants, SmallVectorImpl<SSAValue *> &existingConstants,
@ -60,9 +60,9 @@ bool ConstantFold::foldOperation(Operation *op,
// Check to see if each of the operands is a trivial constant. If so, get // Check to see if each of the operands is a trivial constant. If so, get
// the value. If not, ignore the instruction. // the value. If not, ignore the instruction.
SmallVector<Attribute *, 8> operandConstants; SmallVector<Attribute, 8> operandConstants;
for (auto *operand : op->getOperands()) { for (auto *operand : op->getOperands()) {
Attribute *operandCst = nullptr; Attribute operandCst = nullptr;
if (auto *operandOp = operand->getDefiningOperation()) { if (auto *operandOp = operand->getDefiningOperation()) {
if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>()) if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
operandCst = operandConstantOp->getValue(); operandCst = operandConstantOp->getValue();
@ -71,7 +71,7 @@ bool ConstantFold::foldOperation(Operation *op,
} }
// Attempt to constant fold the operation. // Attempt to constant fold the operation.
SmallVector<Attribute *, 8> resultConstants; SmallVector<Attribute, 8> resultConstants;
if (op->constantFold(operandConstants, resultConstants)) if (op->constantFold(operandConstants, resultConstants))
return true; return true;
@ -106,7 +106,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
for (auto instIt = bb.begin(), e = bb.end(); instIt != e;) { for (auto instIt = bb.begin(), e = bb.end(); instIt != e;) {
auto &inst = *instIt++; auto &inst = *instIt++;
auto constantFactory = [&](Attribute *value, Type *type) -> SSAValue * { auto constantFactory = [&](Attribute value, Type *type) -> SSAValue * {
builder.setInsertionPoint(&inst); builder.setInsertionPoint(&inst);
return builder.create<ConstantOp>(inst.getLoc(), value, type); return builder.create<ConstantOp>(inst.getLoc(), value, type);
}; };
@ -134,7 +134,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
// Override the walker's operation statement visit for constant folding. // Override the walker's operation statement visit for constant folding.
void ConstantFold::visitOperationStmt(OperationStmt *stmt) { void ConstantFold::visitOperationStmt(OperationStmt *stmt) {
auto constantFactory = [&](Attribute *value, Type *type) -> SSAValue * { auto constantFactory = [&](Attribute value, Type *type) -> SSAValue * {
MLFuncBuilder builder(stmt); MLFuncBuilder builder(stmt);
return builder.create<ConstantOp>(stmt->getLoc(), value, type); return builder.create<ConstantOp>(stmt->getLoc(), value, type);
}; };

View File

@ -71,8 +71,8 @@ void SimplifyAffineStructures::visitIfStmt(IfStmt *ifStmt) {
void SimplifyAffineStructures::visitOperationStmt(OperationStmt *opStmt) { void SimplifyAffineStructures::visitOperationStmt(OperationStmt *opStmt) {
for (auto attr : opStmt->getAttrs()) { for (auto attr : opStmt->getAttrs()) {
if (auto *mapAttr = dyn_cast<AffineMapAttr>(attr.second)) { if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) {
MutableAffineMap mMap(mapAttr->getValue()); MutableAffineMap mMap(mapAttr.getValue());
mMap.simplify(); mMap.simplify();
auto map = mMap.getAffineMap(); auto map = mMap.getAffineMap();
opStmt->setAttr(attr.first, AffineMapAttr::get(map)); opStmt->setAttr(attr.first, AffineMapAttr::get(map));

View File

@ -79,7 +79,7 @@ private:
/// As part of canonicalization, we move constants to the top of the entry /// As part of canonicalization, we move constants to the top of the entry
/// block of the current function and de-duplicate them. This keeps track of /// block of the current function and de-duplicate them. This keeps track of
/// constants we have done this for. /// constants we have done this for.
DenseMap<std::pair<Attribute *, Type *>, Operation *> uniquedConstants; DenseMap<std::pair<Attribute, Type *>, Operation *> uniquedConstants;
}; };
}; // end anonymous namespace }; // end anonymous namespace
@ -107,7 +107,7 @@ public:
void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction, void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
WorklistRewriter &rewriter) { WorklistRewriter &rewriter) {
// These are scratch vectors used in the constant folding loop below. // These are scratch vectors used in the constant folding loop below.
SmallVector<Attribute *, 8> operandConstants, resultConstants; SmallVector<Attribute, 8> operandConstants, resultConstants;
while (!worklist.empty()) { while (!worklist.empty()) {
auto *op = popFromWorklist(); auto *op = popFromWorklist();
@ -175,7 +175,7 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
// the operation knows how to constant fold itself. // the operation knows how to constant fold itself.
operandConstants.clear(); operandConstants.clear();
for (auto *operand : op->getOperands()) { for (auto *operand : op->getOperands()) {
Attribute *operandCst = nullptr; Attribute operandCst;
if (auto *operandOp = operand->getDefiningOperation()) { if (auto *operandOp = operand->getDefiningOperation()) {
if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>()) if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
operandCst = operandConstantOp->getValue(); operandCst = operandConstantOp->getValue();

View File

@ -353,11 +353,11 @@ bool mlir::constantFoldBounds(ForStmt *forStmt) {
// Check to see if each of the operands is the result of a constant. If so, // Check to see if each of the operands is the result of a constant. If so,
// get the value. If not, ignore it. // get the value. If not, ignore it.
SmallVector<Attribute *, 8> operandConstants; SmallVector<Attribute, 8> operandConstants;
auto boundOperands = lower ? forStmt->getLowerBoundOperands() auto boundOperands = lower ? forStmt->getLowerBoundOperands()
: forStmt->getUpperBoundOperands(); : forStmt->getUpperBoundOperands();
for (const auto *operand : boundOperands) { for (const auto *operand : boundOperands) {
Attribute *operandCst = nullptr; Attribute operandCst;
if (auto *operandOp = operand->getDefiningOperation()) { if (auto *operandOp = operand->getDefiningOperation()) {
if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>()) if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
operandCst = operandConstantOp->getValue(); operandCst = operandConstantOp->getValue();
@ -369,15 +369,15 @@ bool mlir::constantFoldBounds(ForStmt *forStmt) {
lower ? forStmt->getLowerBoundMap() : forStmt->getUpperBoundMap(); lower ? forStmt->getLowerBoundMap() : forStmt->getUpperBoundMap();
assert(boundMap.getNumResults() >= 1 && assert(boundMap.getNumResults() >= 1 &&
"bound maps should have at least one result"); "bound maps should have at least one result");
SmallVector<Attribute *, 4> foldedResults; SmallVector<Attribute, 4> foldedResults;
if (boundMap.constantFold(operandConstants, foldedResults)) if (boundMap.constantFold(operandConstants, foldedResults))
return true; return true;
// Compute the max or min as applicable over the results. // Compute the max or min as applicable over the results.
assert(!foldedResults.empty() && "bounds should have at least one result"); assert(!foldedResults.empty() && "bounds should have at least one result");
auto maxOrMin = cast<IntegerAttr>(foldedResults[0])->getValue(); auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
for (unsigned i = 1; i < foldedResults.size(); i++) { for (unsigned i = 1; i < foldedResults.size(); i++) {
auto foldedResult = cast<IntegerAttr>(foldedResults[i])->getValue(); auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
maxOrMin = lower ? std::max(maxOrMin, foldedResult) maxOrMin = lower ? std::max(maxOrMin, foldedResult)
: std::min(maxOrMin, foldedResult); : std::min(maxOrMin, foldedResult);
} }

View File

@ -154,7 +154,7 @@ void OpEmitter::emitAttrGetters() {
<< val.getName() << "() const {\n"; << val.getName() << "() const {\n";
os << " return this->getAttrOfType<" os << " return this->getAttrOfType<"
<< attr.getValueAsString("AttrType").trim() << ">(\"" << name << attr.getValueAsString("AttrType").trim() << ">(\"" << name
<< "\")->getValue();\n }\n"; << "\").getValue();\n }\n";
} }
} }
@ -207,9 +207,9 @@ void OpEmitter::emitVerifier() {
// Verify the attributes have the correct type. // Verify the attributes have the correct type.
for (const auto attr : attrs) { for (const auto attr : attrs) {
auto name = attr.first->getName(); auto name = attr.first->getName();
os << " if (!dyn_cast_or_null<" os << " if (!this->getAttr(\"" << name << "\").dyn_cast_or_null<"
<< attr.second->getValueAsString("AttrType") << ">(this->getAttr(\"" << attr.second->getValueAsString("AttrType") << ">("
<< name << "\"))) return emitOpError(\"requires " << ")) return emitOpError(\"requires "
<< attr.second->getValueAsString("PrimitiveType").trim() << attr.second->getValueAsString("PrimitiveType").trim()
<< " attribute '" << name << "'\");\n"; << " attribute '" << name << "'\");\n";
} }