Implement value type abstraction for attributes.
This is done by changing Attribute to be a POD interface around an underlying pointer storage and adding in-class support for isa/dyn_cast/cast. PiperOrigin-RevId: 218764173
This commit is contained in:
parent
64d52014bd
commit
792d1c25e4
|
@ -95,8 +95,8 @@ public:
|
||||||
/// Folds the results of the application of an affine map on the provided
|
/// Folds the results of the application of an affine map on the provided
|
||||||
/// operands to a constant if possible. Returns false if the folding happens,
|
/// operands to a constant if possible. Returns false if the folding happens,
|
||||||
/// true otherwise.
|
/// true otherwise.
|
||||||
bool constantFold(ArrayRef<Attribute *> operandConstants,
|
bool constantFold(ArrayRef<Attribute> operandConstants,
|
||||||
SmallVectorImpl<Attribute *> &results) const;
|
SmallVectorImpl<Attribute> &results) const;
|
||||||
|
|
||||||
friend ::llvm::hash_code hash_value(AffineMap arg);
|
friend ::llvm::hash_code hash_value(AffineMap arg);
|
||||||
|
|
||||||
|
|
|
@ -30,10 +30,32 @@ class MLIRContext;
|
||||||
class Type;
|
class Type;
|
||||||
class VectorOrTensorType;
|
class VectorOrTensorType;
|
||||||
|
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
struct AttributeStorage;
|
||||||
|
struct BoolAttributeStorage;
|
||||||
|
struct IntegerAttributeStorage;
|
||||||
|
struct FloatAttributeStorage;
|
||||||
|
struct StringAttributeStorage;
|
||||||
|
struct ArrayAttributeStorage;
|
||||||
|
struct AffineMapAttributeStorage;
|
||||||
|
struct TypeAttributeStorage;
|
||||||
|
struct FunctionAttributeStorage;
|
||||||
|
struct ElementsAttributeStorage;
|
||||||
|
struct SplatElementsAttributeStorage;
|
||||||
|
struct DenseElementsAttributeStorage;
|
||||||
|
struct DenseIntElementsAttributeStorage;
|
||||||
|
struct DenseFPElementsAttributeStorage;
|
||||||
|
struct OpaqueElementsAttributeStorage;
|
||||||
|
struct SparseElementsAttributeStorage;
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
|
||||||
/// Attributes are known-constant values of operations and functions.
|
/// Attributes are known-constant values of operations and functions.
|
||||||
///
|
///
|
||||||
/// Instances of the Attribute class are immutable, uniqued, immortal, and owned
|
/// Instances of the Attribute class are immutable, uniqued, immortal, and owned
|
||||||
/// by MLIRContext. As such, they are passed around by raw non-const pointer.
|
/// by MLIRContext. As such, an Attribute is a POD interface to an underlying
|
||||||
|
/// storage pointer.
|
||||||
class Attribute {
|
class Attribute {
|
||||||
public:
|
public:
|
||||||
enum class Kind {
|
enum class Kind {
|
||||||
|
@ -55,177 +77,151 @@ public:
|
||||||
LAST_ELEMENTS_ATTR = SparseElements,
|
LAST_ELEMENTS_ATTR = SparseElements,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
typedef detail::AttributeStorage ImplType;
|
||||||
|
|
||||||
|
Attribute() : attr(nullptr) {}
|
||||||
|
/* implicit */ Attribute(const ImplType *attr)
|
||||||
|
: attr(const_cast<ImplType *>(attr)) {}
|
||||||
|
|
||||||
|
Attribute(const Attribute &other) : attr(other.attr) {}
|
||||||
|
Attribute &operator=(Attribute other) {
|
||||||
|
attr = other.attr;
|
||||||
|
return *this;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool operator==(Attribute other) const { return attr == other.attr; }
|
||||||
|
bool operator!=(Attribute other) const { return !(*this == other); }
|
||||||
|
explicit operator bool() const { return attr; }
|
||||||
|
|
||||||
|
bool operator!() const { return attr == nullptr; }
|
||||||
|
|
||||||
|
template <typename U> bool isa() const;
|
||||||
|
template <typename U> U dyn_cast() const;
|
||||||
|
template <typename U> U dyn_cast_or_null() const;
|
||||||
|
template <typename U> U cast() const;
|
||||||
|
|
||||||
/// Return the classification for this attribute.
|
/// Return the classification for this attribute.
|
||||||
Kind getKind() const { return kind; }
|
Kind getKind() const;
|
||||||
|
|
||||||
/// Return true if this field is, or contains, a function attribute.
|
/// Return true if this field is, or contains, a function attribute.
|
||||||
bool isOrContainsFunction() const { return isOrContainsFunctionCache; }
|
bool isOrContainsFunction() const;
|
||||||
|
|
||||||
/// Print the attribute.
|
/// Print the attribute.
|
||||||
void print(raw_ostream &os) const;
|
void print(raw_ostream &os) const;
|
||||||
void dump() const;
|
void dump() const;
|
||||||
|
|
||||||
|
friend ::llvm::hash_code hash_value(Attribute arg);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
explicit Attribute(Kind kind, bool isOrContainsFunction)
|
ImplType *attr;
|
||||||
: kind(kind), isOrContainsFunctionCache(isOrContainsFunction) {}
|
|
||||||
~Attribute() {}
|
|
||||||
|
|
||||||
private:
|
|
||||||
/// Classification of the subclass, used for type checking.
|
|
||||||
Kind kind : 8;
|
|
||||||
|
|
||||||
/// This field is true if this is, or contains, a function attribute.
|
|
||||||
bool isOrContainsFunctionCache : 1;
|
|
||||||
|
|
||||||
Attribute(const Attribute &) = delete;
|
|
||||||
void operator=(const Attribute &) = delete;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
inline raw_ostream &operator<<(raw_ostream &os, const Attribute &attr) {
|
inline raw_ostream &operator<<(raw_ostream &os, Attribute attr) {
|
||||||
attr.print(os);
|
attr.print(os);
|
||||||
return os;
|
return os;
|
||||||
}
|
}
|
||||||
|
|
||||||
class BoolAttr : public Attribute {
|
class BoolAttr : public Attribute {
|
||||||
public:
|
public:
|
||||||
static BoolAttr *get(bool value, MLIRContext *context);
|
typedef detail::BoolAttributeStorage ImplType;
|
||||||
|
BoolAttr() = default;
|
||||||
|
/* implicit */ BoolAttr(Attribute::ImplType *ptr);
|
||||||
|
|
||||||
bool getValue() const { return value; }
|
static BoolAttr get(bool value, MLIRContext *context);
|
||||||
|
|
||||||
|
bool getValue() const;
|
||||||
|
|
||||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||||
static bool classof(const Attribute *attr) {
|
static bool kindof(Kind kind) { return kind == Kind::Bool; }
|
||||||
return attr->getKind() == Kind::Bool;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
BoolAttr(bool value)
|
|
||||||
: Attribute(Kind::Bool, /*isOrContainsFunction=*/false), value(value) {}
|
|
||||||
~BoolAttr() = delete;
|
|
||||||
bool value;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class IntegerAttr : public Attribute {
|
class IntegerAttr : public Attribute {
|
||||||
public:
|
public:
|
||||||
static IntegerAttr *get(int64_t value, MLIRContext *context);
|
typedef detail::IntegerAttributeStorage ImplType;
|
||||||
|
IntegerAttr() = default;
|
||||||
|
/* implicit */ IntegerAttr(Attribute::ImplType *ptr);
|
||||||
|
|
||||||
int64_t getValue() const { return value; }
|
static IntegerAttr get(int64_t value, MLIRContext *context);
|
||||||
|
|
||||||
|
int64_t getValue() const;
|
||||||
|
|
||||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||||
static bool classof(const Attribute *attr) {
|
static bool kindof(Kind kind) { return kind == Kind::Integer; }
|
||||||
return attr->getKind() == Kind::Integer;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
IntegerAttr(int64_t value)
|
|
||||||
: Attribute(Kind::Integer, /*isOrContainsFunction=*/false), value(value) {
|
|
||||||
}
|
|
||||||
~IntegerAttr() = delete;
|
|
||||||
int64_t value;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class FloatAttr final : public Attribute,
|
class FloatAttr final : public Attribute {
|
||||||
public llvm::TrailingObjects<FloatAttr, uint64_t> {
|
|
||||||
public:
|
public:
|
||||||
static FloatAttr *get(double value, MLIRContext *context);
|
typedef detail::FloatAttributeStorage ImplType;
|
||||||
static FloatAttr *get(const APFloat &value, MLIRContext *context);
|
FloatAttr() = default;
|
||||||
|
/* implicit */ FloatAttr(Attribute::ImplType *ptr);
|
||||||
|
|
||||||
|
static FloatAttr get(double value, MLIRContext *context);
|
||||||
|
static FloatAttr get(const APFloat &value, MLIRContext *context);
|
||||||
|
|
||||||
APFloat getValue() const;
|
APFloat getValue() const;
|
||||||
|
|
||||||
double getDouble() const { return getValue().convertToDouble(); }
|
double getDouble() const;
|
||||||
|
|
||||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||||
static bool classof(const Attribute *attr) {
|
static bool kindof(Kind kind) { return kind == Kind::Float; }
|
||||||
return attr->getKind() == Kind::Float;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
FloatAttr(const llvm::fltSemantics &semantics, size_t numObjects)
|
|
||||||
: Attribute(Kind::Float, /*isOrContainsFunction=*/false),
|
|
||||||
semantics(semantics), numObjects(numObjects) {}
|
|
||||||
FloatAttr(const FloatAttr &value) = delete;
|
|
||||||
~FloatAttr() = delete;
|
|
||||||
|
|
||||||
size_t numTrailingObjects(OverloadToken<uint64_t>) const {
|
|
||||||
return numObjects;
|
|
||||||
}
|
|
||||||
|
|
||||||
const llvm::fltSemantics &semantics;
|
|
||||||
size_t numObjects;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class StringAttr : public Attribute {
|
class StringAttr : public Attribute {
|
||||||
public:
|
public:
|
||||||
static StringAttr *get(StringRef bytes, MLIRContext *context);
|
typedef detail::StringAttributeStorage ImplType;
|
||||||
|
StringAttr() = default;
|
||||||
|
/* implicit */ StringAttr(Attribute::ImplType *ptr);
|
||||||
|
|
||||||
StringRef getValue() const { return value; }
|
static StringAttr get(StringRef bytes, MLIRContext *context);
|
||||||
|
|
||||||
|
StringRef getValue() const;
|
||||||
|
|
||||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||||
static bool classof(const Attribute *attr) {
|
static bool kindof(Kind kind) { return kind == Kind::String; }
|
||||||
return attr->getKind() == Kind::String;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
StringAttr(StringRef value)
|
|
||||||
: Attribute(Kind::String, /*isOrContainsFunction=*/false), value(value) {}
|
|
||||||
~StringAttr() = delete;
|
|
||||||
StringRef value;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Array attributes are lists of other attributes. They are not necessarily
|
/// Array attributes are lists of other attributes. They are not necessarily
|
||||||
/// type homogenous given that attributes don't, in general, carry types.
|
/// type homogenous given that attributes don't, in general, carry types.
|
||||||
class ArrayAttr : public Attribute {
|
class ArrayAttr : public Attribute {
|
||||||
public:
|
public:
|
||||||
static ArrayAttr *get(ArrayRef<Attribute *> value, MLIRContext *context);
|
typedef detail::ArrayAttributeStorage ImplType;
|
||||||
|
ArrayAttr() = default;
|
||||||
|
/* implicit */ ArrayAttr(Attribute::ImplType *ptr);
|
||||||
|
|
||||||
ArrayRef<Attribute *> getValue() const { return value; }
|
static ArrayAttr get(ArrayRef<Attribute> value, MLIRContext *context);
|
||||||
|
|
||||||
|
ArrayRef<Attribute> getValue() const;
|
||||||
|
|
||||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||||
static bool classof(const Attribute *attr) {
|
static bool kindof(Kind kind) { return kind == Kind::Array; }
|
||||||
return attr->getKind() == Kind::Array;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
ArrayAttr(ArrayRef<Attribute *> value, bool isOrContainsFunction)
|
|
||||||
: Attribute(Kind::Array, isOrContainsFunction), value(value) {}
|
|
||||||
~ArrayAttr() = delete;
|
|
||||||
ArrayRef<Attribute *> value;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class AffineMapAttr : public Attribute {
|
class AffineMapAttr : public Attribute {
|
||||||
public:
|
public:
|
||||||
static AffineMapAttr *get(AffineMap value);
|
typedef detail::AffineMapAttributeStorage ImplType;
|
||||||
|
AffineMapAttr() = default;
|
||||||
|
/* implicit */ AffineMapAttr(Attribute::ImplType *ptr);
|
||||||
|
|
||||||
AffineMap getValue() const { return value; }
|
static AffineMapAttr get(AffineMap value);
|
||||||
|
|
||||||
|
AffineMap getValue() const;
|
||||||
|
|
||||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||||
static bool classof(const Attribute *attr) {
|
static bool kindof(Kind kind) { return kind == Kind::AffineMap; }
|
||||||
return attr->getKind() == Kind::AffineMap;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
AffineMapAttr(AffineMap value)
|
|
||||||
: Attribute(Kind::AffineMap, /*isOrContainsFunction=*/false),
|
|
||||||
value(value) {}
|
|
||||||
~AffineMapAttr() = delete;
|
|
||||||
AffineMap value;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class TypeAttr : public Attribute {
|
class TypeAttr : public Attribute {
|
||||||
public:
|
public:
|
||||||
static TypeAttr *get(Type *type, MLIRContext *context);
|
typedef detail::TypeAttributeStorage ImplType;
|
||||||
|
TypeAttr() = default;
|
||||||
|
/* implicit */ TypeAttr(Attribute::ImplType *ptr);
|
||||||
|
|
||||||
Type *getValue() const { return value; }
|
static TypeAttr get(Type *type, MLIRContext *context);
|
||||||
|
|
||||||
|
Type *getValue() const;
|
||||||
|
|
||||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||||
static bool classof(const Attribute *attr) {
|
static bool kindof(Kind kind) { return kind == Kind::Type; }
|
||||||
return attr->getKind() == Kind::Type;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
TypeAttr(Type *value)
|
|
||||||
: Attribute(Kind::Type, /*isOrContainsFunction=*/false), value(value) {}
|
|
||||||
~TypeAttr() = delete;
|
|
||||||
Type *value;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// A function attribute represents a reference to a function object.
|
/// A function attribute represents a reference to a function object.
|
||||||
|
@ -237,63 +233,53 @@ private:
|
||||||
/// remain in MLIRContext.
|
/// remain in MLIRContext.
|
||||||
class FunctionAttr : public Attribute {
|
class FunctionAttr : public Attribute {
|
||||||
public:
|
public:
|
||||||
static FunctionAttr *get(const Function *value, MLIRContext *context);
|
typedef detail::FunctionAttributeStorage ImplType;
|
||||||
|
FunctionAttr() = default;
|
||||||
|
/* implicit */ FunctionAttr(Attribute::ImplType *ptr);
|
||||||
|
|
||||||
Function *getValue() const { return value; }
|
static FunctionAttr get(const Function *value, MLIRContext *context);
|
||||||
|
|
||||||
|
Function *getValue() const;
|
||||||
|
|
||||||
FunctionType *getType() const;
|
FunctionType *getType() const;
|
||||||
|
|
||||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||||
static bool classof(const Attribute *attr) {
|
static bool kindof(Kind kind) { return kind == Kind::Function; }
|
||||||
return attr->getKind() == Kind::Function;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// This function is used by the internals of the Function class to null out
|
/// This function is used by the internals of the Function class to null out
|
||||||
/// attributes refering to functions that are about to be deleted.
|
/// attributes refering to functions that are about to be deleted.
|
||||||
static void dropFunctionReference(Function *value);
|
static void dropFunctionReference(Function *value);
|
||||||
|
|
||||||
private:
|
|
||||||
FunctionAttr(Function *value)
|
|
||||||
: Attribute(Kind::Function, /*isOrContainsFunction=*/true), value(value) {
|
|
||||||
}
|
|
||||||
~FunctionAttr() = delete;
|
|
||||||
Function *value;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// A base attribute represents a reference to a vector or tensor constant.
|
/// A base attribute represents a reference to a vector or tensor constant.
|
||||||
class ElementsAttr : public Attribute {
|
class ElementsAttr : public Attribute {
|
||||||
public:
|
public:
|
||||||
ElementsAttr(Kind kind, VectorOrTensorType *type)
|
typedef detail::ElementsAttributeStorage ImplType;
|
||||||
: Attribute(kind, /*isOrContainsFunction=*/false), type(type) {}
|
ElementsAttr() = default;
|
||||||
|
/* implicit */ ElementsAttr(Attribute::ImplType *ptr);
|
||||||
|
|
||||||
VectorOrTensorType *getType() const { return type; }
|
VectorOrTensorType *getType() const;
|
||||||
|
|
||||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||||
static bool classof(const Attribute *attr) {
|
static bool kindof(Kind kind) {
|
||||||
return attr->getKind() >= Kind::FIRST_ELEMENTS_ATTR &&
|
return kind >= Kind::FIRST_ELEMENTS_ATTR &&
|
||||||
attr->getKind() <= Kind::LAST_ELEMENTS_ATTR;
|
kind <= Kind::LAST_ELEMENTS_ATTR;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
|
||||||
VectorOrTensorType *type;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// An attribute represents a reference to a splat vecctor or tensor constant,
|
/// An attribute represents a reference to a splat vecctor or tensor constant,
|
||||||
/// meaning all of the elements have the same value.
|
/// meaning all of the elements have the same value.
|
||||||
class SplatElementsAttr : public ElementsAttr {
|
class SplatElementsAttr : public ElementsAttr {
|
||||||
public:
|
public:
|
||||||
static SplatElementsAttr *get(VectorOrTensorType *type, Attribute *elt);
|
typedef detail::SplatElementsAttributeStorage ImplType;
|
||||||
Attribute *getValue() const { return elt; }
|
SplatElementsAttr() = default;
|
||||||
|
/* implicit */ SplatElementsAttr(Attribute::ImplType *ptr);
|
||||||
|
|
||||||
|
static SplatElementsAttr get(VectorOrTensorType *type, Attribute elt);
|
||||||
|
Attribute getValue() const;
|
||||||
|
|
||||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||||
static bool classof(const Attribute *attr) {
|
static bool kindof(Kind kind) { return kind == Kind::SplatElements; }
|
||||||
return attr->getKind() == Kind::SplatElements;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
SplatElementsAttr(VectorOrTensorType *type, Attribute *elt)
|
|
||||||
: ElementsAttr(Kind::SplatElements, type), elt(elt) {}
|
|
||||||
Attribute *elt;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// An attribute represents a reference to a dense vector or tensor object.
|
/// An attribute represents a reference to a dense vector or tensor object.
|
||||||
|
@ -302,42 +288,42 @@ private:
|
||||||
/// than 64.
|
/// than 64.
|
||||||
class DenseElementsAttr : public ElementsAttr {
|
class DenseElementsAttr : public ElementsAttr {
|
||||||
public:
|
public:
|
||||||
|
typedef detail::DenseElementsAttributeStorage ImplType;
|
||||||
|
DenseElementsAttr() = default;
|
||||||
|
/* implicit */ DenseElementsAttr(Attribute::ImplType *ptr);
|
||||||
|
|
||||||
/// It assumes the elements in the input array have been truncated to the bits
|
/// It assumes the elements in the input array have been truncated to the bits
|
||||||
/// width specified by the element type (note all float type are 64 bits).
|
/// width specified by the element type (note all float type are 64 bits).
|
||||||
/// When the value is retrieved, the bits are read from the storage and extend
|
/// When the value is retrieved, the bits are read from the storage and extend
|
||||||
/// to 64 bits if necessary.
|
/// to 64 bits if necessary.
|
||||||
static DenseElementsAttr *get(VectorOrTensorType *type, ArrayRef<char> data);
|
static DenseElementsAttr get(VectorOrTensorType *type, ArrayRef<char> data);
|
||||||
|
|
||||||
// TODO: Read the data from the attribute list and compress them
|
// TODO: Read the data from the attribute list and compress them
|
||||||
// to a character array. Then call the above method to construct the
|
// to a character array. Then call the above method to construct the
|
||||||
// attribute.
|
// attribute.
|
||||||
static DenseElementsAttr *get(VectorOrTensorType *type,
|
static DenseElementsAttr get(VectorOrTensorType *type,
|
||||||
ArrayRef<Attribute *> values);
|
ArrayRef<Attribute> values);
|
||||||
|
|
||||||
void getValues(SmallVectorImpl<Attribute *> &values) const;
|
void getValues(SmallVectorImpl<Attribute> &values) const;
|
||||||
|
|
||||||
ArrayRef<char> getRawData() const { return data; }
|
ArrayRef<char> getRawData() const;
|
||||||
|
|
||||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||||
static bool classof(const Attribute *attr) {
|
static bool kindof(Kind kind) {
|
||||||
return attr->getKind() == Kind::DenseIntElements ||
|
return kind == Kind::DenseIntElements || kind == Kind::DenseFPElements;
|
||||||
attr->getKind() == Kind::DenseFPElements;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
|
||||||
DenseElementsAttr(Kind kind, VectorOrTensorType *type, ArrayRef<char> data)
|
|
||||||
: ElementsAttr(kind, type), data(data) {}
|
|
||||||
|
|
||||||
private:
|
|
||||||
ArrayRef<char> data;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// An attribute represents a reference to a dense integer vector or tensor
|
/// An attribute represents a reference to a dense integer vector or tensor
|
||||||
/// object.
|
/// object.
|
||||||
class DenseIntElementsAttr : public DenseElementsAttr {
|
class DenseIntElementsAttr : public DenseElementsAttr {
|
||||||
public:
|
public:
|
||||||
|
typedef detail::DenseIntElementsAttributeStorage ImplType;
|
||||||
|
DenseIntElementsAttr() = default;
|
||||||
|
/* implicit */ DenseIntElementsAttr(Attribute::ImplType *ptr);
|
||||||
|
|
||||||
// TODO: returns APInts instead of IntegerAttr.
|
// TODO: returns APInts instead of IntegerAttr.
|
||||||
void getValues(SmallVectorImpl<Attribute *> &values) const;
|
void getValues(SmallVectorImpl<Attribute> &values) const;
|
||||||
|
|
||||||
APInt getValue(ArrayRef<unsigned> indices) const;
|
APInt getValue(ArrayRef<unsigned> indices) const;
|
||||||
|
|
||||||
|
@ -352,41 +338,24 @@ public:
|
||||||
size_t bitsWidth);
|
size_t bitsWidth);
|
||||||
|
|
||||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||||
static bool classof(const Attribute *attr) {
|
static bool kindof(Kind kind) { return kind == Kind::DenseIntElements; }
|
||||||
return attr->getKind() == Kind::DenseIntElements;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
friend class DenseElementsAttr;
|
|
||||||
DenseIntElementsAttr(VectorOrTensorType *type, ArrayRef<char> data,
|
|
||||||
size_t bitsWidth)
|
|
||||||
: DenseElementsAttr(Kind::DenseIntElements, type, data),
|
|
||||||
bitsWidth(bitsWidth) {}
|
|
||||||
|
|
||||||
~DenseIntElementsAttr() = delete;
|
|
||||||
|
|
||||||
size_t bitsWidth;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// An attribute represents a reference to a dense float vector or tensor
|
/// An attribute represents a reference to a dense float vector or tensor
|
||||||
/// object. Each element is stored as a double.
|
/// object. Each element is stored as a double.
|
||||||
class DenseFPElementsAttr : public DenseElementsAttr {
|
class DenseFPElementsAttr : public DenseElementsAttr {
|
||||||
public:
|
public:
|
||||||
|
typedef detail::DenseFPElementsAttributeStorage ImplType;
|
||||||
|
DenseFPElementsAttr() = default;
|
||||||
|
/* implicit */ DenseFPElementsAttr(Attribute::ImplType *ptr);
|
||||||
|
|
||||||
// TODO: returns APFPs instead of FloatAttr.
|
// TODO: returns APFPs instead of FloatAttr.
|
||||||
void getValues(SmallVectorImpl<Attribute *> &values) const;
|
void getValues(SmallVectorImpl<Attribute> &values) const;
|
||||||
|
|
||||||
APFloat getValue(ArrayRef<unsigned> indices) const;
|
APFloat getValue(ArrayRef<unsigned> indices) const;
|
||||||
|
|
||||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||||
static bool classof(const Attribute *attr) {
|
static bool kindof(Kind kind) { return kind == Kind::DenseFPElements; }
|
||||||
return attr->getKind() == Kind::DenseFPElements;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
friend class DenseElementsAttr;
|
|
||||||
DenseFPElementsAttr(VectorOrTensorType *type, ArrayRef<char> data)
|
|
||||||
: DenseElementsAttr(Kind::DenseFPElements, type, data) {}
|
|
||||||
~DenseFPElementsAttr() = delete;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// An attribute represents a reference to a tensor constant with opaque
|
/// An attribute represents a reference to a tensor constant with opaque
|
||||||
|
@ -394,20 +363,16 @@ private:
|
||||||
/// doesn't need to interpret.
|
/// doesn't need to interpret.
|
||||||
class OpaqueElementsAttr : public ElementsAttr {
|
class OpaqueElementsAttr : public ElementsAttr {
|
||||||
public:
|
public:
|
||||||
static OpaqueElementsAttr *get(VectorOrTensorType *type, StringRef bytes);
|
typedef detail::OpaqueElementsAttributeStorage ImplType;
|
||||||
|
OpaqueElementsAttr() = default;
|
||||||
|
/* implicit */ OpaqueElementsAttr(Attribute::ImplType *ptr);
|
||||||
|
|
||||||
StringRef getValue() const { return bytes; }
|
static OpaqueElementsAttr get(VectorOrTensorType *type, StringRef bytes);
|
||||||
|
|
||||||
|
StringRef getValue() const;
|
||||||
|
|
||||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||||
static bool classof(const Attribute *attr) {
|
static bool kindof(Kind kind) { return kind == Kind::OpaqueElements; }
|
||||||
return attr->getKind() == Kind::OpaqueElements;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
OpaqueElementsAttr(VectorOrTensorType *type, StringRef bytes)
|
|
||||||
: ElementsAttr(Kind::OpaqueElements, type), bytes(bytes) {}
|
|
||||||
~OpaqueElementsAttr() = delete;
|
|
||||||
StringRef bytes;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// An attribute represents a reference to a sparse vector or tensor object.
|
/// An attribute represents a reference to a sparse vector or tensor object.
|
||||||
|
@ -427,32 +392,67 @@ private:
|
||||||
/// [0, 0, 0, 0]].
|
/// [0, 0, 0, 0]].
|
||||||
class SparseElementsAttr : public ElementsAttr {
|
class SparseElementsAttr : public ElementsAttr {
|
||||||
public:
|
public:
|
||||||
static SparseElementsAttr *get(VectorOrTensorType *type,
|
typedef detail::SparseElementsAttributeStorage ImplType;
|
||||||
DenseIntElementsAttr *indices,
|
SparseElementsAttr() = default;
|
||||||
DenseElementsAttr *values);
|
/* implicit */ SparseElementsAttr(Attribute::ImplType *ptr);
|
||||||
|
|
||||||
DenseIntElementsAttr *getIndices() const { return indices; }
|
static SparseElementsAttr get(VectorOrTensorType *type,
|
||||||
|
DenseIntElementsAttr indices,
|
||||||
|
DenseElementsAttr values);
|
||||||
|
|
||||||
DenseElementsAttr *getValues() const { return values; }
|
DenseIntElementsAttr getIndices() const;
|
||||||
|
|
||||||
|
DenseElementsAttr getValues() const;
|
||||||
|
|
||||||
/// Return the value at the given index.
|
/// Return the value at the given index.
|
||||||
Attribute *getValue(ArrayRef<unsigned> index) const;
|
Attribute getValue(ArrayRef<unsigned> index) const;
|
||||||
|
|
||||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||||
static bool classof(const Attribute *attr) {
|
static bool kindof(Kind kind) { return kind == Kind::SparseElements; }
|
||||||
return attr->getKind() == Kind::SparseElements;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
SparseElementsAttr(VectorOrTensorType *type, DenseIntElementsAttr *indices,
|
|
||||||
DenseElementsAttr *values)
|
|
||||||
: ElementsAttr(Kind::SparseElements, type), indices(indices),
|
|
||||||
values(values) {}
|
|
||||||
~SparseElementsAttr() = delete;
|
|
||||||
|
|
||||||
DenseIntElementsAttr *const indices;
|
|
||||||
DenseElementsAttr *const values;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename U> bool Attribute::isa() const {
|
||||||
|
assert(attr && "isa<> used on a null attribute.");
|
||||||
|
return U::kindof(getKind());
|
||||||
|
}
|
||||||
|
template <typename U> U Attribute::dyn_cast() const {
|
||||||
|
return isa<U>() ? U(attr) : U(nullptr);
|
||||||
|
}
|
||||||
|
template <typename U> U Attribute::dyn_cast_or_null() const {
|
||||||
|
return (attr && isa<U>()) ? U(attr) : U(nullptr);
|
||||||
|
}
|
||||||
|
template <typename U> U Attribute::cast() const {
|
||||||
|
assert(isa<U>());
|
||||||
|
return U(attr);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make Attribute hashable.
|
||||||
|
inline ::llvm::hash_code hash_value(Attribute arg) {
|
||||||
|
return ::llvm::hash_value(arg.attr);
|
||||||
|
}
|
||||||
|
|
||||||
} // end namespace mlir.
|
} // end namespace mlir.
|
||||||
|
|
||||||
|
namespace llvm {
|
||||||
|
|
||||||
|
// Attribute hash just like pointers
|
||||||
|
template <> struct DenseMapInfo<mlir::Attribute> {
|
||||||
|
static mlir::Attribute getEmptyKey() {
|
||||||
|
auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
|
||||||
|
return mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer));
|
||||||
|
}
|
||||||
|
static mlir::Attribute getTombstoneKey() {
|
||||||
|
auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
|
||||||
|
return mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer));
|
||||||
|
}
|
||||||
|
static unsigned getHashValue(mlir::Attribute val) {
|
||||||
|
return mlir::hash_value(val);
|
||||||
|
}
|
||||||
|
static bool isEqual(mlir::Attribute LHS, mlir::Attribute RHS) {
|
||||||
|
return LHS == RHS;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace llvm
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -93,23 +93,23 @@ public:
|
||||||
UnrankedTensorType *getTensorType(Type *elementType);
|
UnrankedTensorType *getTensorType(Type *elementType);
|
||||||
|
|
||||||
// Attributes.
|
// Attributes.
|
||||||
BoolAttr *getBoolAttr(bool value);
|
|
||||||
IntegerAttr *getIntegerAttr(int64_t value);
|
BoolAttr getBoolAttr(bool value);
|
||||||
FloatAttr *getFloatAttr(double value);
|
IntegerAttr getIntegerAttr(int64_t value);
|
||||||
FloatAttr *getFloatAttr(const APFloat &value);
|
FloatAttr getFloatAttr(double value);
|
||||||
StringAttr *getStringAttr(StringRef bytes);
|
FloatAttr getFloatAttr(const APFloat &value);
|
||||||
ArrayAttr *getArrayAttr(ArrayRef<Attribute *> value);
|
StringAttr getStringAttr(StringRef bytes);
|
||||||
AffineMapAttr *getAffineMapAttr(AffineMap map);
|
ArrayAttr getArrayAttr(ArrayRef<Attribute> value);
|
||||||
TypeAttr *getTypeAttr(Type *type);
|
AffineMapAttr getAffineMapAttr(AffineMap map);
|
||||||
FunctionAttr *getFunctionAttr(const Function *value);
|
TypeAttr getTypeAttr(Type *type);
|
||||||
ElementsAttr *getSplatElementsAttr(VectorOrTensorType *type, Attribute *elt);
|
FunctionAttr getFunctionAttr(const Function *value);
|
||||||
ElementsAttr *getDenseElementsAttr(VectorOrTensorType *type,
|
ElementsAttr getSplatElementsAttr(VectorOrTensorType *type, Attribute elt);
|
||||||
ArrayRef<char> data);
|
ElementsAttr getDenseElementsAttr(VectorOrTensorType *type,
|
||||||
ElementsAttr *getSparseElementsAttr(VectorOrTensorType *type,
|
ArrayRef<char> data);
|
||||||
DenseIntElementsAttr *indices,
|
ElementsAttr getSparseElementsAttr(VectorOrTensorType *type,
|
||||||
DenseElementsAttr *values);
|
DenseIntElementsAttr indices,
|
||||||
ElementsAttr *getOpaqueElementsAttr(VectorOrTensorType *type,
|
DenseElementsAttr values);
|
||||||
StringRef bytes);
|
ElementsAttr getOpaqueElementsAttr(VectorOrTensorType *type, StringRef bytes);
|
||||||
|
|
||||||
// Affine expressions and affine maps.
|
// Affine expressions and affine maps.
|
||||||
AffineExpr getAffineDimExpr(unsigned position);
|
AffineExpr getAffineDimExpr(unsigned position);
|
||||||
|
|
|
@ -60,7 +60,7 @@ public:
|
||||||
|
|
||||||
/// Returns the affine map to be applied by this operation.
|
/// Returns the affine map to be applied by this operation.
|
||||||
AffineMap getAffineMap() const {
|
AffineMap getAffineMap() const {
|
||||||
return getAttrOfType<AffineMapAttr>("map")->getValue();
|
return getAttrOfType<AffineMapAttr>("map").getValue();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns true if the result of this operation can be used as dimension id.
|
/// Returns true if the result of this operation can be used as dimension id.
|
||||||
|
@ -75,8 +75,8 @@ public:
|
||||||
static bool parse(OpAsmParser *parser, OperationState *result);
|
static bool parse(OpAsmParser *parser, OperationState *result);
|
||||||
void print(OpAsmPrinter *p) const;
|
void print(OpAsmPrinter *p) const;
|
||||||
bool verify() const;
|
bool verify() const;
|
||||||
bool constantFold(ArrayRef<Attribute *> operands,
|
bool constantFold(ArrayRef<Attribute> operandConstants,
|
||||||
SmallVectorImpl<Attribute *> &results,
|
SmallVectorImpl<Attribute> &results,
|
||||||
MLIRContext *context) const;
|
MLIRContext *context) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -94,10 +94,10 @@ class ConstantOp : public Op<ConstantOp, OpTrait::ZeroOperands,
|
||||||
OpTrait::OneResult, OpTrait::HasNoSideEffect> {
|
OpTrait::OneResult, OpTrait::HasNoSideEffect> {
|
||||||
public:
|
public:
|
||||||
/// Builds a constant op with the specified attribute value and result type.
|
/// Builds a constant op with the specified attribute value and result type.
|
||||||
static void build(Builder *builder, OperationState *result, Attribute *value,
|
static void build(Builder *builder, OperationState *result, Attribute value,
|
||||||
Type *type);
|
Type *type);
|
||||||
|
|
||||||
Attribute *getValue() const { return getAttr("value"); }
|
Attribute getValue() const { return getAttr("value"); }
|
||||||
|
|
||||||
static StringRef getOperationName() { return "constant"; }
|
static StringRef getOperationName() { return "constant"; }
|
||||||
|
|
||||||
|
@ -105,8 +105,8 @@ public:
|
||||||
static bool parse(OpAsmParser *parser, OperationState *result);
|
static bool parse(OpAsmParser *parser, OperationState *result);
|
||||||
void print(OpAsmPrinter *p) const;
|
void print(OpAsmPrinter *p) const;
|
||||||
bool verify() const;
|
bool verify() const;
|
||||||
Attribute *constantFold(ArrayRef<Attribute *> operands,
|
Attribute constantFold(ArrayRef<Attribute> operands,
|
||||||
MLIRContext *context) const;
|
MLIRContext *context) const;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
friend class Operation;
|
friend class Operation;
|
||||||
|
@ -125,7 +125,7 @@ public:
|
||||||
const APFloat &value, FloatType *type);
|
const APFloat &value, FloatType *type);
|
||||||
|
|
||||||
APFloat getValue() const {
|
APFloat getValue() const {
|
||||||
return getAttrOfType<FloatAttr>("value")->getValue();
|
return getAttrOfType<FloatAttr>("value").getValue();
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool isClassFor(const Operation *op);
|
static bool isClassFor(const Operation *op);
|
||||||
|
@ -152,7 +152,7 @@ public:
|
||||||
Type *type);
|
Type *type);
|
||||||
|
|
||||||
int64_t getValue() const {
|
int64_t getValue() const {
|
||||||
return getAttrOfType<IntegerAttr>("value")->getValue();
|
return getAttrOfType<IntegerAttr>("value").getValue();
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool isClassFor(const Operation *op);
|
static bool isClassFor(const Operation *op);
|
||||||
|
@ -173,7 +173,7 @@ public:
|
||||||
static void build(Builder *builder, OperationState *result, int64_t value);
|
static void build(Builder *builder, OperationState *result, int64_t value);
|
||||||
|
|
||||||
int64_t getValue() const {
|
int64_t getValue() const {
|
||||||
return getAttrOfType<IntegerAttr>("value")->getValue();
|
return getAttrOfType<IntegerAttr>("value").getValue();
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool isClassFor(const Operation *op);
|
static bool isClassFor(const Operation *op);
|
||||||
|
|
|
@ -24,12 +24,12 @@
|
||||||
#ifndef MLIR_IR_FUNCTION_H
|
#ifndef MLIR_IR_FUNCTION_H
|
||||||
#define MLIR_IR_FUNCTION_H
|
#define MLIR_IR_FUNCTION_H
|
||||||
|
|
||||||
|
#include "mlir/IR/Attributes.h"
|
||||||
#include "mlir/IR/Identifier.h"
|
#include "mlir/IR/Identifier.h"
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
#include "llvm/ADT/ilist.h"
|
#include "llvm/ADT/ilist.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
class Attribute;
|
|
||||||
class AttributeListStorage;
|
class AttributeListStorage;
|
||||||
class FunctionType;
|
class FunctionType;
|
||||||
class Location;
|
class Location;
|
||||||
|
@ -39,7 +39,7 @@ class Module;
|
||||||
/// NamedAttribute is used for function attribute lists, it holds an
|
/// NamedAttribute is used for function attribute lists, it holds an
|
||||||
/// identifier for the name and a value for the attribute. The attribute
|
/// identifier for the name and a value for the attribute. The attribute
|
||||||
/// pointer should always be non-null.
|
/// pointer should always be non-null.
|
||||||
using NamedAttribute = std::pair<Identifier, Attribute *>;
|
using NamedAttribute = std::pair<Identifier, Attribute>;
|
||||||
|
|
||||||
/// This is the base class for all of the MLIR function types.
|
/// This is the base class for all of the MLIR function types.
|
||||||
class Function : public llvm::ilist_node_with_parent<Function, Module> {
|
class Function : public llvm::ilist_node_with_parent<Function, Module> {
|
||||||
|
|
|
@ -138,17 +138,16 @@ public:
|
||||||
ArrayRef<NamedAttribute> getAttrs() const { return state->getAttrs(); }
|
ArrayRef<NamedAttribute> getAttrs() const { return state->getAttrs(); }
|
||||||
|
|
||||||
/// Return an attribute with the specified name.
|
/// Return an attribute with the specified name.
|
||||||
Attribute *getAttr(StringRef name) const { return state->getAttr(name); }
|
Attribute getAttr(StringRef name) const { return state->getAttr(name); }
|
||||||
|
|
||||||
/// If the operation has an attribute of the specified type, return it.
|
/// If the operation has an attribute of the specified type, return it.
|
||||||
template <typename AttrClass>
|
template <typename AttrClass> AttrClass getAttrOfType(StringRef name) const {
|
||||||
AttrClass *getAttrOfType(StringRef name) const {
|
return getAttr(name).dyn_cast_or_null<AttrClass>();
|
||||||
return dyn_cast_or_null<AttrClass>(getAttr(name));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// If the an attribute exists with the specified name, change it to the new
|
/// If the an attribute exists with the specified name, change it to the new
|
||||||
/// value. Otherwise, add a new attribute with the specified name/value.
|
/// value. Otherwise, add a new attribute with the specified name/value.
|
||||||
void setAttr(Identifier name, Attribute *value) {
|
void setAttr(Identifier name, Attribute value) {
|
||||||
state->setAttr(name, value);
|
state->setAttr(name, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -211,8 +210,8 @@ public:
|
||||||
/// true if folding failed, or returns false and fills in `results` on
|
/// true if folding failed, or returns false and fills in `results` on
|
||||||
/// success.
|
/// success.
|
||||||
static bool constantFoldHook(const Operation *op,
|
static bool constantFoldHook(const Operation *op,
|
||||||
ArrayRef<Attribute *> operands,
|
ArrayRef<Attribute> operands,
|
||||||
SmallVectorImpl<Attribute *> &results) {
|
SmallVectorImpl<Attribute> &results) {
|
||||||
return op->cast<ConcreteType>()->constantFold(operands, results,
|
return op->cast<ConcreteType>()->constantFold(operands, results,
|
||||||
op->getContext());
|
op->getContext());
|
||||||
}
|
}
|
||||||
|
@ -226,8 +225,8 @@ public:
|
||||||
///
|
///
|
||||||
/// If not overridden, this fallback implementation always fails to fold.
|
/// If not overridden, this fallback implementation always fails to fold.
|
||||||
///
|
///
|
||||||
bool constantFold(ArrayRef<Attribute *> operands,
|
bool constantFold(ArrayRef<Attribute> operands,
|
||||||
SmallVectorImpl<Attribute *> &results,
|
SmallVectorImpl<Attribute> &results,
|
||||||
MLIRContext *context) const {
|
MLIRContext *context) const {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -244,9 +243,9 @@ public:
|
||||||
/// true if folding failed, or returns false and fills in `results` on
|
/// true if folding failed, or returns false and fills in `results` on
|
||||||
/// success.
|
/// success.
|
||||||
static bool constantFoldHook(const Operation *op,
|
static bool constantFoldHook(const Operation *op,
|
||||||
ArrayRef<Attribute *> operands,
|
ArrayRef<Attribute> operands,
|
||||||
SmallVectorImpl<Attribute *> &results) {
|
SmallVectorImpl<Attribute> &results) {
|
||||||
auto *result =
|
auto result =
|
||||||
op->cast<ConcreteType>()->constantFold(operands, op->getContext());
|
op->cast<ConcreteType>()->constantFold(operands, op->getContext());
|
||||||
if (!result)
|
if (!result)
|
||||||
return true;
|
return true;
|
||||||
|
@ -511,8 +510,8 @@ public:
|
||||||
///
|
///
|
||||||
/// If not overridden, this fallback implementation always fails to fold.
|
/// If not overridden, this fallback implementation always fails to fold.
|
||||||
///
|
///
|
||||||
Attribute *constantFold(ArrayRef<Attribute *> operands,
|
Attribute constantFold(ArrayRef<Attribute> operands,
|
||||||
MLIRContext *context) const {
|
MLIRContext *context) const {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -69,7 +69,7 @@ public:
|
||||||
}
|
}
|
||||||
virtual void printType(const Type *type) = 0;
|
virtual void printType(const Type *type) = 0;
|
||||||
virtual void printFunctionReference(const Function *func) = 0;
|
virtual void printFunctionReference(const Function *func) = 0;
|
||||||
virtual void printAttribute(const Attribute *attr) = 0;
|
virtual void printAttribute(Attribute attr) = 0;
|
||||||
virtual void printAffineMap(AffineMap map) = 0;
|
virtual void printAffineMap(AffineMap map) = 0;
|
||||||
virtual void printAffineExpr(AffineExpr expr) = 0;
|
virtual void printAffineExpr(AffineExpr expr) = 0;
|
||||||
|
|
||||||
|
@ -100,8 +100,8 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const Type &type) {
|
||||||
return p;
|
return p;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const Attribute &attr) {
|
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Attribute attr) {
|
||||||
p.printAttribute(&attr);
|
p.printAttribute(attr);
|
||||||
return p;
|
return p;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -210,24 +210,24 @@ public:
|
||||||
/// Parse an arbitrary attribute and return it in result. This also adds the
|
/// Parse an arbitrary attribute and return it in result. This also adds the
|
||||||
/// attribute to the specified attribute list with the specified name. this
|
/// attribute to the specified attribute list with the specified name. this
|
||||||
/// captures the location of the attribute in 'loc' if it is non-null.
|
/// captures the location of the attribute in 'loc' if it is non-null.
|
||||||
virtual bool parseAttribute(Attribute *&result, const char *attrName,
|
virtual bool parseAttribute(Attribute &result, const char *attrName,
|
||||||
SmallVectorImpl<NamedAttribute> &attrs) = 0;
|
SmallVectorImpl<NamedAttribute> &attrs) = 0;
|
||||||
|
|
||||||
/// Parse an attribute of a specific kind, capturing the location into `loc`
|
/// Parse an attribute of a specific kind, capturing the location into `loc`
|
||||||
/// if specified.
|
/// if specified.
|
||||||
template <typename AttrType>
|
template <typename AttrType>
|
||||||
bool parseAttribute(AttrType *&result, const char *attrName,
|
bool parseAttribute(AttrType &result, const char *attrName,
|
||||||
SmallVectorImpl<NamedAttribute> &attrs) {
|
SmallVectorImpl<NamedAttribute> &attrs) {
|
||||||
llvm::SMLoc loc;
|
llvm::SMLoc loc;
|
||||||
getCurrentLocation(&loc);
|
getCurrentLocation(&loc);
|
||||||
|
|
||||||
// Parse any kind of attribute.
|
// Parse any kind of attribute.
|
||||||
Attribute *attr;
|
Attribute attr;
|
||||||
if (parseAttribute(attr, attrName, attrs))
|
if (parseAttribute(attr, attrName, attrs))
|
||||||
return true;
|
return true;
|
||||||
|
|
||||||
// Check for the right kind of attribute.
|
// Check for the right kind of attribute.
|
||||||
result = dyn_cast<AttrType>(attr);
|
result = attr.dyn_cast<AttrType>();
|
||||||
if (!result) {
|
if (!result) {
|
||||||
emitError(loc, "invalid kind of constant specified");
|
emitError(loc, "invalid kind of constant specified");
|
||||||
return true;
|
return true;
|
||||||
|
|
|
@ -113,33 +113,31 @@ public:
|
||||||
ArrayRef<NamedAttribute> getAttrs() const;
|
ArrayRef<NamedAttribute> getAttrs() const;
|
||||||
|
|
||||||
/// Return the specified attribute if present, null otherwise.
|
/// Return the specified attribute if present, null otherwise.
|
||||||
Attribute *getAttr(Identifier name) const {
|
Attribute getAttr(Identifier name) const {
|
||||||
for (auto elt : getAttrs())
|
for (auto elt : getAttrs())
|
||||||
if (elt.first == name)
|
if (elt.first == name)
|
||||||
return elt.second;
|
return elt.second;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
Attribute *getAttr(StringRef name) const {
|
Attribute getAttr(StringRef name) const {
|
||||||
for (auto elt : getAttrs())
|
for (auto elt : getAttrs())
|
||||||
if (elt.first.is(name))
|
if (elt.first.is(name))
|
||||||
return elt.second;
|
return elt.second;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename AttrClass>
|
template <typename AttrClass> AttrClass getAttrOfType(Identifier name) const {
|
||||||
AttrClass *getAttrOfType(Identifier name) const {
|
return getAttr(name).dyn_cast_or_null<AttrClass>();
|
||||||
return dyn_cast_or_null<AttrClass>(getAttr(name));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename AttrClass>
|
template <typename AttrClass> AttrClass getAttrOfType(StringRef name) const {
|
||||||
AttrClass *getAttrOfType(StringRef name) const {
|
return getAttr(name).dyn_cast_or_null<AttrClass>();
|
||||||
return dyn_cast_or_null<AttrClass>(getAttr(name));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// If the an attribute exists with the specified name, change it to the new
|
/// If the an attribute exists with the specified name, change it to the new
|
||||||
/// value. Otherwise, add a new attribute with the specified name/value.
|
/// value. Otherwise, add a new attribute with the specified name/value.
|
||||||
void setAttr(Identifier name, Attribute *value);
|
void setAttr(Identifier name, Attribute value);
|
||||||
|
|
||||||
enum class RemoveResult {
|
enum class RemoveResult {
|
||||||
Removed, NotFound
|
Removed, NotFound
|
||||||
|
@ -250,8 +248,8 @@ public:
|
||||||
/// the operands of the operation, but may be null if non-constant. If
|
/// the operands of the operation, but may be null if non-constant. If
|
||||||
/// constant folding is successful, this returns false and fills in the
|
/// constant folding is successful, this returns false and fills in the
|
||||||
/// `results` vector. If not, this returns true and `results` is unspecified.
|
/// `results` vector. If not, this returns true and `results` is unspecified.
|
||||||
bool constantFold(ArrayRef<Attribute *> operands,
|
bool constantFold(ArrayRef<Attribute> operands,
|
||||||
SmallVectorImpl<Attribute *> &results) const;
|
SmallVectorImpl<Attribute> &results) const;
|
||||||
|
|
||||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||||
static bool classof(const Instruction *inst);
|
static bool classof(const Instruction *inst);
|
||||||
|
|
|
@ -23,11 +23,11 @@
|
||||||
#ifndef MLIR_IR_OPERATION_SUPPORT_H
|
#ifndef MLIR_IR_OPERATION_SUPPORT_H
|
||||||
#define MLIR_IR_OPERATION_SUPPORT_H
|
#define MLIR_IR_OPERATION_SUPPORT_H
|
||||||
|
|
||||||
|
#include "mlir/IR/Attributes.h"
|
||||||
#include "mlir/IR/Identifier.h"
|
#include "mlir/IR/Identifier.h"
|
||||||
#include "llvm/ADT/PointerUnion.h"
|
#include "llvm/ADT/PointerUnion.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
class Attribute;
|
|
||||||
class Dialect;
|
class Dialect;
|
||||||
class Location;
|
class Location;
|
||||||
class Operation;
|
class Operation;
|
||||||
|
@ -80,8 +80,8 @@ public:
|
||||||
/// This hook implements a constant folder for this operation. It returns
|
/// This hook implements a constant folder for this operation. It returns
|
||||||
/// true if folding failed, or returns false and fills in `results` on
|
/// true if folding failed, or returns false and fills in `results` on
|
||||||
/// success.
|
/// success.
|
||||||
bool (&constantFoldHook)(const Operation *op, ArrayRef<Attribute *> operands,
|
bool (&constantFoldHook)(const Operation *op, ArrayRef<Attribute> operands,
|
||||||
SmallVectorImpl<Attribute *> &results);
|
SmallVectorImpl<Attribute> &results);
|
||||||
|
|
||||||
// Returns whether the operation has a particular property.
|
// Returns whether the operation has a particular property.
|
||||||
bool hasProperty(OperationProperty property) const {
|
bool hasProperty(OperationProperty property) const {
|
||||||
|
@ -110,8 +110,8 @@ private:
|
||||||
void (&printAssembly)(const Operation *op, OpAsmPrinter *p),
|
void (&printAssembly)(const Operation *op, OpAsmPrinter *p),
|
||||||
bool (&verifyInvariants)(const Operation *op),
|
bool (&verifyInvariants)(const Operation *op),
|
||||||
bool (&constantFoldHook)(const Operation *op,
|
bool (&constantFoldHook)(const Operation *op,
|
||||||
ArrayRef<Attribute *> operands,
|
ArrayRef<Attribute> operands,
|
||||||
SmallVectorImpl<Attribute *> &results))
|
SmallVectorImpl<Attribute> &results))
|
||||||
: name(name), dialect(dialect), isClassFor(isClassFor),
|
: name(name), dialect(dialect), isClassFor(isClassFor),
|
||||||
parseAssembly(parseAssembly), printAssembly(printAssembly),
|
parseAssembly(parseAssembly), printAssembly(printAssembly),
|
||||||
verifyInvariants(verifyInvariants), constantFoldHook(constantFoldHook),
|
verifyInvariants(verifyInvariants), constantFoldHook(constantFoldHook),
|
||||||
|
@ -124,7 +124,7 @@ private:
|
||||||
/// NamedAttribute is a used for operation attribute lists, it holds an
|
/// NamedAttribute is a used for operation attribute lists, it holds an
|
||||||
/// identifier for the name and a value for the attribute. The attribute
|
/// identifier for the name and a value for the attribute. The attribute
|
||||||
/// pointer should always be non-null.
|
/// pointer should always be non-null.
|
||||||
using NamedAttribute = std::pair<Identifier, Attribute *>;
|
using NamedAttribute = std::pair<Identifier, Attribute>;
|
||||||
|
|
||||||
class OperationName {
|
class OperationName {
|
||||||
public:
|
public:
|
||||||
|
@ -204,7 +204,7 @@ public:
|
||||||
types.append(newTypes.begin(), newTypes.end());
|
types.append(newTypes.begin(), newTypes.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
void addAttribute(StringRef name, Attribute *attr) {
|
void addAttribute(StringRef name, Attribute attr) {
|
||||||
attributes.push_back({Identifier::get(name, context), attr});
|
attributes.push_back({Identifier::get(name, context), attr});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -49,8 +49,8 @@ class AddFOp
|
||||||
public:
|
public:
|
||||||
static StringRef getOperationName() { return "addf"; }
|
static StringRef getOperationName() { return "addf"; }
|
||||||
|
|
||||||
Attribute *constantFold(ArrayRef<Attribute *> operands,
|
Attribute constantFold(ArrayRef<Attribute> operands,
|
||||||
MLIRContext *context) const;
|
MLIRContext *context) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend class Operation;
|
friend class Operation;
|
||||||
|
@ -70,8 +70,8 @@ class AddIOp
|
||||||
public:
|
public:
|
||||||
static StringRef getOperationName() { return "addi"; }
|
static StringRef getOperationName() { return "addi"; }
|
||||||
|
|
||||||
Attribute *constantFold(ArrayRef<Attribute *> operands,
|
Attribute constantFold(ArrayRef<Attribute> operands,
|
||||||
MLIRContext *context) const;
|
MLIRContext *context) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend class Operation;
|
friend class Operation;
|
||||||
|
@ -134,7 +134,7 @@ public:
|
||||||
ArrayRef<SSAValue *> operands);
|
ArrayRef<SSAValue *> operands);
|
||||||
|
|
||||||
Function *getCallee() const {
|
Function *getCallee() const {
|
||||||
return getAttrOfType<FunctionAttr>("callee")->getValue();
|
return getAttrOfType<FunctionAttr>("callee").getValue();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Hooks to customize behavior of this op.
|
// Hooks to customize behavior of this op.
|
||||||
|
@ -218,13 +218,13 @@ public:
|
||||||
static void build(Builder *builder, OperationState *result,
|
static void build(Builder *builder, OperationState *result,
|
||||||
SSAValue *memrefOrTensor, unsigned index);
|
SSAValue *memrefOrTensor, unsigned index);
|
||||||
|
|
||||||
Attribute *constantFold(ArrayRef<Attribute *> operands,
|
Attribute constantFold(ArrayRef<Attribute> operands,
|
||||||
MLIRContext *context) const;
|
MLIRContext *context) const;
|
||||||
|
|
||||||
/// This returns the dimension number that the 'dim' is inspecting.
|
/// This returns the dimension number that the 'dim' is inspecting.
|
||||||
unsigned getIndex() const {
|
unsigned getIndex() const {
|
||||||
return static_cast<unsigned>(
|
return static_cast<unsigned>(
|
||||||
getAttrOfType<IntegerAttr>("index")->getValue());
|
getAttrOfType<IntegerAttr>("index").getValue());
|
||||||
}
|
}
|
||||||
|
|
||||||
static StringRef getOperationName() { return "dim"; }
|
static StringRef getOperationName() { return "dim"; }
|
||||||
|
@ -513,8 +513,8 @@ class MulFOp
|
||||||
public:
|
public:
|
||||||
static StringRef getOperationName() { return "mulf"; }
|
static StringRef getOperationName() { return "mulf"; }
|
||||||
|
|
||||||
Attribute *constantFold(ArrayRef<Attribute *> operands,
|
Attribute constantFold(ArrayRef<Attribute> operands,
|
||||||
MLIRContext *context) const;
|
MLIRContext *context) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend class Operation;
|
friend class Operation;
|
||||||
|
@ -534,8 +534,8 @@ class MulIOp
|
||||||
public:
|
public:
|
||||||
static StringRef getOperationName() { return "muli"; }
|
static StringRef getOperationName() { return "muli"; }
|
||||||
|
|
||||||
Attribute *constantFold(ArrayRef<Attribute *> operands,
|
Attribute constantFold(ArrayRef<Attribute> operands,
|
||||||
MLIRContext *context) const;
|
MLIRContext *context) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend class Operation;
|
friend class Operation;
|
||||||
|
@ -597,8 +597,8 @@ class SubFOp : public BinaryOp<SubFOp, OpTrait::ResultsAreFloatLike,
|
||||||
public:
|
public:
|
||||||
static StringRef getOperationName() { return "subf"; }
|
static StringRef getOperationName() { return "subf"; }
|
||||||
|
|
||||||
Attribute *constantFold(ArrayRef<Attribute *> operands,
|
Attribute constantFold(ArrayRef<Attribute> operands,
|
||||||
MLIRContext *context) const;
|
MLIRContext *context) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend class Operation;
|
friend class Operation;
|
||||||
|
@ -617,8 +617,8 @@ class SubIOp : public BinaryOp<SubIOp, OpTrait::ResultsAreIntegerLike,
|
||||||
public:
|
public:
|
||||||
static StringRef getOperationName() { return "subi"; }
|
static StringRef getOperationName() { return "subi"; }
|
||||||
|
|
||||||
Attribute *constantFold(ArrayRef<Attribute *> operands,
|
Attribute constantFold(ArrayRef<Attribute> operands,
|
||||||
MLIRContext *context) const;
|
MLIRContext *context) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend class Operation;
|
friend class Operation;
|
||||||
|
|
|
@ -82,7 +82,7 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
bool verifyOperation(const Operation &op);
|
bool verifyOperation(const Operation &op);
|
||||||
bool verifyAttribute(Attribute *attr, const Operation &op);
|
bool verifyAttribute(Attribute attr, const Operation &op);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
explicit Verifier(const Function &fn) : fn(fn) {}
|
explicit Verifier(const Function &fn) : fn(fn) {}
|
||||||
|
@ -94,26 +94,26 @@ private:
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
|
||||||
// Check that function attributes are all well formed.
|
// Check that function attributes are all well formed.
|
||||||
bool Verifier::verifyAttribute(Attribute *attr, const Operation &op) {
|
bool Verifier::verifyAttribute(Attribute attr, const Operation &op) {
|
||||||
if (!attr->isOrContainsFunction())
|
if (!attr.isOrContainsFunction())
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
// If we have a function attribute, check that it is non-null and in the
|
// If we have a function attribute, check that it is non-null and in the
|
||||||
// same module as the operation that refers to it.
|
// same module as the operation that refers to it.
|
||||||
if (auto *fnAttr = dyn_cast<FunctionAttr>(attr)) {
|
if (auto fnAttr = attr.dyn_cast<FunctionAttr>()) {
|
||||||
if (!fnAttr->getValue())
|
if (!fnAttr.getValue())
|
||||||
return failure("attribute refers to deallocated function!", op);
|
return failure("attribute refers to deallocated function!", op);
|
||||||
|
|
||||||
if (fnAttr->getValue()->getModule() != fn.getModule())
|
if (fnAttr.getValue()->getModule() != fn.getModule())
|
||||||
return failure("attribute refers to function '" +
|
return failure("attribute refers to function '" +
|
||||||
Twine(fnAttr->getValue()->getName()) +
|
Twine(fnAttr.getValue()->getName()) +
|
||||||
"' defined in another module!",
|
"' defined in another module!",
|
||||||
op);
|
op);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Otherwise, we must have an array attribute, remap the elements.
|
// Otherwise, we must have an array attribute, remap the elements.
|
||||||
for (auto *elt : cast<ArrayAttr>(attr)->getValue()) {
|
for (auto elt : attr.cast<ArrayAttr>().getValue()) {
|
||||||
if (verifyAttribute(elt, op))
|
if (verifyAttribute(elt, op))
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,13 +32,12 @@ namespace {
|
||||||
// evaluated on constant 'operandConsts'.
|
// evaluated on constant 'operandConsts'.
|
||||||
class AffineExprConstantFolder {
|
class AffineExprConstantFolder {
|
||||||
public:
|
public:
|
||||||
AffineExprConstantFolder(unsigned numDims,
|
AffineExprConstantFolder(unsigned numDims, ArrayRef<Attribute> operandConsts)
|
||||||
ArrayRef<Attribute *> operandConsts)
|
|
||||||
: numDims(numDims), operandConsts(operandConsts) {}
|
: numDims(numDims), operandConsts(operandConsts) {}
|
||||||
|
|
||||||
/// Attempt to constant fold the specified affine expr, or return null on
|
/// Attempt to constant fold the specified affine expr, or return null on
|
||||||
/// failure.
|
/// failure.
|
||||||
IntegerAttr *constantFold(AffineExpr expr) {
|
IntegerAttr constantFold(AffineExpr expr) {
|
||||||
switch (expr.getKind()) {
|
switch (expr.getKind()) {
|
||||||
case AffineExprKind::Add:
|
case AffineExprKind::Add:
|
||||||
return constantFoldBinExpr(
|
return constantFoldBinExpr(
|
||||||
|
@ -59,31 +58,32 @@ public:
|
||||||
return IntegerAttr::get(expr.cast<AffineConstantExpr>().getValue(),
|
return IntegerAttr::get(expr.cast<AffineConstantExpr>().getValue(),
|
||||||
expr.getContext());
|
expr.getContext());
|
||||||
case AffineExprKind::DimId:
|
case AffineExprKind::DimId:
|
||||||
return dyn_cast_or_null<IntegerAttr>(
|
return operandConsts[expr.cast<AffineDimExpr>().getPosition()]
|
||||||
operandConsts[expr.cast<AffineDimExpr>().getPosition()]);
|
.dyn_cast_or_null<IntegerAttr>();
|
||||||
case AffineExprKind::SymbolId:
|
case AffineExprKind::SymbolId:
|
||||||
return dyn_cast_or_null<IntegerAttr>(
|
return operandConsts[numDims +
|
||||||
operandConsts[numDims + expr.cast<AffineSymbolExpr>().getPosition()]);
|
expr.cast<AffineSymbolExpr>().getPosition()]
|
||||||
|
.dyn_cast_or_null<IntegerAttr>();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
IntegerAttr *
|
IntegerAttr
|
||||||
constantFoldBinExpr(AffineExpr expr,
|
constantFoldBinExpr(AffineExpr expr,
|
||||||
std::function<uint64_t(int64_t, uint64_t)> op) {
|
std::function<uint64_t(int64_t, uint64_t)> op) {
|
||||||
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
|
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
|
||||||
auto *lhs = constantFold(binOpExpr.getLHS());
|
auto lhs = constantFold(binOpExpr.getLHS());
|
||||||
auto *rhs = constantFold(binOpExpr.getRHS());
|
auto rhs = constantFold(binOpExpr.getRHS());
|
||||||
if (!lhs || !rhs)
|
if (!lhs || !rhs)
|
||||||
return nullptr;
|
return nullptr;
|
||||||
return IntegerAttr::get(op(lhs->getValue(), rhs->getValue()),
|
return IntegerAttr::get(op(lhs.getValue(), rhs.getValue()),
|
||||||
expr.getContext());
|
expr.getContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
// The number of dimension operands in AffineMap containing this expression.
|
// The number of dimension operands in AffineMap containing this expression.
|
||||||
unsigned numDims;
|
unsigned numDims;
|
||||||
// The constant valued operands used to evaluate this AffineExpr.
|
// The constant valued operands used to evaluate this AffineExpr.
|
||||||
ArrayRef<Attribute *> operandConsts;
|
ArrayRef<Attribute> operandConsts;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // end anonymous namespace
|
} // end anonymous namespace
|
||||||
|
@ -137,15 +137,15 @@ ArrayRef<AffineExpr> AffineMap::getRangeSizes() const {
|
||||||
/// Folds the results of the application of an affine map on the provided
|
/// Folds the results of the application of an affine map on the provided
|
||||||
/// operands to a constant if possible. Returns false if the folding happens,
|
/// operands to a constant if possible. Returns false if the folding happens,
|
||||||
/// true otherwise.
|
/// true otherwise.
|
||||||
bool AffineMap::constantFold(ArrayRef<Attribute *> operandConstants,
|
bool AffineMap::constantFold(ArrayRef<Attribute> operandConstants,
|
||||||
SmallVectorImpl<Attribute *> &results) const {
|
SmallVectorImpl<Attribute> &results) const {
|
||||||
assert(getNumInputs() == operandConstants.size());
|
assert(getNumInputs() == operandConstants.size());
|
||||||
|
|
||||||
// Fold each of the result expressions.
|
// Fold each of the result expressions.
|
||||||
AffineExprConstantFolder exprFolder(getNumDims(), operandConstants);
|
AffineExprConstantFolder exprFolder(getNumDims(), operandConstants);
|
||||||
// Constant fold each AffineExpr in AffineMap and add to 'results'.
|
// Constant fold each AffineExpr in AffineMap and add to 'results'.
|
||||||
for (auto expr : getResults()) {
|
for (auto expr : getResults()) {
|
||||||
auto *folded = exprFolder.constantFold(expr);
|
auto folded = exprFolder.constantFold(expr);
|
||||||
// If we didn't fold to a constant, then folding fails.
|
// If we didn't fold to a constant, then folding fails.
|
||||||
if (!folded)
|
if (!folded)
|
||||||
return true;
|
return true;
|
||||||
|
|
|
@ -123,7 +123,7 @@ private:
|
||||||
void visitIfStmt(const IfStmt *ifStmt);
|
void visitIfStmt(const IfStmt *ifStmt);
|
||||||
void visitOperationStmt(const OperationStmt *opStmt);
|
void visitOperationStmt(const OperationStmt *opStmt);
|
||||||
void visitType(const Type *type);
|
void visitType(const Type *type);
|
||||||
void visitAttribute(const Attribute *attr);
|
void visitAttribute(Attribute attr);
|
||||||
void visitOperation(const Operation *op);
|
void visitOperation(const Operation *op);
|
||||||
|
|
||||||
DenseMap<AffineMap, int> affineMapIds;
|
DenseMap<AffineMap, int> affineMapIds;
|
||||||
|
@ -150,11 +150,11 @@ void ModuleState::visitType(const Type *type) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ModuleState::visitAttribute(const Attribute *attr) {
|
void ModuleState::visitAttribute(Attribute attr) {
|
||||||
if (auto *mapAttr = dyn_cast<AffineMapAttr>(attr)) {
|
if (auto mapAttr = attr.dyn_cast<AffineMapAttr>()) {
|
||||||
recordAffineMapReference(mapAttr->getValue());
|
recordAffineMapReference(mapAttr.getValue());
|
||||||
} else if (auto *arrayAttr = dyn_cast<ArrayAttr>(attr)) {
|
} else if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
|
||||||
for (auto elt : arrayAttr->getValue()) {
|
for (auto elt : arrayAttr.getValue()) {
|
||||||
visitAttribute(elt);
|
visitAttribute(elt);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -268,7 +268,7 @@ public:
|
||||||
|
|
||||||
void print(const Module *module);
|
void print(const Module *module);
|
||||||
void printFunctionReference(const Function *func);
|
void printFunctionReference(const Function *func);
|
||||||
void printAttribute(const Attribute *attr);
|
void printAttribute(Attribute attr);
|
||||||
void printType(const Type *type);
|
void printType(const Type *type);
|
||||||
void print(const Function *fn);
|
void print(const Function *fn);
|
||||||
void print(const ExtFunction *fn);
|
void print(const ExtFunction *fn);
|
||||||
|
@ -293,7 +293,7 @@ protected:
|
||||||
void printAffineMapReference(AffineMap affineMap);
|
void printAffineMapReference(AffineMap affineMap);
|
||||||
void printIntegerSetId(int integerSetId) const;
|
void printIntegerSetId(int integerSetId) const;
|
||||||
void printIntegerSetReference(IntegerSet integerSet);
|
void printIntegerSetReference(IntegerSet integerSet);
|
||||||
void printDenseElementsAttr(const DenseElementsAttr *attr);
|
void printDenseElementsAttr(DenseElementsAttr attr);
|
||||||
|
|
||||||
/// This enum is used to represent the binding stength of the enclosing
|
/// This enum is used to represent the binding stength of the enclosing
|
||||||
/// context that an AffineExprStorage is being printed in, so we can
|
/// context that an AffineExprStorage is being printed in, so we can
|
||||||
|
@ -404,36 +404,36 @@ void ModulePrinter::printFunctionReference(const Function *func) {
|
||||||
os << '@' << func->getName();
|
os << '@' << func->getName();
|
||||||
}
|
}
|
||||||
|
|
||||||
void ModulePrinter::printAttribute(const Attribute *attr) {
|
void ModulePrinter::printAttribute(Attribute attr) {
|
||||||
switch (attr->getKind()) {
|
switch (attr.getKind()) {
|
||||||
case Attribute::Kind::Bool:
|
case Attribute::Kind::Bool:
|
||||||
os << (cast<BoolAttr>(attr)->getValue() ? "true" : "false");
|
os << (attr.cast<BoolAttr>().getValue() ? "true" : "false");
|
||||||
break;
|
break;
|
||||||
case Attribute::Kind::Integer:
|
case Attribute::Kind::Integer:
|
||||||
os << cast<IntegerAttr>(attr)->getValue();
|
os << attr.cast<IntegerAttr>().getValue();
|
||||||
break;
|
break;
|
||||||
case Attribute::Kind::Float:
|
case Attribute::Kind::Float:
|
||||||
printFloatValue(cast<FloatAttr>(attr)->getValue(), os);
|
printFloatValue(attr.cast<FloatAttr>().getValue(), os);
|
||||||
break;
|
break;
|
||||||
case Attribute::Kind::String:
|
case Attribute::Kind::String:
|
||||||
os << '"';
|
os << '"';
|
||||||
printEscapedString(cast<StringAttr>(attr)->getValue(), os);
|
printEscapedString(attr.cast<StringAttr>().getValue(), os);
|
||||||
os << '"';
|
os << '"';
|
||||||
break;
|
break;
|
||||||
case Attribute::Kind::Array:
|
case Attribute::Kind::Array:
|
||||||
os << '[';
|
os << '[';
|
||||||
interleaveComma(cast<ArrayAttr>(attr)->getValue(),
|
interleaveComma(attr.cast<ArrayAttr>().getValue(),
|
||||||
[&](Attribute *attr) { printAttribute(attr); });
|
[&](Attribute attr) { printAttribute(attr); });
|
||||||
os << ']';
|
os << ']';
|
||||||
break;
|
break;
|
||||||
case Attribute::Kind::AffineMap:
|
case Attribute::Kind::AffineMap:
|
||||||
printAffineMapReference(cast<AffineMapAttr>(attr)->getValue());
|
printAffineMapReference(attr.cast<AffineMapAttr>().getValue());
|
||||||
break;
|
break;
|
||||||
case Attribute::Kind::Type:
|
case Attribute::Kind::Type:
|
||||||
printType(cast<TypeAttr>(attr)->getValue());
|
printType(attr.cast<TypeAttr>().getValue());
|
||||||
break;
|
break;
|
||||||
case Attribute::Kind::Function: {
|
case Attribute::Kind::Function: {
|
||||||
auto *function = cast<FunctionAttr>(attr)->getValue();
|
auto *function = attr.cast<FunctionAttr>().getValue();
|
||||||
if (!function) {
|
if (!function) {
|
||||||
os << "<<FUNCTION ATTR FOR DELETED FUNCTION>>";
|
os << "<<FUNCTION ATTR FOR DELETED FUNCTION>>";
|
||||||
} else {
|
} else {
|
||||||
|
@ -444,53 +444,52 @@ void ModulePrinter::printAttribute(const Attribute *attr) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case Attribute::Kind::OpaqueElements: {
|
case Attribute::Kind::OpaqueElements: {
|
||||||
auto *eltsAttr = cast<OpaqueElementsAttr>(attr);
|
auto eltsAttr = attr.cast<OpaqueElementsAttr>();
|
||||||
os << "opaque<";
|
os << "opaque<";
|
||||||
printType(eltsAttr->getType());
|
printType(eltsAttr.getType());
|
||||||
os << ", " << '"' << "0x" << llvm::toHex(eltsAttr->getValue()) << '"'
|
os << ", " << '"' << "0x" << llvm::toHex(eltsAttr.getValue()) << '"' << '>';
|
||||||
<< '>';
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case Attribute::Kind::DenseIntElements:
|
case Attribute::Kind::DenseIntElements:
|
||||||
case Attribute::Kind::DenseFPElements: {
|
case Attribute::Kind::DenseFPElements: {
|
||||||
auto *eltsAttr = cast<DenseElementsAttr>(attr);
|
auto eltsAttr = attr.cast<DenseElementsAttr>();
|
||||||
os << "dense<";
|
os << "dense<";
|
||||||
printType(eltsAttr->getType());
|
printType(eltsAttr.getType());
|
||||||
os << ", ";
|
os << ", ";
|
||||||
printDenseElementsAttr(eltsAttr);
|
printDenseElementsAttr(eltsAttr);
|
||||||
os << '>';
|
os << '>';
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case Attribute::Kind::SplatElements: {
|
case Attribute::Kind::SplatElements: {
|
||||||
auto *elementsAttr = cast<SplatElementsAttr>(attr);
|
auto elementsAttr = attr.cast<SplatElementsAttr>();
|
||||||
os << "splat<";
|
os << "splat<";
|
||||||
printType(elementsAttr->getType());
|
printType(elementsAttr.getType());
|
||||||
os << ", ";
|
os << ", ";
|
||||||
printAttribute(elementsAttr->getValue());
|
printAttribute(elementsAttr.getValue());
|
||||||
os << '>';
|
os << '>';
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case Attribute::Kind::SparseElements: {
|
case Attribute::Kind::SparseElements: {
|
||||||
auto *elementsAttr = cast<SparseElementsAttr>(attr);
|
auto elementsAttr = attr.cast<SparseElementsAttr>();
|
||||||
os << "sparse<";
|
os << "sparse<";
|
||||||
printType(elementsAttr->getType());
|
printType(elementsAttr.getType());
|
||||||
os << ", ";
|
os << ", ";
|
||||||
printDenseElementsAttr(elementsAttr->getIndices());
|
printDenseElementsAttr(elementsAttr.getIndices());
|
||||||
os << ", ";
|
os << ", ";
|
||||||
printDenseElementsAttr(elementsAttr->getValues());
|
printDenseElementsAttr(elementsAttr.getValues());
|
||||||
os << '>';
|
os << '>';
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ModulePrinter::printDenseElementsAttr(const DenseElementsAttr *attr) {
|
void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
|
||||||
auto *type = attr->getType();
|
auto *type = attr.getType();
|
||||||
auto shape = type->getShape();
|
auto shape = type->getShape();
|
||||||
auto rank = type->getRank();
|
auto rank = type->getRank();
|
||||||
|
|
||||||
SmallVector<Attribute *, 16> elements;
|
SmallVector<Attribute, 16> elements;
|
||||||
attr->getValues(elements);
|
attr.getValues(elements);
|
||||||
|
|
||||||
// Special case for degenerate tensors.
|
// Special case for degenerate tensors.
|
||||||
if (elements.empty()) {
|
if (elements.empty()) {
|
||||||
|
@ -934,9 +933,7 @@ public:
|
||||||
// Implement OpAsmPrinter.
|
// Implement OpAsmPrinter.
|
||||||
raw_ostream &getStream() const { return os; }
|
raw_ostream &getStream() const { return os; }
|
||||||
void printType(const Type *type) { ModulePrinter::printType(type); }
|
void printType(const Type *type) { ModulePrinter::printType(type); }
|
||||||
void printAttribute(const Attribute *attr) {
|
void printAttribute(Attribute attr) { ModulePrinter::printAttribute(attr); }
|
||||||
ModulePrinter::printAttribute(attr);
|
|
||||||
}
|
|
||||||
void printAffineMap(AffineMap map) {
|
void printAffineMap(AffineMap map) {
|
||||||
return ModulePrinter::printAffineMapReference(map);
|
return ModulePrinter::printAffineMapReference(map);
|
||||||
}
|
}
|
||||||
|
@ -980,7 +977,7 @@ protected:
|
||||||
} else if (auto intOp = op->dyn_cast<ConstantIndexOp>()) {
|
} else if (auto intOp = op->dyn_cast<ConstantIndexOp>()) {
|
||||||
specialName << 'c' << intOp->getValue();
|
specialName << 'c' << intOp->getValue();
|
||||||
} else if (auto constant = op->dyn_cast<ConstantOp>()) {
|
} else if (auto constant = op->dyn_cast<ConstantOp>()) {
|
||||||
if (isa<FunctionAttr>(constant->getValue()))
|
if (constant->getValue().isa<FunctionAttr>())
|
||||||
specialName << 'f';
|
specialName << 'f';
|
||||||
else
|
else
|
||||||
specialName << "cst";
|
specialName << "cst";
|
||||||
|
@ -1570,7 +1567,7 @@ void ModulePrinter::print(const MLFunction *fn) {
|
||||||
|
|
||||||
void Attribute::print(raw_ostream &os) const {
|
void Attribute::print(raw_ostream &os) const {
|
||||||
ModuleState state(/*no context is known*/ nullptr);
|
ModuleState state(/*no context is known*/ nullptr);
|
||||||
ModulePrinter(os, state).printAttribute(this);
|
ModulePrinter(os, state).printAttribute(*this);
|
||||||
}
|
}
|
||||||
|
|
||||||
void Attribute::dump() const { print(llvm::errs()); }
|
void Attribute::dump() const { print(llvm::errs()); }
|
||||||
|
|
|
@ -0,0 +1,131 @@
|
||||||
|
//===- AttributeDetail.h - MLIR Affine Map details Class --------*- C++ -*-===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The MLIR Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
// =============================================================================
|
||||||
|
//
|
||||||
|
// This holds implementation details of Attribute.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef ATTRIBUTEDETAIL_H_
|
||||||
|
#define ATTRIBUTEDETAIL_H_
|
||||||
|
|
||||||
|
#include "mlir/IR/Attributes.h"
|
||||||
|
#include "mlir/IR/MLIRContext.h"
|
||||||
|
|
||||||
|
namespace mlir {
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
/// Base storage class appearing in an attribute.
|
||||||
|
struct AttributeStorage {
|
||||||
|
Attribute::Kind kind : 8;
|
||||||
|
|
||||||
|
/// This field is true if this is, or contains, a function attribute.
|
||||||
|
bool isOrContainsFunctionCache : 1;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// An attribute representing a boolean value.
|
||||||
|
struct BoolAttributeStorage : public AttributeStorage {
|
||||||
|
bool value;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// An attribute representing a integral value.
|
||||||
|
struct IntegerAttributeStorage : public AttributeStorage {
|
||||||
|
int64_t value;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// An attribute representing a floating point value.
|
||||||
|
struct FloatAttributeStorage final
|
||||||
|
: public AttributeStorage,
|
||||||
|
public llvm::TrailingObjects<FloatAttributeStorage, uint64_t> {
|
||||||
|
const llvm::fltSemantics &semantics;
|
||||||
|
size_t numObjects;
|
||||||
|
|
||||||
|
/// Returns an APFloat representing the stored value.
|
||||||
|
APFloat getValue() const {
|
||||||
|
auto val = APInt(APFloat::getSizeInBits(semantics),
|
||||||
|
{getTrailingObjects<uint64_t>(), numObjects});
|
||||||
|
return APFloat(semantics, val);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// An attribute representing a string value.
|
||||||
|
struct StringAttributeStorage : public AttributeStorage {
|
||||||
|
StringRef value;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// An attribute representing an array of other attributes.
|
||||||
|
struct ArrayAttributeStorage : public AttributeStorage {
|
||||||
|
ArrayRef<Attribute> value;
|
||||||
|
};
|
||||||
|
|
||||||
|
// An attribute representing a reference to an affine map.
|
||||||
|
struct AffineMapAttributeStorage : public AttributeStorage {
|
||||||
|
AffineMap value;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// An attribute representing a reference to a type.
|
||||||
|
struct TypeAttributeStorage : public AttributeStorage {
|
||||||
|
Type *value;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// An attribute representing a reference to a function.
|
||||||
|
struct FunctionAttributeStorage : public AttributeStorage {
|
||||||
|
Function *value;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// A base attribute representing a reference to a vector or tensor constant.
|
||||||
|
struct ElementsAttributeStorage : public AttributeStorage {
|
||||||
|
VectorOrTensorType *type;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// An attribute representing a reference to a vector or tensor constant,
|
||||||
|
/// inwhich all elements have the same value.
|
||||||
|
struct SplatElementsAttributeStorage : public ElementsAttributeStorage {
|
||||||
|
Attribute elt;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// An attribute representing a reference to a dense vector or tensor object.
|
||||||
|
struct DenseElementsAttributeStorage : public ElementsAttributeStorage {
|
||||||
|
ArrayRef<char> data;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// An attribute representing a reference to a dense integer vector or tensor
|
||||||
|
/// object.
|
||||||
|
struct DenseIntElementsAttributeStorage : public DenseElementsAttributeStorage {
|
||||||
|
size_t bitsWidth;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// An attribute representing a reference to a dense float vector or tensor
|
||||||
|
/// object.
|
||||||
|
struct DenseFPElementsAttributeStorage : public DenseElementsAttributeStorage {
|
||||||
|
};
|
||||||
|
|
||||||
|
/// An attribute representing a reference to a tensor constant with opaque
|
||||||
|
/// content.
|
||||||
|
struct OpaqueElementsAttributeStorage : public ElementsAttributeStorage {
|
||||||
|
StringRef bytes;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// An attribute representing a reference to a sparse vector or tensor object.
|
||||||
|
struct SparseElementsAttributeStorage : public ElementsAttributeStorage {
|
||||||
|
DenseIntElementsAttr indices;
|
||||||
|
DenseElementsAttr values;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
} // namespace mlir
|
||||||
|
|
||||||
|
#endif // ATTRIBUTEDETAIL_H_
|
|
@ -0,0 +1,214 @@
|
||||||
|
//===- Attributes.cpp - MLIR Affine Expr Classes --------------------------===//
|
||||||
|
//
|
||||||
|
// Copyright 2019 The MLIR Authors.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
// =============================================================================
|
||||||
|
|
||||||
|
#include "mlir/IR/Attributes.h"
|
||||||
|
#include "AttributeDetail.h"
|
||||||
|
#include "mlir/IR/Function.h"
|
||||||
|
#include "mlir/IR/Types.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
using namespace mlir::detail;
|
||||||
|
|
||||||
|
Attribute::Kind Attribute::getKind() const { return attr->kind; }
|
||||||
|
|
||||||
|
bool Attribute::isOrContainsFunction() const {
|
||||||
|
return attr->isOrContainsFunctionCache;
|
||||||
|
}
|
||||||
|
|
||||||
|
BoolAttr::BoolAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
|
||||||
|
|
||||||
|
bool BoolAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
|
||||||
|
|
||||||
|
IntegerAttr::IntegerAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
|
||||||
|
|
||||||
|
int64_t IntegerAttr::getValue() const {
|
||||||
|
return static_cast<ImplType *>(attr)->value;
|
||||||
|
}
|
||||||
|
|
||||||
|
FloatAttr::FloatAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
|
||||||
|
|
||||||
|
APFloat FloatAttr::getValue() const {
|
||||||
|
return static_cast<ImplType *>(attr)->getValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
double FloatAttr::getDouble() const { return getValue().convertToDouble(); }
|
||||||
|
|
||||||
|
StringAttr::StringAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
|
||||||
|
|
||||||
|
StringRef StringAttr::getValue() const {
|
||||||
|
return static_cast<ImplType *>(attr)->value;
|
||||||
|
}
|
||||||
|
|
||||||
|
ArrayAttr::ArrayAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
|
||||||
|
|
||||||
|
ArrayRef<Attribute> ArrayAttr::getValue() const {
|
||||||
|
return static_cast<ImplType *>(attr)->value;
|
||||||
|
}
|
||||||
|
|
||||||
|
AffineMapAttr::AffineMapAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
|
||||||
|
|
||||||
|
AffineMap AffineMapAttr::getValue() const {
|
||||||
|
return static_cast<ImplType *>(attr)->value;
|
||||||
|
}
|
||||||
|
|
||||||
|
TypeAttr::TypeAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
|
||||||
|
|
||||||
|
Type *TypeAttr::getValue() const {
|
||||||
|
return static_cast<ImplType *>(attr)->value;
|
||||||
|
}
|
||||||
|
|
||||||
|
FunctionAttr::FunctionAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
|
||||||
|
|
||||||
|
Function *FunctionAttr::getValue() const {
|
||||||
|
return static_cast<ImplType *>(attr)->value;
|
||||||
|
}
|
||||||
|
|
||||||
|
FunctionType *FunctionAttr::getType() const { return getValue()->getType(); }
|
||||||
|
|
||||||
|
ElementsAttr::ElementsAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
|
||||||
|
|
||||||
|
VectorOrTensorType *ElementsAttr::getType() const {
|
||||||
|
return static_cast<ImplType *>(attr)->type;
|
||||||
|
}
|
||||||
|
|
||||||
|
SplatElementsAttr::SplatElementsAttr(Attribute::ImplType *ptr)
|
||||||
|
: ElementsAttr(ptr) {}
|
||||||
|
|
||||||
|
Attribute SplatElementsAttr::getValue() const {
|
||||||
|
return static_cast<ImplType *>(attr)->elt;
|
||||||
|
}
|
||||||
|
|
||||||
|
DenseElementsAttr::DenseElementsAttr(Attribute::ImplType *ptr)
|
||||||
|
: ElementsAttr(ptr) {}
|
||||||
|
|
||||||
|
void DenseElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
|
||||||
|
switch (getKind()) {
|
||||||
|
case Attribute::Kind::DenseIntElements:
|
||||||
|
cast<DenseIntElementsAttr>().getValues(values);
|
||||||
|
return;
|
||||||
|
case Attribute::Kind::DenseFPElements:
|
||||||
|
cast<DenseFPElementsAttr>().getValues(values);
|
||||||
|
return;
|
||||||
|
default:
|
||||||
|
llvm_unreachable("unexpected element type");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ArrayRef<char> DenseElementsAttr::getRawData() const {
|
||||||
|
return static_cast<ImplType *>(attr)->data;
|
||||||
|
}
|
||||||
|
|
||||||
|
DenseIntElementsAttr::DenseIntElementsAttr(Attribute::ImplType *ptr)
|
||||||
|
: DenseElementsAttr(ptr) {}
|
||||||
|
|
||||||
|
/// Writes the lowest `bitWidth` bits of `value` to bit position `bitPos`
|
||||||
|
/// starting from `rawData`.
|
||||||
|
void DenseIntElementsAttr::writeBits(char *data, size_t bitPos, size_t bitWidth,
|
||||||
|
uint64_t value) {
|
||||||
|
// Read the destination bytes which will be written to.
|
||||||
|
uint64_t dst = 0;
|
||||||
|
auto dstData = reinterpret_cast<char *>(&dst);
|
||||||
|
auto endPos = bitPos + bitWidth;
|
||||||
|
auto start = data + bitPos / 8;
|
||||||
|
auto end = data + endPos / 8 + (endPos % 8 != 0);
|
||||||
|
std::copy(start, end, dstData);
|
||||||
|
|
||||||
|
// Clean up the invalid bits in the destination bytes.
|
||||||
|
dst &= ~(-1UL << (bitPos % 8));
|
||||||
|
|
||||||
|
// Get the valid bits of the source value, shift them to right position,
|
||||||
|
// then add them to the destination bytes.
|
||||||
|
value <<= bitPos % 8;
|
||||||
|
dst |= value;
|
||||||
|
|
||||||
|
// Write the destination bytes back.
|
||||||
|
ArrayRef<char> range({dstData, (size_t)(end - start)});
|
||||||
|
std::copy(range.begin(), range.end(), start);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reads the next `bitWidth` bits from the bit position `bitPos` of `rawData`
|
||||||
|
/// and put them in the lowest bits.
|
||||||
|
uint64_t DenseIntElementsAttr::readBits(const char *rawData, size_t bitPos,
|
||||||
|
size_t bitsWidth) {
|
||||||
|
uint64_t dst = 0;
|
||||||
|
auto dstData = reinterpret_cast<char *>(&dst);
|
||||||
|
auto endPos = bitPos + bitsWidth;
|
||||||
|
auto start = rawData + bitPos / 8;
|
||||||
|
auto end = rawData + endPos / 8 + (endPos % 8 != 0);
|
||||||
|
std::copy(start, end, dstData);
|
||||||
|
|
||||||
|
dst >>= bitPos % 8;
|
||||||
|
dst &= ~(-1UL << bitsWidth);
|
||||||
|
return dst;
|
||||||
|
}
|
||||||
|
|
||||||
|
void DenseIntElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
|
||||||
|
auto bitsWidth = static_cast<ImplType *>(attr)->bitsWidth;
|
||||||
|
auto elementNum = getType()->getNumElements();
|
||||||
|
auto context = getType()->getContext();
|
||||||
|
values.reserve(elementNum);
|
||||||
|
if (bitsWidth == 64) {
|
||||||
|
ArrayRef<int64_t> vs(
|
||||||
|
{reinterpret_cast<const int64_t *>(getRawData().data()),
|
||||||
|
getRawData().size() / 8});
|
||||||
|
for (auto value : vs) {
|
||||||
|
auto attr = IntegerAttr::get(value, context);
|
||||||
|
values.push_back(attr);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const auto *rawData = getRawData().data();
|
||||||
|
for (size_t pos = 0; pos < elementNum * bitsWidth; pos += bitsWidth) {
|
||||||
|
uint64_t bits = readBits(rawData, pos, bitsWidth);
|
||||||
|
APInt value(bitsWidth, bits, /*isSigned=*/true);
|
||||||
|
auto attr = IntegerAttr::get(value.getSExtValue(), context);
|
||||||
|
values.push_back(attr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
DenseFPElementsAttr::DenseFPElementsAttr(Attribute::ImplType *ptr)
|
||||||
|
: DenseElementsAttr(ptr) {}
|
||||||
|
|
||||||
|
void DenseFPElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
|
||||||
|
auto elementNum = getType()->getNumElements();
|
||||||
|
auto context = getType()->getContext();
|
||||||
|
ArrayRef<double> vs({reinterpret_cast<const double *>(getRawData().data()),
|
||||||
|
getRawData().size() / 8});
|
||||||
|
values.reserve(elementNum);
|
||||||
|
for (auto v : vs) {
|
||||||
|
auto attr = FloatAttr::get(v, context);
|
||||||
|
values.push_back(attr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
OpaqueElementsAttr::OpaqueElementsAttr(Attribute::ImplType *ptr)
|
||||||
|
: ElementsAttr(ptr) {}
|
||||||
|
|
||||||
|
StringRef OpaqueElementsAttr::getValue() const {
|
||||||
|
return static_cast<ImplType *>(attr)->bytes;
|
||||||
|
}
|
||||||
|
|
||||||
|
SparseElementsAttr::SparseElementsAttr(Attribute::ImplType *ptr)
|
||||||
|
: ElementsAttr(ptr) {}
|
||||||
|
|
||||||
|
DenseIntElementsAttr SparseElementsAttr::getIndices() const {
|
||||||
|
return static_cast<ImplType *>(attr)->indices;
|
||||||
|
}
|
||||||
|
|
||||||
|
DenseElementsAttr SparseElementsAttr::getValues() const {
|
||||||
|
return static_cast<ImplType *>(attr)->values;
|
||||||
|
}
|
|
@ -112,60 +112,60 @@ UnrankedTensorType *Builder::getTensorType(Type *elementType) {
|
||||||
// Attributes.
|
// Attributes.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
BoolAttr *Builder::getBoolAttr(bool value) {
|
BoolAttr Builder::getBoolAttr(bool value) {
|
||||||
return BoolAttr::get(value, context);
|
return BoolAttr::get(value, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
IntegerAttr *Builder::getIntegerAttr(int64_t value) {
|
IntegerAttr Builder::getIntegerAttr(int64_t value) {
|
||||||
return IntegerAttr::get(value, context);
|
return IntegerAttr::get(value, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
FloatAttr *Builder::getFloatAttr(double value) {
|
FloatAttr Builder::getFloatAttr(double value) {
|
||||||
return FloatAttr::get(APFloat(value), context);
|
return FloatAttr::get(APFloat(value), context);
|
||||||
}
|
}
|
||||||
|
|
||||||
FloatAttr *Builder::getFloatAttr(const APFloat &value) {
|
FloatAttr Builder::getFloatAttr(const APFloat &value) {
|
||||||
return FloatAttr::get(value, context);
|
return FloatAttr::get(value, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
StringAttr *Builder::getStringAttr(StringRef bytes) {
|
StringAttr Builder::getStringAttr(StringRef bytes) {
|
||||||
return StringAttr::get(bytes, context);
|
return StringAttr::get(bytes, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
ArrayAttr *Builder::getArrayAttr(ArrayRef<Attribute *> value) {
|
ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
|
||||||
return ArrayAttr::get(value, context);
|
return ArrayAttr::get(value, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
AffineMapAttr *Builder::getAffineMapAttr(AffineMap map) {
|
AffineMapAttr Builder::getAffineMapAttr(AffineMap map) {
|
||||||
return AffineMapAttr::get(map);
|
return AffineMapAttr::get(map);
|
||||||
}
|
}
|
||||||
|
|
||||||
TypeAttr *Builder::getTypeAttr(Type *type) {
|
TypeAttr Builder::getTypeAttr(Type *type) {
|
||||||
return TypeAttr::get(type, context);
|
return TypeAttr::get(type, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
FunctionAttr *Builder::getFunctionAttr(const Function *value) {
|
FunctionAttr Builder::getFunctionAttr(const Function *value) {
|
||||||
return FunctionAttr::get(value, context);
|
return FunctionAttr::get(value, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
ElementsAttr *Builder::getSplatElementsAttr(VectorOrTensorType *type,
|
ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType *type,
|
||||||
Attribute *elt) {
|
Attribute elt) {
|
||||||
return SplatElementsAttr::get(type, elt);
|
return SplatElementsAttr::get(type, elt);
|
||||||
}
|
}
|
||||||
|
|
||||||
ElementsAttr *Builder::getDenseElementsAttr(VectorOrTensorType *type,
|
ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType *type,
|
||||||
ArrayRef<char> data) {
|
ArrayRef<char> data) {
|
||||||
return DenseElementsAttr::get(type, data);
|
return DenseElementsAttr::get(type, data);
|
||||||
}
|
}
|
||||||
|
|
||||||
ElementsAttr *Builder::getSparseElementsAttr(VectorOrTensorType *type,
|
ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType *type,
|
||||||
DenseIntElementsAttr *indices,
|
DenseIntElementsAttr indices,
|
||||||
DenseElementsAttr *values) {
|
DenseElementsAttr values) {
|
||||||
return SparseElementsAttr::get(type, indices, values);
|
return SparseElementsAttr::get(type, indices, values);
|
||||||
}
|
}
|
||||||
|
|
||||||
ElementsAttr *Builder::getOpaqueElementsAttr(VectorOrTensorType *type,
|
ElementsAttr Builder::getOpaqueElementsAttr(VectorOrTensorType *type,
|
||||||
StringRef bytes) {
|
StringRef bytes) {
|
||||||
return OpaqueElementsAttr::get(type, bytes);
|
return OpaqueElementsAttr::get(type, bytes);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -86,13 +86,13 @@ bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||||
auto &builder = parser->getBuilder();
|
auto &builder = parser->getBuilder();
|
||||||
auto *affineIntTy = builder.getIndexType();
|
auto *affineIntTy = builder.getIndexType();
|
||||||
|
|
||||||
AffineMapAttr *mapAttr;
|
AffineMapAttr mapAttr;
|
||||||
unsigned numDims;
|
unsigned numDims;
|
||||||
if (parser->parseAttribute(mapAttr, "map", result->attributes) ||
|
if (parser->parseAttribute(mapAttr, "map", result->attributes) ||
|
||||||
parseDimAndSymbolList(parser, result->operands, numDims) ||
|
parseDimAndSymbolList(parser, result->operands, numDims) ||
|
||||||
parser->parseOptionalAttributeDict(result->attributes))
|
parser->parseOptionalAttributeDict(result->attributes))
|
||||||
return true;
|
return true;
|
||||||
auto map = mapAttr->getValue();
|
auto map = mapAttr.getValue();
|
||||||
|
|
||||||
if (map.getNumDims() != numDims ||
|
if (map.getNumDims() != numDims ||
|
||||||
numDims + map.getNumSymbols() != result->operands.size()) {
|
numDims + map.getNumSymbols() != result->operands.size()) {
|
||||||
|
@ -113,12 +113,12 @@ void AffineApplyOp::print(OpAsmPrinter *p) const {
|
||||||
|
|
||||||
bool AffineApplyOp::verify() const {
|
bool AffineApplyOp::verify() const {
|
||||||
// Check that affine map attribute was specified.
|
// Check that affine map attribute was specified.
|
||||||
auto *affineMapAttr = getAttrOfType<AffineMapAttr>("map");
|
auto affineMapAttr = getAttrOfType<AffineMapAttr>("map");
|
||||||
if (!affineMapAttr)
|
if (!affineMapAttr)
|
||||||
return emitOpError("requires an affine map");
|
return emitOpError("requires an affine map");
|
||||||
|
|
||||||
// Check input and output dimensions match.
|
// Check input and output dimensions match.
|
||||||
auto map = affineMapAttr->getValue();
|
auto map = affineMapAttr.getValue();
|
||||||
|
|
||||||
// Verify that operand count matches affine map dimension and symbol count.
|
// Verify that operand count matches affine map dimension and symbol count.
|
||||||
if (getNumOperands() != map.getNumDims() + map.getNumSymbols())
|
if (getNumOperands() != map.getNumDims() + map.getNumSymbols())
|
||||||
|
@ -155,8 +155,8 @@ bool AffineApplyOp::isValidSymbol() const {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool AffineApplyOp::constantFold(ArrayRef<Attribute *> operandConstants,
|
bool AffineApplyOp::constantFold(ArrayRef<Attribute> operandConstants,
|
||||||
SmallVectorImpl<Attribute *> &results,
|
SmallVectorImpl<Attribute> &results,
|
||||||
MLIRContext *context) const {
|
MLIRContext *context) const {
|
||||||
auto map = getAffineMap();
|
auto map = getAffineMap();
|
||||||
if (map.constantFold(operandConstants, results))
|
if (map.constantFold(operandConstants, results))
|
||||||
|
@ -171,21 +171,21 @@ bool AffineApplyOp::constantFold(ArrayRef<Attribute *> operandConstants,
|
||||||
|
|
||||||
/// Builds a constant op with the specified attribute value and result type.
|
/// Builds a constant op with the specified attribute value and result type.
|
||||||
void ConstantOp::build(Builder *builder, OperationState *result,
|
void ConstantOp::build(Builder *builder, OperationState *result,
|
||||||
Attribute *value, Type *type) {
|
Attribute value, Type *type) {
|
||||||
result->addAttribute("value", value);
|
result->addAttribute("value", value);
|
||||||
result->types.push_back(type);
|
result->types.push_back(type);
|
||||||
}
|
}
|
||||||
|
|
||||||
void ConstantOp::print(OpAsmPrinter *p) const {
|
void ConstantOp::print(OpAsmPrinter *p) const {
|
||||||
*p << "constant " << *getValue();
|
*p << "constant " << getValue();
|
||||||
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value");
|
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value");
|
||||||
|
|
||||||
if (!isa<FunctionAttr>(getValue()))
|
if (!getValue().isa<FunctionAttr>())
|
||||||
*p << " : " << *getType();
|
*p << " : " << *getType();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
|
bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||||
Attribute *valueAttr;
|
Attribute valueAttr;
|
||||||
Type *type;
|
Type *type;
|
||||||
|
|
||||||
if (parser->parseAttribute(valueAttr, "value", result->attributes) ||
|
if (parser->parseAttribute(valueAttr, "value", result->attributes) ||
|
||||||
|
@ -194,8 +194,8 @@ bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||||
|
|
||||||
// 'constant' taking a function reference doesn't get a redundant type
|
// 'constant' taking a function reference doesn't get a redundant type
|
||||||
// specifier. The attribute itself carries it.
|
// specifier. The attribute itself carries it.
|
||||||
if (auto *fnAttr = dyn_cast<FunctionAttr>(valueAttr))
|
if (auto fnAttr = valueAttr.dyn_cast<FunctionAttr>())
|
||||||
return parser->addTypeToList(fnAttr->getValue()->getType(), result->types);
|
return parser->addTypeToList(fnAttr.getValue()->getType(), result->types);
|
||||||
|
|
||||||
return parser->parseColonType(type) ||
|
return parser->parseColonType(type) ||
|
||||||
parser->addTypeToList(type, result->types);
|
parser->addTypeToList(type, result->types);
|
||||||
|
@ -204,32 +204,32 @@ bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||||
/// The constant op requires an attribute, and furthermore requires that it
|
/// The constant op requires an attribute, and furthermore requires that it
|
||||||
/// matches the return type.
|
/// matches the return type.
|
||||||
bool ConstantOp::verify() const {
|
bool ConstantOp::verify() const {
|
||||||
auto *value = getValue();
|
auto value = getValue();
|
||||||
if (!value)
|
if (!value)
|
||||||
return emitOpError("requires a 'value' attribute");
|
return emitOpError("requires a 'value' attribute");
|
||||||
|
|
||||||
auto *type = this->getType();
|
auto *type = this->getType();
|
||||||
if (isa<IntegerType>(type) || type->isIndex()) {
|
if (isa<IntegerType>(type) || type->isIndex()) {
|
||||||
if (!isa<IntegerAttr>(value))
|
if (!value.isa<IntegerAttr>())
|
||||||
return emitOpError(
|
return emitOpError(
|
||||||
"requires 'value' to be an integer for an integer result type");
|
"requires 'value' to be an integer for an integer result type");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isa<FloatType>(type)) {
|
if (isa<FloatType>(type)) {
|
||||||
if (!isa<FloatAttr>(value))
|
if (!value.isa<FloatAttr>())
|
||||||
return emitOpError("requires 'value' to be a floating point constant");
|
return emitOpError("requires 'value' to be a floating point constant");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (type->isTFString()) {
|
if (type->isTFString()) {
|
||||||
if (!isa<StringAttr>(value))
|
if (!value.isa<StringAttr>())
|
||||||
return emitOpError("requires 'value' to be a string constant");
|
return emitOpError("requires 'value' to be a string constant");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isa<FunctionType>(type)) {
|
if (isa<FunctionType>(type)) {
|
||||||
if (!isa<FunctionAttr>(value))
|
if (!value.isa<FunctionAttr>())
|
||||||
return emitOpError("requires 'value' to be a function reference");
|
return emitOpError("requires 'value' to be a function reference");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -238,8 +238,8 @@ bool ConstantOp::verify() const {
|
||||||
"requires a result type that aligns with the 'value' attribute");
|
"requires a result type that aligns with the 'value' attribute");
|
||||||
}
|
}
|
||||||
|
|
||||||
Attribute *ConstantOp::constantFold(ArrayRef<Attribute *> operands,
|
Attribute ConstantOp::constantFold(ArrayRef<Attribute> operands,
|
||||||
MLIRContext *context) const {
|
MLIRContext *context) const {
|
||||||
assert(operands.empty() && "constant has no operands");
|
assert(operands.empty() && "constant has no operands");
|
||||||
return getValue();
|
return getValue();
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#include "mlir/IR/MLIRContext.h"
|
#include "mlir/IR/MLIRContext.h"
|
||||||
#include "AffineExprDetail.h"
|
#include "AffineExprDetail.h"
|
||||||
#include "AffineMapDetail.h"
|
#include "AffineMapDetail.h"
|
||||||
|
#include "AttributeDetail.h"
|
||||||
#include "AttributeListStorage.h"
|
#include "AttributeListStorage.h"
|
||||||
#include "IntegerSetDetail.h"
|
#include "IntegerSetDetail.h"
|
||||||
#include "mlir/IR/AffineExpr.h"
|
#include "mlir/IR/AffineExpr.h"
|
||||||
|
@ -169,35 +170,35 @@ struct MemRefTypeKeyInfo : DenseMapInfo<MemRefType *> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct FloatAttrKeyInfo : DenseMapInfo<FloatAttr *> {
|
struct FloatAttrKeyInfo : DenseMapInfo<FloatAttributeStorage *> {
|
||||||
// Float attributes are uniqued based on wrapped APFloat.
|
// Float attributes are uniqued based on wrapped APFloat.
|
||||||
using KeyTy = APFloat;
|
using KeyTy = APFloat;
|
||||||
using DenseMapInfo<FloatAttr *>::getHashValue;
|
using DenseMapInfo<FloatAttributeStorage *>::getHashValue;
|
||||||
using DenseMapInfo<FloatAttr *>::isEqual;
|
using DenseMapInfo<FloatAttributeStorage *>::isEqual;
|
||||||
|
|
||||||
static unsigned getHashValue(KeyTy key) { return llvm::hash_value(key); }
|
static unsigned getHashValue(KeyTy key) { return llvm::hash_value(key); }
|
||||||
|
|
||||||
static bool isEqual(const KeyTy &lhs, const FloatAttr *rhs) {
|
static bool isEqual(const KeyTy &lhs, const FloatAttributeStorage *rhs) {
|
||||||
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
||||||
return false;
|
return false;
|
||||||
return lhs.bitwiseIsEqual(rhs->getValue());
|
return lhs.bitwiseIsEqual(rhs->getValue());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ArrayAttrKeyInfo : DenseMapInfo<ArrayAttr *> {
|
struct ArrayAttrKeyInfo : DenseMapInfo<ArrayAttributeStorage *> {
|
||||||
// Array attributes are uniqued based on their elements.
|
// Array attributes are uniqued based on their elements.
|
||||||
using KeyTy = ArrayRef<Attribute *>;
|
using KeyTy = ArrayRef<Attribute>;
|
||||||
using DenseMapInfo<ArrayAttr *>::getHashValue;
|
using DenseMapInfo<ArrayAttributeStorage *>::getHashValue;
|
||||||
using DenseMapInfo<ArrayAttr *>::isEqual;
|
using DenseMapInfo<ArrayAttributeStorage *>::isEqual;
|
||||||
|
|
||||||
static unsigned getHashValue(KeyTy key) {
|
static unsigned getHashValue(KeyTy key) {
|
||||||
return hash_combine_range(key.begin(), key.end());
|
return hash_combine_range(key.begin(), key.end());
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool isEqual(const KeyTy &lhs, const ArrayAttr *rhs) {
|
static bool isEqual(const KeyTy &lhs, const ArrayAttributeStorage *rhs) {
|
||||||
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
||||||
return false;
|
return false;
|
||||||
return lhs == rhs->getValue();
|
return lhs == rhs->value;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -218,37 +219,39 @@ struct AttributeListKeyInfo : DenseMapInfo<AttributeListStorage *> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttr *> {
|
struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttributeStorage *> {
|
||||||
using KeyTy = std::pair<VectorOrTensorType *, ArrayRef<char>>;
|
using KeyTy = std::pair<VectorOrTensorType *, ArrayRef<char>>;
|
||||||
using DenseMapInfo<DenseElementsAttr *>::getHashValue;
|
using DenseMapInfo<DenseElementsAttributeStorage *>::getHashValue;
|
||||||
using DenseMapInfo<DenseElementsAttr *>::isEqual;
|
using DenseMapInfo<DenseElementsAttributeStorage *>::isEqual;
|
||||||
|
|
||||||
static unsigned getHashValue(KeyTy key) {
|
static unsigned getHashValue(KeyTy key) {
|
||||||
return hash_combine(
|
return hash_combine(
|
||||||
key.first, hash_combine_range(key.second.begin(), key.second.end()));
|
key.first, hash_combine_range(key.second.begin(), key.second.end()));
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool isEqual(const KeyTy &lhs, const DenseElementsAttr *rhs) {
|
static bool isEqual(const KeyTy &lhs,
|
||||||
|
const DenseElementsAttributeStorage *rhs) {
|
||||||
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
||||||
return false;
|
return false;
|
||||||
return lhs == std::make_pair(rhs->getType(), rhs->getRawData());
|
return lhs == std::make_pair(rhs->type, rhs->data);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct OpaqueElementsAttrInfo : DenseMapInfo<OpaqueElementsAttr *> {
|
struct OpaqueElementsAttrInfo : DenseMapInfo<OpaqueElementsAttributeStorage *> {
|
||||||
using KeyTy = std::pair<VectorOrTensorType *, StringRef>;
|
using KeyTy = std::pair<VectorOrTensorType *, StringRef>;
|
||||||
using DenseMapInfo<OpaqueElementsAttr *>::getHashValue;
|
using DenseMapInfo<OpaqueElementsAttributeStorage *>::getHashValue;
|
||||||
using DenseMapInfo<OpaqueElementsAttr *>::isEqual;
|
using DenseMapInfo<OpaqueElementsAttributeStorage *>::isEqual;
|
||||||
|
|
||||||
static unsigned getHashValue(KeyTy key) {
|
static unsigned getHashValue(KeyTy key) {
|
||||||
return hash_combine(
|
return hash_combine(
|
||||||
key.first, hash_combine_range(key.second.begin(), key.second.end()));
|
key.first, hash_combine_range(key.second.begin(), key.second.end()));
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool isEqual(const KeyTy &lhs, const OpaqueElementsAttr *rhs) {
|
static bool isEqual(const KeyTy &lhs,
|
||||||
|
const OpaqueElementsAttributeStorage *rhs) {
|
||||||
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
||||||
return false;
|
return false;
|
||||||
return lhs == std::make_pair(rhs->getType(), rhs->getValue());
|
return lhs == std::make_pair(rhs->type, rhs->bytes);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // end anonymous namespace.
|
} // end anonymous namespace.
|
||||||
|
@ -343,28 +346,29 @@ public:
|
||||||
MemRefTypeSet memrefs;
|
MemRefTypeSet memrefs;
|
||||||
|
|
||||||
// Attribute uniquing.
|
// Attribute uniquing.
|
||||||
BoolAttr *boolAttrs[2] = {nullptr};
|
BoolAttributeStorage *boolAttrs[2] = {nullptr};
|
||||||
DenseMap<int64_t, IntegerAttr *> integerAttrs;
|
DenseMap<int64_t, IntegerAttributeStorage *> integerAttrs;
|
||||||
DenseSet<FloatAttr *, FloatAttrKeyInfo> floatAttrs;
|
DenseSet<FloatAttributeStorage *, FloatAttrKeyInfo> floatAttrs;
|
||||||
StringMap<StringAttr *> stringAttrs;
|
StringMap<StringAttributeStorage *> stringAttrs;
|
||||||
using ArrayAttrSet = DenseSet<ArrayAttr *, ArrayAttrKeyInfo>;
|
using ArrayAttrSet = DenseSet<ArrayAttributeStorage *, ArrayAttrKeyInfo>;
|
||||||
ArrayAttrSet arrayAttrs;
|
ArrayAttrSet arrayAttrs;
|
||||||
DenseMap<AffineMap, AffineMapAttr *> affineMapAttrs;
|
DenseMap<AffineMap, AffineMapAttributeStorage *> affineMapAttrs;
|
||||||
DenseMap<Type *, TypeAttr *> typeAttrs;
|
DenseMap<Type *, TypeAttributeStorage *> typeAttrs;
|
||||||
using AttributeListSet =
|
using AttributeListSet =
|
||||||
DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
|
DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
|
||||||
AttributeListSet attributeLists;
|
AttributeListSet attributeLists;
|
||||||
DenseMap<const Function *, FunctionAttr *> functionAttrs;
|
DenseMap<const Function *, FunctionAttributeStorage *> functionAttrs;
|
||||||
DenseMap<std::pair<VectorOrTensorType *, Attribute *>, SplatElementsAttr *>
|
DenseMap<std::pair<VectorOrTensorType *, Attribute>,
|
||||||
|
SplatElementsAttributeStorage *>
|
||||||
splatElementsAttrs;
|
splatElementsAttrs;
|
||||||
using DenseElementsAttrSet =
|
using DenseElementsAttrSet =
|
||||||
DenseSet<DenseElementsAttr *, DenseElementsAttrInfo>;
|
DenseSet<DenseElementsAttributeStorage *, DenseElementsAttrInfo>;
|
||||||
DenseElementsAttrSet denseElementsAttrs;
|
DenseElementsAttrSet denseElementsAttrs;
|
||||||
using OpaqueElementsAttrSet =
|
using OpaqueElementsAttrSet =
|
||||||
DenseSet<OpaqueElementsAttr *, OpaqueElementsAttrInfo>;
|
DenseSet<OpaqueElementsAttributeStorage *, OpaqueElementsAttrInfo>;
|
||||||
OpaqueElementsAttrSet opaqueElementsAttrs;
|
OpaqueElementsAttrSet opaqueElementsAttrs;
|
||||||
DenseMap<std::tuple<Type *, DenseElementsAttr *, DenseElementsAttr *>,
|
DenseMap<std::tuple<Type *, Attribute, Attribute>,
|
||||||
SparseElementsAttr *>
|
SparseElementsAttributeStorage *>
|
||||||
sparseElementsAttrs;
|
sparseElementsAttrs;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
@ -716,31 +720,36 @@ MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType,
|
||||||
// Attribute uniquing
|
// Attribute uniquing
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
BoolAttr *BoolAttr::get(bool value, MLIRContext *context) {
|
BoolAttr BoolAttr::get(bool value, MLIRContext *context) {
|
||||||
auto *&result = context->getImpl().boolAttrs[value];
|
auto *&result = context->getImpl().boolAttrs[value];
|
||||||
if (result)
|
if (result)
|
||||||
return result;
|
return result;
|
||||||
|
|
||||||
result = context->getImpl().allocator.Allocate<BoolAttr>();
|
result = context->getImpl().allocator.Allocate<BoolAttributeStorage>();
|
||||||
new (result) BoolAttr(value);
|
new (result) BoolAttributeStorage{{Attribute::Kind::Bool,
|
||||||
|
/*isOrContainsFunction=*/false},
|
||||||
|
value};
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
IntegerAttr *IntegerAttr::get(int64_t value, MLIRContext *context) {
|
IntegerAttr IntegerAttr::get(int64_t value, MLIRContext *context) {
|
||||||
auto *&result = context->getImpl().integerAttrs[value];
|
auto *&result = context->getImpl().integerAttrs[value];
|
||||||
if (result)
|
if (result)
|
||||||
return result;
|
return result;
|
||||||
|
|
||||||
result = context->getImpl().allocator.Allocate<IntegerAttr>();
|
result = context->getImpl().allocator.Allocate<IntegerAttributeStorage>();
|
||||||
new (result) IntegerAttr(value);
|
new (result) IntegerAttributeStorage{{Attribute::Kind::Integer,
|
||||||
|
/*isOrContainsFunction=*/false},
|
||||||
|
value};
|
||||||
|
result->value = value;
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
FloatAttr *FloatAttr::get(double value, MLIRContext *context) {
|
FloatAttr FloatAttr::get(double value, MLIRContext *context) {
|
||||||
return get(APFloat(value), context);
|
return get(APFloat(value), context);
|
||||||
}
|
}
|
||||||
|
|
||||||
FloatAttr *FloatAttr::get(const APFloat &value, MLIRContext *context) {
|
FloatAttr FloatAttr::get(const APFloat &value, MLIRContext *context) {
|
||||||
auto &impl = context->getImpl();
|
auto &impl = context->getImpl();
|
||||||
|
|
||||||
// Look to see if the float attribute has been created already.
|
// Look to see if the float attribute has been created already.
|
||||||
|
@ -755,33 +764,35 @@ FloatAttr *FloatAttr::get(const APFloat &value, MLIRContext *context) {
|
||||||
// Here one word's bitwidth equals to that of uint64_t.
|
// Here one word's bitwidth equals to that of uint64_t.
|
||||||
auto elements = ArrayRef<uint64_t>(apint.getRawData(), apint.getNumWords());
|
auto elements = ArrayRef<uint64_t>(apint.getRawData(), apint.getNumWords());
|
||||||
|
|
||||||
auto byteSize = FloatAttr::totalSizeToAlloc<uint64_t>(elements.size());
|
auto byteSize =
|
||||||
auto rawMem = impl.allocator.Allocate(byteSize, alignof(FloatAttr));
|
FloatAttributeStorage::totalSizeToAlloc<uint64_t>(elements.size());
|
||||||
auto result = ::new (rawMem) FloatAttr(value.getSemantics(), elements.size());
|
auto rawMem =
|
||||||
|
impl.allocator.Allocate(byteSize, alignof(FloatAttributeStorage));
|
||||||
|
auto result = ::new (rawMem) FloatAttributeStorage{
|
||||||
|
{Attribute::Kind::Float, /*isOrContainsFunction=*/false},
|
||||||
|
{},
|
||||||
|
value.getSemantics(),
|
||||||
|
elements.size()};
|
||||||
std::uninitialized_copy(elements.begin(), elements.end(),
|
std::uninitialized_copy(elements.begin(), elements.end(),
|
||||||
result->getTrailingObjects<uint64_t>());
|
result->getTrailingObjects<uint64_t>());
|
||||||
return *existing.first = result;
|
return *existing.first = result;
|
||||||
}
|
}
|
||||||
|
|
||||||
APFloat FloatAttr::getValue() const {
|
StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
|
||||||
auto val = APInt(APFloat::getSizeInBits(semantics),
|
|
||||||
{getTrailingObjects<uint64_t>(), numObjects});
|
|
||||||
return APFloat(semantics, val);
|
|
||||||
}
|
|
||||||
|
|
||||||
StringAttr *StringAttr::get(StringRef bytes, MLIRContext *context) {
|
|
||||||
auto it = context->getImpl().stringAttrs.insert({bytes, nullptr}).first;
|
auto it = context->getImpl().stringAttrs.insert({bytes, nullptr}).first;
|
||||||
|
|
||||||
if (it->second)
|
if (it->second)
|
||||||
return it->second;
|
return it->second;
|
||||||
|
|
||||||
auto result = context->getImpl().allocator.Allocate<StringAttr>();
|
auto result = context->getImpl().allocator.Allocate<StringAttributeStorage>();
|
||||||
new (result) StringAttr(it->first());
|
new (result) StringAttributeStorage{{Attribute::Kind::String,
|
||||||
|
/*isOrContainsFunction=*/false},
|
||||||
|
it->first()};
|
||||||
it->second = result;
|
it->second = result;
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
ArrayAttr *ArrayAttr::get(ArrayRef<Attribute *> value, MLIRContext *context) {
|
ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
|
||||||
auto &impl = context->getImpl();
|
auto &impl = context->getImpl();
|
||||||
|
|
||||||
// Look to see if we already have this.
|
// Look to see if we already have this.
|
||||||
|
@ -792,61 +803,66 @@ ArrayAttr *ArrayAttr::get(ArrayRef<Attribute *> value, MLIRContext *context) {
|
||||||
return *existing.first;
|
return *existing.first;
|
||||||
|
|
||||||
// On the first use, we allocate them into the bump pointer.
|
// On the first use, we allocate them into the bump pointer.
|
||||||
auto *result = impl.allocator.Allocate<ArrayAttr>();
|
auto *result = impl.allocator.Allocate<ArrayAttributeStorage>();
|
||||||
|
|
||||||
// Copy the elements into the bump pointer.
|
// Copy the elements into the bump pointer.
|
||||||
value = impl.copyInto(value);
|
value = impl.copyInto(value);
|
||||||
|
|
||||||
// Check to see if any of the elements have a function attr.
|
// Check to see if any of the elements have a function attr.
|
||||||
bool hasFunctionAttr = false;
|
bool hasFunctionAttr = false;
|
||||||
for (auto *elt : value)
|
for (auto elt : value)
|
||||||
if (elt->isOrContainsFunction()) {
|
if (elt.isOrContainsFunction()) {
|
||||||
hasFunctionAttr = true;
|
hasFunctionAttr = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize the memory using placement new.
|
// Initialize the memory using placement new.
|
||||||
new (result) ArrayAttr(value, hasFunctionAttr);
|
new (result)
|
||||||
|
ArrayAttributeStorage{{Attribute::Kind::Array, hasFunctionAttr}, value};
|
||||||
|
|
||||||
// Cache and return it.
|
// Cache and return it.
|
||||||
return *existing.first = result;
|
return *existing.first = result;
|
||||||
}
|
}
|
||||||
|
|
||||||
AffineMapAttr *AffineMapAttr::get(AffineMap value) {
|
AffineMapAttr AffineMapAttr::get(AffineMap value) {
|
||||||
auto *context = value.getResult(0).getContext();
|
auto *context = value.getResult(0).getContext();
|
||||||
auto &result = context->getImpl().affineMapAttrs[value];
|
auto &result = context->getImpl().affineMapAttrs[value];
|
||||||
if (result)
|
if (result)
|
||||||
return result;
|
return result;
|
||||||
|
|
||||||
result = context->getImpl().allocator.Allocate<AffineMapAttr>();
|
result = context->getImpl().allocator.Allocate<AffineMapAttributeStorage>();
|
||||||
new (result) AffineMapAttr(value);
|
new (result) AffineMapAttributeStorage{{Attribute::Kind::AffineMap,
|
||||||
|
/*isOrContainsFunction=*/false},
|
||||||
|
value};
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
TypeAttr *TypeAttr::get(Type *type, MLIRContext *context) {
|
TypeAttr TypeAttr::get(Type *type, MLIRContext *context) {
|
||||||
auto *&result = context->getImpl().typeAttrs[type];
|
auto *&result = context->getImpl().typeAttrs[type];
|
||||||
if (result)
|
if (result)
|
||||||
return result;
|
return result;
|
||||||
|
|
||||||
result = context->getImpl().allocator.Allocate<TypeAttr>();
|
result = context->getImpl().allocator.Allocate<TypeAttributeStorage>();
|
||||||
new (result) TypeAttr(type);
|
new (result) TypeAttributeStorage{{Attribute::Kind::Type,
|
||||||
|
/*isOrContainsFunction=*/false},
|
||||||
|
type};
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
FunctionAttr *FunctionAttr::get(const Function *value, MLIRContext *context) {
|
FunctionAttr FunctionAttr::get(const Function *value, MLIRContext *context) {
|
||||||
assert(value && "Cannot get FunctionAttr for a null function");
|
assert(value && "Cannot get FunctionAttr for a null function");
|
||||||
|
|
||||||
auto *&result = context->getImpl().functionAttrs[value];
|
auto *&result = context->getImpl().functionAttrs[value];
|
||||||
if (result)
|
if (result)
|
||||||
return result;
|
return result;
|
||||||
|
|
||||||
result = context->getImpl().allocator.Allocate<FunctionAttr>();
|
result = context->getImpl().allocator.Allocate<FunctionAttributeStorage>();
|
||||||
new (result) FunctionAttr(const_cast<Function *>(value));
|
new (result) FunctionAttributeStorage{{Attribute::Kind::Function,
|
||||||
|
/*isOrContainsFunction=*/true},
|
||||||
|
const_cast<Function *>(value)};
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
FunctionType *FunctionAttr::getType() const { return getValue()->getType(); }
|
|
||||||
|
|
||||||
/// This function is used by the internals of the Function class to null out
|
/// This function is used by the internals of the Function class to null out
|
||||||
/// attributes refering to functions that are about to be deleted.
|
/// attributes refering to functions that are about to be deleted.
|
||||||
void FunctionAttr::dropFunctionReference(Function *value) {
|
void FunctionAttr::dropFunctionReference(Function *value) {
|
||||||
|
@ -935,30 +951,29 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs,
|
||||||
return *existing.first = result;
|
return *existing.first = result;
|
||||||
}
|
}
|
||||||
|
|
||||||
OpaqueElementsAttr *OpaqueElementsAttr::get(VectorOrTensorType *type,
|
SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType *type,
|
||||||
StringRef bytes) {
|
Attribute elt) {
|
||||||
assert(isValidTensorElementType(type->getElementType()) &&
|
|
||||||
"Input element type should be a valid tensor element type");
|
|
||||||
|
|
||||||
auto &impl = type->getContext()->getImpl();
|
auto &impl = type->getContext()->getImpl();
|
||||||
|
|
||||||
// Look to see if this constant is already defined.
|
// Look to see if we already have this.
|
||||||
OpaqueElementsAttrInfo::KeyTy key({type, bytes});
|
auto *&result = impl.splatElementsAttrs[{type, elt}];
|
||||||
auto existing = impl.opaqueElementsAttrs.insert_as(nullptr, key);
|
|
||||||
|
|
||||||
// If we already have it, return that value.
|
// If we already have it, return that value.
|
||||||
if (!existing.second)
|
if (result)
|
||||||
return *existing.first;
|
return result;
|
||||||
|
|
||||||
// Otherwise, allocate a new one, unique it and return it.
|
// Otherwise, allocate them into the bump pointer.
|
||||||
auto *result = impl.allocator.Allocate<OpaqueElementsAttr>();
|
result = impl.allocator.Allocate<SplatElementsAttributeStorage>();
|
||||||
bytes = bytes.copy(impl.allocator);
|
new (result) SplatElementsAttributeStorage{{{Attribute::Kind::SplatElements,
|
||||||
new (result) OpaqueElementsAttr(type, bytes);
|
/*isOrContainsFunction=*/false},
|
||||||
return *existing.first = result;
|
type},
|
||||||
|
elt};
|
||||||
|
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
DenseElementsAttr *DenseElementsAttr::get(VectorOrTensorType *type,
|
DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type,
|
||||||
ArrayRef<char> data) {
|
ArrayRef<char> data) {
|
||||||
auto bitsRequired = (long)type->getBitWidth() * type->getNumElements();
|
auto bitsRequired = (long)type->getBitWidth() * type->getNumElements();
|
||||||
(void)(bitsRequired);
|
(void)(bitsRequired);
|
||||||
assert((bitsRequired <= data.size() * 8L) &&
|
assert((bitsRequired <= data.size() * 8L) &&
|
||||||
|
@ -981,18 +996,25 @@ DenseElementsAttr *DenseElementsAttr::get(VectorOrTensorType *type,
|
||||||
case Type::Kind::F16:
|
case Type::Kind::F16:
|
||||||
case Type::Kind::F32:
|
case Type::Kind::F32:
|
||||||
case Type::Kind::F64: {
|
case Type::Kind::F64: {
|
||||||
auto *result = impl.allocator.Allocate<DenseFPElementsAttr>();
|
auto *result = impl.allocator.Allocate<DenseFPElementsAttributeStorage>();
|
||||||
auto *copy = (char *)impl.allocator.Allocate(data.size(), 64);
|
auto *copy = (char *)impl.allocator.Allocate(data.size(), 64);
|
||||||
std::uninitialized_copy(data.begin(), data.end(), copy);
|
std::uninitialized_copy(data.begin(), data.end(), copy);
|
||||||
new (result) DenseFPElementsAttr(type, {copy, data.size()});
|
new (result) DenseFPElementsAttributeStorage{
|
||||||
|
{{{Attribute::Kind::DenseFPElements, /*isOrContainsFunction=*/false},
|
||||||
|
type},
|
||||||
|
{copy, data.size()}}};
|
||||||
return *existing.first = result;
|
return *existing.first = result;
|
||||||
}
|
}
|
||||||
case Type::Kind::Integer: {
|
case Type::Kind::Integer: {
|
||||||
auto width = cast<IntegerType>(eltType)->getWidth();
|
auto width = ::cast<IntegerType>(eltType)->getWidth();
|
||||||
auto *result = impl.allocator.Allocate<DenseIntElementsAttr>();
|
auto *result = impl.allocator.Allocate<DenseIntElementsAttributeStorage>();
|
||||||
auto *copy = (char *)impl.allocator.Allocate(data.size(), 64);
|
auto *copy = (char *)impl.allocator.Allocate(data.size(), 64);
|
||||||
std::uninitialized_copy(data.begin(), data.end(), copy);
|
std::uninitialized_copy(data.begin(), data.end(), copy);
|
||||||
new (result) DenseIntElementsAttr(type, {copy, data.size()}, width);
|
new (result) DenseIntElementsAttributeStorage{
|
||||||
|
{{{Attribute::Kind::DenseIntElements, /*isOrContainsFunction=*/false},
|
||||||
|
type},
|
||||||
|
{copy, data.size()}},
|
||||||
|
width};
|
||||||
return *existing.first = result;
|
return *existing.first = result;
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
|
@ -1000,118 +1022,33 @@ DenseElementsAttr *DenseElementsAttr::get(VectorOrTensorType *type,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Writes the lowest `bitWidth` bits of `value` to bit position `bitPos`
|
OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType *type,
|
||||||
/// starting from `rawData`.
|
StringRef bytes) {
|
||||||
void DenseIntElementsAttr::writeBits(char *data, size_t bitPos, size_t bitWidth,
|
assert(isValidTensorElementType(type->getElementType()) &&
|
||||||
uint64_t value) {
|
"Input element type should be a valid tensor element type");
|
||||||
// Read the destination bytes which will be written to.
|
|
||||||
uint64_t dst = 0;
|
|
||||||
auto dstData = reinterpret_cast<char *>(&dst);
|
|
||||||
auto endPos = bitPos + bitWidth;
|
|
||||||
auto start = data + bitPos / 8;
|
|
||||||
auto end = data + endPos / 8 + (endPos % 8 != 0);
|
|
||||||
std::copy(start, end, dstData);
|
|
||||||
|
|
||||||
// Clean up the invalid bits in the destination bytes.
|
|
||||||
dst &= ~(-1UL << (bitPos % 8));
|
|
||||||
|
|
||||||
// Get the valid bits of the source value, shift them to right position,
|
|
||||||
// then add them to the destination bytes.
|
|
||||||
value <<= bitPos % 8;
|
|
||||||
dst |= value;
|
|
||||||
|
|
||||||
// Write the destination bytes back.
|
|
||||||
ArrayRef<char> range({dstData, (size_t)(end - start)});
|
|
||||||
std::copy(range.begin(), range.end(), start);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Reads the next `bitWidth` bits from the bit position `bitPos` of `rawData`
|
|
||||||
/// and put them in the lowest bits.
|
|
||||||
uint64_t DenseIntElementsAttr::readBits(const char *rawData, size_t bitPos,
|
|
||||||
size_t bitsWidth) {
|
|
||||||
uint64_t dst = 0;
|
|
||||||
auto dstData = reinterpret_cast<char *>(&dst);
|
|
||||||
auto endPos = bitPos + bitsWidth;
|
|
||||||
auto start = rawData + bitPos / 8;
|
|
||||||
auto end = rawData + endPos / 8 + (endPos % 8 != 0);
|
|
||||||
std::copy(start, end, dstData);
|
|
||||||
|
|
||||||
dst >>= bitPos % 8;
|
|
||||||
dst &= ~(-1UL << bitsWidth);
|
|
||||||
return dst;
|
|
||||||
}
|
|
||||||
|
|
||||||
void DenseElementsAttr::getValues(SmallVectorImpl<Attribute *> &values) const {
|
|
||||||
switch (getKind()) {
|
|
||||||
case Attribute::Kind::DenseIntElements:
|
|
||||||
cast<DenseIntElementsAttr>(this)->getValues(values);
|
|
||||||
return;
|
|
||||||
case Attribute::Kind::DenseFPElements:
|
|
||||||
cast<DenseFPElementsAttr>(this)->getValues(values);
|
|
||||||
return;
|
|
||||||
default:
|
|
||||||
llvm_unreachable("unexpected element type");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void DenseIntElementsAttr::getValues(
|
|
||||||
SmallVectorImpl<Attribute *> &values) const {
|
|
||||||
auto elementNum = getType()->getNumElements();
|
|
||||||
auto context = getType()->getContext();
|
|
||||||
values.reserve(elementNum);
|
|
||||||
if (bitsWidth == 64) {
|
|
||||||
ArrayRef<int64_t> vs(
|
|
||||||
{reinterpret_cast<const int64_t *>(getRawData().data()),
|
|
||||||
getRawData().size() / 8});
|
|
||||||
for (auto value : vs) {
|
|
||||||
auto *attr = IntegerAttr::get(value, context);
|
|
||||||
values.push_back(attr);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
const auto *rawData = getRawData().data();
|
|
||||||
for (size_t pos = 0; pos < elementNum * bitsWidth; pos += bitsWidth) {
|
|
||||||
uint64_t bits = readBits(rawData, pos, bitsWidth);
|
|
||||||
APInt value(bitsWidth, bits, /*isSigned=*/true);
|
|
||||||
auto *attr = IntegerAttr::get(value.getSExtValue(), context);
|
|
||||||
values.push_back(attr);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void DenseFPElementsAttr::getValues(
|
|
||||||
SmallVectorImpl<Attribute *> &values) const {
|
|
||||||
auto elementNum = getType()->getNumElements();
|
|
||||||
auto context = getType()->getContext();
|
|
||||||
ArrayRef<double> vs({reinterpret_cast<const double *>(getRawData().data()),
|
|
||||||
getRawData().size() / 8});
|
|
||||||
values.reserve(elementNum);
|
|
||||||
for (auto v : vs) {
|
|
||||||
auto *attr = FloatAttr::get(v, context);
|
|
||||||
values.push_back(attr);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
SplatElementsAttr *SplatElementsAttr::get(VectorOrTensorType *type,
|
|
||||||
Attribute *elt) {
|
|
||||||
auto &impl = type->getContext()->getImpl();
|
auto &impl = type->getContext()->getImpl();
|
||||||
|
|
||||||
// Look to see if we already have this.
|
// Look to see if this constant is already defined.
|
||||||
auto *&result = impl.splatElementsAttrs[{type, elt}];
|
OpaqueElementsAttrInfo::KeyTy key({type, bytes});
|
||||||
|
auto existing = impl.opaqueElementsAttrs.insert_as(nullptr, key);
|
||||||
|
|
||||||
// If we already have it, return that value.
|
// If we already have it, return that value.
|
||||||
if (result)
|
if (!existing.second)
|
||||||
return result;
|
return *existing.first;
|
||||||
|
|
||||||
// Otherwise, allocate them into the bump pointer.
|
// Otherwise, allocate a new one, unique it and return it.
|
||||||
result = impl.allocator.Allocate<SplatElementsAttr>();
|
auto *result = impl.allocator.Allocate<OpaqueElementsAttributeStorage>();
|
||||||
new (result) SplatElementsAttr(type, elt);
|
bytes = bytes.copy(impl.allocator);
|
||||||
|
new (result) OpaqueElementsAttributeStorage{
|
||||||
return result;
|
{{Attribute::Kind::OpaqueElements, /*isOrContainsFunction=*/false}, type},
|
||||||
|
bytes};
|
||||||
|
return *existing.first = result;
|
||||||
}
|
}
|
||||||
|
|
||||||
SparseElementsAttr *SparseElementsAttr::get(VectorOrTensorType *type,
|
SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType *type,
|
||||||
DenseIntElementsAttr *indices,
|
DenseIntElementsAttr indices,
|
||||||
DenseElementsAttr *values) {
|
DenseElementsAttr values) {
|
||||||
auto &impl = type->getContext()->getImpl();
|
auto &impl = type->getContext()->getImpl();
|
||||||
|
|
||||||
// Look to see if we already have this.
|
// Look to see if we already have this.
|
||||||
|
@ -1123,8 +1060,12 @@ SparseElementsAttr *SparseElementsAttr::get(VectorOrTensorType *type,
|
||||||
return result;
|
return result;
|
||||||
|
|
||||||
// Otherwise, allocate them into the bump pointer.
|
// Otherwise, allocate them into the bump pointer.
|
||||||
result = impl.allocator.Allocate<SparseElementsAttr>();
|
result = impl.allocator.Allocate<SparseElementsAttributeStorage>();
|
||||||
new (result) SparseElementsAttr(type, indices, values);
|
new (result) SparseElementsAttributeStorage{{{Attribute::Kind::SparseElements,
|
||||||
|
/*isOrContainsFunction=*/false},
|
||||||
|
type},
|
||||||
|
indices,
|
||||||
|
values};
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
|
@ -148,7 +148,7 @@ ArrayRef<NamedAttribute> Operation::getAttrs() const {
|
||||||
|
|
||||||
/// If an attribute exists with the specified name, change it to the new
|
/// If an attribute exists with the specified name, change it to the new
|
||||||
/// value. Otherwise, add a new attribute with the specified name/value.
|
/// value. Otherwise, add a new attribute with the specified name/value.
|
||||||
void Operation::setAttr(Identifier name, Attribute *value) {
|
void Operation::setAttr(Identifier name, Attribute value) {
|
||||||
assert(value && "attributes may never be null");
|
assert(value && "attributes may never be null");
|
||||||
auto origAttrs = getAttrs();
|
auto origAttrs = getAttrs();
|
||||||
|
|
||||||
|
@ -225,8 +225,8 @@ void Operation::erase() {
|
||||||
/// Attempt to constant fold this operation with the specified constant
|
/// Attempt to constant fold this operation with the specified constant
|
||||||
/// operand values. If successful, this returns false and fills in the
|
/// operand values. If successful, this returns false and fills in the
|
||||||
/// results vector. If not, this returns true and results is unspecified.
|
/// results vector. If not, this returns true and results is unspecified.
|
||||||
bool Operation::constantFold(ArrayRef<Attribute *> operands,
|
bool Operation::constantFold(ArrayRef<Attribute> operands,
|
||||||
SmallVectorImpl<Attribute *> &results) const {
|
SmallVectorImpl<Attribute> &results) const {
|
||||||
// If we have a registered operation definition matching this one, use it to
|
// If we have a registered operation definition matching this one, use it to
|
||||||
// try to constant fold the operation.
|
// try to constant fold the operation.
|
||||||
if (auto *abstractOp = getAbstractOperation())
|
if (auto *abstractOp = getAbstractOperation())
|
||||||
|
|
|
@ -195,7 +195,7 @@ public:
|
||||||
// Attribute parsing.
|
// Attribute parsing.
|
||||||
Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
|
Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
|
||||||
FunctionType *type);
|
FunctionType *type);
|
||||||
Attribute *parseAttribute();
|
Attribute parseAttribute();
|
||||||
|
|
||||||
ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);
|
ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);
|
||||||
|
|
||||||
|
@ -204,8 +204,8 @@ public:
|
||||||
AffineMap parseAffineMapReference();
|
AffineMap parseAffineMapReference();
|
||||||
IntegerSet parseIntegerSetInline();
|
IntegerSet parseIntegerSetInline();
|
||||||
IntegerSet parseIntegerSetReference();
|
IntegerSet parseIntegerSetReference();
|
||||||
DenseElementsAttr *parseDenseElementsAttr(VectorOrTensorType *type);
|
DenseElementsAttr parseDenseElementsAttr(VectorOrTensorType *type);
|
||||||
DenseElementsAttr *parseDenseElementsAttr(Type *eltType, bool isVector);
|
DenseElementsAttr parseDenseElementsAttr(Type *eltType, bool isVector);
|
||||||
VectorOrTensorType *parseVectorOrTensorType();
|
VectorOrTensorType *parseVectorOrTensorType();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -684,7 +684,7 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int> &dims) {
|
||||||
case Token::floatliteral:
|
case Token::floatliteral:
|
||||||
case Token::integer:
|
case Token::integer:
|
||||||
case Token::minus: {
|
case Token::minus: {
|
||||||
auto *result = p.parseAttribute();
|
auto result = p.parseAttribute();
|
||||||
if (!result)
|
if (!result)
|
||||||
return p.emitError("expected tensor element");
|
return p.emitError("expected tensor element");
|
||||||
// check result matches the element type.
|
// check result matches the element type.
|
||||||
|
@ -693,16 +693,16 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int> &dims) {
|
||||||
case Type::Kind::F16:
|
case Type::Kind::F16:
|
||||||
case Type::Kind::F32:
|
case Type::Kind::F32:
|
||||||
case Type::Kind::F64: {
|
case Type::Kind::F64: {
|
||||||
if (!isa<FloatAttr>(result))
|
if (!result.isa<FloatAttr>())
|
||||||
return p.emitError("expected tensor literal element has float type");
|
return p.emitError("expected tensor literal element has float type");
|
||||||
double value = cast<FloatAttr>(result)->getDouble();
|
double value = result.cast<FloatAttr>().getDouble();
|
||||||
addToStorage(*(uint64_t *)(&value));
|
addToStorage(*(uint64_t *)(&value));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case Type::Kind::Integer: {
|
case Type::Kind::Integer: {
|
||||||
if (!isa<IntegerAttr>(result))
|
if (!result.isa<IntegerAttr>())
|
||||||
return p.emitError("expected tensor literal element has integer type");
|
return p.emitError("expected tensor literal element has integer type");
|
||||||
auto value = cast<IntegerAttr>(result)->getValue();
|
auto value = result.cast<IntegerAttr>().getValue();
|
||||||
// If we couldn't successfully round trip the value, it means some bits
|
// If we couldn't successfully round trip the value, it means some bits
|
||||||
// are truncated and we should give up here.
|
// are truncated and we should give up here.
|
||||||
llvm::APInt apint(bitsWidth, (uint64_t)value, /*isSigned=*/true);
|
llvm::APInt apint(bitsWidth, (uint64_t)value, /*isSigned=*/true);
|
||||||
|
@ -804,7 +804,7 @@ Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
|
||||||
/// | `sparse<` (tensor-type | vector-type)`,`
|
/// | `sparse<` (tensor-type | vector-type)`,`
|
||||||
/// attribute-value`, ` attribute-value `>`
|
/// attribute-value`, ` attribute-value `>`
|
||||||
///
|
///
|
||||||
Attribute *Parser::parseAttribute() {
|
Attribute Parser::parseAttribute() {
|
||||||
switch (getToken().getKind()) {
|
switch (getToken().getKind()) {
|
||||||
case Token::kw_true:
|
case Token::kw_true:
|
||||||
consumeToken(Token::kw_true);
|
consumeToken(Token::kw_true);
|
||||||
|
@ -859,7 +859,7 @@ Attribute *Parser::parseAttribute() {
|
||||||
|
|
||||||
case Token::l_square: {
|
case Token::l_square: {
|
||||||
consumeToken(Token::l_square);
|
consumeToken(Token::l_square);
|
||||||
SmallVector<Attribute *, 4> elements;
|
SmallVector<Attribute, 4> elements;
|
||||||
|
|
||||||
auto parseElt = [&]() -> ParseResult {
|
auto parseElt = [&]() -> ParseResult {
|
||||||
elements.push_back(parseAttribute());
|
elements.push_back(parseAttribute());
|
||||||
|
@ -928,7 +928,7 @@ Attribute *Parser::parseAttribute() {
|
||||||
case Token::floatliteral:
|
case Token::floatliteral:
|
||||||
case Token::integer:
|
case Token::integer:
|
||||||
case Token::minus: {
|
case Token::minus: {
|
||||||
auto *scalar = parseAttribute();
|
auto scalar = parseAttribute();
|
||||||
if (parseToken(Token::greater, "expected '>'"))
|
if (parseToken(Token::greater, "expected '>'"))
|
||||||
return nullptr;
|
return nullptr;
|
||||||
return builder.getSplatElementsAttr(type, scalar);
|
return builder.getSplatElementsAttr(type, scalar);
|
||||||
|
@ -973,7 +973,7 @@ Attribute *Parser::parseAttribute() {
|
||||||
case Token::l_square: {
|
case Token::l_square: {
|
||||||
/// Parse indices
|
/// Parse indices
|
||||||
auto *indicesEltType = builder.getIntegerType(32);
|
auto *indicesEltType = builder.getIntegerType(32);
|
||||||
auto *indices =
|
auto indices =
|
||||||
parseDenseElementsAttr(indicesEltType, isa<VectorType>(type));
|
parseDenseElementsAttr(indicesEltType, isa<VectorType>(type));
|
||||||
|
|
||||||
if (parseToken(Token::comma, "expected ','"))
|
if (parseToken(Token::comma, "expected ','"))
|
||||||
|
@ -981,12 +981,12 @@ Attribute *Parser::parseAttribute() {
|
||||||
|
|
||||||
/// Parse values.
|
/// Parse values.
|
||||||
auto *valuesEltType = type->getElementType();
|
auto *valuesEltType = type->getElementType();
|
||||||
auto *values =
|
auto values =
|
||||||
parseDenseElementsAttr(valuesEltType, isa<VectorType>(type));
|
parseDenseElementsAttr(valuesEltType, isa<VectorType>(type));
|
||||||
|
|
||||||
/// Sanity check.
|
/// Sanity check.
|
||||||
auto *indicesType = indices->getType();
|
auto *indicesType = indices.getType();
|
||||||
auto *valuesType = values->getType();
|
auto *valuesType = values.getType();
|
||||||
auto sameShape = (indicesType->getRank() == 1) ||
|
auto sameShape = (indicesType->getRank() == 1) ||
|
||||||
(type->getRank() == indicesType->getDimSize(1));
|
(type->getRank() == indicesType->getDimSize(1));
|
||||||
auto sameElementNum =
|
auto sameElementNum =
|
||||||
|
@ -1009,7 +1009,7 @@ Attribute *Parser::parseAttribute() {
|
||||||
|
|
||||||
// Build the sparse elements attribute by the indices and values.
|
// Build the sparse elements attribute by the indices and values.
|
||||||
return builder.getSparseElementsAttr(
|
return builder.getSparseElementsAttr(
|
||||||
type, cast<DenseIntElementsAttr>(indices), values);
|
type, indices.cast<DenseIntElementsAttr>(), values);
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return (emitError("expected '[' to start sparse tensor literal"),
|
return (emitError("expected '[' to start sparse tensor literal"),
|
||||||
|
@ -1035,8 +1035,7 @@ Attribute *Parser::parseAttribute() {
|
||||||
///
|
///
|
||||||
/// This method returns a constructed dense elements attribute with the shape
|
/// This method returns a constructed dense elements attribute with the shape
|
||||||
/// from the parsing result.
|
/// from the parsing result.
|
||||||
DenseElementsAttr *Parser::parseDenseElementsAttr(Type *eltType,
|
DenseElementsAttr Parser::parseDenseElementsAttr(Type *eltType, bool isVector) {
|
||||||
bool isVector) {
|
|
||||||
TensorLiteralParser literalParser(*this, eltType);
|
TensorLiteralParser literalParser(*this, eltType);
|
||||||
if (literalParser.parse())
|
if (literalParser.parse())
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -1047,8 +1046,8 @@ DenseElementsAttr *Parser::parseDenseElementsAttr(Type *eltType,
|
||||||
} else {
|
} else {
|
||||||
type = builder.getTensorType(literalParser.getShape(), eltType);
|
type = builder.getTensorType(literalParser.getShape(), eltType);
|
||||||
}
|
}
|
||||||
return (DenseElementsAttr *)builder.getDenseElementsAttr(
|
return builder.getDenseElementsAttr(type, literalParser.getValues())
|
||||||
type, literalParser.getValues());
|
.cast<DenseElementsAttr>();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Dense elements attribute.
|
/// Dense elements attribute.
|
||||||
|
@ -1061,7 +1060,7 @@ DenseElementsAttr *Parser::parseDenseElementsAttr(Type *eltType,
|
||||||
/// This method compares the shapes from the parsing result and that from the
|
/// This method compares the shapes from the parsing result and that from the
|
||||||
/// input argument. It returns a constructed dense elements attribute if both
|
/// input argument. It returns a constructed dense elements attribute if both
|
||||||
/// match.
|
/// match.
|
||||||
DenseElementsAttr *Parser::parseDenseElementsAttr(VectorOrTensorType *type) {
|
DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType *type) {
|
||||||
auto *eltTy = type->getElementType();
|
auto *eltTy = type->getElementType();
|
||||||
TensorLiteralParser literalParser(*this, eltTy);
|
TensorLiteralParser literalParser(*this, eltTy);
|
||||||
if (literalParser.parse())
|
if (literalParser.parse())
|
||||||
|
@ -1076,8 +1075,8 @@ DenseElementsAttr *Parser::parseDenseElementsAttr(VectorOrTensorType *type) {
|
||||||
s << "])";
|
s << "])";
|
||||||
return (emitError(s.str()), nullptr);
|
return (emitError(s.str()), nullptr);
|
||||||
}
|
}
|
||||||
return (DenseElementsAttr *)builder.getDenseElementsAttr(
|
return builder.getDenseElementsAttr(type, literalParser.getValues())
|
||||||
type, literalParser.getValues());
|
.cast<DenseElementsAttr>();
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Vector or tensor type for elements attribute.
|
/// Vector or tensor type for elements attribute.
|
||||||
|
@ -2133,7 +2132,7 @@ public:
|
||||||
/// Parse an arbitrary attribute and return it in result. This also adds
|
/// Parse an arbitrary attribute and return it in result. This also adds
|
||||||
/// the attribute to the specified attribute list with the specified name.
|
/// the attribute to the specified attribute list with the specified name.
|
||||||
/// this captures the location of the attribute in 'loc' if it is non-null.
|
/// this captures the location of the attribute in 'loc' if it is non-null.
|
||||||
bool parseAttribute(Attribute *&result, const char *attrName,
|
bool parseAttribute(Attribute &result, const char *attrName,
|
||||||
SmallVectorImpl<NamedAttribute> &attrs) override {
|
SmallVectorImpl<NamedAttribute> &attrs) override {
|
||||||
result = parser.parseAttribute();
|
result = parser.parseAttribute();
|
||||||
if (!result)
|
if (!result)
|
||||||
|
@ -3336,27 +3335,27 @@ ParseResult ModuleParser::parseMLFunc() {
|
||||||
/// Given an attribute that could refer to a function attribute in the
|
/// Given an attribute that could refer to a function attribute in the
|
||||||
/// remapping table, walk it and rewrite it to use the mapped function. If it
|
/// remapping table, walk it and rewrite it to use the mapped function. If it
|
||||||
/// doesn't refer to anything in the table, then it is returned unmodified.
|
/// doesn't refer to anything in the table, then it is returned unmodified.
|
||||||
static Attribute *
|
static Attribute
|
||||||
remapFunctionAttrs(Attribute *input,
|
remapFunctionAttrs(Attribute input,
|
||||||
DenseMap<FunctionAttr *, FunctionAttr *> &remappingTable,
|
DenseMap<Attribute, FunctionAttr> &remappingTable,
|
||||||
MLIRContext *context) {
|
MLIRContext *context) {
|
||||||
// Most attributes are trivially unrelated to function attributes, skip them
|
// Most attributes are trivially unrelated to function attributes, skip them
|
||||||
// rapidly.
|
// rapidly.
|
||||||
if (!input->isOrContainsFunction())
|
if (!input.isOrContainsFunction())
|
||||||
return input;
|
return input;
|
||||||
|
|
||||||
// If we have a function attribute, remap it.
|
// If we have a function attribute, remap it.
|
||||||
if (auto *fnAttr = dyn_cast<FunctionAttr>(input)) {
|
if (auto fnAttr = input.dyn_cast<FunctionAttr>()) {
|
||||||
auto it = remappingTable.find(fnAttr);
|
auto it = remappingTable.find(fnAttr);
|
||||||
return it != remappingTable.end() ? it->second : input;
|
return it != remappingTable.end() ? it->second : input;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Otherwise, we must have an array attribute, remap the elements.
|
// Otherwise, we must have an array attribute, remap the elements.
|
||||||
auto *arrayAttr = cast<ArrayAttr>(input);
|
auto arrayAttr = input.cast<ArrayAttr>();
|
||||||
SmallVector<Attribute *, 8> remappedElts;
|
SmallVector<Attribute, 8> remappedElts;
|
||||||
bool anyChange = false;
|
bool anyChange = false;
|
||||||
for (auto *elt : arrayAttr->getValue()) {
|
for (auto elt : arrayAttr.getValue()) {
|
||||||
auto *newElt = remapFunctionAttrs(elt, remappingTable, context);
|
auto newElt = remapFunctionAttrs(elt, remappingTable, context);
|
||||||
remappedElts.push_back(newElt);
|
remappedElts.push_back(newElt);
|
||||||
anyChange |= (elt != newElt);
|
anyChange |= (elt != newElt);
|
||||||
}
|
}
|
||||||
|
@ -3370,11 +3369,11 @@ remapFunctionAttrs(Attribute *input,
|
||||||
/// Remap function attributes to resolve forward references to their actual
|
/// Remap function attributes to resolve forward references to their actual
|
||||||
/// definition.
|
/// definition.
|
||||||
static void remapFunctionAttrsInOperation(
|
static void remapFunctionAttrsInOperation(
|
||||||
Operation *op, DenseMap<FunctionAttr *, FunctionAttr *> &remappingTable) {
|
Operation *op, DenseMap<Attribute, FunctionAttr> &remappingTable) {
|
||||||
for (auto attr : op->getAttrs()) {
|
for (auto attr : op->getAttrs()) {
|
||||||
// Do the remapping, if we got the same thing back, then it must contain
|
// Do the remapping, if we got the same thing back, then it must contain
|
||||||
// functions that aren't getting remapped.
|
// functions that aren't getting remapped.
|
||||||
auto *newVal =
|
auto newVal =
|
||||||
remapFunctionAttrs(attr.second, remappingTable, op->getContext());
|
remapFunctionAttrs(attr.second, remappingTable, op->getContext());
|
||||||
if (newVal == attr.second)
|
if (newVal == attr.second)
|
||||||
continue;
|
continue;
|
||||||
|
@ -3391,7 +3390,7 @@ static void remapFunctionAttrsInOperation(
|
||||||
ParseResult ModuleParser::finalizeModule() {
|
ParseResult ModuleParser::finalizeModule() {
|
||||||
|
|
||||||
// Resolve all forward references, building a remapping table of attributes.
|
// Resolve all forward references, building a remapping table of attributes.
|
||||||
DenseMap<FunctionAttr *, FunctionAttr *> remappingTable;
|
DenseMap<Attribute, FunctionAttr> remappingTable;
|
||||||
for (auto forwardRef : getState().functionForwardRefs) {
|
for (auto forwardRef : getState().functionForwardRefs) {
|
||||||
auto name = forwardRef.first;
|
auto name = forwardRef.first;
|
||||||
|
|
||||||
|
@ -3428,13 +3427,13 @@ ParseResult ModuleParser::finalizeModule() {
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
struct MLFnWalker : public StmtWalker<MLFnWalker> {
|
struct MLFnWalker : public StmtWalker<MLFnWalker> {
|
||||||
MLFnWalker(DenseMap<FunctionAttr *, FunctionAttr *> &remappingTable)
|
MLFnWalker(DenseMap<Attribute, FunctionAttr> &remappingTable)
|
||||||
: remappingTable(remappingTable) {}
|
: remappingTable(remappingTable) {}
|
||||||
void visitOperationStmt(OperationStmt *opStmt) {
|
void visitOperationStmt(OperationStmt *opStmt) {
|
||||||
remapFunctionAttrsInOperation(opStmt, remappingTable);
|
remapFunctionAttrsInOperation(opStmt, remappingTable);
|
||||||
}
|
}
|
||||||
|
|
||||||
DenseMap<FunctionAttr *, FunctionAttr *> &remappingTable;
|
DenseMap<Attribute, FunctionAttr> &remappingTable;
|
||||||
};
|
};
|
||||||
|
|
||||||
MLFnWalker(remappingTable).walk(mlFn);
|
MLFnWalker(remappingTable).walk(mlFn);
|
||||||
|
|
|
@ -44,13 +44,13 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
|
||||||
// AddFOp
|
// AddFOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
Attribute *AddFOp::constantFold(ArrayRef<Attribute *> operands,
|
Attribute AddFOp::constantFold(ArrayRef<Attribute> operands,
|
||||||
MLIRContext *context) const {
|
MLIRContext *context) const {
|
||||||
assert(operands.size() == 2 && "addf takes two operands");
|
assert(operands.size() == 2 && "addf takes two operands");
|
||||||
|
|
||||||
if (auto *lhs = dyn_cast_or_null<FloatAttr>(operands[0])) {
|
if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
|
||||||
if (auto *rhs = dyn_cast_or_null<FloatAttr>(operands[1]))
|
if (auto rhs = operands[1].dyn_cast_or_null<FloatAttr>())
|
||||||
return FloatAttr::get(lhs->getValue() + rhs->getValue(), context);
|
return FloatAttr::get(lhs.getValue() + rhs.getValue(), context);
|
||||||
}
|
}
|
||||||
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -60,13 +60,13 @@ Attribute *AddFOp::constantFold(ArrayRef<Attribute *> operands,
|
||||||
// AddIOp
|
// AddIOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
Attribute *AddIOp::constantFold(ArrayRef<Attribute *> operands,
|
Attribute AddIOp::constantFold(ArrayRef<Attribute> operands,
|
||||||
MLIRContext *context) const {
|
MLIRContext *context) const {
|
||||||
assert(operands.size() == 2 && "addi takes two operands");
|
assert(operands.size() == 2 && "addi takes two operands");
|
||||||
|
|
||||||
if (auto *lhs = dyn_cast_or_null<IntegerAttr>(operands[0])) {
|
if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
|
||||||
if (auto *rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
|
if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>())
|
||||||
return IntegerAttr::get(lhs->getValue() + rhs->getValue(), context);
|
return IntegerAttr::get(lhs.getValue() + rhs.getValue(), context);
|
||||||
}
|
}
|
||||||
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -192,12 +192,12 @@ void CallOp::print(OpAsmPrinter *p) const {
|
||||||
|
|
||||||
bool CallOp::verify() const {
|
bool CallOp::verify() const {
|
||||||
// Check that the callee attribute was specified.
|
// Check that the callee attribute was specified.
|
||||||
auto *fnAttr = getAttrOfType<FunctionAttr>("callee");
|
auto fnAttr = getAttrOfType<FunctionAttr>("callee");
|
||||||
if (!fnAttr)
|
if (!fnAttr)
|
||||||
return emitOpError("requires a 'callee' function attribute");
|
return emitOpError("requires a 'callee' function attribute");
|
||||||
|
|
||||||
// Verify that the operand and result types match the callee.
|
// Verify that the operand and result types match the callee.
|
||||||
auto *fnType = fnAttr->getValue()->getType();
|
auto *fnType = fnAttr.getValue()->getType();
|
||||||
if (fnType->getNumInputs() != getNumOperands())
|
if (fnType->getNumInputs() != getNumOperands())
|
||||||
return emitOpError("incorrect number of operands for callee");
|
return emitOpError("incorrect number of operands for callee");
|
||||||
|
|
||||||
|
@ -329,7 +329,7 @@ void DimOp::print(OpAsmPrinter *p) const {
|
||||||
|
|
||||||
bool DimOp::parse(OpAsmParser *parser, OperationState *result) {
|
bool DimOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||||
OpAsmParser::OperandType operandInfo;
|
OpAsmParser::OperandType operandInfo;
|
||||||
IntegerAttr *indexAttr;
|
IntegerAttr indexAttr;
|
||||||
Type *type;
|
Type *type;
|
||||||
|
|
||||||
return parser->parseOperand(operandInfo) || parser->parseComma() ||
|
return parser->parseOperand(operandInfo) || parser->parseComma() ||
|
||||||
|
@ -346,7 +346,7 @@ bool DimOp::verify() const {
|
||||||
auto indexAttr = getAttrOfType<IntegerAttr>("index");
|
auto indexAttr = getAttrOfType<IntegerAttr>("index");
|
||||||
if (!indexAttr)
|
if (!indexAttr)
|
||||||
return emitOpError("requires an integer attribute named 'index'");
|
return emitOpError("requires an integer attribute named 'index'");
|
||||||
uint64_t index = (uint64_t)indexAttr->getValue();
|
uint64_t index = (uint64_t)indexAttr.getValue();
|
||||||
|
|
||||||
auto *type = getOperand()->getType();
|
auto *type = getOperand()->getType();
|
||||||
if (auto *tensorType = dyn_cast<RankedTensorType>(type)) {
|
if (auto *tensorType = dyn_cast<RankedTensorType>(type)) {
|
||||||
|
@ -365,8 +365,8 @@ bool DimOp::verify() const {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
Attribute *DimOp::constantFold(ArrayRef<Attribute *> operands,
|
Attribute DimOp::constantFold(ArrayRef<Attribute> operands,
|
||||||
MLIRContext *context) const {
|
MLIRContext *context) const {
|
||||||
// Constant fold dim when the size along the index referred to is a constant.
|
// Constant fold dim when the size along the index referred to is a constant.
|
||||||
auto *opType = getOperand()->getType();
|
auto *opType = getOperand()->getType();
|
||||||
int indexSize = -1;
|
int indexSize = -1;
|
||||||
|
@ -671,13 +671,13 @@ bool MemRefCastOp::verify() const {
|
||||||
// MulFOp
|
// MulFOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
Attribute *MulFOp::constantFold(ArrayRef<Attribute *> operands,
|
Attribute MulFOp::constantFold(ArrayRef<Attribute> operands,
|
||||||
MLIRContext *context) const {
|
MLIRContext *context) const {
|
||||||
assert(operands.size() == 2 && "mulf takes two operands");
|
assert(operands.size() == 2 && "mulf takes two operands");
|
||||||
|
|
||||||
if (auto *lhs = dyn_cast_or_null<FloatAttr>(operands[0])) {
|
if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
|
||||||
if (auto *rhs = dyn_cast_or_null<FloatAttr>(operands[1]))
|
if (auto rhs = operands[1].dyn_cast_or_null<FloatAttr>())
|
||||||
return FloatAttr::get(lhs->getValue() * rhs->getValue(), context);
|
return FloatAttr::get(lhs.getValue() * rhs.getValue(), context);
|
||||||
}
|
}
|
||||||
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -687,23 +687,23 @@ Attribute *MulFOp::constantFold(ArrayRef<Attribute *> operands,
|
||||||
// MulIOp
|
// MulIOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
Attribute *MulIOp::constantFold(ArrayRef<Attribute *> operands,
|
Attribute MulIOp::constantFold(ArrayRef<Attribute> operands,
|
||||||
MLIRContext *context) const {
|
MLIRContext *context) const {
|
||||||
assert(operands.size() == 2 && "muli takes two operands");
|
assert(operands.size() == 2 && "muli takes two operands");
|
||||||
|
|
||||||
if (auto *lhs = dyn_cast_or_null<IntegerAttr>(operands[0])) {
|
if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
|
||||||
// 0*x == 0
|
// 0*x == 0
|
||||||
if (lhs->getValue() == 0)
|
if (lhs.getValue() == 0)
|
||||||
return lhs;
|
return lhs;
|
||||||
|
|
||||||
if (auto *rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
|
if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>())
|
||||||
// TODO: Handle the overflow case.
|
// TODO: Handle the overflow case.
|
||||||
return IntegerAttr::get(lhs->getValue() * rhs->getValue(), context);
|
return IntegerAttr::get(lhs.getValue() * rhs.getValue(), context);
|
||||||
}
|
}
|
||||||
|
|
||||||
// x*0 == 0
|
// x*0 == 0
|
||||||
if (auto *rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
|
if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>())
|
||||||
if (rhs->getValue() == 0)
|
if (rhs.getValue() == 0)
|
||||||
return rhs;
|
return rhs;
|
||||||
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -817,13 +817,13 @@ bool StoreOp::verify() const {
|
||||||
// SubFOp
|
// SubFOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
Attribute *SubFOp::constantFold(ArrayRef<Attribute *> operands,
|
Attribute SubFOp::constantFold(ArrayRef<Attribute> operands,
|
||||||
MLIRContext *context) const {
|
MLIRContext *context) const {
|
||||||
assert(operands.size() == 2 && "subf takes two operands");
|
assert(operands.size() == 2 && "subf takes two operands");
|
||||||
|
|
||||||
if (auto *lhs = dyn_cast_or_null<FloatAttr>(operands[0])) {
|
if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
|
||||||
if (auto *rhs = dyn_cast_or_null<FloatAttr>(operands[1]))
|
if (auto rhs = operands[1].dyn_cast_or_null<FloatAttr>())
|
||||||
return FloatAttr::get(lhs->getValue() - rhs->getValue(), context);
|
return FloatAttr::get(lhs.getValue() - rhs.getValue(), context);
|
||||||
}
|
}
|
||||||
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
@ -833,13 +833,13 @@ Attribute *SubFOp::constantFold(ArrayRef<Attribute *> operands,
|
||||||
// SubIOp
|
// SubIOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
Attribute *SubIOp::constantFold(ArrayRef<Attribute *> operands,
|
Attribute SubIOp::constantFold(ArrayRef<Attribute> operands,
|
||||||
MLIRContext *context) const {
|
MLIRContext *context) const {
|
||||||
assert(operands.size() == 2 && "subi takes two operands");
|
assert(operands.size() == 2 && "subi takes two operands");
|
||||||
|
|
||||||
if (auto *lhs = dyn_cast_or_null<IntegerAttr>(operands[0])) {
|
if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
|
||||||
if (auto *rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
|
if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>())
|
||||||
return IntegerAttr::get(lhs->getValue() - rhs->getValue(), context);
|
return IntegerAttr::get(lhs.getValue() - rhs.getValue(), context);
|
||||||
}
|
}
|
||||||
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
|
@ -31,7 +31,7 @@ struct ConstantFold : public FunctionPass, StmtWalker<ConstantFold> {
|
||||||
SmallVector<SSAValue *, 8> existingConstants;
|
SmallVector<SSAValue *, 8> existingConstants;
|
||||||
// Operation statements that were folded and that need to be erased.
|
// Operation statements that were folded and that need to be erased.
|
||||||
std::vector<OperationStmt *> opStmtsToErase;
|
std::vector<OperationStmt *> opStmtsToErase;
|
||||||
using ConstantFactoryType = std::function<SSAValue *(Attribute *, Type *)>;
|
using ConstantFactoryType = std::function<SSAValue *(Attribute, Type *)>;
|
||||||
|
|
||||||
bool foldOperation(Operation *op,
|
bool foldOperation(Operation *op,
|
||||||
SmallVectorImpl<SSAValue *> &existingConstants,
|
SmallVectorImpl<SSAValue *> &existingConstants,
|
||||||
|
@ -60,9 +60,9 @@ bool ConstantFold::foldOperation(Operation *op,
|
||||||
|
|
||||||
// Check to see if each of the operands is a trivial constant. If so, get
|
// Check to see if each of the operands is a trivial constant. If so, get
|
||||||
// the value. If not, ignore the instruction.
|
// the value. If not, ignore the instruction.
|
||||||
SmallVector<Attribute *, 8> operandConstants;
|
SmallVector<Attribute, 8> operandConstants;
|
||||||
for (auto *operand : op->getOperands()) {
|
for (auto *operand : op->getOperands()) {
|
||||||
Attribute *operandCst = nullptr;
|
Attribute operandCst = nullptr;
|
||||||
if (auto *operandOp = operand->getDefiningOperation()) {
|
if (auto *operandOp = operand->getDefiningOperation()) {
|
||||||
if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
|
if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
|
||||||
operandCst = operandConstantOp->getValue();
|
operandCst = operandConstantOp->getValue();
|
||||||
|
@ -71,7 +71,7 @@ bool ConstantFold::foldOperation(Operation *op,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Attempt to constant fold the operation.
|
// Attempt to constant fold the operation.
|
||||||
SmallVector<Attribute *, 8> resultConstants;
|
SmallVector<Attribute, 8> resultConstants;
|
||||||
if (op->constantFold(operandConstants, resultConstants))
|
if (op->constantFold(operandConstants, resultConstants))
|
||||||
return true;
|
return true;
|
||||||
|
|
||||||
|
@ -106,7 +106,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
|
||||||
for (auto instIt = bb.begin(), e = bb.end(); instIt != e;) {
|
for (auto instIt = bb.begin(), e = bb.end(); instIt != e;) {
|
||||||
auto &inst = *instIt++;
|
auto &inst = *instIt++;
|
||||||
|
|
||||||
auto constantFactory = [&](Attribute *value, Type *type) -> SSAValue * {
|
auto constantFactory = [&](Attribute value, Type *type) -> SSAValue * {
|
||||||
builder.setInsertionPoint(&inst);
|
builder.setInsertionPoint(&inst);
|
||||||
return builder.create<ConstantOp>(inst.getLoc(), value, type);
|
return builder.create<ConstantOp>(inst.getLoc(), value, type);
|
||||||
};
|
};
|
||||||
|
@ -134,7 +134,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
|
||||||
|
|
||||||
// Override the walker's operation statement visit for constant folding.
|
// Override the walker's operation statement visit for constant folding.
|
||||||
void ConstantFold::visitOperationStmt(OperationStmt *stmt) {
|
void ConstantFold::visitOperationStmt(OperationStmt *stmt) {
|
||||||
auto constantFactory = [&](Attribute *value, Type *type) -> SSAValue * {
|
auto constantFactory = [&](Attribute value, Type *type) -> SSAValue * {
|
||||||
MLFuncBuilder builder(stmt);
|
MLFuncBuilder builder(stmt);
|
||||||
return builder.create<ConstantOp>(stmt->getLoc(), value, type);
|
return builder.create<ConstantOp>(stmt->getLoc(), value, type);
|
||||||
};
|
};
|
||||||
|
|
|
@ -71,8 +71,8 @@ void SimplifyAffineStructures::visitIfStmt(IfStmt *ifStmt) {
|
||||||
|
|
||||||
void SimplifyAffineStructures::visitOperationStmt(OperationStmt *opStmt) {
|
void SimplifyAffineStructures::visitOperationStmt(OperationStmt *opStmt) {
|
||||||
for (auto attr : opStmt->getAttrs()) {
|
for (auto attr : opStmt->getAttrs()) {
|
||||||
if (auto *mapAttr = dyn_cast<AffineMapAttr>(attr.second)) {
|
if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) {
|
||||||
MutableAffineMap mMap(mapAttr->getValue());
|
MutableAffineMap mMap(mapAttr.getValue());
|
||||||
mMap.simplify();
|
mMap.simplify();
|
||||||
auto map = mMap.getAffineMap();
|
auto map = mMap.getAffineMap();
|
||||||
opStmt->setAttr(attr.first, AffineMapAttr::get(map));
|
opStmt->setAttr(attr.first, AffineMapAttr::get(map));
|
||||||
|
|
|
@ -79,7 +79,7 @@ private:
|
||||||
/// As part of canonicalization, we move constants to the top of the entry
|
/// As part of canonicalization, we move constants to the top of the entry
|
||||||
/// block of the current function and de-duplicate them. This keeps track of
|
/// block of the current function and de-duplicate them. This keeps track of
|
||||||
/// constants we have done this for.
|
/// constants we have done this for.
|
||||||
DenseMap<std::pair<Attribute *, Type *>, Operation *> uniquedConstants;
|
DenseMap<std::pair<Attribute, Type *>, Operation *> uniquedConstants;
|
||||||
};
|
};
|
||||||
}; // end anonymous namespace
|
}; // end anonymous namespace
|
||||||
|
|
||||||
|
@ -107,7 +107,7 @@ public:
|
||||||
void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
|
void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
|
||||||
WorklistRewriter &rewriter) {
|
WorklistRewriter &rewriter) {
|
||||||
// These are scratch vectors used in the constant folding loop below.
|
// These are scratch vectors used in the constant folding loop below.
|
||||||
SmallVector<Attribute *, 8> operandConstants, resultConstants;
|
SmallVector<Attribute, 8> operandConstants, resultConstants;
|
||||||
|
|
||||||
while (!worklist.empty()) {
|
while (!worklist.empty()) {
|
||||||
auto *op = popFromWorklist();
|
auto *op = popFromWorklist();
|
||||||
|
@ -175,7 +175,7 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
|
||||||
// the operation knows how to constant fold itself.
|
// the operation knows how to constant fold itself.
|
||||||
operandConstants.clear();
|
operandConstants.clear();
|
||||||
for (auto *operand : op->getOperands()) {
|
for (auto *operand : op->getOperands()) {
|
||||||
Attribute *operandCst = nullptr;
|
Attribute operandCst;
|
||||||
if (auto *operandOp = operand->getDefiningOperation()) {
|
if (auto *operandOp = operand->getDefiningOperation()) {
|
||||||
if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
|
if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
|
||||||
operandCst = operandConstantOp->getValue();
|
operandCst = operandConstantOp->getValue();
|
||||||
|
|
|
@ -353,11 +353,11 @@ bool mlir::constantFoldBounds(ForStmt *forStmt) {
|
||||||
|
|
||||||
// Check to see if each of the operands is the result of a constant. If so,
|
// Check to see if each of the operands is the result of a constant. If so,
|
||||||
// get the value. If not, ignore it.
|
// get the value. If not, ignore it.
|
||||||
SmallVector<Attribute *, 8> operandConstants;
|
SmallVector<Attribute, 8> operandConstants;
|
||||||
auto boundOperands = lower ? forStmt->getLowerBoundOperands()
|
auto boundOperands = lower ? forStmt->getLowerBoundOperands()
|
||||||
: forStmt->getUpperBoundOperands();
|
: forStmt->getUpperBoundOperands();
|
||||||
for (const auto *operand : boundOperands) {
|
for (const auto *operand : boundOperands) {
|
||||||
Attribute *operandCst = nullptr;
|
Attribute operandCst;
|
||||||
if (auto *operandOp = operand->getDefiningOperation()) {
|
if (auto *operandOp = operand->getDefiningOperation()) {
|
||||||
if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
|
if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
|
||||||
operandCst = operandConstantOp->getValue();
|
operandCst = operandConstantOp->getValue();
|
||||||
|
@ -369,15 +369,15 @@ bool mlir::constantFoldBounds(ForStmt *forStmt) {
|
||||||
lower ? forStmt->getLowerBoundMap() : forStmt->getUpperBoundMap();
|
lower ? forStmt->getLowerBoundMap() : forStmt->getUpperBoundMap();
|
||||||
assert(boundMap.getNumResults() >= 1 &&
|
assert(boundMap.getNumResults() >= 1 &&
|
||||||
"bound maps should have at least one result");
|
"bound maps should have at least one result");
|
||||||
SmallVector<Attribute *, 4> foldedResults;
|
SmallVector<Attribute, 4> foldedResults;
|
||||||
if (boundMap.constantFold(operandConstants, foldedResults))
|
if (boundMap.constantFold(operandConstants, foldedResults))
|
||||||
return true;
|
return true;
|
||||||
|
|
||||||
// Compute the max or min as applicable over the results.
|
// Compute the max or min as applicable over the results.
|
||||||
assert(!foldedResults.empty() && "bounds should have at least one result");
|
assert(!foldedResults.empty() && "bounds should have at least one result");
|
||||||
auto maxOrMin = cast<IntegerAttr>(foldedResults[0])->getValue();
|
auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
|
||||||
for (unsigned i = 1; i < foldedResults.size(); i++) {
|
for (unsigned i = 1; i < foldedResults.size(); i++) {
|
||||||
auto foldedResult = cast<IntegerAttr>(foldedResults[i])->getValue();
|
auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
|
||||||
maxOrMin = lower ? std::max(maxOrMin, foldedResult)
|
maxOrMin = lower ? std::max(maxOrMin, foldedResult)
|
||||||
: std::min(maxOrMin, foldedResult);
|
: std::min(maxOrMin, foldedResult);
|
||||||
}
|
}
|
||||||
|
|
|
@ -154,7 +154,7 @@ void OpEmitter::emitAttrGetters() {
|
||||||
<< val.getName() << "() const {\n";
|
<< val.getName() << "() const {\n";
|
||||||
os << " return this->getAttrOfType<"
|
os << " return this->getAttrOfType<"
|
||||||
<< attr.getValueAsString("AttrType").trim() << ">(\"" << name
|
<< attr.getValueAsString("AttrType").trim() << ">(\"" << name
|
||||||
<< "\")->getValue();\n }\n";
|
<< "\").getValue();\n }\n";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -207,9 +207,9 @@ void OpEmitter::emitVerifier() {
|
||||||
// Verify the attributes have the correct type.
|
// Verify the attributes have the correct type.
|
||||||
for (const auto attr : attrs) {
|
for (const auto attr : attrs) {
|
||||||
auto name = attr.first->getName();
|
auto name = attr.first->getName();
|
||||||
os << " if (!dyn_cast_or_null<"
|
os << " if (!this->getAttr(\"" << name << "\").dyn_cast_or_null<"
|
||||||
<< attr.second->getValueAsString("AttrType") << ">(this->getAttr(\""
|
<< attr.second->getValueAsString("AttrType") << ">("
|
||||||
<< name << "\"))) return emitOpError(\"requires "
|
<< ")) return emitOpError(\"requires "
|
||||||
<< attr.second->getValueAsString("PrimitiveType").trim()
|
<< attr.second->getValueAsString("PrimitiveType").trim()
|
||||||
<< " attribute '" << name << "'\");\n";
|
<< " attribute '" << name << "'\");\n";
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue