Implement value type abstraction for attributes.
This is done by changing Attribute to be a POD interface around an underlying pointer storage and adding in-class support for isa/dyn_cast/cast. PiperOrigin-RevId: 218764173
This commit is contained in:
parent
64d52014bd
commit
792d1c25e4
|
@ -95,8 +95,8 @@ public:
|
|||
/// Folds the results of the application of an affine map on the provided
|
||||
/// operands to a constant if possible. Returns false if the folding happens,
|
||||
/// true otherwise.
|
||||
bool constantFold(ArrayRef<Attribute *> operandConstants,
|
||||
SmallVectorImpl<Attribute *> &results) const;
|
||||
bool constantFold(ArrayRef<Attribute> operandConstants,
|
||||
SmallVectorImpl<Attribute> &results) const;
|
||||
|
||||
friend ::llvm::hash_code hash_value(AffineMap arg);
|
||||
|
||||
|
|
|
@ -30,10 +30,32 @@ class MLIRContext;
|
|||
class Type;
|
||||
class VectorOrTensorType;
|
||||
|
||||
namespace detail {
|
||||
|
||||
struct AttributeStorage;
|
||||
struct BoolAttributeStorage;
|
||||
struct IntegerAttributeStorage;
|
||||
struct FloatAttributeStorage;
|
||||
struct StringAttributeStorage;
|
||||
struct ArrayAttributeStorage;
|
||||
struct AffineMapAttributeStorage;
|
||||
struct TypeAttributeStorage;
|
||||
struct FunctionAttributeStorage;
|
||||
struct ElementsAttributeStorage;
|
||||
struct SplatElementsAttributeStorage;
|
||||
struct DenseElementsAttributeStorage;
|
||||
struct DenseIntElementsAttributeStorage;
|
||||
struct DenseFPElementsAttributeStorage;
|
||||
struct OpaqueElementsAttributeStorage;
|
||||
struct SparseElementsAttributeStorage;
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/// Attributes are known-constant values of operations and functions.
|
||||
///
|
||||
/// Instances of the Attribute class are immutable, uniqued, immortal, and owned
|
||||
/// by MLIRContext. As such, they are passed around by raw non-const pointer.
|
||||
/// by MLIRContext. As such, an Attribute is a POD interface to an underlying
|
||||
/// storage pointer.
|
||||
class Attribute {
|
||||
public:
|
||||
enum class Kind {
|
||||
|
@ -55,177 +77,151 @@ public:
|
|||
LAST_ELEMENTS_ATTR = SparseElements,
|
||||
};
|
||||
|
||||
typedef detail::AttributeStorage ImplType;
|
||||
|
||||
Attribute() : attr(nullptr) {}
|
||||
/* implicit */ Attribute(const ImplType *attr)
|
||||
: attr(const_cast<ImplType *>(attr)) {}
|
||||
|
||||
Attribute(const Attribute &other) : attr(other.attr) {}
|
||||
Attribute &operator=(Attribute other) {
|
||||
attr = other.attr;
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool operator==(Attribute other) const { return attr == other.attr; }
|
||||
bool operator!=(Attribute other) const { return !(*this == other); }
|
||||
explicit operator bool() const { return attr; }
|
||||
|
||||
bool operator!() const { return attr == nullptr; }
|
||||
|
||||
template <typename U> bool isa() const;
|
||||
template <typename U> U dyn_cast() const;
|
||||
template <typename U> U dyn_cast_or_null() const;
|
||||
template <typename U> U cast() const;
|
||||
|
||||
/// Return the classification for this attribute.
|
||||
Kind getKind() const { return kind; }
|
||||
Kind getKind() const;
|
||||
|
||||
/// Return true if this field is, or contains, a function attribute.
|
||||
bool isOrContainsFunction() const { return isOrContainsFunctionCache; }
|
||||
bool isOrContainsFunction() const;
|
||||
|
||||
/// Print the attribute.
|
||||
void print(raw_ostream &os) const;
|
||||
void dump() const;
|
||||
|
||||
friend ::llvm::hash_code hash_value(Attribute arg);
|
||||
|
||||
protected:
|
||||
explicit Attribute(Kind kind, bool isOrContainsFunction)
|
||||
: kind(kind), isOrContainsFunctionCache(isOrContainsFunction) {}
|
||||
~Attribute() {}
|
||||
|
||||
private:
|
||||
/// Classification of the subclass, used for type checking.
|
||||
Kind kind : 8;
|
||||
|
||||
/// This field is true if this is, or contains, a function attribute.
|
||||
bool isOrContainsFunctionCache : 1;
|
||||
|
||||
Attribute(const Attribute &) = delete;
|
||||
void operator=(const Attribute &) = delete;
|
||||
ImplType *attr;
|
||||
};
|
||||
|
||||
inline raw_ostream &operator<<(raw_ostream &os, const Attribute &attr) {
|
||||
inline raw_ostream &operator<<(raw_ostream &os, Attribute attr) {
|
||||
attr.print(os);
|
||||
return os;
|
||||
}
|
||||
|
||||
class BoolAttr : public Attribute {
|
||||
public:
|
||||
static BoolAttr *get(bool value, MLIRContext *context);
|
||||
typedef detail::BoolAttributeStorage ImplType;
|
||||
BoolAttr() = default;
|
||||
/* implicit */ BoolAttr(Attribute::ImplType *ptr);
|
||||
|
||||
bool getValue() const { return value; }
|
||||
static BoolAttr get(bool value, MLIRContext *context);
|
||||
|
||||
bool getValue() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const Attribute *attr) {
|
||||
return attr->getKind() == Kind::Bool;
|
||||
}
|
||||
|
||||
private:
|
||||
BoolAttr(bool value)
|
||||
: Attribute(Kind::Bool, /*isOrContainsFunction=*/false), value(value) {}
|
||||
~BoolAttr() = delete;
|
||||
bool value;
|
||||
static bool kindof(Kind kind) { return kind == Kind::Bool; }
|
||||
};
|
||||
|
||||
class IntegerAttr : public Attribute {
|
||||
public:
|
||||
static IntegerAttr *get(int64_t value, MLIRContext *context);
|
||||
typedef detail::IntegerAttributeStorage ImplType;
|
||||
IntegerAttr() = default;
|
||||
/* implicit */ IntegerAttr(Attribute::ImplType *ptr);
|
||||
|
||||
int64_t getValue() const { return value; }
|
||||
static IntegerAttr get(int64_t value, MLIRContext *context);
|
||||
|
||||
int64_t getValue() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const Attribute *attr) {
|
||||
return attr->getKind() == Kind::Integer;
|
||||
}
|
||||
|
||||
private:
|
||||
IntegerAttr(int64_t value)
|
||||
: Attribute(Kind::Integer, /*isOrContainsFunction=*/false), value(value) {
|
||||
}
|
||||
~IntegerAttr() = delete;
|
||||
int64_t value;
|
||||
static bool kindof(Kind kind) { return kind == Kind::Integer; }
|
||||
};
|
||||
|
||||
class FloatAttr final : public Attribute,
|
||||
public llvm::TrailingObjects<FloatAttr, uint64_t> {
|
||||
class FloatAttr final : public Attribute {
|
||||
public:
|
||||
static FloatAttr *get(double value, MLIRContext *context);
|
||||
static FloatAttr *get(const APFloat &value, MLIRContext *context);
|
||||
typedef detail::FloatAttributeStorage ImplType;
|
||||
FloatAttr() = default;
|
||||
/* implicit */ FloatAttr(Attribute::ImplType *ptr);
|
||||
|
||||
static FloatAttr get(double value, MLIRContext *context);
|
||||
static FloatAttr get(const APFloat &value, MLIRContext *context);
|
||||
|
||||
APFloat getValue() const;
|
||||
|
||||
double getDouble() const { return getValue().convertToDouble(); }
|
||||
double getDouble() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const Attribute *attr) {
|
||||
return attr->getKind() == Kind::Float;
|
||||
}
|
||||
|
||||
private:
|
||||
FloatAttr(const llvm::fltSemantics &semantics, size_t numObjects)
|
||||
: Attribute(Kind::Float, /*isOrContainsFunction=*/false),
|
||||
semantics(semantics), numObjects(numObjects) {}
|
||||
FloatAttr(const FloatAttr &value) = delete;
|
||||
~FloatAttr() = delete;
|
||||
|
||||
size_t numTrailingObjects(OverloadToken<uint64_t>) const {
|
||||
return numObjects;
|
||||
}
|
||||
|
||||
const llvm::fltSemantics &semantics;
|
||||
size_t numObjects;
|
||||
static bool kindof(Kind kind) { return kind == Kind::Float; }
|
||||
};
|
||||
|
||||
class StringAttr : public Attribute {
|
||||
public:
|
||||
static StringAttr *get(StringRef bytes, MLIRContext *context);
|
||||
typedef detail::StringAttributeStorage ImplType;
|
||||
StringAttr() = default;
|
||||
/* implicit */ StringAttr(Attribute::ImplType *ptr);
|
||||
|
||||
StringRef getValue() const { return value; }
|
||||
static StringAttr get(StringRef bytes, MLIRContext *context);
|
||||
|
||||
StringRef getValue() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const Attribute *attr) {
|
||||
return attr->getKind() == Kind::String;
|
||||
}
|
||||
|
||||
private:
|
||||
StringAttr(StringRef value)
|
||||
: Attribute(Kind::String, /*isOrContainsFunction=*/false), value(value) {}
|
||||
~StringAttr() = delete;
|
||||
StringRef value;
|
||||
static bool kindof(Kind kind) { return kind == Kind::String; }
|
||||
};
|
||||
|
||||
/// Array attributes are lists of other attributes. They are not necessarily
|
||||
/// type homogenous given that attributes don't, in general, carry types.
|
||||
class ArrayAttr : public Attribute {
|
||||
public:
|
||||
static ArrayAttr *get(ArrayRef<Attribute *> value, MLIRContext *context);
|
||||
typedef detail::ArrayAttributeStorage ImplType;
|
||||
ArrayAttr() = default;
|
||||
/* implicit */ ArrayAttr(Attribute::ImplType *ptr);
|
||||
|
||||
ArrayRef<Attribute *> getValue() const { return value; }
|
||||
static ArrayAttr get(ArrayRef<Attribute> value, MLIRContext *context);
|
||||
|
||||
ArrayRef<Attribute> getValue() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const Attribute *attr) {
|
||||
return attr->getKind() == Kind::Array;
|
||||
}
|
||||
|
||||
private:
|
||||
ArrayAttr(ArrayRef<Attribute *> value, bool isOrContainsFunction)
|
||||
: Attribute(Kind::Array, isOrContainsFunction), value(value) {}
|
||||
~ArrayAttr() = delete;
|
||||
ArrayRef<Attribute *> value;
|
||||
static bool kindof(Kind kind) { return kind == Kind::Array; }
|
||||
};
|
||||
|
||||
class AffineMapAttr : public Attribute {
|
||||
public:
|
||||
static AffineMapAttr *get(AffineMap value);
|
||||
typedef detail::AffineMapAttributeStorage ImplType;
|
||||
AffineMapAttr() = default;
|
||||
/* implicit */ AffineMapAttr(Attribute::ImplType *ptr);
|
||||
|
||||
AffineMap getValue() const { return value; }
|
||||
static AffineMapAttr get(AffineMap value);
|
||||
|
||||
AffineMap getValue() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const Attribute *attr) {
|
||||
return attr->getKind() == Kind::AffineMap;
|
||||
}
|
||||
|
||||
private:
|
||||
AffineMapAttr(AffineMap value)
|
||||
: Attribute(Kind::AffineMap, /*isOrContainsFunction=*/false),
|
||||
value(value) {}
|
||||
~AffineMapAttr() = delete;
|
||||
AffineMap value;
|
||||
static bool kindof(Kind kind) { return kind == Kind::AffineMap; }
|
||||
};
|
||||
|
||||
class TypeAttr : public Attribute {
|
||||
public:
|
||||
static TypeAttr *get(Type *type, MLIRContext *context);
|
||||
typedef detail::TypeAttributeStorage ImplType;
|
||||
TypeAttr() = default;
|
||||
/* implicit */ TypeAttr(Attribute::ImplType *ptr);
|
||||
|
||||
Type *getValue() const { return value; }
|
||||
static TypeAttr get(Type *type, MLIRContext *context);
|
||||
|
||||
Type *getValue() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const Attribute *attr) {
|
||||
return attr->getKind() == Kind::Type;
|
||||
}
|
||||
|
||||
private:
|
||||
TypeAttr(Type *value)
|
||||
: Attribute(Kind::Type, /*isOrContainsFunction=*/false), value(value) {}
|
||||
~TypeAttr() = delete;
|
||||
Type *value;
|
||||
static bool kindof(Kind kind) { return kind == Kind::Type; }
|
||||
};
|
||||
|
||||
/// A function attribute represents a reference to a function object.
|
||||
|
@ -237,63 +233,53 @@ private:
|
|||
/// remain in MLIRContext.
|
||||
class FunctionAttr : public Attribute {
|
||||
public:
|
||||
static FunctionAttr *get(const Function *value, MLIRContext *context);
|
||||
typedef detail::FunctionAttributeStorage ImplType;
|
||||
FunctionAttr() = default;
|
||||
/* implicit */ FunctionAttr(Attribute::ImplType *ptr);
|
||||
|
||||
Function *getValue() const { return value; }
|
||||
static FunctionAttr get(const Function *value, MLIRContext *context);
|
||||
|
||||
Function *getValue() const;
|
||||
|
||||
FunctionType *getType() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const Attribute *attr) {
|
||||
return attr->getKind() == Kind::Function;
|
||||
}
|
||||
static bool kindof(Kind kind) { return kind == Kind::Function; }
|
||||
|
||||
/// This function is used by the internals of the Function class to null out
|
||||
/// attributes refering to functions that are about to be deleted.
|
||||
static void dropFunctionReference(Function *value);
|
||||
|
||||
private:
|
||||
FunctionAttr(Function *value)
|
||||
: Attribute(Kind::Function, /*isOrContainsFunction=*/true), value(value) {
|
||||
}
|
||||
~FunctionAttr() = delete;
|
||||
Function *value;
|
||||
};
|
||||
|
||||
/// A base attribute represents a reference to a vector or tensor constant.
|
||||
class ElementsAttr : public Attribute {
|
||||
public:
|
||||
ElementsAttr(Kind kind, VectorOrTensorType *type)
|
||||
: Attribute(kind, /*isOrContainsFunction=*/false), type(type) {}
|
||||
typedef detail::ElementsAttributeStorage ImplType;
|
||||
ElementsAttr() = default;
|
||||
/* implicit */ ElementsAttr(Attribute::ImplType *ptr);
|
||||
|
||||
VectorOrTensorType *getType() const { return type; }
|
||||
VectorOrTensorType *getType() const;
|
||||
|
||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||
static bool classof(const Attribute *attr) {
|
||||
return attr->getKind() >= Kind::FIRST_ELEMENTS_ATTR &&
|
||||
attr->getKind() <= Kind::LAST_ELEMENTS_ATTR;
|
||||
static bool kindof(Kind kind) {
|
||||
return kind >= Kind::FIRST_ELEMENTS_ATTR &&
|
||||
kind <= Kind::LAST_ELEMENTS_ATTR;
|
||||
}
|
||||
|
||||
private:
|
||||
VectorOrTensorType *type;
|
||||
};
|
||||
|
||||
/// An attribute represents a reference to a splat vecctor or tensor constant,
|
||||
/// meaning all of the elements have the same value.
|
||||
class SplatElementsAttr : public ElementsAttr {
|
||||
public:
|
||||
static SplatElementsAttr *get(VectorOrTensorType *type, Attribute *elt);
|
||||
Attribute *getValue() const { return elt; }
|
||||
typedef detail::SplatElementsAttributeStorage ImplType;
|
||||
SplatElementsAttr() = default;
|
||||
/* implicit */ SplatElementsAttr(Attribute::ImplType *ptr);
|
||||
|
||||
static SplatElementsAttr get(VectorOrTensorType *type, Attribute elt);
|
||||
Attribute getValue() const;
|
||||
|
||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||
static bool classof(const Attribute *attr) {
|
||||
return attr->getKind() == Kind::SplatElements;
|
||||
}
|
||||
|
||||
private:
|
||||
SplatElementsAttr(VectorOrTensorType *type, Attribute *elt)
|
||||
: ElementsAttr(Kind::SplatElements, type), elt(elt) {}
|
||||
Attribute *elt;
|
||||
static bool kindof(Kind kind) { return kind == Kind::SplatElements; }
|
||||
};
|
||||
|
||||
/// An attribute represents a reference to a dense vector or tensor object.
|
||||
|
@ -302,42 +288,42 @@ private:
|
|||
/// than 64.
|
||||
class DenseElementsAttr : public ElementsAttr {
|
||||
public:
|
||||
typedef detail::DenseElementsAttributeStorage ImplType;
|
||||
DenseElementsAttr() = default;
|
||||
/* implicit */ DenseElementsAttr(Attribute::ImplType *ptr);
|
||||
|
||||
/// It assumes the elements in the input array have been truncated to the bits
|
||||
/// width specified by the element type (note all float type are 64 bits).
|
||||
/// When the value is retrieved, the bits are read from the storage and extend
|
||||
/// to 64 bits if necessary.
|
||||
static DenseElementsAttr *get(VectorOrTensorType *type, ArrayRef<char> data);
|
||||
static DenseElementsAttr get(VectorOrTensorType *type, ArrayRef<char> data);
|
||||
|
||||
// TODO: Read the data from the attribute list and compress them
|
||||
// to a character array. Then call the above method to construct the
|
||||
// attribute.
|
||||
static DenseElementsAttr *get(VectorOrTensorType *type,
|
||||
ArrayRef<Attribute *> values);
|
||||
static DenseElementsAttr get(VectorOrTensorType *type,
|
||||
ArrayRef<Attribute> values);
|
||||
|
||||
void getValues(SmallVectorImpl<Attribute *> &values) const;
|
||||
void getValues(SmallVectorImpl<Attribute> &values) const;
|
||||
|
||||
ArrayRef<char> getRawData() const { return data; }
|
||||
ArrayRef<char> getRawData() const;
|
||||
|
||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||
static bool classof(const Attribute *attr) {
|
||||
return attr->getKind() == Kind::DenseIntElements ||
|
||||
attr->getKind() == Kind::DenseFPElements;
|
||||
static bool kindof(Kind kind) {
|
||||
return kind == Kind::DenseIntElements || kind == Kind::DenseFPElements;
|
||||
}
|
||||
|
||||
protected:
|
||||
DenseElementsAttr(Kind kind, VectorOrTensorType *type, ArrayRef<char> data)
|
||||
: ElementsAttr(kind, type), data(data) {}
|
||||
|
||||
private:
|
||||
ArrayRef<char> data;
|
||||
};
|
||||
|
||||
/// An attribute represents a reference to a dense integer vector or tensor
|
||||
/// object.
|
||||
class DenseIntElementsAttr : public DenseElementsAttr {
|
||||
public:
|
||||
typedef detail::DenseIntElementsAttributeStorage ImplType;
|
||||
DenseIntElementsAttr() = default;
|
||||
/* implicit */ DenseIntElementsAttr(Attribute::ImplType *ptr);
|
||||
|
||||
// TODO: returns APInts instead of IntegerAttr.
|
||||
void getValues(SmallVectorImpl<Attribute *> &values) const;
|
||||
void getValues(SmallVectorImpl<Attribute> &values) const;
|
||||
|
||||
APInt getValue(ArrayRef<unsigned> indices) const;
|
||||
|
||||
|
@ -352,41 +338,24 @@ public:
|
|||
size_t bitsWidth);
|
||||
|
||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||
static bool classof(const Attribute *attr) {
|
||||
return attr->getKind() == Kind::DenseIntElements;
|
||||
}
|
||||
|
||||
private:
|
||||
friend class DenseElementsAttr;
|
||||
DenseIntElementsAttr(VectorOrTensorType *type, ArrayRef<char> data,
|
||||
size_t bitsWidth)
|
||||
: DenseElementsAttr(Kind::DenseIntElements, type, data),
|
||||
bitsWidth(bitsWidth) {}
|
||||
|
||||
~DenseIntElementsAttr() = delete;
|
||||
|
||||
size_t bitsWidth;
|
||||
static bool kindof(Kind kind) { return kind == Kind::DenseIntElements; }
|
||||
};
|
||||
|
||||
/// An attribute represents a reference to a dense float vector or tensor
|
||||
/// object. Each element is stored as a double.
|
||||
class DenseFPElementsAttr : public DenseElementsAttr {
|
||||
public:
|
||||
typedef detail::DenseFPElementsAttributeStorage ImplType;
|
||||
DenseFPElementsAttr() = default;
|
||||
/* implicit */ DenseFPElementsAttr(Attribute::ImplType *ptr);
|
||||
|
||||
// TODO: returns APFPs instead of FloatAttr.
|
||||
void getValues(SmallVectorImpl<Attribute *> &values) const;
|
||||
void getValues(SmallVectorImpl<Attribute> &values) const;
|
||||
|
||||
APFloat getValue(ArrayRef<unsigned> indices) const;
|
||||
|
||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||
static bool classof(const Attribute *attr) {
|
||||
return attr->getKind() == Kind::DenseFPElements;
|
||||
}
|
||||
|
||||
private:
|
||||
friend class DenseElementsAttr;
|
||||
DenseFPElementsAttr(VectorOrTensorType *type, ArrayRef<char> data)
|
||||
: DenseElementsAttr(Kind::DenseFPElements, type, data) {}
|
||||
~DenseFPElementsAttr() = delete;
|
||||
static bool kindof(Kind kind) { return kind == Kind::DenseFPElements; }
|
||||
};
|
||||
|
||||
/// An attribute represents a reference to a tensor constant with opaque
|
||||
|
@ -394,20 +363,16 @@ private:
|
|||
/// doesn't need to interpret.
|
||||
class OpaqueElementsAttr : public ElementsAttr {
|
||||
public:
|
||||
static OpaqueElementsAttr *get(VectorOrTensorType *type, StringRef bytes);
|
||||
typedef detail::OpaqueElementsAttributeStorage ImplType;
|
||||
OpaqueElementsAttr() = default;
|
||||
/* implicit */ OpaqueElementsAttr(Attribute::ImplType *ptr);
|
||||
|
||||
StringRef getValue() const { return bytes; }
|
||||
static OpaqueElementsAttr get(VectorOrTensorType *type, StringRef bytes);
|
||||
|
||||
StringRef getValue() const;
|
||||
|
||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||
static bool classof(const Attribute *attr) {
|
||||
return attr->getKind() == Kind::OpaqueElements;
|
||||
}
|
||||
|
||||
private:
|
||||
OpaqueElementsAttr(VectorOrTensorType *type, StringRef bytes)
|
||||
: ElementsAttr(Kind::OpaqueElements, type), bytes(bytes) {}
|
||||
~OpaqueElementsAttr() = delete;
|
||||
StringRef bytes;
|
||||
static bool kindof(Kind kind) { return kind == Kind::OpaqueElements; }
|
||||
};
|
||||
|
||||
/// An attribute represents a reference to a sparse vector or tensor object.
|
||||
|
@ -427,32 +392,67 @@ private:
|
|||
/// [0, 0, 0, 0]].
|
||||
class SparseElementsAttr : public ElementsAttr {
|
||||
public:
|
||||
static SparseElementsAttr *get(VectorOrTensorType *type,
|
||||
DenseIntElementsAttr *indices,
|
||||
DenseElementsAttr *values);
|
||||
typedef detail::SparseElementsAttributeStorage ImplType;
|
||||
SparseElementsAttr() = default;
|
||||
/* implicit */ SparseElementsAttr(Attribute::ImplType *ptr);
|
||||
|
||||
DenseIntElementsAttr *getIndices() const { return indices; }
|
||||
static SparseElementsAttr get(VectorOrTensorType *type,
|
||||
DenseIntElementsAttr indices,
|
||||
DenseElementsAttr values);
|
||||
|
||||
DenseElementsAttr *getValues() const { return values; }
|
||||
DenseIntElementsAttr getIndices() const;
|
||||
|
||||
DenseElementsAttr getValues() const;
|
||||
|
||||
/// Return the value at the given index.
|
||||
Attribute *getValue(ArrayRef<unsigned> index) const;
|
||||
Attribute getValue(ArrayRef<unsigned> index) const;
|
||||
|
||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||
static bool classof(const Attribute *attr) {
|
||||
return attr->getKind() == Kind::SparseElements;
|
||||
}
|
||||
|
||||
private:
|
||||
SparseElementsAttr(VectorOrTensorType *type, DenseIntElementsAttr *indices,
|
||||
DenseElementsAttr *values)
|
||||
: ElementsAttr(Kind::SparseElements, type), indices(indices),
|
||||
values(values) {}
|
||||
~SparseElementsAttr() = delete;
|
||||
|
||||
DenseIntElementsAttr *const indices;
|
||||
DenseElementsAttr *const values;
|
||||
static bool kindof(Kind kind) { return kind == Kind::SparseElements; }
|
||||
};
|
||||
|
||||
template <typename U> bool Attribute::isa() const {
|
||||
assert(attr && "isa<> used on a null attribute.");
|
||||
return U::kindof(getKind());
|
||||
}
|
||||
template <typename U> U Attribute::dyn_cast() const {
|
||||
return isa<U>() ? U(attr) : U(nullptr);
|
||||
}
|
||||
template <typename U> U Attribute::dyn_cast_or_null() const {
|
||||
return (attr && isa<U>()) ? U(attr) : U(nullptr);
|
||||
}
|
||||
template <typename U> U Attribute::cast() const {
|
||||
assert(isa<U>());
|
||||
return U(attr);
|
||||
}
|
||||
|
||||
// Make Attribute hashable.
|
||||
inline ::llvm::hash_code hash_value(Attribute arg) {
|
||||
return ::llvm::hash_value(arg.attr);
|
||||
}
|
||||
|
||||
} // end namespace mlir.
|
||||
|
||||
namespace llvm {
|
||||
|
||||
// Attribute hash just like pointers
|
||||
template <> struct DenseMapInfo<mlir::Attribute> {
|
||||
static mlir::Attribute getEmptyKey() {
|
||||
auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||
return mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer));
|
||||
}
|
||||
static mlir::Attribute getTombstoneKey() {
|
||||
auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
|
||||
return mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer));
|
||||
}
|
||||
static unsigned getHashValue(mlir::Attribute val) {
|
||||
return mlir::hash_value(val);
|
||||
}
|
||||
static bool isEqual(mlir::Attribute LHS, mlir::Attribute RHS) {
|
||||
return LHS == RHS;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace llvm
|
||||
|
||||
#endif
|
||||
|
|
|
@ -93,23 +93,23 @@ public:
|
|||
UnrankedTensorType *getTensorType(Type *elementType);
|
||||
|
||||
// Attributes.
|
||||
BoolAttr *getBoolAttr(bool value);
|
||||
IntegerAttr *getIntegerAttr(int64_t value);
|
||||
FloatAttr *getFloatAttr(double value);
|
||||
FloatAttr *getFloatAttr(const APFloat &value);
|
||||
StringAttr *getStringAttr(StringRef bytes);
|
||||
ArrayAttr *getArrayAttr(ArrayRef<Attribute *> value);
|
||||
AffineMapAttr *getAffineMapAttr(AffineMap map);
|
||||
TypeAttr *getTypeAttr(Type *type);
|
||||
FunctionAttr *getFunctionAttr(const Function *value);
|
||||
ElementsAttr *getSplatElementsAttr(VectorOrTensorType *type, Attribute *elt);
|
||||
ElementsAttr *getDenseElementsAttr(VectorOrTensorType *type,
|
||||
ArrayRef<char> data);
|
||||
ElementsAttr *getSparseElementsAttr(VectorOrTensorType *type,
|
||||
DenseIntElementsAttr *indices,
|
||||
DenseElementsAttr *values);
|
||||
ElementsAttr *getOpaqueElementsAttr(VectorOrTensorType *type,
|
||||
StringRef bytes);
|
||||
|
||||
BoolAttr getBoolAttr(bool value);
|
||||
IntegerAttr getIntegerAttr(int64_t value);
|
||||
FloatAttr getFloatAttr(double value);
|
||||
FloatAttr getFloatAttr(const APFloat &value);
|
||||
StringAttr getStringAttr(StringRef bytes);
|
||||
ArrayAttr getArrayAttr(ArrayRef<Attribute> value);
|
||||
AffineMapAttr getAffineMapAttr(AffineMap map);
|
||||
TypeAttr getTypeAttr(Type *type);
|
||||
FunctionAttr getFunctionAttr(const Function *value);
|
||||
ElementsAttr getSplatElementsAttr(VectorOrTensorType *type, Attribute elt);
|
||||
ElementsAttr getDenseElementsAttr(VectorOrTensorType *type,
|
||||
ArrayRef<char> data);
|
||||
ElementsAttr getSparseElementsAttr(VectorOrTensorType *type,
|
||||
DenseIntElementsAttr indices,
|
||||
DenseElementsAttr values);
|
||||
ElementsAttr getOpaqueElementsAttr(VectorOrTensorType *type, StringRef bytes);
|
||||
|
||||
// Affine expressions and affine maps.
|
||||
AffineExpr getAffineDimExpr(unsigned position);
|
||||
|
|
|
@ -60,7 +60,7 @@ public:
|
|||
|
||||
/// Returns the affine map to be applied by this operation.
|
||||
AffineMap getAffineMap() const {
|
||||
return getAttrOfType<AffineMapAttr>("map")->getValue();
|
||||
return getAttrOfType<AffineMapAttr>("map").getValue();
|
||||
}
|
||||
|
||||
/// Returns true if the result of this operation can be used as dimension id.
|
||||
|
@ -75,8 +75,8 @@ public:
|
|||
static bool parse(OpAsmParser *parser, OperationState *result);
|
||||
void print(OpAsmPrinter *p) const;
|
||||
bool verify() const;
|
||||
bool constantFold(ArrayRef<Attribute *> operands,
|
||||
SmallVectorImpl<Attribute *> &results,
|
||||
bool constantFold(ArrayRef<Attribute> operandConstants,
|
||||
SmallVectorImpl<Attribute> &results,
|
||||
MLIRContext *context) const;
|
||||
|
||||
private:
|
||||
|
@ -94,10 +94,10 @@ class ConstantOp : public Op<ConstantOp, OpTrait::ZeroOperands,
|
|||
OpTrait::OneResult, OpTrait::HasNoSideEffect> {
|
||||
public:
|
||||
/// Builds a constant op with the specified attribute value and result type.
|
||||
static void build(Builder *builder, OperationState *result, Attribute *value,
|
||||
static void build(Builder *builder, OperationState *result, Attribute value,
|
||||
Type *type);
|
||||
|
||||
Attribute *getValue() const { return getAttr("value"); }
|
||||
Attribute getValue() const { return getAttr("value"); }
|
||||
|
||||
static StringRef getOperationName() { return "constant"; }
|
||||
|
||||
|
@ -105,8 +105,8 @@ public:
|
|||
static bool parse(OpAsmParser *parser, OperationState *result);
|
||||
void print(OpAsmPrinter *p) const;
|
||||
bool verify() const;
|
||||
Attribute *constantFold(ArrayRef<Attribute *> operands,
|
||||
MLIRContext *context) const;
|
||||
Attribute constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) const;
|
||||
|
||||
protected:
|
||||
friend class Operation;
|
||||
|
@ -125,7 +125,7 @@ public:
|
|||
const APFloat &value, FloatType *type);
|
||||
|
||||
APFloat getValue() const {
|
||||
return getAttrOfType<FloatAttr>("value")->getValue();
|
||||
return getAttrOfType<FloatAttr>("value").getValue();
|
||||
}
|
||||
|
||||
static bool isClassFor(const Operation *op);
|
||||
|
@ -152,7 +152,7 @@ public:
|
|||
Type *type);
|
||||
|
||||
int64_t getValue() const {
|
||||
return getAttrOfType<IntegerAttr>("value")->getValue();
|
||||
return getAttrOfType<IntegerAttr>("value").getValue();
|
||||
}
|
||||
|
||||
static bool isClassFor(const Operation *op);
|
||||
|
@ -173,7 +173,7 @@ public:
|
|||
static void build(Builder *builder, OperationState *result, int64_t value);
|
||||
|
||||
int64_t getValue() const {
|
||||
return getAttrOfType<IntegerAttr>("value")->getValue();
|
||||
return getAttrOfType<IntegerAttr>("value").getValue();
|
||||
}
|
||||
|
||||
static bool isClassFor(const Operation *op);
|
||||
|
|
|
@ -24,12 +24,12 @@
|
|||
#ifndef MLIR_IR_FUNCTION_H
|
||||
#define MLIR_IR_FUNCTION_H
|
||||
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Identifier.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/ilist.h"
|
||||
|
||||
namespace mlir {
|
||||
class Attribute;
|
||||
class AttributeListStorage;
|
||||
class FunctionType;
|
||||
class Location;
|
||||
|
@ -39,7 +39,7 @@ class Module;
|
|||
/// NamedAttribute is used for function attribute lists, it holds an
|
||||
/// identifier for the name and a value for the attribute. The attribute
|
||||
/// pointer should always be non-null.
|
||||
using NamedAttribute = std::pair<Identifier, Attribute *>;
|
||||
using NamedAttribute = std::pair<Identifier, Attribute>;
|
||||
|
||||
/// This is the base class for all of the MLIR function types.
|
||||
class Function : public llvm::ilist_node_with_parent<Function, Module> {
|
||||
|
|
|
@ -138,17 +138,16 @@ public:
|
|||
ArrayRef<NamedAttribute> getAttrs() const { return state->getAttrs(); }
|
||||
|
||||
/// Return an attribute with the specified name.
|
||||
Attribute *getAttr(StringRef name) const { return state->getAttr(name); }
|
||||
Attribute getAttr(StringRef name) const { return state->getAttr(name); }
|
||||
|
||||
/// If the operation has an attribute of the specified type, return it.
|
||||
template <typename AttrClass>
|
||||
AttrClass *getAttrOfType(StringRef name) const {
|
||||
return dyn_cast_or_null<AttrClass>(getAttr(name));
|
||||
template <typename AttrClass> AttrClass getAttrOfType(StringRef name) const {
|
||||
return getAttr(name).dyn_cast_or_null<AttrClass>();
|
||||
}
|
||||
|
||||
/// If the an attribute exists with the specified name, change it to the new
|
||||
/// value. Otherwise, add a new attribute with the specified name/value.
|
||||
void setAttr(Identifier name, Attribute *value) {
|
||||
void setAttr(Identifier name, Attribute value) {
|
||||
state->setAttr(name, value);
|
||||
}
|
||||
|
||||
|
@ -211,8 +210,8 @@ public:
|
|||
/// true if folding failed, or returns false and fills in `results` on
|
||||
/// success.
|
||||
static bool constantFoldHook(const Operation *op,
|
||||
ArrayRef<Attribute *> operands,
|
||||
SmallVectorImpl<Attribute *> &results) {
|
||||
ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<Attribute> &results) {
|
||||
return op->cast<ConcreteType>()->constantFold(operands, results,
|
||||
op->getContext());
|
||||
}
|
||||
|
@ -226,8 +225,8 @@ public:
|
|||
///
|
||||
/// If not overridden, this fallback implementation always fails to fold.
|
||||
///
|
||||
bool constantFold(ArrayRef<Attribute *> operands,
|
||||
SmallVectorImpl<Attribute *> &results,
|
||||
bool constantFold(ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<Attribute> &results,
|
||||
MLIRContext *context) const {
|
||||
return true;
|
||||
}
|
||||
|
@ -244,9 +243,9 @@ public:
|
|||
/// true if folding failed, or returns false and fills in `results` on
|
||||
/// success.
|
||||
static bool constantFoldHook(const Operation *op,
|
||||
ArrayRef<Attribute *> operands,
|
||||
SmallVectorImpl<Attribute *> &results) {
|
||||
auto *result =
|
||||
ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<Attribute> &results) {
|
||||
auto result =
|
||||
op->cast<ConcreteType>()->constantFold(operands, op->getContext());
|
||||
if (!result)
|
||||
return true;
|
||||
|
@ -511,8 +510,8 @@ public:
|
|||
///
|
||||
/// If not overridden, this fallback implementation always fails to fold.
|
||||
///
|
||||
Attribute *constantFold(ArrayRef<Attribute *> operands,
|
||||
MLIRContext *context) const {
|
||||
Attribute constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) const {
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
|
|
@ -69,7 +69,7 @@ public:
|
|||
}
|
||||
virtual void printType(const Type *type) = 0;
|
||||
virtual void printFunctionReference(const Function *func) = 0;
|
||||
virtual void printAttribute(const Attribute *attr) = 0;
|
||||
virtual void printAttribute(Attribute attr) = 0;
|
||||
virtual void printAffineMap(AffineMap map) = 0;
|
||||
virtual void printAffineExpr(AffineExpr expr) = 0;
|
||||
|
||||
|
@ -100,8 +100,8 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const Type &type) {
|
|||
return p;
|
||||
}
|
||||
|
||||
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const Attribute &attr) {
|
||||
p.printAttribute(&attr);
|
||||
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Attribute attr) {
|
||||
p.printAttribute(attr);
|
||||
return p;
|
||||
}
|
||||
|
||||
|
@ -210,24 +210,24 @@ public:
|
|||
/// Parse an arbitrary attribute and return it in result. This also adds the
|
||||
/// attribute to the specified attribute list with the specified name. this
|
||||
/// captures the location of the attribute in 'loc' if it is non-null.
|
||||
virtual bool parseAttribute(Attribute *&result, const char *attrName,
|
||||
virtual bool parseAttribute(Attribute &result, const char *attrName,
|
||||
SmallVectorImpl<NamedAttribute> &attrs) = 0;
|
||||
|
||||
/// Parse an attribute of a specific kind, capturing the location into `loc`
|
||||
/// if specified.
|
||||
template <typename AttrType>
|
||||
bool parseAttribute(AttrType *&result, const char *attrName,
|
||||
bool parseAttribute(AttrType &result, const char *attrName,
|
||||
SmallVectorImpl<NamedAttribute> &attrs) {
|
||||
llvm::SMLoc loc;
|
||||
getCurrentLocation(&loc);
|
||||
|
||||
// Parse any kind of attribute.
|
||||
Attribute *attr;
|
||||
Attribute attr;
|
||||
if (parseAttribute(attr, attrName, attrs))
|
||||
return true;
|
||||
|
||||
// Check for the right kind of attribute.
|
||||
result = dyn_cast<AttrType>(attr);
|
||||
result = attr.dyn_cast<AttrType>();
|
||||
if (!result) {
|
||||
emitError(loc, "invalid kind of constant specified");
|
||||
return true;
|
||||
|
|
|
@ -113,33 +113,31 @@ public:
|
|||
ArrayRef<NamedAttribute> getAttrs() const;
|
||||
|
||||
/// Return the specified attribute if present, null otherwise.
|
||||
Attribute *getAttr(Identifier name) const {
|
||||
Attribute getAttr(Identifier name) const {
|
||||
for (auto elt : getAttrs())
|
||||
if (elt.first == name)
|
||||
return elt.second;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Attribute *getAttr(StringRef name) const {
|
||||
Attribute getAttr(StringRef name) const {
|
||||
for (auto elt : getAttrs())
|
||||
if (elt.first.is(name))
|
||||
return elt.second;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <typename AttrClass>
|
||||
AttrClass *getAttrOfType(Identifier name) const {
|
||||
return dyn_cast_or_null<AttrClass>(getAttr(name));
|
||||
template <typename AttrClass> AttrClass getAttrOfType(Identifier name) const {
|
||||
return getAttr(name).dyn_cast_or_null<AttrClass>();
|
||||
}
|
||||
|
||||
template <typename AttrClass>
|
||||
AttrClass *getAttrOfType(StringRef name) const {
|
||||
return dyn_cast_or_null<AttrClass>(getAttr(name));
|
||||
template <typename AttrClass> AttrClass getAttrOfType(StringRef name) const {
|
||||
return getAttr(name).dyn_cast_or_null<AttrClass>();
|
||||
}
|
||||
|
||||
/// If the an attribute exists with the specified name, change it to the new
|
||||
/// value. Otherwise, add a new attribute with the specified name/value.
|
||||
void setAttr(Identifier name, Attribute *value);
|
||||
void setAttr(Identifier name, Attribute value);
|
||||
|
||||
enum class RemoveResult {
|
||||
Removed, NotFound
|
||||
|
@ -250,8 +248,8 @@ public:
|
|||
/// the operands of the operation, but may be null if non-constant. If
|
||||
/// constant folding is successful, this returns false and fills in the
|
||||
/// `results` vector. If not, this returns true and `results` is unspecified.
|
||||
bool constantFold(ArrayRef<Attribute *> operands,
|
||||
SmallVectorImpl<Attribute *> &results) const;
|
||||
bool constantFold(ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<Attribute> &results) const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool classof(const Instruction *inst);
|
||||
|
|
|
@ -23,11 +23,11 @@
|
|||
#ifndef MLIR_IR_OPERATION_SUPPORT_H
|
||||
#define MLIR_IR_OPERATION_SUPPORT_H
|
||||
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Identifier.h"
|
||||
#include "llvm/ADT/PointerUnion.h"
|
||||
|
||||
namespace mlir {
|
||||
class Attribute;
|
||||
class Dialect;
|
||||
class Location;
|
||||
class Operation;
|
||||
|
@ -80,8 +80,8 @@ public:
|
|||
/// This hook implements a constant folder for this operation. It returns
|
||||
/// true if folding failed, or returns false and fills in `results` on
|
||||
/// success.
|
||||
bool (&constantFoldHook)(const Operation *op, ArrayRef<Attribute *> operands,
|
||||
SmallVectorImpl<Attribute *> &results);
|
||||
bool (&constantFoldHook)(const Operation *op, ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<Attribute> &results);
|
||||
|
||||
// Returns whether the operation has a particular property.
|
||||
bool hasProperty(OperationProperty property) const {
|
||||
|
@ -110,8 +110,8 @@ private:
|
|||
void (&printAssembly)(const Operation *op, OpAsmPrinter *p),
|
||||
bool (&verifyInvariants)(const Operation *op),
|
||||
bool (&constantFoldHook)(const Operation *op,
|
||||
ArrayRef<Attribute *> operands,
|
||||
SmallVectorImpl<Attribute *> &results))
|
||||
ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<Attribute> &results))
|
||||
: name(name), dialect(dialect), isClassFor(isClassFor),
|
||||
parseAssembly(parseAssembly), printAssembly(printAssembly),
|
||||
verifyInvariants(verifyInvariants), constantFoldHook(constantFoldHook),
|
||||
|
@ -124,7 +124,7 @@ private:
|
|||
/// NamedAttribute is a used for operation attribute lists, it holds an
|
||||
/// identifier for the name and a value for the attribute. The attribute
|
||||
/// pointer should always be non-null.
|
||||
using NamedAttribute = std::pair<Identifier, Attribute *>;
|
||||
using NamedAttribute = std::pair<Identifier, Attribute>;
|
||||
|
||||
class OperationName {
|
||||
public:
|
||||
|
@ -204,7 +204,7 @@ public:
|
|||
types.append(newTypes.begin(), newTypes.end());
|
||||
}
|
||||
|
||||
void addAttribute(StringRef name, Attribute *attr) {
|
||||
void addAttribute(StringRef name, Attribute attr) {
|
||||
attributes.push_back({Identifier::get(name, context), attr});
|
||||
}
|
||||
};
|
||||
|
|
|
@ -49,8 +49,8 @@ class AddFOp
|
|||
public:
|
||||
static StringRef getOperationName() { return "addf"; }
|
||||
|
||||
Attribute *constantFold(ArrayRef<Attribute *> operands,
|
||||
MLIRContext *context) const;
|
||||
Attribute constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) const;
|
||||
|
||||
private:
|
||||
friend class Operation;
|
||||
|
@ -70,8 +70,8 @@ class AddIOp
|
|||
public:
|
||||
static StringRef getOperationName() { return "addi"; }
|
||||
|
||||
Attribute *constantFold(ArrayRef<Attribute *> operands,
|
||||
MLIRContext *context) const;
|
||||
Attribute constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) const;
|
||||
|
||||
private:
|
||||
friend class Operation;
|
||||
|
@ -134,7 +134,7 @@ public:
|
|||
ArrayRef<SSAValue *> operands);
|
||||
|
||||
Function *getCallee() const {
|
||||
return getAttrOfType<FunctionAttr>("callee")->getValue();
|
||||
return getAttrOfType<FunctionAttr>("callee").getValue();
|
||||
}
|
||||
|
||||
// Hooks to customize behavior of this op.
|
||||
|
@ -218,13 +218,13 @@ public:
|
|||
static void build(Builder *builder, OperationState *result,
|
||||
SSAValue *memrefOrTensor, unsigned index);
|
||||
|
||||
Attribute *constantFold(ArrayRef<Attribute *> operands,
|
||||
MLIRContext *context) const;
|
||||
Attribute constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) const;
|
||||
|
||||
/// This returns the dimension number that the 'dim' is inspecting.
|
||||
unsigned getIndex() const {
|
||||
return static_cast<unsigned>(
|
||||
getAttrOfType<IntegerAttr>("index")->getValue());
|
||||
getAttrOfType<IntegerAttr>("index").getValue());
|
||||
}
|
||||
|
||||
static StringRef getOperationName() { return "dim"; }
|
||||
|
@ -513,8 +513,8 @@ class MulFOp
|
|||
public:
|
||||
static StringRef getOperationName() { return "mulf"; }
|
||||
|
||||
Attribute *constantFold(ArrayRef<Attribute *> operands,
|
||||
MLIRContext *context) const;
|
||||
Attribute constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) const;
|
||||
|
||||
private:
|
||||
friend class Operation;
|
||||
|
@ -534,8 +534,8 @@ class MulIOp
|
|||
public:
|
||||
static StringRef getOperationName() { return "muli"; }
|
||||
|
||||
Attribute *constantFold(ArrayRef<Attribute *> operands,
|
||||
MLIRContext *context) const;
|
||||
Attribute constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) const;
|
||||
|
||||
private:
|
||||
friend class Operation;
|
||||
|
@ -597,8 +597,8 @@ class SubFOp : public BinaryOp<SubFOp, OpTrait::ResultsAreFloatLike,
|
|||
public:
|
||||
static StringRef getOperationName() { return "subf"; }
|
||||
|
||||
Attribute *constantFold(ArrayRef<Attribute *> operands,
|
||||
MLIRContext *context) const;
|
||||
Attribute constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) const;
|
||||
|
||||
private:
|
||||
friend class Operation;
|
||||
|
@ -617,8 +617,8 @@ class SubIOp : public BinaryOp<SubIOp, OpTrait::ResultsAreIntegerLike,
|
|||
public:
|
||||
static StringRef getOperationName() { return "subi"; }
|
||||
|
||||
Attribute *constantFold(ArrayRef<Attribute *> operands,
|
||||
MLIRContext *context) const;
|
||||
Attribute constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) const;
|
||||
|
||||
private:
|
||||
friend class Operation;
|
||||
|
|
|
@ -82,7 +82,7 @@ public:
|
|||
}
|
||||
|
||||
bool verifyOperation(const Operation &op);
|
||||
bool verifyAttribute(Attribute *attr, const Operation &op);
|
||||
bool verifyAttribute(Attribute attr, const Operation &op);
|
||||
|
||||
protected:
|
||||
explicit Verifier(const Function &fn) : fn(fn) {}
|
||||
|
@ -94,26 +94,26 @@ private:
|
|||
} // end anonymous namespace
|
||||
|
||||
// Check that function attributes are all well formed.
|
||||
bool Verifier::verifyAttribute(Attribute *attr, const Operation &op) {
|
||||
if (!attr->isOrContainsFunction())
|
||||
bool Verifier::verifyAttribute(Attribute attr, const Operation &op) {
|
||||
if (!attr.isOrContainsFunction())
|
||||
return false;
|
||||
|
||||
// If we have a function attribute, check that it is non-null and in the
|
||||
// same module as the operation that refers to it.
|
||||
if (auto *fnAttr = dyn_cast<FunctionAttr>(attr)) {
|
||||
if (!fnAttr->getValue())
|
||||
if (auto fnAttr = attr.dyn_cast<FunctionAttr>()) {
|
||||
if (!fnAttr.getValue())
|
||||
return failure("attribute refers to deallocated function!", op);
|
||||
|
||||
if (fnAttr->getValue()->getModule() != fn.getModule())
|
||||
if (fnAttr.getValue()->getModule() != fn.getModule())
|
||||
return failure("attribute refers to function '" +
|
||||
Twine(fnAttr->getValue()->getName()) +
|
||||
Twine(fnAttr.getValue()->getName()) +
|
||||
"' defined in another module!",
|
||||
op);
|
||||
return false;
|
||||
}
|
||||
|
||||
// Otherwise, we must have an array attribute, remap the elements.
|
||||
for (auto *elt : cast<ArrayAttr>(attr)->getValue()) {
|
||||
for (auto elt : attr.cast<ArrayAttr>().getValue()) {
|
||||
if (verifyAttribute(elt, op))
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -32,13 +32,12 @@ namespace {
|
|||
// evaluated on constant 'operandConsts'.
|
||||
class AffineExprConstantFolder {
|
||||
public:
|
||||
AffineExprConstantFolder(unsigned numDims,
|
||||
ArrayRef<Attribute *> operandConsts)
|
||||
AffineExprConstantFolder(unsigned numDims, ArrayRef<Attribute> operandConsts)
|
||||
: numDims(numDims), operandConsts(operandConsts) {}
|
||||
|
||||
/// Attempt to constant fold the specified affine expr, or return null on
|
||||
/// failure.
|
||||
IntegerAttr *constantFold(AffineExpr expr) {
|
||||
IntegerAttr constantFold(AffineExpr expr) {
|
||||
switch (expr.getKind()) {
|
||||
case AffineExprKind::Add:
|
||||
return constantFoldBinExpr(
|
||||
|
@ -59,31 +58,32 @@ public:
|
|||
return IntegerAttr::get(expr.cast<AffineConstantExpr>().getValue(),
|
||||
expr.getContext());
|
||||
case AffineExprKind::DimId:
|
||||
return dyn_cast_or_null<IntegerAttr>(
|
||||
operandConsts[expr.cast<AffineDimExpr>().getPosition()]);
|
||||
return operandConsts[expr.cast<AffineDimExpr>().getPosition()]
|
||||
.dyn_cast_or_null<IntegerAttr>();
|
||||
case AffineExprKind::SymbolId:
|
||||
return dyn_cast_or_null<IntegerAttr>(
|
||||
operandConsts[numDims + expr.cast<AffineSymbolExpr>().getPosition()]);
|
||||
return operandConsts[numDims +
|
||||
expr.cast<AffineSymbolExpr>().getPosition()]
|
||||
.dyn_cast_or_null<IntegerAttr>();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
IntegerAttr *
|
||||
IntegerAttr
|
||||
constantFoldBinExpr(AffineExpr expr,
|
||||
std::function<uint64_t(int64_t, uint64_t)> op) {
|
||||
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
|
||||
auto *lhs = constantFold(binOpExpr.getLHS());
|
||||
auto *rhs = constantFold(binOpExpr.getRHS());
|
||||
auto lhs = constantFold(binOpExpr.getLHS());
|
||||
auto rhs = constantFold(binOpExpr.getRHS());
|
||||
if (!lhs || !rhs)
|
||||
return nullptr;
|
||||
return IntegerAttr::get(op(lhs->getValue(), rhs->getValue()),
|
||||
return IntegerAttr::get(op(lhs.getValue(), rhs.getValue()),
|
||||
expr.getContext());
|
||||
}
|
||||
|
||||
// The number of dimension operands in AffineMap containing this expression.
|
||||
unsigned numDims;
|
||||
// The constant valued operands used to evaluate this AffineExpr.
|
||||
ArrayRef<Attribute *> operandConsts;
|
||||
ArrayRef<Attribute> operandConsts;
|
||||
};
|
||||
|
||||
} // end anonymous namespace
|
||||
|
@ -137,15 +137,15 @@ ArrayRef<AffineExpr> AffineMap::getRangeSizes() const {
|
|||
/// Folds the results of the application of an affine map on the provided
|
||||
/// operands to a constant if possible. Returns false if the folding happens,
|
||||
/// true otherwise.
|
||||
bool AffineMap::constantFold(ArrayRef<Attribute *> operandConstants,
|
||||
SmallVectorImpl<Attribute *> &results) const {
|
||||
bool AffineMap::constantFold(ArrayRef<Attribute> operandConstants,
|
||||
SmallVectorImpl<Attribute> &results) const {
|
||||
assert(getNumInputs() == operandConstants.size());
|
||||
|
||||
// Fold each of the result expressions.
|
||||
AffineExprConstantFolder exprFolder(getNumDims(), operandConstants);
|
||||
// Constant fold each AffineExpr in AffineMap and add to 'results'.
|
||||
for (auto expr : getResults()) {
|
||||
auto *folded = exprFolder.constantFold(expr);
|
||||
auto folded = exprFolder.constantFold(expr);
|
||||
// If we didn't fold to a constant, then folding fails.
|
||||
if (!folded)
|
||||
return true;
|
||||
|
|
|
@ -123,7 +123,7 @@ private:
|
|||
void visitIfStmt(const IfStmt *ifStmt);
|
||||
void visitOperationStmt(const OperationStmt *opStmt);
|
||||
void visitType(const Type *type);
|
||||
void visitAttribute(const Attribute *attr);
|
||||
void visitAttribute(Attribute attr);
|
||||
void visitOperation(const Operation *op);
|
||||
|
||||
DenseMap<AffineMap, int> affineMapIds;
|
||||
|
@ -150,11 +150,11 @@ void ModuleState::visitType(const Type *type) {
|
|||
}
|
||||
}
|
||||
|
||||
void ModuleState::visitAttribute(const Attribute *attr) {
|
||||
if (auto *mapAttr = dyn_cast<AffineMapAttr>(attr)) {
|
||||
recordAffineMapReference(mapAttr->getValue());
|
||||
} else if (auto *arrayAttr = dyn_cast<ArrayAttr>(attr)) {
|
||||
for (auto elt : arrayAttr->getValue()) {
|
||||
void ModuleState::visitAttribute(Attribute attr) {
|
||||
if (auto mapAttr = attr.dyn_cast<AffineMapAttr>()) {
|
||||
recordAffineMapReference(mapAttr.getValue());
|
||||
} else if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
|
||||
for (auto elt : arrayAttr.getValue()) {
|
||||
visitAttribute(elt);
|
||||
}
|
||||
}
|
||||
|
@ -268,7 +268,7 @@ public:
|
|||
|
||||
void print(const Module *module);
|
||||
void printFunctionReference(const Function *func);
|
||||
void printAttribute(const Attribute *attr);
|
||||
void printAttribute(Attribute attr);
|
||||
void printType(const Type *type);
|
||||
void print(const Function *fn);
|
||||
void print(const ExtFunction *fn);
|
||||
|
@ -293,7 +293,7 @@ protected:
|
|||
void printAffineMapReference(AffineMap affineMap);
|
||||
void printIntegerSetId(int integerSetId) const;
|
||||
void printIntegerSetReference(IntegerSet integerSet);
|
||||
void printDenseElementsAttr(const DenseElementsAttr *attr);
|
||||
void printDenseElementsAttr(DenseElementsAttr attr);
|
||||
|
||||
/// This enum is used to represent the binding stength of the enclosing
|
||||
/// context that an AffineExprStorage is being printed in, so we can
|
||||
|
@ -404,36 +404,36 @@ void ModulePrinter::printFunctionReference(const Function *func) {
|
|||
os << '@' << func->getName();
|
||||
}
|
||||
|
||||
void ModulePrinter::printAttribute(const Attribute *attr) {
|
||||
switch (attr->getKind()) {
|
||||
void ModulePrinter::printAttribute(Attribute attr) {
|
||||
switch (attr.getKind()) {
|
||||
case Attribute::Kind::Bool:
|
||||
os << (cast<BoolAttr>(attr)->getValue() ? "true" : "false");
|
||||
os << (attr.cast<BoolAttr>().getValue() ? "true" : "false");
|
||||
break;
|
||||
case Attribute::Kind::Integer:
|
||||
os << cast<IntegerAttr>(attr)->getValue();
|
||||
os << attr.cast<IntegerAttr>().getValue();
|
||||
break;
|
||||
case Attribute::Kind::Float:
|
||||
printFloatValue(cast<FloatAttr>(attr)->getValue(), os);
|
||||
printFloatValue(attr.cast<FloatAttr>().getValue(), os);
|
||||
break;
|
||||
case Attribute::Kind::String:
|
||||
os << '"';
|
||||
printEscapedString(cast<StringAttr>(attr)->getValue(), os);
|
||||
printEscapedString(attr.cast<StringAttr>().getValue(), os);
|
||||
os << '"';
|
||||
break;
|
||||
case Attribute::Kind::Array:
|
||||
os << '[';
|
||||
interleaveComma(cast<ArrayAttr>(attr)->getValue(),
|
||||
[&](Attribute *attr) { printAttribute(attr); });
|
||||
interleaveComma(attr.cast<ArrayAttr>().getValue(),
|
||||
[&](Attribute attr) { printAttribute(attr); });
|
||||
os << ']';
|
||||
break;
|
||||
case Attribute::Kind::AffineMap:
|
||||
printAffineMapReference(cast<AffineMapAttr>(attr)->getValue());
|
||||
printAffineMapReference(attr.cast<AffineMapAttr>().getValue());
|
||||
break;
|
||||
case Attribute::Kind::Type:
|
||||
printType(cast<TypeAttr>(attr)->getValue());
|
||||
printType(attr.cast<TypeAttr>().getValue());
|
||||
break;
|
||||
case Attribute::Kind::Function: {
|
||||
auto *function = cast<FunctionAttr>(attr)->getValue();
|
||||
auto *function = attr.cast<FunctionAttr>().getValue();
|
||||
if (!function) {
|
||||
os << "<<FUNCTION ATTR FOR DELETED FUNCTION>>";
|
||||
} else {
|
||||
|
@ -444,53 +444,52 @@ void ModulePrinter::printAttribute(const Attribute *attr) {
|
|||
break;
|
||||
}
|
||||
case Attribute::Kind::OpaqueElements: {
|
||||
auto *eltsAttr = cast<OpaqueElementsAttr>(attr);
|
||||
auto eltsAttr = attr.cast<OpaqueElementsAttr>();
|
||||
os << "opaque<";
|
||||
printType(eltsAttr->getType());
|
||||
os << ", " << '"' << "0x" << llvm::toHex(eltsAttr->getValue()) << '"'
|
||||
<< '>';
|
||||
printType(eltsAttr.getType());
|
||||
os << ", " << '"' << "0x" << llvm::toHex(eltsAttr.getValue()) << '"' << '>';
|
||||
break;
|
||||
}
|
||||
case Attribute::Kind::DenseIntElements:
|
||||
case Attribute::Kind::DenseFPElements: {
|
||||
auto *eltsAttr = cast<DenseElementsAttr>(attr);
|
||||
auto eltsAttr = attr.cast<DenseElementsAttr>();
|
||||
os << "dense<";
|
||||
printType(eltsAttr->getType());
|
||||
printType(eltsAttr.getType());
|
||||
os << ", ";
|
||||
printDenseElementsAttr(eltsAttr);
|
||||
os << '>';
|
||||
break;
|
||||
}
|
||||
case Attribute::Kind::SplatElements: {
|
||||
auto *elementsAttr = cast<SplatElementsAttr>(attr);
|
||||
auto elementsAttr = attr.cast<SplatElementsAttr>();
|
||||
os << "splat<";
|
||||
printType(elementsAttr->getType());
|
||||
printType(elementsAttr.getType());
|
||||
os << ", ";
|
||||
printAttribute(elementsAttr->getValue());
|
||||
printAttribute(elementsAttr.getValue());
|
||||
os << '>';
|
||||
break;
|
||||
}
|
||||
case Attribute::Kind::SparseElements: {
|
||||
auto *elementsAttr = cast<SparseElementsAttr>(attr);
|
||||
auto elementsAttr = attr.cast<SparseElementsAttr>();
|
||||
os << "sparse<";
|
||||
printType(elementsAttr->getType());
|
||||
printType(elementsAttr.getType());
|
||||
os << ", ";
|
||||
printDenseElementsAttr(elementsAttr->getIndices());
|
||||
printDenseElementsAttr(elementsAttr.getIndices());
|
||||
os << ", ";
|
||||
printDenseElementsAttr(elementsAttr->getValues());
|
||||
printDenseElementsAttr(elementsAttr.getValues());
|
||||
os << '>';
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ModulePrinter::printDenseElementsAttr(const DenseElementsAttr *attr) {
|
||||
auto *type = attr->getType();
|
||||
void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
|
||||
auto *type = attr.getType();
|
||||
auto shape = type->getShape();
|
||||
auto rank = type->getRank();
|
||||
|
||||
SmallVector<Attribute *, 16> elements;
|
||||
attr->getValues(elements);
|
||||
SmallVector<Attribute, 16> elements;
|
||||
attr.getValues(elements);
|
||||
|
||||
// Special case for degenerate tensors.
|
||||
if (elements.empty()) {
|
||||
|
@ -934,9 +933,7 @@ public:
|
|||
// Implement OpAsmPrinter.
|
||||
raw_ostream &getStream() const { return os; }
|
||||
void printType(const Type *type) { ModulePrinter::printType(type); }
|
||||
void printAttribute(const Attribute *attr) {
|
||||
ModulePrinter::printAttribute(attr);
|
||||
}
|
||||
void printAttribute(Attribute attr) { ModulePrinter::printAttribute(attr); }
|
||||
void printAffineMap(AffineMap map) {
|
||||
return ModulePrinter::printAffineMapReference(map);
|
||||
}
|
||||
|
@ -980,7 +977,7 @@ protected:
|
|||
} else if (auto intOp = op->dyn_cast<ConstantIndexOp>()) {
|
||||
specialName << 'c' << intOp->getValue();
|
||||
} else if (auto constant = op->dyn_cast<ConstantOp>()) {
|
||||
if (isa<FunctionAttr>(constant->getValue()))
|
||||
if (constant->getValue().isa<FunctionAttr>())
|
||||
specialName << 'f';
|
||||
else
|
||||
specialName << "cst";
|
||||
|
@ -1570,7 +1567,7 @@ void ModulePrinter::print(const MLFunction *fn) {
|
|||
|
||||
void Attribute::print(raw_ostream &os) const {
|
||||
ModuleState state(/*no context is known*/ nullptr);
|
||||
ModulePrinter(os, state).printAttribute(this);
|
||||
ModulePrinter(os, state).printAttribute(*this);
|
||||
}
|
||||
|
||||
void Attribute::dump() const { print(llvm::errs()); }
|
||||
|
|
|
@ -0,0 +1,131 @@
|
|||
//===- AttributeDetail.h - MLIR Affine Map details Class --------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This holds implementation details of Attribute.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef ATTRIBUTEDETAIL_H_
|
||||
#define ATTRIBUTEDETAIL_H_
|
||||
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace detail {
|
||||
|
||||
/// Base storage class appearing in an attribute.
|
||||
struct AttributeStorage {
|
||||
Attribute::Kind kind : 8;
|
||||
|
||||
/// This field is true if this is, or contains, a function attribute.
|
||||
bool isOrContainsFunctionCache : 1;
|
||||
};
|
||||
|
||||
/// An attribute representing a boolean value.
|
||||
struct BoolAttributeStorage : public AttributeStorage {
|
||||
bool value;
|
||||
};
|
||||
|
||||
/// An attribute representing a integral value.
|
||||
struct IntegerAttributeStorage : public AttributeStorage {
|
||||
int64_t value;
|
||||
};
|
||||
|
||||
/// An attribute representing a floating point value.
|
||||
struct FloatAttributeStorage final
|
||||
: public AttributeStorage,
|
||||
public llvm::TrailingObjects<FloatAttributeStorage, uint64_t> {
|
||||
const llvm::fltSemantics &semantics;
|
||||
size_t numObjects;
|
||||
|
||||
/// Returns an APFloat representing the stored value.
|
||||
APFloat getValue() const {
|
||||
auto val = APInt(APFloat::getSizeInBits(semantics),
|
||||
{getTrailingObjects<uint64_t>(), numObjects});
|
||||
return APFloat(semantics, val);
|
||||
}
|
||||
};
|
||||
|
||||
/// An attribute representing a string value.
|
||||
struct StringAttributeStorage : public AttributeStorage {
|
||||
StringRef value;
|
||||
};
|
||||
|
||||
/// An attribute representing an array of other attributes.
|
||||
struct ArrayAttributeStorage : public AttributeStorage {
|
||||
ArrayRef<Attribute> value;
|
||||
};
|
||||
|
||||
// An attribute representing a reference to an affine map.
|
||||
struct AffineMapAttributeStorage : public AttributeStorage {
|
||||
AffineMap value;
|
||||
};
|
||||
|
||||
/// An attribute representing a reference to a type.
|
||||
struct TypeAttributeStorage : public AttributeStorage {
|
||||
Type *value;
|
||||
};
|
||||
|
||||
/// An attribute representing a reference to a function.
|
||||
struct FunctionAttributeStorage : public AttributeStorage {
|
||||
Function *value;
|
||||
};
|
||||
|
||||
/// A base attribute representing a reference to a vector or tensor constant.
|
||||
struct ElementsAttributeStorage : public AttributeStorage {
|
||||
VectorOrTensorType *type;
|
||||
};
|
||||
|
||||
/// An attribute representing a reference to a vector or tensor constant,
|
||||
/// inwhich all elements have the same value.
|
||||
struct SplatElementsAttributeStorage : public ElementsAttributeStorage {
|
||||
Attribute elt;
|
||||
};
|
||||
|
||||
/// An attribute representing a reference to a dense vector or tensor object.
|
||||
struct DenseElementsAttributeStorage : public ElementsAttributeStorage {
|
||||
ArrayRef<char> data;
|
||||
};
|
||||
|
||||
/// An attribute representing a reference to a dense integer vector or tensor
|
||||
/// object.
|
||||
struct DenseIntElementsAttributeStorage : public DenseElementsAttributeStorage {
|
||||
size_t bitsWidth;
|
||||
};
|
||||
|
||||
/// An attribute representing a reference to a dense float vector or tensor
|
||||
/// object.
|
||||
struct DenseFPElementsAttributeStorage : public DenseElementsAttributeStorage {
|
||||
};
|
||||
|
||||
/// An attribute representing a reference to a tensor constant with opaque
|
||||
/// content.
|
||||
struct OpaqueElementsAttributeStorage : public ElementsAttributeStorage {
|
||||
StringRef bytes;
|
||||
};
|
||||
|
||||
/// An attribute representing a reference to a sparse vector or tensor object.
|
||||
struct SparseElementsAttributeStorage : public ElementsAttributeStorage {
|
||||
DenseIntElementsAttr indices;
|
||||
DenseElementsAttr values;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
} // namespace mlir
|
||||
|
||||
#endif // ATTRIBUTEDETAIL_H_
|
|
@ -0,0 +1,214 @@
|
|||
//===- Attributes.cpp - MLIR Affine Expr Classes --------------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "AttributeDetail.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::detail;
|
||||
|
||||
Attribute::Kind Attribute::getKind() const { return attr->kind; }
|
||||
|
||||
bool Attribute::isOrContainsFunction() const {
|
||||
return attr->isOrContainsFunctionCache;
|
||||
}
|
||||
|
||||
BoolAttr::BoolAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
|
||||
|
||||
bool BoolAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
|
||||
|
||||
IntegerAttr::IntegerAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
|
||||
|
||||
int64_t IntegerAttr::getValue() const {
|
||||
return static_cast<ImplType *>(attr)->value;
|
||||
}
|
||||
|
||||
FloatAttr::FloatAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
|
||||
|
||||
APFloat FloatAttr::getValue() const {
|
||||
return static_cast<ImplType *>(attr)->getValue();
|
||||
}
|
||||
|
||||
double FloatAttr::getDouble() const { return getValue().convertToDouble(); }
|
||||
|
||||
StringAttr::StringAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
|
||||
|
||||
StringRef StringAttr::getValue() const {
|
||||
return static_cast<ImplType *>(attr)->value;
|
||||
}
|
||||
|
||||
ArrayAttr::ArrayAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
|
||||
|
||||
ArrayRef<Attribute> ArrayAttr::getValue() const {
|
||||
return static_cast<ImplType *>(attr)->value;
|
||||
}
|
||||
|
||||
AffineMapAttr::AffineMapAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
|
||||
|
||||
AffineMap AffineMapAttr::getValue() const {
|
||||
return static_cast<ImplType *>(attr)->value;
|
||||
}
|
||||
|
||||
TypeAttr::TypeAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
|
||||
|
||||
Type *TypeAttr::getValue() const {
|
||||
return static_cast<ImplType *>(attr)->value;
|
||||
}
|
||||
|
||||
FunctionAttr::FunctionAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
|
||||
|
||||
Function *FunctionAttr::getValue() const {
|
||||
return static_cast<ImplType *>(attr)->value;
|
||||
}
|
||||
|
||||
FunctionType *FunctionAttr::getType() const { return getValue()->getType(); }
|
||||
|
||||
ElementsAttr::ElementsAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
|
||||
|
||||
VectorOrTensorType *ElementsAttr::getType() const {
|
||||
return static_cast<ImplType *>(attr)->type;
|
||||
}
|
||||
|
||||
SplatElementsAttr::SplatElementsAttr(Attribute::ImplType *ptr)
|
||||
: ElementsAttr(ptr) {}
|
||||
|
||||
Attribute SplatElementsAttr::getValue() const {
|
||||
return static_cast<ImplType *>(attr)->elt;
|
||||
}
|
||||
|
||||
DenseElementsAttr::DenseElementsAttr(Attribute::ImplType *ptr)
|
||||
: ElementsAttr(ptr) {}
|
||||
|
||||
void DenseElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
|
||||
switch (getKind()) {
|
||||
case Attribute::Kind::DenseIntElements:
|
||||
cast<DenseIntElementsAttr>().getValues(values);
|
||||
return;
|
||||
case Attribute::Kind::DenseFPElements:
|
||||
cast<DenseFPElementsAttr>().getValues(values);
|
||||
return;
|
||||
default:
|
||||
llvm_unreachable("unexpected element type");
|
||||
}
|
||||
}
|
||||
|
||||
ArrayRef<char> DenseElementsAttr::getRawData() const {
|
||||
return static_cast<ImplType *>(attr)->data;
|
||||
}
|
||||
|
||||
DenseIntElementsAttr::DenseIntElementsAttr(Attribute::ImplType *ptr)
|
||||
: DenseElementsAttr(ptr) {}
|
||||
|
||||
/// Writes the lowest `bitWidth` bits of `value` to bit position `bitPos`
|
||||
/// starting from `rawData`.
|
||||
void DenseIntElementsAttr::writeBits(char *data, size_t bitPos, size_t bitWidth,
|
||||
uint64_t value) {
|
||||
// Read the destination bytes which will be written to.
|
||||
uint64_t dst = 0;
|
||||
auto dstData = reinterpret_cast<char *>(&dst);
|
||||
auto endPos = bitPos + bitWidth;
|
||||
auto start = data + bitPos / 8;
|
||||
auto end = data + endPos / 8 + (endPos % 8 != 0);
|
||||
std::copy(start, end, dstData);
|
||||
|
||||
// Clean up the invalid bits in the destination bytes.
|
||||
dst &= ~(-1UL << (bitPos % 8));
|
||||
|
||||
// Get the valid bits of the source value, shift them to right position,
|
||||
// then add them to the destination bytes.
|
||||
value <<= bitPos % 8;
|
||||
dst |= value;
|
||||
|
||||
// Write the destination bytes back.
|
||||
ArrayRef<char> range({dstData, (size_t)(end - start)});
|
||||
std::copy(range.begin(), range.end(), start);
|
||||
}
|
||||
|
||||
/// Reads the next `bitWidth` bits from the bit position `bitPos` of `rawData`
|
||||
/// and put them in the lowest bits.
|
||||
uint64_t DenseIntElementsAttr::readBits(const char *rawData, size_t bitPos,
|
||||
size_t bitsWidth) {
|
||||
uint64_t dst = 0;
|
||||
auto dstData = reinterpret_cast<char *>(&dst);
|
||||
auto endPos = bitPos + bitsWidth;
|
||||
auto start = rawData + bitPos / 8;
|
||||
auto end = rawData + endPos / 8 + (endPos % 8 != 0);
|
||||
std::copy(start, end, dstData);
|
||||
|
||||
dst >>= bitPos % 8;
|
||||
dst &= ~(-1UL << bitsWidth);
|
||||
return dst;
|
||||
}
|
||||
|
||||
void DenseIntElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
|
||||
auto bitsWidth = static_cast<ImplType *>(attr)->bitsWidth;
|
||||
auto elementNum = getType()->getNumElements();
|
||||
auto context = getType()->getContext();
|
||||
values.reserve(elementNum);
|
||||
if (bitsWidth == 64) {
|
||||
ArrayRef<int64_t> vs(
|
||||
{reinterpret_cast<const int64_t *>(getRawData().data()),
|
||||
getRawData().size() / 8});
|
||||
for (auto value : vs) {
|
||||
auto attr = IntegerAttr::get(value, context);
|
||||
values.push_back(attr);
|
||||
}
|
||||
} else {
|
||||
const auto *rawData = getRawData().data();
|
||||
for (size_t pos = 0; pos < elementNum * bitsWidth; pos += bitsWidth) {
|
||||
uint64_t bits = readBits(rawData, pos, bitsWidth);
|
||||
APInt value(bitsWidth, bits, /*isSigned=*/true);
|
||||
auto attr = IntegerAttr::get(value.getSExtValue(), context);
|
||||
values.push_back(attr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
DenseFPElementsAttr::DenseFPElementsAttr(Attribute::ImplType *ptr)
|
||||
: DenseElementsAttr(ptr) {}
|
||||
|
||||
void DenseFPElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
|
||||
auto elementNum = getType()->getNumElements();
|
||||
auto context = getType()->getContext();
|
||||
ArrayRef<double> vs({reinterpret_cast<const double *>(getRawData().data()),
|
||||
getRawData().size() / 8});
|
||||
values.reserve(elementNum);
|
||||
for (auto v : vs) {
|
||||
auto attr = FloatAttr::get(v, context);
|
||||
values.push_back(attr);
|
||||
}
|
||||
}
|
||||
|
||||
OpaqueElementsAttr::OpaqueElementsAttr(Attribute::ImplType *ptr)
|
||||
: ElementsAttr(ptr) {}
|
||||
|
||||
StringRef OpaqueElementsAttr::getValue() const {
|
||||
return static_cast<ImplType *>(attr)->bytes;
|
||||
}
|
||||
|
||||
SparseElementsAttr::SparseElementsAttr(Attribute::ImplType *ptr)
|
||||
: ElementsAttr(ptr) {}
|
||||
|
||||
DenseIntElementsAttr SparseElementsAttr::getIndices() const {
|
||||
return static_cast<ImplType *>(attr)->indices;
|
||||
}
|
||||
|
||||
DenseElementsAttr SparseElementsAttr::getValues() const {
|
||||
return static_cast<ImplType *>(attr)->values;
|
||||
}
|
|
@ -112,60 +112,60 @@ UnrankedTensorType *Builder::getTensorType(Type *elementType) {
|
|||
// Attributes.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
BoolAttr *Builder::getBoolAttr(bool value) {
|
||||
BoolAttr Builder::getBoolAttr(bool value) {
|
||||
return BoolAttr::get(value, context);
|
||||
}
|
||||
|
||||
IntegerAttr *Builder::getIntegerAttr(int64_t value) {
|
||||
IntegerAttr Builder::getIntegerAttr(int64_t value) {
|
||||
return IntegerAttr::get(value, context);
|
||||
}
|
||||
|
||||
FloatAttr *Builder::getFloatAttr(double value) {
|
||||
FloatAttr Builder::getFloatAttr(double value) {
|
||||
return FloatAttr::get(APFloat(value), context);
|
||||
}
|
||||
|
||||
FloatAttr *Builder::getFloatAttr(const APFloat &value) {
|
||||
FloatAttr Builder::getFloatAttr(const APFloat &value) {
|
||||
return FloatAttr::get(value, context);
|
||||
}
|
||||
|
||||
StringAttr *Builder::getStringAttr(StringRef bytes) {
|
||||
StringAttr Builder::getStringAttr(StringRef bytes) {
|
||||
return StringAttr::get(bytes, context);
|
||||
}
|
||||
|
||||
ArrayAttr *Builder::getArrayAttr(ArrayRef<Attribute *> value) {
|
||||
ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
|
||||
return ArrayAttr::get(value, context);
|
||||
}
|
||||
|
||||
AffineMapAttr *Builder::getAffineMapAttr(AffineMap map) {
|
||||
AffineMapAttr Builder::getAffineMapAttr(AffineMap map) {
|
||||
return AffineMapAttr::get(map);
|
||||
}
|
||||
|
||||
TypeAttr *Builder::getTypeAttr(Type *type) {
|
||||
TypeAttr Builder::getTypeAttr(Type *type) {
|
||||
return TypeAttr::get(type, context);
|
||||
}
|
||||
|
||||
FunctionAttr *Builder::getFunctionAttr(const Function *value) {
|
||||
FunctionAttr Builder::getFunctionAttr(const Function *value) {
|
||||
return FunctionAttr::get(value, context);
|
||||
}
|
||||
|
||||
ElementsAttr *Builder::getSplatElementsAttr(VectorOrTensorType *type,
|
||||
Attribute *elt) {
|
||||
ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType *type,
|
||||
Attribute elt) {
|
||||
return SplatElementsAttr::get(type, elt);
|
||||
}
|
||||
|
||||
ElementsAttr *Builder::getDenseElementsAttr(VectorOrTensorType *type,
|
||||
ArrayRef<char> data) {
|
||||
ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType *type,
|
||||
ArrayRef<char> data) {
|
||||
return DenseElementsAttr::get(type, data);
|
||||
}
|
||||
|
||||
ElementsAttr *Builder::getSparseElementsAttr(VectorOrTensorType *type,
|
||||
DenseIntElementsAttr *indices,
|
||||
DenseElementsAttr *values) {
|
||||
ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType *type,
|
||||
DenseIntElementsAttr indices,
|
||||
DenseElementsAttr values) {
|
||||
return SparseElementsAttr::get(type, indices, values);
|
||||
}
|
||||
|
||||
ElementsAttr *Builder::getOpaqueElementsAttr(VectorOrTensorType *type,
|
||||
StringRef bytes) {
|
||||
ElementsAttr Builder::getOpaqueElementsAttr(VectorOrTensorType *type,
|
||||
StringRef bytes) {
|
||||
return OpaqueElementsAttr::get(type, bytes);
|
||||
}
|
||||
|
||||
|
|
|
@ -86,13 +86,13 @@ bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
|
|||
auto &builder = parser->getBuilder();
|
||||
auto *affineIntTy = builder.getIndexType();
|
||||
|
||||
AffineMapAttr *mapAttr;
|
||||
AffineMapAttr mapAttr;
|
||||
unsigned numDims;
|
||||
if (parser->parseAttribute(mapAttr, "map", result->attributes) ||
|
||||
parseDimAndSymbolList(parser, result->operands, numDims) ||
|
||||
parser->parseOptionalAttributeDict(result->attributes))
|
||||
return true;
|
||||
auto map = mapAttr->getValue();
|
||||
auto map = mapAttr.getValue();
|
||||
|
||||
if (map.getNumDims() != numDims ||
|
||||
numDims + map.getNumSymbols() != result->operands.size()) {
|
||||
|
@ -113,12 +113,12 @@ void AffineApplyOp::print(OpAsmPrinter *p) const {
|
|||
|
||||
bool AffineApplyOp::verify() const {
|
||||
// Check that affine map attribute was specified.
|
||||
auto *affineMapAttr = getAttrOfType<AffineMapAttr>("map");
|
||||
auto affineMapAttr = getAttrOfType<AffineMapAttr>("map");
|
||||
if (!affineMapAttr)
|
||||
return emitOpError("requires an affine map");
|
||||
|
||||
// Check input and output dimensions match.
|
||||
auto map = affineMapAttr->getValue();
|
||||
auto map = affineMapAttr.getValue();
|
||||
|
||||
// Verify that operand count matches affine map dimension and symbol count.
|
||||
if (getNumOperands() != map.getNumDims() + map.getNumSymbols())
|
||||
|
@ -155,8 +155,8 @@ bool AffineApplyOp::isValidSymbol() const {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool AffineApplyOp::constantFold(ArrayRef<Attribute *> operandConstants,
|
||||
SmallVectorImpl<Attribute *> &results,
|
||||
bool AffineApplyOp::constantFold(ArrayRef<Attribute> operandConstants,
|
||||
SmallVectorImpl<Attribute> &results,
|
||||
MLIRContext *context) const {
|
||||
auto map = getAffineMap();
|
||||
if (map.constantFold(operandConstants, results))
|
||||
|
@ -171,21 +171,21 @@ bool AffineApplyOp::constantFold(ArrayRef<Attribute *> operandConstants,
|
|||
|
||||
/// Builds a constant op with the specified attribute value and result type.
|
||||
void ConstantOp::build(Builder *builder, OperationState *result,
|
||||
Attribute *value, Type *type) {
|
||||
Attribute value, Type *type) {
|
||||
result->addAttribute("value", value);
|
||||
result->types.push_back(type);
|
||||
}
|
||||
|
||||
void ConstantOp::print(OpAsmPrinter *p) const {
|
||||
*p << "constant " << *getValue();
|
||||
*p << "constant " << getValue();
|
||||
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value");
|
||||
|
||||
if (!isa<FunctionAttr>(getValue()))
|
||||
if (!getValue().isa<FunctionAttr>())
|
||||
*p << " : " << *getType();
|
||||
}
|
||||
|
||||
bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
Attribute *valueAttr;
|
||||
Attribute valueAttr;
|
||||
Type *type;
|
||||
|
||||
if (parser->parseAttribute(valueAttr, "value", result->attributes) ||
|
||||
|
@ -194,8 +194,8 @@ bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
|
|||
|
||||
// 'constant' taking a function reference doesn't get a redundant type
|
||||
// specifier. The attribute itself carries it.
|
||||
if (auto *fnAttr = dyn_cast<FunctionAttr>(valueAttr))
|
||||
return parser->addTypeToList(fnAttr->getValue()->getType(), result->types);
|
||||
if (auto fnAttr = valueAttr.dyn_cast<FunctionAttr>())
|
||||
return parser->addTypeToList(fnAttr.getValue()->getType(), result->types);
|
||||
|
||||
return parser->parseColonType(type) ||
|
||||
parser->addTypeToList(type, result->types);
|
||||
|
@ -204,32 +204,32 @@ bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
|
|||
/// The constant op requires an attribute, and furthermore requires that it
|
||||
/// matches the return type.
|
||||
bool ConstantOp::verify() const {
|
||||
auto *value = getValue();
|
||||
auto value = getValue();
|
||||
if (!value)
|
||||
return emitOpError("requires a 'value' attribute");
|
||||
|
||||
auto *type = this->getType();
|
||||
if (isa<IntegerType>(type) || type->isIndex()) {
|
||||
if (!isa<IntegerAttr>(value))
|
||||
if (!value.isa<IntegerAttr>())
|
||||
return emitOpError(
|
||||
"requires 'value' to be an integer for an integer result type");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (isa<FloatType>(type)) {
|
||||
if (!isa<FloatAttr>(value))
|
||||
if (!value.isa<FloatAttr>())
|
||||
return emitOpError("requires 'value' to be a floating point constant");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (type->isTFString()) {
|
||||
if (!isa<StringAttr>(value))
|
||||
if (!value.isa<StringAttr>())
|
||||
return emitOpError("requires 'value' to be a string constant");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (isa<FunctionType>(type)) {
|
||||
if (!isa<FunctionAttr>(value))
|
||||
if (!value.isa<FunctionAttr>())
|
||||
return emitOpError("requires 'value' to be a function reference");
|
||||
return false;
|
||||
}
|
||||
|
@ -238,8 +238,8 @@ bool ConstantOp::verify() const {
|
|||
"requires a result type that aligns with the 'value' attribute");
|
||||
}
|
||||
|
||||
Attribute *ConstantOp::constantFold(ArrayRef<Attribute *> operands,
|
||||
MLIRContext *context) const {
|
||||
Attribute ConstantOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) const {
|
||||
assert(operands.empty() && "constant has no operands");
|
||||
return getValue();
|
||||
}
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "AffineExprDetail.h"
|
||||
#include "AffineMapDetail.h"
|
||||
#include "AttributeDetail.h"
|
||||
#include "AttributeListStorage.h"
|
||||
#include "IntegerSetDetail.h"
|
||||
#include "mlir/IR/AffineExpr.h"
|
||||
|
@ -169,35 +170,35 @@ struct MemRefTypeKeyInfo : DenseMapInfo<MemRefType *> {
|
|||
}
|
||||
};
|
||||
|
||||
struct FloatAttrKeyInfo : DenseMapInfo<FloatAttr *> {
|
||||
struct FloatAttrKeyInfo : DenseMapInfo<FloatAttributeStorage *> {
|
||||
// Float attributes are uniqued based on wrapped APFloat.
|
||||
using KeyTy = APFloat;
|
||||
using DenseMapInfo<FloatAttr *>::getHashValue;
|
||||
using DenseMapInfo<FloatAttr *>::isEqual;
|
||||
using DenseMapInfo<FloatAttributeStorage *>::getHashValue;
|
||||
using DenseMapInfo<FloatAttributeStorage *>::isEqual;
|
||||
|
||||
static unsigned getHashValue(KeyTy key) { return llvm::hash_value(key); }
|
||||
|
||||
static bool isEqual(const KeyTy &lhs, const FloatAttr *rhs) {
|
||||
static bool isEqual(const KeyTy &lhs, const FloatAttributeStorage *rhs) {
|
||||
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
||||
return false;
|
||||
return lhs.bitwiseIsEqual(rhs->getValue());
|
||||
}
|
||||
};
|
||||
|
||||
struct ArrayAttrKeyInfo : DenseMapInfo<ArrayAttr *> {
|
||||
struct ArrayAttrKeyInfo : DenseMapInfo<ArrayAttributeStorage *> {
|
||||
// Array attributes are uniqued based on their elements.
|
||||
using KeyTy = ArrayRef<Attribute *>;
|
||||
using DenseMapInfo<ArrayAttr *>::getHashValue;
|
||||
using DenseMapInfo<ArrayAttr *>::isEqual;
|
||||
using KeyTy = ArrayRef<Attribute>;
|
||||
using DenseMapInfo<ArrayAttributeStorage *>::getHashValue;
|
||||
using DenseMapInfo<ArrayAttributeStorage *>::isEqual;
|
||||
|
||||
static unsigned getHashValue(KeyTy key) {
|
||||
return hash_combine_range(key.begin(), key.end());
|
||||
}
|
||||
|
||||
static bool isEqual(const KeyTy &lhs, const ArrayAttr *rhs) {
|
||||
static bool isEqual(const KeyTy &lhs, const ArrayAttributeStorage *rhs) {
|
||||
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
||||
return false;
|
||||
return lhs == rhs->getValue();
|
||||
return lhs == rhs->value;
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -218,37 +219,39 @@ struct AttributeListKeyInfo : DenseMapInfo<AttributeListStorage *> {
|
|||
}
|
||||
};
|
||||
|
||||
struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttr *> {
|
||||
struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttributeStorage *> {
|
||||
using KeyTy = std::pair<VectorOrTensorType *, ArrayRef<char>>;
|
||||
using DenseMapInfo<DenseElementsAttr *>::getHashValue;
|
||||
using DenseMapInfo<DenseElementsAttr *>::isEqual;
|
||||
using DenseMapInfo<DenseElementsAttributeStorage *>::getHashValue;
|
||||
using DenseMapInfo<DenseElementsAttributeStorage *>::isEqual;
|
||||
|
||||
static unsigned getHashValue(KeyTy key) {
|
||||
return hash_combine(
|
||||
key.first, hash_combine_range(key.second.begin(), key.second.end()));
|
||||
}
|
||||
|
||||
static bool isEqual(const KeyTy &lhs, const DenseElementsAttr *rhs) {
|
||||
static bool isEqual(const KeyTy &lhs,
|
||||
const DenseElementsAttributeStorage *rhs) {
|
||||
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
||||
return false;
|
||||
return lhs == std::make_pair(rhs->getType(), rhs->getRawData());
|
||||
return lhs == std::make_pair(rhs->type, rhs->data);
|
||||
}
|
||||
};
|
||||
|
||||
struct OpaqueElementsAttrInfo : DenseMapInfo<OpaqueElementsAttr *> {
|
||||
struct OpaqueElementsAttrInfo : DenseMapInfo<OpaqueElementsAttributeStorage *> {
|
||||
using KeyTy = std::pair<VectorOrTensorType *, StringRef>;
|
||||
using DenseMapInfo<OpaqueElementsAttr *>::getHashValue;
|
||||
using DenseMapInfo<OpaqueElementsAttr *>::isEqual;
|
||||
using DenseMapInfo<OpaqueElementsAttributeStorage *>::getHashValue;
|
||||
using DenseMapInfo<OpaqueElementsAttributeStorage *>::isEqual;
|
||||
|
||||
static unsigned getHashValue(KeyTy key) {
|
||||
return hash_combine(
|
||||
key.first, hash_combine_range(key.second.begin(), key.second.end()));
|
||||
}
|
||||
|
||||
static bool isEqual(const KeyTy &lhs, const OpaqueElementsAttr *rhs) {
|
||||
static bool isEqual(const KeyTy &lhs,
|
||||
const OpaqueElementsAttributeStorage *rhs) {
|
||||
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
||||
return false;
|
||||
return lhs == std::make_pair(rhs->getType(), rhs->getValue());
|
||||
return lhs == std::make_pair(rhs->type, rhs->bytes);
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace.
|
||||
|
@ -343,28 +346,29 @@ public:
|
|||
MemRefTypeSet memrefs;
|
||||
|
||||
// Attribute uniquing.
|
||||
BoolAttr *boolAttrs[2] = {nullptr};
|
||||
DenseMap<int64_t, IntegerAttr *> integerAttrs;
|
||||
DenseSet<FloatAttr *, FloatAttrKeyInfo> floatAttrs;
|
||||
StringMap<StringAttr *> stringAttrs;
|
||||
using ArrayAttrSet = DenseSet<ArrayAttr *, ArrayAttrKeyInfo>;
|
||||
BoolAttributeStorage *boolAttrs[2] = {nullptr};
|
||||
DenseMap<int64_t, IntegerAttributeStorage *> integerAttrs;
|
||||
DenseSet<FloatAttributeStorage *, FloatAttrKeyInfo> floatAttrs;
|
||||
StringMap<StringAttributeStorage *> stringAttrs;
|
||||
using ArrayAttrSet = DenseSet<ArrayAttributeStorage *, ArrayAttrKeyInfo>;
|
||||
ArrayAttrSet arrayAttrs;
|
||||
DenseMap<AffineMap, AffineMapAttr *> affineMapAttrs;
|
||||
DenseMap<Type *, TypeAttr *> typeAttrs;
|
||||
DenseMap<AffineMap, AffineMapAttributeStorage *> affineMapAttrs;
|
||||
DenseMap<Type *, TypeAttributeStorage *> typeAttrs;
|
||||
using AttributeListSet =
|
||||
DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
|
||||
AttributeListSet attributeLists;
|
||||
DenseMap<const Function *, FunctionAttr *> functionAttrs;
|
||||
DenseMap<std::pair<VectorOrTensorType *, Attribute *>, SplatElementsAttr *>
|
||||
DenseMap<const Function *, FunctionAttributeStorage *> functionAttrs;
|
||||
DenseMap<std::pair<VectorOrTensorType *, Attribute>,
|
||||
SplatElementsAttributeStorage *>
|
||||
splatElementsAttrs;
|
||||
using DenseElementsAttrSet =
|
||||
DenseSet<DenseElementsAttr *, DenseElementsAttrInfo>;
|
||||
DenseSet<DenseElementsAttributeStorage *, DenseElementsAttrInfo>;
|
||||
DenseElementsAttrSet denseElementsAttrs;
|
||||
using OpaqueElementsAttrSet =
|
||||
DenseSet<OpaqueElementsAttr *, OpaqueElementsAttrInfo>;
|
||||
DenseSet<OpaqueElementsAttributeStorage *, OpaqueElementsAttrInfo>;
|
||||
OpaqueElementsAttrSet opaqueElementsAttrs;
|
||||
DenseMap<std::tuple<Type *, DenseElementsAttr *, DenseElementsAttr *>,
|
||||
SparseElementsAttr *>
|
||||
DenseMap<std::tuple<Type *, Attribute, Attribute>,
|
||||
SparseElementsAttributeStorage *>
|
||||
sparseElementsAttrs;
|
||||
|
||||
public:
|
||||
|
@ -716,31 +720,36 @@ MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType,
|
|||
// Attribute uniquing
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
BoolAttr *BoolAttr::get(bool value, MLIRContext *context) {
|
||||
BoolAttr BoolAttr::get(bool value, MLIRContext *context) {
|
||||
auto *&result = context->getImpl().boolAttrs[value];
|
||||
if (result)
|
||||
return result;
|
||||
|
||||
result = context->getImpl().allocator.Allocate<BoolAttr>();
|
||||
new (result) BoolAttr(value);
|
||||
result = context->getImpl().allocator.Allocate<BoolAttributeStorage>();
|
||||
new (result) BoolAttributeStorage{{Attribute::Kind::Bool,
|
||||
/*isOrContainsFunction=*/false},
|
||||
value};
|
||||
return result;
|
||||
}
|
||||
|
||||
IntegerAttr *IntegerAttr::get(int64_t value, MLIRContext *context) {
|
||||
IntegerAttr IntegerAttr::get(int64_t value, MLIRContext *context) {
|
||||
auto *&result = context->getImpl().integerAttrs[value];
|
||||
if (result)
|
||||
return result;
|
||||
|
||||
result = context->getImpl().allocator.Allocate<IntegerAttr>();
|
||||
new (result) IntegerAttr(value);
|
||||
result = context->getImpl().allocator.Allocate<IntegerAttributeStorage>();
|
||||
new (result) IntegerAttributeStorage{{Attribute::Kind::Integer,
|
||||
/*isOrContainsFunction=*/false},
|
||||
value};
|
||||
result->value = value;
|
||||
return result;
|
||||
}
|
||||
|
||||
FloatAttr *FloatAttr::get(double value, MLIRContext *context) {
|
||||
FloatAttr FloatAttr::get(double value, MLIRContext *context) {
|
||||
return get(APFloat(value), context);
|
||||
}
|
||||
|
||||
FloatAttr *FloatAttr::get(const APFloat &value, MLIRContext *context) {
|
||||
FloatAttr FloatAttr::get(const APFloat &value, MLIRContext *context) {
|
||||
auto &impl = context->getImpl();
|
||||
|
||||
// Look to see if the float attribute has been created already.
|
||||
|
@ -755,33 +764,35 @@ FloatAttr *FloatAttr::get(const APFloat &value, MLIRContext *context) {
|
|||
// Here one word's bitwidth equals to that of uint64_t.
|
||||
auto elements = ArrayRef<uint64_t>(apint.getRawData(), apint.getNumWords());
|
||||
|
||||
auto byteSize = FloatAttr::totalSizeToAlloc<uint64_t>(elements.size());
|
||||
auto rawMem = impl.allocator.Allocate(byteSize, alignof(FloatAttr));
|
||||
auto result = ::new (rawMem) FloatAttr(value.getSemantics(), elements.size());
|
||||
auto byteSize =
|
||||
FloatAttributeStorage::totalSizeToAlloc<uint64_t>(elements.size());
|
||||
auto rawMem =
|
||||
impl.allocator.Allocate(byteSize, alignof(FloatAttributeStorage));
|
||||
auto result = ::new (rawMem) FloatAttributeStorage{
|
||||
{Attribute::Kind::Float, /*isOrContainsFunction=*/false},
|
||||
{},
|
||||
value.getSemantics(),
|
||||
elements.size()};
|
||||
std::uninitialized_copy(elements.begin(), elements.end(),
|
||||
result->getTrailingObjects<uint64_t>());
|
||||
return *existing.first = result;
|
||||
}
|
||||
|
||||
APFloat FloatAttr::getValue() const {
|
||||
auto val = APInt(APFloat::getSizeInBits(semantics),
|
||||
{getTrailingObjects<uint64_t>(), numObjects});
|
||||
return APFloat(semantics, val);
|
||||
}
|
||||
|
||||
StringAttr *StringAttr::get(StringRef bytes, MLIRContext *context) {
|
||||
StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
|
||||
auto it = context->getImpl().stringAttrs.insert({bytes, nullptr}).first;
|
||||
|
||||
if (it->second)
|
||||
return it->second;
|
||||
|
||||
auto result = context->getImpl().allocator.Allocate<StringAttr>();
|
||||
new (result) StringAttr(it->first());
|
||||
auto result = context->getImpl().allocator.Allocate<StringAttributeStorage>();
|
||||
new (result) StringAttributeStorage{{Attribute::Kind::String,
|
||||
/*isOrContainsFunction=*/false},
|
||||
it->first()};
|
||||
it->second = result;
|
||||
return result;
|
||||
}
|
||||
|
||||
ArrayAttr *ArrayAttr::get(ArrayRef<Attribute *> value, MLIRContext *context) {
|
||||
ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
|
||||
auto &impl = context->getImpl();
|
||||
|
||||
// Look to see if we already have this.
|
||||
|
@ -792,61 +803,66 @@ ArrayAttr *ArrayAttr::get(ArrayRef<Attribute *> value, MLIRContext *context) {
|
|||
return *existing.first;
|
||||
|
||||
// On the first use, we allocate them into the bump pointer.
|
||||
auto *result = impl.allocator.Allocate<ArrayAttr>();
|
||||
auto *result = impl.allocator.Allocate<ArrayAttributeStorage>();
|
||||
|
||||
// Copy the elements into the bump pointer.
|
||||
value = impl.copyInto(value);
|
||||
|
||||
// Check to see if any of the elements have a function attr.
|
||||
bool hasFunctionAttr = false;
|
||||
for (auto *elt : value)
|
||||
if (elt->isOrContainsFunction()) {
|
||||
for (auto elt : value)
|
||||
if (elt.isOrContainsFunction()) {
|
||||
hasFunctionAttr = true;
|
||||
break;
|
||||
}
|
||||
|
||||
// Initialize the memory using placement new.
|
||||
new (result) ArrayAttr(value, hasFunctionAttr);
|
||||
new (result)
|
||||
ArrayAttributeStorage{{Attribute::Kind::Array, hasFunctionAttr}, value};
|
||||
|
||||
// Cache and return it.
|
||||
return *existing.first = result;
|
||||
}
|
||||
|
||||
AffineMapAttr *AffineMapAttr::get(AffineMap value) {
|
||||
AffineMapAttr AffineMapAttr::get(AffineMap value) {
|
||||
auto *context = value.getResult(0).getContext();
|
||||
auto &result = context->getImpl().affineMapAttrs[value];
|
||||
if (result)
|
||||
return result;
|
||||
|
||||
result = context->getImpl().allocator.Allocate<AffineMapAttr>();
|
||||
new (result) AffineMapAttr(value);
|
||||
result = context->getImpl().allocator.Allocate<AffineMapAttributeStorage>();
|
||||
new (result) AffineMapAttributeStorage{{Attribute::Kind::AffineMap,
|
||||
/*isOrContainsFunction=*/false},
|
||||
value};
|
||||
return result;
|
||||
}
|
||||
|
||||
TypeAttr *TypeAttr::get(Type *type, MLIRContext *context) {
|
||||
TypeAttr TypeAttr::get(Type *type, MLIRContext *context) {
|
||||
auto *&result = context->getImpl().typeAttrs[type];
|
||||
if (result)
|
||||
return result;
|
||||
|
||||
result = context->getImpl().allocator.Allocate<TypeAttr>();
|
||||
new (result) TypeAttr(type);
|
||||
result = context->getImpl().allocator.Allocate<TypeAttributeStorage>();
|
||||
new (result) TypeAttributeStorage{{Attribute::Kind::Type,
|
||||
/*isOrContainsFunction=*/false},
|
||||
type};
|
||||
return result;
|
||||
}
|
||||
|
||||
FunctionAttr *FunctionAttr::get(const Function *value, MLIRContext *context) {
|
||||
FunctionAttr FunctionAttr::get(const Function *value, MLIRContext *context) {
|
||||
assert(value && "Cannot get FunctionAttr for a null function");
|
||||
|
||||
auto *&result = context->getImpl().functionAttrs[value];
|
||||
if (result)
|
||||
return result;
|
||||
|
||||
result = context->getImpl().allocator.Allocate<FunctionAttr>();
|
||||
new (result) FunctionAttr(const_cast<Function *>(value));
|
||||
result = context->getImpl().allocator.Allocate<FunctionAttributeStorage>();
|
||||
new (result) FunctionAttributeStorage{{Attribute::Kind::Function,
|
||||
/*isOrContainsFunction=*/true},
|
||||
const_cast<Function *>(value)};
|
||||
return result;
|
||||
}
|
||||
|
||||
FunctionType *FunctionAttr::getType() const { return getValue()->getType(); }
|
||||
|
||||
/// This function is used by the internals of the Function class to null out
|
||||
/// attributes refering to functions that are about to be deleted.
|
||||
void FunctionAttr::dropFunctionReference(Function *value) {
|
||||
|
@ -935,30 +951,29 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs,
|
|||
return *existing.first = result;
|
||||
}
|
||||
|
||||
OpaqueElementsAttr *OpaqueElementsAttr::get(VectorOrTensorType *type,
|
||||
StringRef bytes) {
|
||||
assert(isValidTensorElementType(type->getElementType()) &&
|
||||
"Input element type should be a valid tensor element type");
|
||||
|
||||
SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType *type,
|
||||
Attribute elt) {
|
||||
auto &impl = type->getContext()->getImpl();
|
||||
|
||||
// Look to see if this constant is already defined.
|
||||
OpaqueElementsAttrInfo::KeyTy key({type, bytes});
|
||||
auto existing = impl.opaqueElementsAttrs.insert_as(nullptr, key);
|
||||
// Look to see if we already have this.
|
||||
auto *&result = impl.splatElementsAttrs[{type, elt}];
|
||||
|
||||
// If we already have it, return that value.
|
||||
if (!existing.second)
|
||||
return *existing.first;
|
||||
if (result)
|
||||
return result;
|
||||
|
||||
// Otherwise, allocate a new one, unique it and return it.
|
||||
auto *result = impl.allocator.Allocate<OpaqueElementsAttr>();
|
||||
bytes = bytes.copy(impl.allocator);
|
||||
new (result) OpaqueElementsAttr(type, bytes);
|
||||
return *existing.first = result;
|
||||
// Otherwise, allocate them into the bump pointer.
|
||||
result = impl.allocator.Allocate<SplatElementsAttributeStorage>();
|
||||
new (result) SplatElementsAttributeStorage{{{Attribute::Kind::SplatElements,
|
||||
/*isOrContainsFunction=*/false},
|
||||
type},
|
||||
elt};
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
DenseElementsAttr *DenseElementsAttr::get(VectorOrTensorType *type,
|
||||
ArrayRef<char> data) {
|
||||
DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type,
|
||||
ArrayRef<char> data) {
|
||||
auto bitsRequired = (long)type->getBitWidth() * type->getNumElements();
|
||||
(void)(bitsRequired);
|
||||
assert((bitsRequired <= data.size() * 8L) &&
|
||||
|
@ -981,18 +996,25 @@ DenseElementsAttr *DenseElementsAttr::get(VectorOrTensorType *type,
|
|||
case Type::Kind::F16:
|
||||
case Type::Kind::F32:
|
||||
case Type::Kind::F64: {
|
||||
auto *result = impl.allocator.Allocate<DenseFPElementsAttr>();
|
||||
auto *result = impl.allocator.Allocate<DenseFPElementsAttributeStorage>();
|
||||
auto *copy = (char *)impl.allocator.Allocate(data.size(), 64);
|
||||
std::uninitialized_copy(data.begin(), data.end(), copy);
|
||||
new (result) DenseFPElementsAttr(type, {copy, data.size()});
|
||||
new (result) DenseFPElementsAttributeStorage{
|
||||
{{{Attribute::Kind::DenseFPElements, /*isOrContainsFunction=*/false},
|
||||
type},
|
||||
{copy, data.size()}}};
|
||||
return *existing.first = result;
|
||||
}
|
||||
case Type::Kind::Integer: {
|
||||
auto width = cast<IntegerType>(eltType)->getWidth();
|
||||
auto *result = impl.allocator.Allocate<DenseIntElementsAttr>();
|
||||
auto width = ::cast<IntegerType>(eltType)->getWidth();
|
||||
auto *result = impl.allocator.Allocate<DenseIntElementsAttributeStorage>();
|
||||
auto *copy = (char *)impl.allocator.Allocate(data.size(), 64);
|
||||
std::uninitialized_copy(data.begin(), data.end(), copy);
|
||||
new (result) DenseIntElementsAttr(type, {copy, data.size()}, width);
|
||||
new (result) DenseIntElementsAttributeStorage{
|
||||
{{{Attribute::Kind::DenseIntElements, /*isOrContainsFunction=*/false},
|
||||
type},
|
||||
{copy, data.size()}},
|
||||
width};
|
||||
return *existing.first = result;
|
||||
}
|
||||
default:
|
||||
|
@ -1000,118 +1022,33 @@ DenseElementsAttr *DenseElementsAttr::get(VectorOrTensorType *type,
|
|||
}
|
||||
}
|
||||
|
||||
/// Writes the lowest `bitWidth` bits of `value` to bit position `bitPos`
|
||||
/// starting from `rawData`.
|
||||
void DenseIntElementsAttr::writeBits(char *data, size_t bitPos, size_t bitWidth,
|
||||
uint64_t value) {
|
||||
// Read the destination bytes which will be written to.
|
||||
uint64_t dst = 0;
|
||||
auto dstData = reinterpret_cast<char *>(&dst);
|
||||
auto endPos = bitPos + bitWidth;
|
||||
auto start = data + bitPos / 8;
|
||||
auto end = data + endPos / 8 + (endPos % 8 != 0);
|
||||
std::copy(start, end, dstData);
|
||||
OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType *type,
|
||||
StringRef bytes) {
|
||||
assert(isValidTensorElementType(type->getElementType()) &&
|
||||
"Input element type should be a valid tensor element type");
|
||||
|
||||
// Clean up the invalid bits in the destination bytes.
|
||||
dst &= ~(-1UL << (bitPos % 8));
|
||||
|
||||
// Get the valid bits of the source value, shift them to right position,
|
||||
// then add them to the destination bytes.
|
||||
value <<= bitPos % 8;
|
||||
dst |= value;
|
||||
|
||||
// Write the destination bytes back.
|
||||
ArrayRef<char> range({dstData, (size_t)(end - start)});
|
||||
std::copy(range.begin(), range.end(), start);
|
||||
}
|
||||
|
||||
/// Reads the next `bitWidth` bits from the bit position `bitPos` of `rawData`
|
||||
/// and put them in the lowest bits.
|
||||
uint64_t DenseIntElementsAttr::readBits(const char *rawData, size_t bitPos,
|
||||
size_t bitsWidth) {
|
||||
uint64_t dst = 0;
|
||||
auto dstData = reinterpret_cast<char *>(&dst);
|
||||
auto endPos = bitPos + bitsWidth;
|
||||
auto start = rawData + bitPos / 8;
|
||||
auto end = rawData + endPos / 8 + (endPos % 8 != 0);
|
||||
std::copy(start, end, dstData);
|
||||
|
||||
dst >>= bitPos % 8;
|
||||
dst &= ~(-1UL << bitsWidth);
|
||||
return dst;
|
||||
}
|
||||
|
||||
void DenseElementsAttr::getValues(SmallVectorImpl<Attribute *> &values) const {
|
||||
switch (getKind()) {
|
||||
case Attribute::Kind::DenseIntElements:
|
||||
cast<DenseIntElementsAttr>(this)->getValues(values);
|
||||
return;
|
||||
case Attribute::Kind::DenseFPElements:
|
||||
cast<DenseFPElementsAttr>(this)->getValues(values);
|
||||
return;
|
||||
default:
|
||||
llvm_unreachable("unexpected element type");
|
||||
}
|
||||
}
|
||||
|
||||
void DenseIntElementsAttr::getValues(
|
||||
SmallVectorImpl<Attribute *> &values) const {
|
||||
auto elementNum = getType()->getNumElements();
|
||||
auto context = getType()->getContext();
|
||||
values.reserve(elementNum);
|
||||
if (bitsWidth == 64) {
|
||||
ArrayRef<int64_t> vs(
|
||||
{reinterpret_cast<const int64_t *>(getRawData().data()),
|
||||
getRawData().size() / 8});
|
||||
for (auto value : vs) {
|
||||
auto *attr = IntegerAttr::get(value, context);
|
||||
values.push_back(attr);
|
||||
}
|
||||
} else {
|
||||
const auto *rawData = getRawData().data();
|
||||
for (size_t pos = 0; pos < elementNum * bitsWidth; pos += bitsWidth) {
|
||||
uint64_t bits = readBits(rawData, pos, bitsWidth);
|
||||
APInt value(bitsWidth, bits, /*isSigned=*/true);
|
||||
auto *attr = IntegerAttr::get(value.getSExtValue(), context);
|
||||
values.push_back(attr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void DenseFPElementsAttr::getValues(
|
||||
SmallVectorImpl<Attribute *> &values) const {
|
||||
auto elementNum = getType()->getNumElements();
|
||||
auto context = getType()->getContext();
|
||||
ArrayRef<double> vs({reinterpret_cast<const double *>(getRawData().data()),
|
||||
getRawData().size() / 8});
|
||||
values.reserve(elementNum);
|
||||
for (auto v : vs) {
|
||||
auto *attr = FloatAttr::get(v, context);
|
||||
values.push_back(attr);
|
||||
}
|
||||
}
|
||||
|
||||
SplatElementsAttr *SplatElementsAttr::get(VectorOrTensorType *type,
|
||||
Attribute *elt) {
|
||||
auto &impl = type->getContext()->getImpl();
|
||||
|
||||
// Look to see if we already have this.
|
||||
auto *&result = impl.splatElementsAttrs[{type, elt}];
|
||||
// Look to see if this constant is already defined.
|
||||
OpaqueElementsAttrInfo::KeyTy key({type, bytes});
|
||||
auto existing = impl.opaqueElementsAttrs.insert_as(nullptr, key);
|
||||
|
||||
// If we already have it, return that value.
|
||||
if (result)
|
||||
return result;
|
||||
if (!existing.second)
|
||||
return *existing.first;
|
||||
|
||||
// Otherwise, allocate them into the bump pointer.
|
||||
result = impl.allocator.Allocate<SplatElementsAttr>();
|
||||
new (result) SplatElementsAttr(type, elt);
|
||||
|
||||
return result;
|
||||
// Otherwise, allocate a new one, unique it and return it.
|
||||
auto *result = impl.allocator.Allocate<OpaqueElementsAttributeStorage>();
|
||||
bytes = bytes.copy(impl.allocator);
|
||||
new (result) OpaqueElementsAttributeStorage{
|
||||
{{Attribute::Kind::OpaqueElements, /*isOrContainsFunction=*/false}, type},
|
||||
bytes};
|
||||
return *existing.first = result;
|
||||
}
|
||||
|
||||
SparseElementsAttr *SparseElementsAttr::get(VectorOrTensorType *type,
|
||||
DenseIntElementsAttr *indices,
|
||||
DenseElementsAttr *values) {
|
||||
SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType *type,
|
||||
DenseIntElementsAttr indices,
|
||||
DenseElementsAttr values) {
|
||||
auto &impl = type->getContext()->getImpl();
|
||||
|
||||
// Look to see if we already have this.
|
||||
|
@ -1123,8 +1060,12 @@ SparseElementsAttr *SparseElementsAttr::get(VectorOrTensorType *type,
|
|||
return result;
|
||||
|
||||
// Otherwise, allocate them into the bump pointer.
|
||||
result = impl.allocator.Allocate<SparseElementsAttr>();
|
||||
new (result) SparseElementsAttr(type, indices, values);
|
||||
result = impl.allocator.Allocate<SparseElementsAttributeStorage>();
|
||||
new (result) SparseElementsAttributeStorage{{{Attribute::Kind::SparseElements,
|
||||
/*isOrContainsFunction=*/false},
|
||||
type},
|
||||
indices,
|
||||
values};
|
||||
|
||||
return result;
|
||||
}
|
||||
|
|
|
@ -148,7 +148,7 @@ ArrayRef<NamedAttribute> Operation::getAttrs() const {
|
|||
|
||||
/// If an attribute exists with the specified name, change it to the new
|
||||
/// value. Otherwise, add a new attribute with the specified name/value.
|
||||
void Operation::setAttr(Identifier name, Attribute *value) {
|
||||
void Operation::setAttr(Identifier name, Attribute value) {
|
||||
assert(value && "attributes may never be null");
|
||||
auto origAttrs = getAttrs();
|
||||
|
||||
|
@ -225,8 +225,8 @@ void Operation::erase() {
|
|||
/// Attempt to constant fold this operation with the specified constant
|
||||
/// operand values. If successful, this returns false and fills in the
|
||||
/// results vector. If not, this returns true and results is unspecified.
|
||||
bool Operation::constantFold(ArrayRef<Attribute *> operands,
|
||||
SmallVectorImpl<Attribute *> &results) const {
|
||||
bool Operation::constantFold(ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<Attribute> &results) const {
|
||||
// If we have a registered operation definition matching this one, use it to
|
||||
// try to constant fold the operation.
|
||||
if (auto *abstractOp = getAbstractOperation())
|
||||
|
|
|
@ -195,7 +195,7 @@ public:
|
|||
// Attribute parsing.
|
||||
Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
|
||||
FunctionType *type);
|
||||
Attribute *parseAttribute();
|
||||
Attribute parseAttribute();
|
||||
|
||||
ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);
|
||||
|
||||
|
@ -204,8 +204,8 @@ public:
|
|||
AffineMap parseAffineMapReference();
|
||||
IntegerSet parseIntegerSetInline();
|
||||
IntegerSet parseIntegerSetReference();
|
||||
DenseElementsAttr *parseDenseElementsAttr(VectorOrTensorType *type);
|
||||
DenseElementsAttr *parseDenseElementsAttr(Type *eltType, bool isVector);
|
||||
DenseElementsAttr parseDenseElementsAttr(VectorOrTensorType *type);
|
||||
DenseElementsAttr parseDenseElementsAttr(Type *eltType, bool isVector);
|
||||
VectorOrTensorType *parseVectorOrTensorType();
|
||||
|
||||
private:
|
||||
|
@ -684,7 +684,7 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int> &dims) {
|
|||
case Token::floatliteral:
|
||||
case Token::integer:
|
||||
case Token::minus: {
|
||||
auto *result = p.parseAttribute();
|
||||
auto result = p.parseAttribute();
|
||||
if (!result)
|
||||
return p.emitError("expected tensor element");
|
||||
// check result matches the element type.
|
||||
|
@ -693,16 +693,16 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int> &dims) {
|
|||
case Type::Kind::F16:
|
||||
case Type::Kind::F32:
|
||||
case Type::Kind::F64: {
|
||||
if (!isa<FloatAttr>(result))
|
||||
if (!result.isa<FloatAttr>())
|
||||
return p.emitError("expected tensor literal element has float type");
|
||||
double value = cast<FloatAttr>(result)->getDouble();
|
||||
double value = result.cast<FloatAttr>().getDouble();
|
||||
addToStorage(*(uint64_t *)(&value));
|
||||
break;
|
||||
}
|
||||
case Type::Kind::Integer: {
|
||||
if (!isa<IntegerAttr>(result))
|
||||
if (!result.isa<IntegerAttr>())
|
||||
return p.emitError("expected tensor literal element has integer type");
|
||||
auto value = cast<IntegerAttr>(result)->getValue();
|
||||
auto value = result.cast<IntegerAttr>().getValue();
|
||||
// If we couldn't successfully round trip the value, it means some bits
|
||||
// are truncated and we should give up here.
|
||||
llvm::APInt apint(bitsWidth, (uint64_t)value, /*isSigned=*/true);
|
||||
|
@ -804,7 +804,7 @@ Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
|
|||
/// | `sparse<` (tensor-type | vector-type)`,`
|
||||
/// attribute-value`, ` attribute-value `>`
|
||||
///
|
||||
Attribute *Parser::parseAttribute() {
|
||||
Attribute Parser::parseAttribute() {
|
||||
switch (getToken().getKind()) {
|
||||
case Token::kw_true:
|
||||
consumeToken(Token::kw_true);
|
||||
|
@ -859,7 +859,7 @@ Attribute *Parser::parseAttribute() {
|
|||
|
||||
case Token::l_square: {
|
||||
consumeToken(Token::l_square);
|
||||
SmallVector<Attribute *, 4> elements;
|
||||
SmallVector<Attribute, 4> elements;
|
||||
|
||||
auto parseElt = [&]() -> ParseResult {
|
||||
elements.push_back(parseAttribute());
|
||||
|
@ -928,7 +928,7 @@ Attribute *Parser::parseAttribute() {
|
|||
case Token::floatliteral:
|
||||
case Token::integer:
|
||||
case Token::minus: {
|
||||
auto *scalar = parseAttribute();
|
||||
auto scalar = parseAttribute();
|
||||
if (parseToken(Token::greater, "expected '>'"))
|
||||
return nullptr;
|
||||
return builder.getSplatElementsAttr(type, scalar);
|
||||
|
@ -973,7 +973,7 @@ Attribute *Parser::parseAttribute() {
|
|||
case Token::l_square: {
|
||||
/// Parse indices
|
||||
auto *indicesEltType = builder.getIntegerType(32);
|
||||
auto *indices =
|
||||
auto indices =
|
||||
parseDenseElementsAttr(indicesEltType, isa<VectorType>(type));
|
||||
|
||||
if (parseToken(Token::comma, "expected ','"))
|
||||
|
@ -981,12 +981,12 @@ Attribute *Parser::parseAttribute() {
|
|||
|
||||
/// Parse values.
|
||||
auto *valuesEltType = type->getElementType();
|
||||
auto *values =
|
||||
auto values =
|
||||
parseDenseElementsAttr(valuesEltType, isa<VectorType>(type));
|
||||
|
||||
/// Sanity check.
|
||||
auto *indicesType = indices->getType();
|
||||
auto *valuesType = values->getType();
|
||||
auto *indicesType = indices.getType();
|
||||
auto *valuesType = values.getType();
|
||||
auto sameShape = (indicesType->getRank() == 1) ||
|
||||
(type->getRank() == indicesType->getDimSize(1));
|
||||
auto sameElementNum =
|
||||
|
@ -1009,7 +1009,7 @@ Attribute *Parser::parseAttribute() {
|
|||
|
||||
// Build the sparse elements attribute by the indices and values.
|
||||
return builder.getSparseElementsAttr(
|
||||
type, cast<DenseIntElementsAttr>(indices), values);
|
||||
type, indices.cast<DenseIntElementsAttr>(), values);
|
||||
}
|
||||
default:
|
||||
return (emitError("expected '[' to start sparse tensor literal"),
|
||||
|
@ -1035,8 +1035,7 @@ Attribute *Parser::parseAttribute() {
|
|||
///
|
||||
/// This method returns a constructed dense elements attribute with the shape
|
||||
/// from the parsing result.
|
||||
DenseElementsAttr *Parser::parseDenseElementsAttr(Type *eltType,
|
||||
bool isVector) {
|
||||
DenseElementsAttr Parser::parseDenseElementsAttr(Type *eltType, bool isVector) {
|
||||
TensorLiteralParser literalParser(*this, eltType);
|
||||
if (literalParser.parse())
|
||||
return nullptr;
|
||||
|
@ -1047,8 +1046,8 @@ DenseElementsAttr *Parser::parseDenseElementsAttr(Type *eltType,
|
|||
} else {
|
||||
type = builder.getTensorType(literalParser.getShape(), eltType);
|
||||
}
|
||||
return (DenseElementsAttr *)builder.getDenseElementsAttr(
|
||||
type, literalParser.getValues());
|
||||
return builder.getDenseElementsAttr(type, literalParser.getValues())
|
||||
.cast<DenseElementsAttr>();
|
||||
}
|
||||
|
||||
/// Dense elements attribute.
|
||||
|
@ -1061,7 +1060,7 @@ DenseElementsAttr *Parser::parseDenseElementsAttr(Type *eltType,
|
|||
/// This method compares the shapes from the parsing result and that from the
|
||||
/// input argument. It returns a constructed dense elements attribute if both
|
||||
/// match.
|
||||
DenseElementsAttr *Parser::parseDenseElementsAttr(VectorOrTensorType *type) {
|
||||
DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType *type) {
|
||||
auto *eltTy = type->getElementType();
|
||||
TensorLiteralParser literalParser(*this, eltTy);
|
||||
if (literalParser.parse())
|
||||
|
@ -1076,8 +1075,8 @@ DenseElementsAttr *Parser::parseDenseElementsAttr(VectorOrTensorType *type) {
|
|||
s << "])";
|
||||
return (emitError(s.str()), nullptr);
|
||||
}
|
||||
return (DenseElementsAttr *)builder.getDenseElementsAttr(
|
||||
type, literalParser.getValues());
|
||||
return builder.getDenseElementsAttr(type, literalParser.getValues())
|
||||
.cast<DenseElementsAttr>();
|
||||
}
|
||||
|
||||
/// Vector or tensor type for elements attribute.
|
||||
|
@ -2133,7 +2132,7 @@ public:
|
|||
/// Parse an arbitrary attribute and return it in result. This also adds
|
||||
/// the attribute to the specified attribute list with the specified name.
|
||||
/// this captures the location of the attribute in 'loc' if it is non-null.
|
||||
bool parseAttribute(Attribute *&result, const char *attrName,
|
||||
bool parseAttribute(Attribute &result, const char *attrName,
|
||||
SmallVectorImpl<NamedAttribute> &attrs) override {
|
||||
result = parser.parseAttribute();
|
||||
if (!result)
|
||||
|
@ -3336,27 +3335,27 @@ ParseResult ModuleParser::parseMLFunc() {
|
|||
/// Given an attribute that could refer to a function attribute in the
|
||||
/// remapping table, walk it and rewrite it to use the mapped function. If it
|
||||
/// doesn't refer to anything in the table, then it is returned unmodified.
|
||||
static Attribute *
|
||||
remapFunctionAttrs(Attribute *input,
|
||||
DenseMap<FunctionAttr *, FunctionAttr *> &remappingTable,
|
||||
static Attribute
|
||||
remapFunctionAttrs(Attribute input,
|
||||
DenseMap<Attribute, FunctionAttr> &remappingTable,
|
||||
MLIRContext *context) {
|
||||
// Most attributes are trivially unrelated to function attributes, skip them
|
||||
// rapidly.
|
||||
if (!input->isOrContainsFunction())
|
||||
if (!input.isOrContainsFunction())
|
||||
return input;
|
||||
|
||||
// If we have a function attribute, remap it.
|
||||
if (auto *fnAttr = dyn_cast<FunctionAttr>(input)) {
|
||||
if (auto fnAttr = input.dyn_cast<FunctionAttr>()) {
|
||||
auto it = remappingTable.find(fnAttr);
|
||||
return it != remappingTable.end() ? it->second : input;
|
||||
}
|
||||
|
||||
// Otherwise, we must have an array attribute, remap the elements.
|
||||
auto *arrayAttr = cast<ArrayAttr>(input);
|
||||
SmallVector<Attribute *, 8> remappedElts;
|
||||
auto arrayAttr = input.cast<ArrayAttr>();
|
||||
SmallVector<Attribute, 8> remappedElts;
|
||||
bool anyChange = false;
|
||||
for (auto *elt : arrayAttr->getValue()) {
|
||||
auto *newElt = remapFunctionAttrs(elt, remappingTable, context);
|
||||
for (auto elt : arrayAttr.getValue()) {
|
||||
auto newElt = remapFunctionAttrs(elt, remappingTable, context);
|
||||
remappedElts.push_back(newElt);
|
||||
anyChange |= (elt != newElt);
|
||||
}
|
||||
|
@ -3370,11 +3369,11 @@ remapFunctionAttrs(Attribute *input,
|
|||
/// Remap function attributes to resolve forward references to their actual
|
||||
/// definition.
|
||||
static void remapFunctionAttrsInOperation(
|
||||
Operation *op, DenseMap<FunctionAttr *, FunctionAttr *> &remappingTable) {
|
||||
Operation *op, DenseMap<Attribute, FunctionAttr> &remappingTable) {
|
||||
for (auto attr : op->getAttrs()) {
|
||||
// Do the remapping, if we got the same thing back, then it must contain
|
||||
// functions that aren't getting remapped.
|
||||
auto *newVal =
|
||||
auto newVal =
|
||||
remapFunctionAttrs(attr.second, remappingTable, op->getContext());
|
||||
if (newVal == attr.second)
|
||||
continue;
|
||||
|
@ -3391,7 +3390,7 @@ static void remapFunctionAttrsInOperation(
|
|||
ParseResult ModuleParser::finalizeModule() {
|
||||
|
||||
// Resolve all forward references, building a remapping table of attributes.
|
||||
DenseMap<FunctionAttr *, FunctionAttr *> remappingTable;
|
||||
DenseMap<Attribute, FunctionAttr> remappingTable;
|
||||
for (auto forwardRef : getState().functionForwardRefs) {
|
||||
auto name = forwardRef.first;
|
||||
|
||||
|
@ -3428,13 +3427,13 @@ ParseResult ModuleParser::finalizeModule() {
|
|||
continue;
|
||||
|
||||
struct MLFnWalker : public StmtWalker<MLFnWalker> {
|
||||
MLFnWalker(DenseMap<FunctionAttr *, FunctionAttr *> &remappingTable)
|
||||
MLFnWalker(DenseMap<Attribute, FunctionAttr> &remappingTable)
|
||||
: remappingTable(remappingTable) {}
|
||||
void visitOperationStmt(OperationStmt *opStmt) {
|
||||
remapFunctionAttrsInOperation(opStmt, remappingTable);
|
||||
}
|
||||
|
||||
DenseMap<FunctionAttr *, FunctionAttr *> &remappingTable;
|
||||
DenseMap<Attribute, FunctionAttr> &remappingTable;
|
||||
};
|
||||
|
||||
MLFnWalker(remappingTable).walk(mlFn);
|
||||
|
|
|
@ -44,13 +44,13 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
|
|||
// AddFOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute *AddFOp::constantFold(ArrayRef<Attribute *> operands,
|
||||
MLIRContext *context) const {
|
||||
Attribute AddFOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) const {
|
||||
assert(operands.size() == 2 && "addf takes two operands");
|
||||
|
||||
if (auto *lhs = dyn_cast_or_null<FloatAttr>(operands[0])) {
|
||||
if (auto *rhs = dyn_cast_or_null<FloatAttr>(operands[1]))
|
||||
return FloatAttr::get(lhs->getValue() + rhs->getValue(), context);
|
||||
if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
|
||||
if (auto rhs = operands[1].dyn_cast_or_null<FloatAttr>())
|
||||
return FloatAttr::get(lhs.getValue() + rhs.getValue(), context);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
|
@ -60,13 +60,13 @@ Attribute *AddFOp::constantFold(ArrayRef<Attribute *> operands,
|
|||
// AddIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute *AddIOp::constantFold(ArrayRef<Attribute *> operands,
|
||||
MLIRContext *context) const {
|
||||
Attribute AddIOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) const {
|
||||
assert(operands.size() == 2 && "addi takes two operands");
|
||||
|
||||
if (auto *lhs = dyn_cast_or_null<IntegerAttr>(operands[0])) {
|
||||
if (auto *rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
|
||||
return IntegerAttr::get(lhs->getValue() + rhs->getValue(), context);
|
||||
if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
|
||||
if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>())
|
||||
return IntegerAttr::get(lhs.getValue() + rhs.getValue(), context);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
|
@ -192,12 +192,12 @@ void CallOp::print(OpAsmPrinter *p) const {
|
|||
|
||||
bool CallOp::verify() const {
|
||||
// Check that the callee attribute was specified.
|
||||
auto *fnAttr = getAttrOfType<FunctionAttr>("callee");
|
||||
auto fnAttr = getAttrOfType<FunctionAttr>("callee");
|
||||
if (!fnAttr)
|
||||
return emitOpError("requires a 'callee' function attribute");
|
||||
|
||||
// Verify that the operand and result types match the callee.
|
||||
auto *fnType = fnAttr->getValue()->getType();
|
||||
auto *fnType = fnAttr.getValue()->getType();
|
||||
if (fnType->getNumInputs() != getNumOperands())
|
||||
return emitOpError("incorrect number of operands for callee");
|
||||
|
||||
|
@ -329,7 +329,7 @@ void DimOp::print(OpAsmPrinter *p) const {
|
|||
|
||||
bool DimOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
OpAsmParser::OperandType operandInfo;
|
||||
IntegerAttr *indexAttr;
|
||||
IntegerAttr indexAttr;
|
||||
Type *type;
|
||||
|
||||
return parser->parseOperand(operandInfo) || parser->parseComma() ||
|
||||
|
@ -346,7 +346,7 @@ bool DimOp::verify() const {
|
|||
auto indexAttr = getAttrOfType<IntegerAttr>("index");
|
||||
if (!indexAttr)
|
||||
return emitOpError("requires an integer attribute named 'index'");
|
||||
uint64_t index = (uint64_t)indexAttr->getValue();
|
||||
uint64_t index = (uint64_t)indexAttr.getValue();
|
||||
|
||||
auto *type = getOperand()->getType();
|
||||
if (auto *tensorType = dyn_cast<RankedTensorType>(type)) {
|
||||
|
@ -365,8 +365,8 @@ bool DimOp::verify() const {
|
|||
return false;
|
||||
}
|
||||
|
||||
Attribute *DimOp::constantFold(ArrayRef<Attribute *> operands,
|
||||
MLIRContext *context) const {
|
||||
Attribute DimOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) const {
|
||||
// Constant fold dim when the size along the index referred to is a constant.
|
||||
auto *opType = getOperand()->getType();
|
||||
int indexSize = -1;
|
||||
|
@ -671,13 +671,13 @@ bool MemRefCastOp::verify() const {
|
|||
// MulFOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute *MulFOp::constantFold(ArrayRef<Attribute *> operands,
|
||||
MLIRContext *context) const {
|
||||
Attribute MulFOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) const {
|
||||
assert(operands.size() == 2 && "mulf takes two operands");
|
||||
|
||||
if (auto *lhs = dyn_cast_or_null<FloatAttr>(operands[0])) {
|
||||
if (auto *rhs = dyn_cast_or_null<FloatAttr>(operands[1]))
|
||||
return FloatAttr::get(lhs->getValue() * rhs->getValue(), context);
|
||||
if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
|
||||
if (auto rhs = operands[1].dyn_cast_or_null<FloatAttr>())
|
||||
return FloatAttr::get(lhs.getValue() * rhs.getValue(), context);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
|
@ -687,23 +687,23 @@ Attribute *MulFOp::constantFold(ArrayRef<Attribute *> operands,
|
|||
// MulIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute *MulIOp::constantFold(ArrayRef<Attribute *> operands,
|
||||
MLIRContext *context) const {
|
||||
Attribute MulIOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) const {
|
||||
assert(operands.size() == 2 && "muli takes two operands");
|
||||
|
||||
if (auto *lhs = dyn_cast_or_null<IntegerAttr>(operands[0])) {
|
||||
if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
|
||||
// 0*x == 0
|
||||
if (lhs->getValue() == 0)
|
||||
if (lhs.getValue() == 0)
|
||||
return lhs;
|
||||
|
||||
if (auto *rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
|
||||
if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>())
|
||||
// TODO: Handle the overflow case.
|
||||
return IntegerAttr::get(lhs->getValue() * rhs->getValue(), context);
|
||||
return IntegerAttr::get(lhs.getValue() * rhs.getValue(), context);
|
||||
}
|
||||
|
||||
// x*0 == 0
|
||||
if (auto *rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
|
||||
if (rhs->getValue() == 0)
|
||||
if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>())
|
||||
if (rhs.getValue() == 0)
|
||||
return rhs;
|
||||
|
||||
return nullptr;
|
||||
|
@ -817,13 +817,13 @@ bool StoreOp::verify() const {
|
|||
// SubFOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute *SubFOp::constantFold(ArrayRef<Attribute *> operands,
|
||||
MLIRContext *context) const {
|
||||
Attribute SubFOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) const {
|
||||
assert(operands.size() == 2 && "subf takes two operands");
|
||||
|
||||
if (auto *lhs = dyn_cast_or_null<FloatAttr>(operands[0])) {
|
||||
if (auto *rhs = dyn_cast_or_null<FloatAttr>(operands[1]))
|
||||
return FloatAttr::get(lhs->getValue() - rhs->getValue(), context);
|
||||
if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
|
||||
if (auto rhs = operands[1].dyn_cast_or_null<FloatAttr>())
|
||||
return FloatAttr::get(lhs.getValue() - rhs.getValue(), context);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
|
@ -833,13 +833,13 @@ Attribute *SubFOp::constantFold(ArrayRef<Attribute *> operands,
|
|||
// SubIOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Attribute *SubIOp::constantFold(ArrayRef<Attribute *> operands,
|
||||
MLIRContext *context) const {
|
||||
Attribute SubIOp::constantFold(ArrayRef<Attribute> operands,
|
||||
MLIRContext *context) const {
|
||||
assert(operands.size() == 2 && "subi takes two operands");
|
||||
|
||||
if (auto *lhs = dyn_cast_or_null<IntegerAttr>(operands[0])) {
|
||||
if (auto *rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
|
||||
return IntegerAttr::get(lhs->getValue() - rhs->getValue(), context);
|
||||
if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
|
||||
if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>())
|
||||
return IntegerAttr::get(lhs.getValue() - rhs.getValue(), context);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
|
|
|
@ -31,7 +31,7 @@ struct ConstantFold : public FunctionPass, StmtWalker<ConstantFold> {
|
|||
SmallVector<SSAValue *, 8> existingConstants;
|
||||
// Operation statements that were folded and that need to be erased.
|
||||
std::vector<OperationStmt *> opStmtsToErase;
|
||||
using ConstantFactoryType = std::function<SSAValue *(Attribute *, Type *)>;
|
||||
using ConstantFactoryType = std::function<SSAValue *(Attribute, Type *)>;
|
||||
|
||||
bool foldOperation(Operation *op,
|
||||
SmallVectorImpl<SSAValue *> &existingConstants,
|
||||
|
@ -60,9 +60,9 @@ bool ConstantFold::foldOperation(Operation *op,
|
|||
|
||||
// Check to see if each of the operands is a trivial constant. If so, get
|
||||
// the value. If not, ignore the instruction.
|
||||
SmallVector<Attribute *, 8> operandConstants;
|
||||
SmallVector<Attribute, 8> operandConstants;
|
||||
for (auto *operand : op->getOperands()) {
|
||||
Attribute *operandCst = nullptr;
|
||||
Attribute operandCst = nullptr;
|
||||
if (auto *operandOp = operand->getDefiningOperation()) {
|
||||
if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
|
||||
operandCst = operandConstantOp->getValue();
|
||||
|
@ -71,7 +71,7 @@ bool ConstantFold::foldOperation(Operation *op,
|
|||
}
|
||||
|
||||
// Attempt to constant fold the operation.
|
||||
SmallVector<Attribute *, 8> resultConstants;
|
||||
SmallVector<Attribute, 8> resultConstants;
|
||||
if (op->constantFold(operandConstants, resultConstants))
|
||||
return true;
|
||||
|
||||
|
@ -106,7 +106,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
|
|||
for (auto instIt = bb.begin(), e = bb.end(); instIt != e;) {
|
||||
auto &inst = *instIt++;
|
||||
|
||||
auto constantFactory = [&](Attribute *value, Type *type) -> SSAValue * {
|
||||
auto constantFactory = [&](Attribute value, Type *type) -> SSAValue * {
|
||||
builder.setInsertionPoint(&inst);
|
||||
return builder.create<ConstantOp>(inst.getLoc(), value, type);
|
||||
};
|
||||
|
@ -134,7 +134,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
|
|||
|
||||
// Override the walker's operation statement visit for constant folding.
|
||||
void ConstantFold::visitOperationStmt(OperationStmt *stmt) {
|
||||
auto constantFactory = [&](Attribute *value, Type *type) -> SSAValue * {
|
||||
auto constantFactory = [&](Attribute value, Type *type) -> SSAValue * {
|
||||
MLFuncBuilder builder(stmt);
|
||||
return builder.create<ConstantOp>(stmt->getLoc(), value, type);
|
||||
};
|
||||
|
|
|
@ -71,8 +71,8 @@ void SimplifyAffineStructures::visitIfStmt(IfStmt *ifStmt) {
|
|||
|
||||
void SimplifyAffineStructures::visitOperationStmt(OperationStmt *opStmt) {
|
||||
for (auto attr : opStmt->getAttrs()) {
|
||||
if (auto *mapAttr = dyn_cast<AffineMapAttr>(attr.second)) {
|
||||
MutableAffineMap mMap(mapAttr->getValue());
|
||||
if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) {
|
||||
MutableAffineMap mMap(mapAttr.getValue());
|
||||
mMap.simplify();
|
||||
auto map = mMap.getAffineMap();
|
||||
opStmt->setAttr(attr.first, AffineMapAttr::get(map));
|
||||
|
|
|
@ -79,7 +79,7 @@ private:
|
|||
/// As part of canonicalization, we move constants to the top of the entry
|
||||
/// block of the current function and de-duplicate them. This keeps track of
|
||||
/// constants we have done this for.
|
||||
DenseMap<std::pair<Attribute *, Type *>, Operation *> uniquedConstants;
|
||||
DenseMap<std::pair<Attribute, Type *>, Operation *> uniquedConstants;
|
||||
};
|
||||
}; // end anonymous namespace
|
||||
|
||||
|
@ -107,7 +107,7 @@ public:
|
|||
void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
|
||||
WorklistRewriter &rewriter) {
|
||||
// These are scratch vectors used in the constant folding loop below.
|
||||
SmallVector<Attribute *, 8> operandConstants, resultConstants;
|
||||
SmallVector<Attribute, 8> operandConstants, resultConstants;
|
||||
|
||||
while (!worklist.empty()) {
|
||||
auto *op = popFromWorklist();
|
||||
|
@ -175,7 +175,7 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
|
|||
// the operation knows how to constant fold itself.
|
||||
operandConstants.clear();
|
||||
for (auto *operand : op->getOperands()) {
|
||||
Attribute *operandCst = nullptr;
|
||||
Attribute operandCst;
|
||||
if (auto *operandOp = operand->getDefiningOperation()) {
|
||||
if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
|
||||
operandCst = operandConstantOp->getValue();
|
||||
|
|
|
@ -353,11 +353,11 @@ bool mlir::constantFoldBounds(ForStmt *forStmt) {
|
|||
|
||||
// Check to see if each of the operands is the result of a constant. If so,
|
||||
// get the value. If not, ignore it.
|
||||
SmallVector<Attribute *, 8> operandConstants;
|
||||
SmallVector<Attribute, 8> operandConstants;
|
||||
auto boundOperands = lower ? forStmt->getLowerBoundOperands()
|
||||
: forStmt->getUpperBoundOperands();
|
||||
for (const auto *operand : boundOperands) {
|
||||
Attribute *operandCst = nullptr;
|
||||
Attribute operandCst;
|
||||
if (auto *operandOp = operand->getDefiningOperation()) {
|
||||
if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
|
||||
operandCst = operandConstantOp->getValue();
|
||||
|
@ -369,15 +369,15 @@ bool mlir::constantFoldBounds(ForStmt *forStmt) {
|
|||
lower ? forStmt->getLowerBoundMap() : forStmt->getUpperBoundMap();
|
||||
assert(boundMap.getNumResults() >= 1 &&
|
||||
"bound maps should have at least one result");
|
||||
SmallVector<Attribute *, 4> foldedResults;
|
||||
SmallVector<Attribute, 4> foldedResults;
|
||||
if (boundMap.constantFold(operandConstants, foldedResults))
|
||||
return true;
|
||||
|
||||
// Compute the max or min as applicable over the results.
|
||||
assert(!foldedResults.empty() && "bounds should have at least one result");
|
||||
auto maxOrMin = cast<IntegerAttr>(foldedResults[0])->getValue();
|
||||
auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
|
||||
for (unsigned i = 1; i < foldedResults.size(); i++) {
|
||||
auto foldedResult = cast<IntegerAttr>(foldedResults[i])->getValue();
|
||||
auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
|
||||
maxOrMin = lower ? std::max(maxOrMin, foldedResult)
|
||||
: std::min(maxOrMin, foldedResult);
|
||||
}
|
||||
|
|
|
@ -154,7 +154,7 @@ void OpEmitter::emitAttrGetters() {
|
|||
<< val.getName() << "() const {\n";
|
||||
os << " return this->getAttrOfType<"
|
||||
<< attr.getValueAsString("AttrType").trim() << ">(\"" << name
|
||||
<< "\")->getValue();\n }\n";
|
||||
<< "\").getValue();\n }\n";
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -207,9 +207,9 @@ void OpEmitter::emitVerifier() {
|
|||
// Verify the attributes have the correct type.
|
||||
for (const auto attr : attrs) {
|
||||
auto name = attr.first->getName();
|
||||
os << " if (!dyn_cast_or_null<"
|
||||
<< attr.second->getValueAsString("AttrType") << ">(this->getAttr(\""
|
||||
<< name << "\"))) return emitOpError(\"requires "
|
||||
os << " if (!this->getAttr(\"" << name << "\").dyn_cast_or_null<"
|
||||
<< attr.second->getValueAsString("AttrType") << ">("
|
||||
<< ")) return emitOpError(\"requires "
|
||||
<< attr.second->getValueAsString("PrimitiveType").trim()
|
||||
<< " attribute '" << name << "'\");\n";
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue