Implement value type abstraction for attributes.

This is done by changing Attribute to be a POD interface around an underlying pointer storage and adding in-class support for isa/dyn_cast/cast.

PiperOrigin-RevId: 218764173
This commit is contained in:
River Riddle 2018-10-25 15:46:10 -07:00 committed by jpienaar
parent 64d52014bd
commit 792d1c25e4
26 changed files with 971 additions and 692 deletions

View File

@ -95,8 +95,8 @@ public:
/// Folds the results of the application of an affine map on the provided
/// operands to a constant if possible. Returns false if the folding happens,
/// true otherwise.
bool constantFold(ArrayRef<Attribute *> operandConstants,
SmallVectorImpl<Attribute *> &results) const;
bool constantFold(ArrayRef<Attribute> operandConstants,
SmallVectorImpl<Attribute> &results) const;
friend ::llvm::hash_code hash_value(AffineMap arg);

View File

@ -30,10 +30,32 @@ class MLIRContext;
class Type;
class VectorOrTensorType;
namespace detail {
struct AttributeStorage;
struct BoolAttributeStorage;
struct IntegerAttributeStorage;
struct FloatAttributeStorage;
struct StringAttributeStorage;
struct ArrayAttributeStorage;
struct AffineMapAttributeStorage;
struct TypeAttributeStorage;
struct FunctionAttributeStorage;
struct ElementsAttributeStorage;
struct SplatElementsAttributeStorage;
struct DenseElementsAttributeStorage;
struct DenseIntElementsAttributeStorage;
struct DenseFPElementsAttributeStorage;
struct OpaqueElementsAttributeStorage;
struct SparseElementsAttributeStorage;
} // namespace detail
/// Attributes are known-constant values of operations and functions.
///
/// Instances of the Attribute class are immutable, uniqued, immortal, and owned
/// by MLIRContext. As such, they are passed around by raw non-const pointer.
/// by MLIRContext. As such, an Attribute is a POD interface to an underlying
/// storage pointer.
class Attribute {
public:
enum class Kind {
@ -55,177 +77,151 @@ public:
LAST_ELEMENTS_ATTR = SparseElements,
};
typedef detail::AttributeStorage ImplType;
Attribute() : attr(nullptr) {}
/* implicit */ Attribute(const ImplType *attr)
: attr(const_cast<ImplType *>(attr)) {}
Attribute(const Attribute &other) : attr(other.attr) {}
Attribute &operator=(Attribute other) {
attr = other.attr;
return *this;
}
bool operator==(Attribute other) const { return attr == other.attr; }
bool operator!=(Attribute other) const { return !(*this == other); }
explicit operator bool() const { return attr; }
bool operator!() const { return attr == nullptr; }
template <typename U> bool isa() const;
template <typename U> U dyn_cast() const;
template <typename U> U dyn_cast_or_null() const;
template <typename U> U cast() const;
/// Return the classification for this attribute.
Kind getKind() const { return kind; }
Kind getKind() const;
/// Return true if this field is, or contains, a function attribute.
bool isOrContainsFunction() const { return isOrContainsFunctionCache; }
bool isOrContainsFunction() const;
/// Print the attribute.
void print(raw_ostream &os) const;
void dump() const;
friend ::llvm::hash_code hash_value(Attribute arg);
protected:
explicit Attribute(Kind kind, bool isOrContainsFunction)
: kind(kind), isOrContainsFunctionCache(isOrContainsFunction) {}
~Attribute() {}
private:
/// Classification of the subclass, used for type checking.
Kind kind : 8;
/// This field is true if this is, or contains, a function attribute.
bool isOrContainsFunctionCache : 1;
Attribute(const Attribute &) = delete;
void operator=(const Attribute &) = delete;
ImplType *attr;
};
inline raw_ostream &operator<<(raw_ostream &os, const Attribute &attr) {
inline raw_ostream &operator<<(raw_ostream &os, Attribute attr) {
attr.print(os);
return os;
}
class BoolAttr : public Attribute {
public:
static BoolAttr *get(bool value, MLIRContext *context);
typedef detail::BoolAttributeStorage ImplType;
BoolAttr() = default;
/* implicit */ BoolAttr(Attribute::ImplType *ptr);
bool getValue() const { return value; }
static BoolAttr get(bool value, MLIRContext *context);
bool getValue() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Attribute *attr) {
return attr->getKind() == Kind::Bool;
}
private:
BoolAttr(bool value)
: Attribute(Kind::Bool, /*isOrContainsFunction=*/false), value(value) {}
~BoolAttr() = delete;
bool value;
static bool kindof(Kind kind) { return kind == Kind::Bool; }
};
class IntegerAttr : public Attribute {
public:
static IntegerAttr *get(int64_t value, MLIRContext *context);
typedef detail::IntegerAttributeStorage ImplType;
IntegerAttr() = default;
/* implicit */ IntegerAttr(Attribute::ImplType *ptr);
int64_t getValue() const { return value; }
static IntegerAttr get(int64_t value, MLIRContext *context);
int64_t getValue() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Attribute *attr) {
return attr->getKind() == Kind::Integer;
}
private:
IntegerAttr(int64_t value)
: Attribute(Kind::Integer, /*isOrContainsFunction=*/false), value(value) {
}
~IntegerAttr() = delete;
int64_t value;
static bool kindof(Kind kind) { return kind == Kind::Integer; }
};
class FloatAttr final : public Attribute,
public llvm::TrailingObjects<FloatAttr, uint64_t> {
class FloatAttr final : public Attribute {
public:
static FloatAttr *get(double value, MLIRContext *context);
static FloatAttr *get(const APFloat &value, MLIRContext *context);
typedef detail::FloatAttributeStorage ImplType;
FloatAttr() = default;
/* implicit */ FloatAttr(Attribute::ImplType *ptr);
static FloatAttr get(double value, MLIRContext *context);
static FloatAttr get(const APFloat &value, MLIRContext *context);
APFloat getValue() const;
double getDouble() const { return getValue().convertToDouble(); }
double getDouble() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Attribute *attr) {
return attr->getKind() == Kind::Float;
}
private:
FloatAttr(const llvm::fltSemantics &semantics, size_t numObjects)
: Attribute(Kind::Float, /*isOrContainsFunction=*/false),
semantics(semantics), numObjects(numObjects) {}
FloatAttr(const FloatAttr &value) = delete;
~FloatAttr() = delete;
size_t numTrailingObjects(OverloadToken<uint64_t>) const {
return numObjects;
}
const llvm::fltSemantics &semantics;
size_t numObjects;
static bool kindof(Kind kind) { return kind == Kind::Float; }
};
class StringAttr : public Attribute {
public:
static StringAttr *get(StringRef bytes, MLIRContext *context);
typedef detail::StringAttributeStorage ImplType;
StringAttr() = default;
/* implicit */ StringAttr(Attribute::ImplType *ptr);
StringRef getValue() const { return value; }
static StringAttr get(StringRef bytes, MLIRContext *context);
StringRef getValue() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Attribute *attr) {
return attr->getKind() == Kind::String;
}
private:
StringAttr(StringRef value)
: Attribute(Kind::String, /*isOrContainsFunction=*/false), value(value) {}
~StringAttr() = delete;
StringRef value;
static bool kindof(Kind kind) { return kind == Kind::String; }
};
/// Array attributes are lists of other attributes. They are not necessarily
/// type homogenous given that attributes don't, in general, carry types.
class ArrayAttr : public Attribute {
public:
static ArrayAttr *get(ArrayRef<Attribute *> value, MLIRContext *context);
typedef detail::ArrayAttributeStorage ImplType;
ArrayAttr() = default;
/* implicit */ ArrayAttr(Attribute::ImplType *ptr);
ArrayRef<Attribute *> getValue() const { return value; }
static ArrayAttr get(ArrayRef<Attribute> value, MLIRContext *context);
ArrayRef<Attribute> getValue() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Attribute *attr) {
return attr->getKind() == Kind::Array;
}
private:
ArrayAttr(ArrayRef<Attribute *> value, bool isOrContainsFunction)
: Attribute(Kind::Array, isOrContainsFunction), value(value) {}
~ArrayAttr() = delete;
ArrayRef<Attribute *> value;
static bool kindof(Kind kind) { return kind == Kind::Array; }
};
class AffineMapAttr : public Attribute {
public:
static AffineMapAttr *get(AffineMap value);
typedef detail::AffineMapAttributeStorage ImplType;
AffineMapAttr() = default;
/* implicit */ AffineMapAttr(Attribute::ImplType *ptr);
AffineMap getValue() const { return value; }
static AffineMapAttr get(AffineMap value);
AffineMap getValue() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Attribute *attr) {
return attr->getKind() == Kind::AffineMap;
}
private:
AffineMapAttr(AffineMap value)
: Attribute(Kind::AffineMap, /*isOrContainsFunction=*/false),
value(value) {}
~AffineMapAttr() = delete;
AffineMap value;
static bool kindof(Kind kind) { return kind == Kind::AffineMap; }
};
class TypeAttr : public Attribute {
public:
static TypeAttr *get(Type *type, MLIRContext *context);
typedef detail::TypeAttributeStorage ImplType;
TypeAttr() = default;
/* implicit */ TypeAttr(Attribute::ImplType *ptr);
Type *getValue() const { return value; }
static TypeAttr get(Type *type, MLIRContext *context);
Type *getValue() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Attribute *attr) {
return attr->getKind() == Kind::Type;
}
private:
TypeAttr(Type *value)
: Attribute(Kind::Type, /*isOrContainsFunction=*/false), value(value) {}
~TypeAttr() = delete;
Type *value;
static bool kindof(Kind kind) { return kind == Kind::Type; }
};
/// A function attribute represents a reference to a function object.
@ -237,63 +233,53 @@ private:
/// remain in MLIRContext.
class FunctionAttr : public Attribute {
public:
static FunctionAttr *get(const Function *value, MLIRContext *context);
typedef detail::FunctionAttributeStorage ImplType;
FunctionAttr() = default;
/* implicit */ FunctionAttr(Attribute::ImplType *ptr);
Function *getValue() const { return value; }
static FunctionAttr get(const Function *value, MLIRContext *context);
Function *getValue() const;
FunctionType *getType() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Attribute *attr) {
return attr->getKind() == Kind::Function;
}
static bool kindof(Kind kind) { return kind == Kind::Function; }
/// This function is used by the internals of the Function class to null out
/// attributes refering to functions that are about to be deleted.
static void dropFunctionReference(Function *value);
private:
FunctionAttr(Function *value)
: Attribute(Kind::Function, /*isOrContainsFunction=*/true), value(value) {
}
~FunctionAttr() = delete;
Function *value;
};
/// A base attribute represents a reference to a vector or tensor constant.
class ElementsAttr : public Attribute {
public:
ElementsAttr(Kind kind, VectorOrTensorType *type)
: Attribute(kind, /*isOrContainsFunction=*/false), type(type) {}
typedef detail::ElementsAttributeStorage ImplType;
ElementsAttr() = default;
/* implicit */ ElementsAttr(Attribute::ImplType *ptr);
VectorOrTensorType *getType() const { return type; }
VectorOrTensorType *getType() const;
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(const Attribute *attr) {
return attr->getKind() >= Kind::FIRST_ELEMENTS_ATTR &&
attr->getKind() <= Kind::LAST_ELEMENTS_ATTR;
static bool kindof(Kind kind) {
return kind >= Kind::FIRST_ELEMENTS_ATTR &&
kind <= Kind::LAST_ELEMENTS_ATTR;
}
private:
VectorOrTensorType *type;
};
/// An attribute represents a reference to a splat vecctor or tensor constant,
/// meaning all of the elements have the same value.
class SplatElementsAttr : public ElementsAttr {
public:
static SplatElementsAttr *get(VectorOrTensorType *type, Attribute *elt);
Attribute *getValue() const { return elt; }
typedef detail::SplatElementsAttributeStorage ImplType;
SplatElementsAttr() = default;
/* implicit */ SplatElementsAttr(Attribute::ImplType *ptr);
static SplatElementsAttr get(VectorOrTensorType *type, Attribute elt);
Attribute getValue() const;
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(const Attribute *attr) {
return attr->getKind() == Kind::SplatElements;
}
private:
SplatElementsAttr(VectorOrTensorType *type, Attribute *elt)
: ElementsAttr(Kind::SplatElements, type), elt(elt) {}
Attribute *elt;
static bool kindof(Kind kind) { return kind == Kind::SplatElements; }
};
/// An attribute represents a reference to a dense vector or tensor object.
@ -302,42 +288,42 @@ private:
/// than 64.
class DenseElementsAttr : public ElementsAttr {
public:
typedef detail::DenseElementsAttributeStorage ImplType;
DenseElementsAttr() = default;
/* implicit */ DenseElementsAttr(Attribute::ImplType *ptr);
/// It assumes the elements in the input array have been truncated to the bits
/// width specified by the element type (note all float type are 64 bits).
/// When the value is retrieved, the bits are read from the storage and extend
/// to 64 bits if necessary.
static DenseElementsAttr *get(VectorOrTensorType *type, ArrayRef<char> data);
static DenseElementsAttr get(VectorOrTensorType *type, ArrayRef<char> data);
// TODO: Read the data from the attribute list and compress them
// to a character array. Then call the above method to construct the
// attribute.
static DenseElementsAttr *get(VectorOrTensorType *type,
ArrayRef<Attribute *> values);
static DenseElementsAttr get(VectorOrTensorType *type,
ArrayRef<Attribute> values);
void getValues(SmallVectorImpl<Attribute *> &values) const;
void getValues(SmallVectorImpl<Attribute> &values) const;
ArrayRef<char> getRawData() const { return data; }
ArrayRef<char> getRawData() const;
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(const Attribute *attr) {
return attr->getKind() == Kind::DenseIntElements ||
attr->getKind() == Kind::DenseFPElements;
static bool kindof(Kind kind) {
return kind == Kind::DenseIntElements || kind == Kind::DenseFPElements;
}
protected:
DenseElementsAttr(Kind kind, VectorOrTensorType *type, ArrayRef<char> data)
: ElementsAttr(kind, type), data(data) {}
private:
ArrayRef<char> data;
};
/// An attribute represents a reference to a dense integer vector or tensor
/// object.
class DenseIntElementsAttr : public DenseElementsAttr {
public:
typedef detail::DenseIntElementsAttributeStorage ImplType;
DenseIntElementsAttr() = default;
/* implicit */ DenseIntElementsAttr(Attribute::ImplType *ptr);
// TODO: returns APInts instead of IntegerAttr.
void getValues(SmallVectorImpl<Attribute *> &values) const;
void getValues(SmallVectorImpl<Attribute> &values) const;
APInt getValue(ArrayRef<unsigned> indices) const;
@ -352,41 +338,24 @@ public:
size_t bitsWidth);
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(const Attribute *attr) {
return attr->getKind() == Kind::DenseIntElements;
}
private:
friend class DenseElementsAttr;
DenseIntElementsAttr(VectorOrTensorType *type, ArrayRef<char> data,
size_t bitsWidth)
: DenseElementsAttr(Kind::DenseIntElements, type, data),
bitsWidth(bitsWidth) {}
~DenseIntElementsAttr() = delete;
size_t bitsWidth;
static bool kindof(Kind kind) { return kind == Kind::DenseIntElements; }
};
/// An attribute represents a reference to a dense float vector or tensor
/// object. Each element is stored as a double.
class DenseFPElementsAttr : public DenseElementsAttr {
public:
typedef detail::DenseFPElementsAttributeStorage ImplType;
DenseFPElementsAttr() = default;
/* implicit */ DenseFPElementsAttr(Attribute::ImplType *ptr);
// TODO: returns APFPs instead of FloatAttr.
void getValues(SmallVectorImpl<Attribute *> &values) const;
void getValues(SmallVectorImpl<Attribute> &values) const;
APFloat getValue(ArrayRef<unsigned> indices) const;
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(const Attribute *attr) {
return attr->getKind() == Kind::DenseFPElements;
}
private:
friend class DenseElementsAttr;
DenseFPElementsAttr(VectorOrTensorType *type, ArrayRef<char> data)
: DenseElementsAttr(Kind::DenseFPElements, type, data) {}
~DenseFPElementsAttr() = delete;
static bool kindof(Kind kind) { return kind == Kind::DenseFPElements; }
};
/// An attribute represents a reference to a tensor constant with opaque
@ -394,20 +363,16 @@ private:
/// doesn't need to interpret.
class OpaqueElementsAttr : public ElementsAttr {
public:
static OpaqueElementsAttr *get(VectorOrTensorType *type, StringRef bytes);
typedef detail::OpaqueElementsAttributeStorage ImplType;
OpaqueElementsAttr() = default;
/* implicit */ OpaqueElementsAttr(Attribute::ImplType *ptr);
StringRef getValue() const { return bytes; }
static OpaqueElementsAttr get(VectorOrTensorType *type, StringRef bytes);
StringRef getValue() const;
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(const Attribute *attr) {
return attr->getKind() == Kind::OpaqueElements;
}
private:
OpaqueElementsAttr(VectorOrTensorType *type, StringRef bytes)
: ElementsAttr(Kind::OpaqueElements, type), bytes(bytes) {}
~OpaqueElementsAttr() = delete;
StringRef bytes;
static bool kindof(Kind kind) { return kind == Kind::OpaqueElements; }
};
/// An attribute represents a reference to a sparse vector or tensor object.
@ -427,32 +392,67 @@ private:
/// [0, 0, 0, 0]].
class SparseElementsAttr : public ElementsAttr {
public:
static SparseElementsAttr *get(VectorOrTensorType *type,
DenseIntElementsAttr *indices,
DenseElementsAttr *values);
typedef detail::SparseElementsAttributeStorage ImplType;
SparseElementsAttr() = default;
/* implicit */ SparseElementsAttr(Attribute::ImplType *ptr);
DenseIntElementsAttr *getIndices() const { return indices; }
static SparseElementsAttr get(VectorOrTensorType *type,
DenseIntElementsAttr indices,
DenseElementsAttr values);
DenseElementsAttr *getValues() const { return values; }
DenseIntElementsAttr getIndices() const;
DenseElementsAttr getValues() const;
/// Return the value at the given index.
Attribute *getValue(ArrayRef<unsigned> index) const;
Attribute getValue(ArrayRef<unsigned> index) const;
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(const Attribute *attr) {
return attr->getKind() == Kind::SparseElements;
}
private:
SparseElementsAttr(VectorOrTensorType *type, DenseIntElementsAttr *indices,
DenseElementsAttr *values)
: ElementsAttr(Kind::SparseElements, type), indices(indices),
values(values) {}
~SparseElementsAttr() = delete;
DenseIntElementsAttr *const indices;
DenseElementsAttr *const values;
static bool kindof(Kind kind) { return kind == Kind::SparseElements; }
};
template <typename U> bool Attribute::isa() const {
assert(attr && "isa<> used on a null attribute.");
return U::kindof(getKind());
}
template <typename U> U Attribute::dyn_cast() const {
return isa<U>() ? U(attr) : U(nullptr);
}
template <typename U> U Attribute::dyn_cast_or_null() const {
return (attr && isa<U>()) ? U(attr) : U(nullptr);
}
template <typename U> U Attribute::cast() const {
assert(isa<U>());
return U(attr);
}
// Make Attribute hashable.
inline ::llvm::hash_code hash_value(Attribute arg) {
return ::llvm::hash_value(arg.attr);
}
} // end namespace mlir.
namespace llvm {
// Attribute hash just like pointers
template <> struct DenseMapInfo<mlir::Attribute> {
static mlir::Attribute getEmptyKey() {
auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer));
}
static mlir::Attribute getTombstoneKey() {
auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
return mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer));
}
static unsigned getHashValue(mlir::Attribute val) {
return mlir::hash_value(val);
}
static bool isEqual(mlir::Attribute LHS, mlir::Attribute RHS) {
return LHS == RHS;
}
};
} // namespace llvm
#endif

View File

@ -93,23 +93,23 @@ public:
UnrankedTensorType *getTensorType(Type *elementType);
// Attributes.
BoolAttr *getBoolAttr(bool value);
IntegerAttr *getIntegerAttr(int64_t value);
FloatAttr *getFloatAttr(double value);
FloatAttr *getFloatAttr(const APFloat &value);
StringAttr *getStringAttr(StringRef bytes);
ArrayAttr *getArrayAttr(ArrayRef<Attribute *> value);
AffineMapAttr *getAffineMapAttr(AffineMap map);
TypeAttr *getTypeAttr(Type *type);
FunctionAttr *getFunctionAttr(const Function *value);
ElementsAttr *getSplatElementsAttr(VectorOrTensorType *type, Attribute *elt);
ElementsAttr *getDenseElementsAttr(VectorOrTensorType *type,
ArrayRef<char> data);
ElementsAttr *getSparseElementsAttr(VectorOrTensorType *type,
DenseIntElementsAttr *indices,
DenseElementsAttr *values);
ElementsAttr *getOpaqueElementsAttr(VectorOrTensorType *type,
StringRef bytes);
BoolAttr getBoolAttr(bool value);
IntegerAttr getIntegerAttr(int64_t value);
FloatAttr getFloatAttr(double value);
FloatAttr getFloatAttr(const APFloat &value);
StringAttr getStringAttr(StringRef bytes);
ArrayAttr getArrayAttr(ArrayRef<Attribute> value);
AffineMapAttr getAffineMapAttr(AffineMap map);
TypeAttr getTypeAttr(Type *type);
FunctionAttr getFunctionAttr(const Function *value);
ElementsAttr getSplatElementsAttr(VectorOrTensorType *type, Attribute elt);
ElementsAttr getDenseElementsAttr(VectorOrTensorType *type,
ArrayRef<char> data);
ElementsAttr getSparseElementsAttr(VectorOrTensorType *type,
DenseIntElementsAttr indices,
DenseElementsAttr values);
ElementsAttr getOpaqueElementsAttr(VectorOrTensorType *type, StringRef bytes);
// Affine expressions and affine maps.
AffineExpr getAffineDimExpr(unsigned position);

View File

@ -60,7 +60,7 @@ public:
/// Returns the affine map to be applied by this operation.
AffineMap getAffineMap() const {
return getAttrOfType<AffineMapAttr>("map")->getValue();
return getAttrOfType<AffineMapAttr>("map").getValue();
}
/// Returns true if the result of this operation can be used as dimension id.
@ -75,8 +75,8 @@ public:
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
bool verify() const;
bool constantFold(ArrayRef<Attribute *> operands,
SmallVectorImpl<Attribute *> &results,
bool constantFold(ArrayRef<Attribute> operandConstants,
SmallVectorImpl<Attribute> &results,
MLIRContext *context) const;
private:
@ -94,10 +94,10 @@ class ConstantOp : public Op<ConstantOp, OpTrait::ZeroOperands,
OpTrait::OneResult, OpTrait::HasNoSideEffect> {
public:
/// Builds a constant op with the specified attribute value and result type.
static void build(Builder *builder, OperationState *result, Attribute *value,
static void build(Builder *builder, OperationState *result, Attribute value,
Type *type);
Attribute *getValue() const { return getAttr("value"); }
Attribute getValue() const { return getAttr("value"); }
static StringRef getOperationName() { return "constant"; }
@ -105,8 +105,8 @@ public:
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
bool verify() const;
Attribute *constantFold(ArrayRef<Attribute *> operands,
MLIRContext *context) const;
Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const;
protected:
friend class Operation;
@ -125,7 +125,7 @@ public:
const APFloat &value, FloatType *type);
APFloat getValue() const {
return getAttrOfType<FloatAttr>("value")->getValue();
return getAttrOfType<FloatAttr>("value").getValue();
}
static bool isClassFor(const Operation *op);
@ -152,7 +152,7 @@ public:
Type *type);
int64_t getValue() const {
return getAttrOfType<IntegerAttr>("value")->getValue();
return getAttrOfType<IntegerAttr>("value").getValue();
}
static bool isClassFor(const Operation *op);
@ -173,7 +173,7 @@ public:
static void build(Builder *builder, OperationState *result, int64_t value);
int64_t getValue() const {
return getAttrOfType<IntegerAttr>("value")->getValue();
return getAttrOfType<IntegerAttr>("value").getValue();
}
static bool isClassFor(const Operation *op);

View File

@ -24,12 +24,12 @@
#ifndef MLIR_IR_FUNCTION_H
#define MLIR_IR_FUNCTION_H
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Identifier.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ilist.h"
namespace mlir {
class Attribute;
class AttributeListStorage;
class FunctionType;
class Location;
@ -39,7 +39,7 @@ class Module;
/// NamedAttribute is used for function attribute lists, it holds an
/// identifier for the name and a value for the attribute. The attribute
/// pointer should always be non-null.
using NamedAttribute = std::pair<Identifier, Attribute *>;
using NamedAttribute = std::pair<Identifier, Attribute>;
/// This is the base class for all of the MLIR function types.
class Function : public llvm::ilist_node_with_parent<Function, Module> {

View File

@ -138,17 +138,16 @@ public:
ArrayRef<NamedAttribute> getAttrs() const { return state->getAttrs(); }
/// Return an attribute with the specified name.
Attribute *getAttr(StringRef name) const { return state->getAttr(name); }
Attribute getAttr(StringRef name) const { return state->getAttr(name); }
/// If the operation has an attribute of the specified type, return it.
template <typename AttrClass>
AttrClass *getAttrOfType(StringRef name) const {
return dyn_cast_or_null<AttrClass>(getAttr(name));
template <typename AttrClass> AttrClass getAttrOfType(StringRef name) const {
return getAttr(name).dyn_cast_or_null<AttrClass>();
}
/// If the an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
void setAttr(Identifier name, Attribute *value) {
void setAttr(Identifier name, Attribute value) {
state->setAttr(name, value);
}
@ -211,8 +210,8 @@ public:
/// true if folding failed, or returns false and fills in `results` on
/// success.
static bool constantFoldHook(const Operation *op,
ArrayRef<Attribute *> operands,
SmallVectorImpl<Attribute *> &results) {
ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute> &results) {
return op->cast<ConcreteType>()->constantFold(operands, results,
op->getContext());
}
@ -226,8 +225,8 @@ public:
///
/// If not overridden, this fallback implementation always fails to fold.
///
bool constantFold(ArrayRef<Attribute *> operands,
SmallVectorImpl<Attribute *> &results,
bool constantFold(ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute> &results,
MLIRContext *context) const {
return true;
}
@ -244,9 +243,9 @@ public:
/// true if folding failed, or returns false and fills in `results` on
/// success.
static bool constantFoldHook(const Operation *op,
ArrayRef<Attribute *> operands,
SmallVectorImpl<Attribute *> &results) {
auto *result =
ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute> &results) {
auto result =
op->cast<ConcreteType>()->constantFold(operands, op->getContext());
if (!result)
return true;
@ -511,8 +510,8 @@ public:
///
/// If not overridden, this fallback implementation always fails to fold.
///
Attribute *constantFold(ArrayRef<Attribute *> operands,
MLIRContext *context) const {
Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const {
return nullptr;
}
};

View File

@ -69,7 +69,7 @@ public:
}
virtual void printType(const Type *type) = 0;
virtual void printFunctionReference(const Function *func) = 0;
virtual void printAttribute(const Attribute *attr) = 0;
virtual void printAttribute(Attribute attr) = 0;
virtual void printAffineMap(AffineMap map) = 0;
virtual void printAffineExpr(AffineExpr expr) = 0;
@ -100,8 +100,8 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const Type &type) {
return p;
}
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const Attribute &attr) {
p.printAttribute(&attr);
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Attribute attr) {
p.printAttribute(attr);
return p;
}
@ -210,24 +210,24 @@ public:
/// Parse an arbitrary attribute and return it in result. This also adds the
/// attribute to the specified attribute list with the specified name. this
/// captures the location of the attribute in 'loc' if it is non-null.
virtual bool parseAttribute(Attribute *&result, const char *attrName,
virtual bool parseAttribute(Attribute &result, const char *attrName,
SmallVectorImpl<NamedAttribute> &attrs) = 0;
/// Parse an attribute of a specific kind, capturing the location into `loc`
/// if specified.
template <typename AttrType>
bool parseAttribute(AttrType *&result, const char *attrName,
bool parseAttribute(AttrType &result, const char *attrName,
SmallVectorImpl<NamedAttribute> &attrs) {
llvm::SMLoc loc;
getCurrentLocation(&loc);
// Parse any kind of attribute.
Attribute *attr;
Attribute attr;
if (parseAttribute(attr, attrName, attrs))
return true;
// Check for the right kind of attribute.
result = dyn_cast<AttrType>(attr);
result = attr.dyn_cast<AttrType>();
if (!result) {
emitError(loc, "invalid kind of constant specified");
return true;

View File

@ -113,33 +113,31 @@ public:
ArrayRef<NamedAttribute> getAttrs() const;
/// Return the specified attribute if present, null otherwise.
Attribute *getAttr(Identifier name) const {
Attribute getAttr(Identifier name) const {
for (auto elt : getAttrs())
if (elt.first == name)
return elt.second;
return nullptr;
}
Attribute *getAttr(StringRef name) const {
Attribute getAttr(StringRef name) const {
for (auto elt : getAttrs())
if (elt.first.is(name))
return elt.second;
return nullptr;
}
template <typename AttrClass>
AttrClass *getAttrOfType(Identifier name) const {
return dyn_cast_or_null<AttrClass>(getAttr(name));
template <typename AttrClass> AttrClass getAttrOfType(Identifier name) const {
return getAttr(name).dyn_cast_or_null<AttrClass>();
}
template <typename AttrClass>
AttrClass *getAttrOfType(StringRef name) const {
return dyn_cast_or_null<AttrClass>(getAttr(name));
template <typename AttrClass> AttrClass getAttrOfType(StringRef name) const {
return getAttr(name).dyn_cast_or_null<AttrClass>();
}
/// If the an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
void setAttr(Identifier name, Attribute *value);
void setAttr(Identifier name, Attribute value);
enum class RemoveResult {
Removed, NotFound
@ -250,8 +248,8 @@ public:
/// the operands of the operation, but may be null if non-constant. If
/// constant folding is successful, this returns false and fills in the
/// `results` vector. If not, this returns true and `results` is unspecified.
bool constantFold(ArrayRef<Attribute *> operands,
SmallVectorImpl<Attribute *> &results) const;
bool constantFold(ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute> &results) const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Instruction *inst);

View File

@ -23,11 +23,11 @@
#ifndef MLIR_IR_OPERATION_SUPPORT_H
#define MLIR_IR_OPERATION_SUPPORT_H
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Identifier.h"
#include "llvm/ADT/PointerUnion.h"
namespace mlir {
class Attribute;
class Dialect;
class Location;
class Operation;
@ -80,8 +80,8 @@ public:
/// This hook implements a constant folder for this operation. It returns
/// true if folding failed, or returns false and fills in `results` on
/// success.
bool (&constantFoldHook)(const Operation *op, ArrayRef<Attribute *> operands,
SmallVectorImpl<Attribute *> &results);
bool (&constantFoldHook)(const Operation *op, ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute> &results);
// Returns whether the operation has a particular property.
bool hasProperty(OperationProperty property) const {
@ -110,8 +110,8 @@ private:
void (&printAssembly)(const Operation *op, OpAsmPrinter *p),
bool (&verifyInvariants)(const Operation *op),
bool (&constantFoldHook)(const Operation *op,
ArrayRef<Attribute *> operands,
SmallVectorImpl<Attribute *> &results))
ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute> &results))
: name(name), dialect(dialect), isClassFor(isClassFor),
parseAssembly(parseAssembly), printAssembly(printAssembly),
verifyInvariants(verifyInvariants), constantFoldHook(constantFoldHook),
@ -124,7 +124,7 @@ private:
/// NamedAttribute is a used for operation attribute lists, it holds an
/// identifier for the name and a value for the attribute. The attribute
/// pointer should always be non-null.
using NamedAttribute = std::pair<Identifier, Attribute *>;
using NamedAttribute = std::pair<Identifier, Attribute>;
class OperationName {
public:
@ -204,7 +204,7 @@ public:
types.append(newTypes.begin(), newTypes.end());
}
void addAttribute(StringRef name, Attribute *attr) {
void addAttribute(StringRef name, Attribute attr) {
attributes.push_back({Identifier::get(name, context), attr});
}
};

View File

@ -49,8 +49,8 @@ class AddFOp
public:
static StringRef getOperationName() { return "addf"; }
Attribute *constantFold(ArrayRef<Attribute *> operands,
MLIRContext *context) const;
Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const;
private:
friend class Operation;
@ -70,8 +70,8 @@ class AddIOp
public:
static StringRef getOperationName() { return "addi"; }
Attribute *constantFold(ArrayRef<Attribute *> operands,
MLIRContext *context) const;
Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const;
private:
friend class Operation;
@ -134,7 +134,7 @@ public:
ArrayRef<SSAValue *> operands);
Function *getCallee() const {
return getAttrOfType<FunctionAttr>("callee")->getValue();
return getAttrOfType<FunctionAttr>("callee").getValue();
}
// Hooks to customize behavior of this op.
@ -218,13 +218,13 @@ public:
static void build(Builder *builder, OperationState *result,
SSAValue *memrefOrTensor, unsigned index);
Attribute *constantFold(ArrayRef<Attribute *> operands,
MLIRContext *context) const;
Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const;
/// This returns the dimension number that the 'dim' is inspecting.
unsigned getIndex() const {
return static_cast<unsigned>(
getAttrOfType<IntegerAttr>("index")->getValue());
getAttrOfType<IntegerAttr>("index").getValue());
}
static StringRef getOperationName() { return "dim"; }
@ -513,8 +513,8 @@ class MulFOp
public:
static StringRef getOperationName() { return "mulf"; }
Attribute *constantFold(ArrayRef<Attribute *> operands,
MLIRContext *context) const;
Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const;
private:
friend class Operation;
@ -534,8 +534,8 @@ class MulIOp
public:
static StringRef getOperationName() { return "muli"; }
Attribute *constantFold(ArrayRef<Attribute *> operands,
MLIRContext *context) const;
Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const;
private:
friend class Operation;
@ -597,8 +597,8 @@ class SubFOp : public BinaryOp<SubFOp, OpTrait::ResultsAreFloatLike,
public:
static StringRef getOperationName() { return "subf"; }
Attribute *constantFold(ArrayRef<Attribute *> operands,
MLIRContext *context) const;
Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const;
private:
friend class Operation;
@ -617,8 +617,8 @@ class SubIOp : public BinaryOp<SubIOp, OpTrait::ResultsAreIntegerLike,
public:
static StringRef getOperationName() { return "subi"; }
Attribute *constantFold(ArrayRef<Attribute *> operands,
MLIRContext *context) const;
Attribute constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const;
private:
friend class Operation;

View File

@ -82,7 +82,7 @@ public:
}
bool verifyOperation(const Operation &op);
bool verifyAttribute(Attribute *attr, const Operation &op);
bool verifyAttribute(Attribute attr, const Operation &op);
protected:
explicit Verifier(const Function &fn) : fn(fn) {}
@ -94,26 +94,26 @@ private:
} // end anonymous namespace
// Check that function attributes are all well formed.
bool Verifier::verifyAttribute(Attribute *attr, const Operation &op) {
if (!attr->isOrContainsFunction())
bool Verifier::verifyAttribute(Attribute attr, const Operation &op) {
if (!attr.isOrContainsFunction())
return false;
// If we have a function attribute, check that it is non-null and in the
// same module as the operation that refers to it.
if (auto *fnAttr = dyn_cast<FunctionAttr>(attr)) {
if (!fnAttr->getValue())
if (auto fnAttr = attr.dyn_cast<FunctionAttr>()) {
if (!fnAttr.getValue())
return failure("attribute refers to deallocated function!", op);
if (fnAttr->getValue()->getModule() != fn.getModule())
if (fnAttr.getValue()->getModule() != fn.getModule())
return failure("attribute refers to function '" +
Twine(fnAttr->getValue()->getName()) +
Twine(fnAttr.getValue()->getName()) +
"' defined in another module!",
op);
return false;
}
// Otherwise, we must have an array attribute, remap the elements.
for (auto *elt : cast<ArrayAttr>(attr)->getValue()) {
for (auto elt : attr.cast<ArrayAttr>().getValue()) {
if (verifyAttribute(elt, op))
return true;
}

View File

@ -32,13 +32,12 @@ namespace {
// evaluated on constant 'operandConsts'.
class AffineExprConstantFolder {
public:
AffineExprConstantFolder(unsigned numDims,
ArrayRef<Attribute *> operandConsts)
AffineExprConstantFolder(unsigned numDims, ArrayRef<Attribute> operandConsts)
: numDims(numDims), operandConsts(operandConsts) {}
/// Attempt to constant fold the specified affine expr, or return null on
/// failure.
IntegerAttr *constantFold(AffineExpr expr) {
IntegerAttr constantFold(AffineExpr expr) {
switch (expr.getKind()) {
case AffineExprKind::Add:
return constantFoldBinExpr(
@ -59,31 +58,32 @@ public:
return IntegerAttr::get(expr.cast<AffineConstantExpr>().getValue(),
expr.getContext());
case AffineExprKind::DimId:
return dyn_cast_or_null<IntegerAttr>(
operandConsts[expr.cast<AffineDimExpr>().getPosition()]);
return operandConsts[expr.cast<AffineDimExpr>().getPosition()]
.dyn_cast_or_null<IntegerAttr>();
case AffineExprKind::SymbolId:
return dyn_cast_or_null<IntegerAttr>(
operandConsts[numDims + expr.cast<AffineSymbolExpr>().getPosition()]);
return operandConsts[numDims +
expr.cast<AffineSymbolExpr>().getPosition()]
.dyn_cast_or_null<IntegerAttr>();
}
}
private:
IntegerAttr *
IntegerAttr
constantFoldBinExpr(AffineExpr expr,
std::function<uint64_t(int64_t, uint64_t)> op) {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
auto *lhs = constantFold(binOpExpr.getLHS());
auto *rhs = constantFold(binOpExpr.getRHS());
auto lhs = constantFold(binOpExpr.getLHS());
auto rhs = constantFold(binOpExpr.getRHS());
if (!lhs || !rhs)
return nullptr;
return IntegerAttr::get(op(lhs->getValue(), rhs->getValue()),
return IntegerAttr::get(op(lhs.getValue(), rhs.getValue()),
expr.getContext());
}
// The number of dimension operands in AffineMap containing this expression.
unsigned numDims;
// The constant valued operands used to evaluate this AffineExpr.
ArrayRef<Attribute *> operandConsts;
ArrayRef<Attribute> operandConsts;
};
} // end anonymous namespace
@ -137,15 +137,15 @@ ArrayRef<AffineExpr> AffineMap::getRangeSizes() const {
/// Folds the results of the application of an affine map on the provided
/// operands to a constant if possible. Returns false if the folding happens,
/// true otherwise.
bool AffineMap::constantFold(ArrayRef<Attribute *> operandConstants,
SmallVectorImpl<Attribute *> &results) const {
bool AffineMap::constantFold(ArrayRef<Attribute> operandConstants,
SmallVectorImpl<Attribute> &results) const {
assert(getNumInputs() == operandConstants.size());
// Fold each of the result expressions.
AffineExprConstantFolder exprFolder(getNumDims(), operandConstants);
// Constant fold each AffineExpr in AffineMap and add to 'results'.
for (auto expr : getResults()) {
auto *folded = exprFolder.constantFold(expr);
auto folded = exprFolder.constantFold(expr);
// If we didn't fold to a constant, then folding fails.
if (!folded)
return true;

View File

@ -123,7 +123,7 @@ private:
void visitIfStmt(const IfStmt *ifStmt);
void visitOperationStmt(const OperationStmt *opStmt);
void visitType(const Type *type);
void visitAttribute(const Attribute *attr);
void visitAttribute(Attribute attr);
void visitOperation(const Operation *op);
DenseMap<AffineMap, int> affineMapIds;
@ -150,11 +150,11 @@ void ModuleState::visitType(const Type *type) {
}
}
void ModuleState::visitAttribute(const Attribute *attr) {
if (auto *mapAttr = dyn_cast<AffineMapAttr>(attr)) {
recordAffineMapReference(mapAttr->getValue());
} else if (auto *arrayAttr = dyn_cast<ArrayAttr>(attr)) {
for (auto elt : arrayAttr->getValue()) {
void ModuleState::visitAttribute(Attribute attr) {
if (auto mapAttr = attr.dyn_cast<AffineMapAttr>()) {
recordAffineMapReference(mapAttr.getValue());
} else if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
for (auto elt : arrayAttr.getValue()) {
visitAttribute(elt);
}
}
@ -268,7 +268,7 @@ public:
void print(const Module *module);
void printFunctionReference(const Function *func);
void printAttribute(const Attribute *attr);
void printAttribute(Attribute attr);
void printType(const Type *type);
void print(const Function *fn);
void print(const ExtFunction *fn);
@ -293,7 +293,7 @@ protected:
void printAffineMapReference(AffineMap affineMap);
void printIntegerSetId(int integerSetId) const;
void printIntegerSetReference(IntegerSet integerSet);
void printDenseElementsAttr(const DenseElementsAttr *attr);
void printDenseElementsAttr(DenseElementsAttr attr);
/// This enum is used to represent the binding stength of the enclosing
/// context that an AffineExprStorage is being printed in, so we can
@ -404,36 +404,36 @@ void ModulePrinter::printFunctionReference(const Function *func) {
os << '@' << func->getName();
}
void ModulePrinter::printAttribute(const Attribute *attr) {
switch (attr->getKind()) {
void ModulePrinter::printAttribute(Attribute attr) {
switch (attr.getKind()) {
case Attribute::Kind::Bool:
os << (cast<BoolAttr>(attr)->getValue() ? "true" : "false");
os << (attr.cast<BoolAttr>().getValue() ? "true" : "false");
break;
case Attribute::Kind::Integer:
os << cast<IntegerAttr>(attr)->getValue();
os << attr.cast<IntegerAttr>().getValue();
break;
case Attribute::Kind::Float:
printFloatValue(cast<FloatAttr>(attr)->getValue(), os);
printFloatValue(attr.cast<FloatAttr>().getValue(), os);
break;
case Attribute::Kind::String:
os << '"';
printEscapedString(cast<StringAttr>(attr)->getValue(), os);
printEscapedString(attr.cast<StringAttr>().getValue(), os);
os << '"';
break;
case Attribute::Kind::Array:
os << '[';
interleaveComma(cast<ArrayAttr>(attr)->getValue(),
[&](Attribute *attr) { printAttribute(attr); });
interleaveComma(attr.cast<ArrayAttr>().getValue(),
[&](Attribute attr) { printAttribute(attr); });
os << ']';
break;
case Attribute::Kind::AffineMap:
printAffineMapReference(cast<AffineMapAttr>(attr)->getValue());
printAffineMapReference(attr.cast<AffineMapAttr>().getValue());
break;
case Attribute::Kind::Type:
printType(cast<TypeAttr>(attr)->getValue());
printType(attr.cast<TypeAttr>().getValue());
break;
case Attribute::Kind::Function: {
auto *function = cast<FunctionAttr>(attr)->getValue();
auto *function = attr.cast<FunctionAttr>().getValue();
if (!function) {
os << "<<FUNCTION ATTR FOR DELETED FUNCTION>>";
} else {
@ -444,53 +444,52 @@ void ModulePrinter::printAttribute(const Attribute *attr) {
break;
}
case Attribute::Kind::OpaqueElements: {
auto *eltsAttr = cast<OpaqueElementsAttr>(attr);
auto eltsAttr = attr.cast<OpaqueElementsAttr>();
os << "opaque<";
printType(eltsAttr->getType());
os << ", " << '"' << "0x" << llvm::toHex(eltsAttr->getValue()) << '"'
<< '>';
printType(eltsAttr.getType());
os << ", " << '"' << "0x" << llvm::toHex(eltsAttr.getValue()) << '"' << '>';
break;
}
case Attribute::Kind::DenseIntElements:
case Attribute::Kind::DenseFPElements: {
auto *eltsAttr = cast<DenseElementsAttr>(attr);
auto eltsAttr = attr.cast<DenseElementsAttr>();
os << "dense<";
printType(eltsAttr->getType());
printType(eltsAttr.getType());
os << ", ";
printDenseElementsAttr(eltsAttr);
os << '>';
break;
}
case Attribute::Kind::SplatElements: {
auto *elementsAttr = cast<SplatElementsAttr>(attr);
auto elementsAttr = attr.cast<SplatElementsAttr>();
os << "splat<";
printType(elementsAttr->getType());
printType(elementsAttr.getType());
os << ", ";
printAttribute(elementsAttr->getValue());
printAttribute(elementsAttr.getValue());
os << '>';
break;
}
case Attribute::Kind::SparseElements: {
auto *elementsAttr = cast<SparseElementsAttr>(attr);
auto elementsAttr = attr.cast<SparseElementsAttr>();
os << "sparse<";
printType(elementsAttr->getType());
printType(elementsAttr.getType());
os << ", ";
printDenseElementsAttr(elementsAttr->getIndices());
printDenseElementsAttr(elementsAttr.getIndices());
os << ", ";
printDenseElementsAttr(elementsAttr->getValues());
printDenseElementsAttr(elementsAttr.getValues());
os << '>';
break;
}
}
}
void ModulePrinter::printDenseElementsAttr(const DenseElementsAttr *attr) {
auto *type = attr->getType();
void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
auto *type = attr.getType();
auto shape = type->getShape();
auto rank = type->getRank();
SmallVector<Attribute *, 16> elements;
attr->getValues(elements);
SmallVector<Attribute, 16> elements;
attr.getValues(elements);
// Special case for degenerate tensors.
if (elements.empty()) {
@ -934,9 +933,7 @@ public:
// Implement OpAsmPrinter.
raw_ostream &getStream() const { return os; }
void printType(const Type *type) { ModulePrinter::printType(type); }
void printAttribute(const Attribute *attr) {
ModulePrinter::printAttribute(attr);
}
void printAttribute(Attribute attr) { ModulePrinter::printAttribute(attr); }
void printAffineMap(AffineMap map) {
return ModulePrinter::printAffineMapReference(map);
}
@ -980,7 +977,7 @@ protected:
} else if (auto intOp = op->dyn_cast<ConstantIndexOp>()) {
specialName << 'c' << intOp->getValue();
} else if (auto constant = op->dyn_cast<ConstantOp>()) {
if (isa<FunctionAttr>(constant->getValue()))
if (constant->getValue().isa<FunctionAttr>())
specialName << 'f';
else
specialName << "cst";
@ -1570,7 +1567,7 @@ void ModulePrinter::print(const MLFunction *fn) {
void Attribute::print(raw_ostream &os) const {
ModuleState state(/*no context is known*/ nullptr);
ModulePrinter(os, state).printAttribute(this);
ModulePrinter(os, state).printAttribute(*this);
}
void Attribute::dump() const { print(llvm::errs()); }

View File

@ -0,0 +1,131 @@
//===- AttributeDetail.h - MLIR Affine Map details Class --------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This holds implementation details of Attribute.
//
//===----------------------------------------------------------------------===//
#ifndef ATTRIBUTEDETAIL_H_
#define ATTRIBUTEDETAIL_H_
#include "mlir/IR/Attributes.h"
#include "mlir/IR/MLIRContext.h"
namespace mlir {
namespace detail {
/// Base storage class appearing in an attribute.
struct AttributeStorage {
Attribute::Kind kind : 8;
/// This field is true if this is, or contains, a function attribute.
bool isOrContainsFunctionCache : 1;
};
/// An attribute representing a boolean value.
struct BoolAttributeStorage : public AttributeStorage {
bool value;
};
/// An attribute representing a integral value.
struct IntegerAttributeStorage : public AttributeStorage {
int64_t value;
};
/// An attribute representing a floating point value.
struct FloatAttributeStorage final
: public AttributeStorage,
public llvm::TrailingObjects<FloatAttributeStorage, uint64_t> {
const llvm::fltSemantics &semantics;
size_t numObjects;
/// Returns an APFloat representing the stored value.
APFloat getValue() const {
auto val = APInt(APFloat::getSizeInBits(semantics),
{getTrailingObjects<uint64_t>(), numObjects});
return APFloat(semantics, val);
}
};
/// An attribute representing a string value.
struct StringAttributeStorage : public AttributeStorage {
StringRef value;
};
/// An attribute representing an array of other attributes.
struct ArrayAttributeStorage : public AttributeStorage {
ArrayRef<Attribute> value;
};
// An attribute representing a reference to an affine map.
struct AffineMapAttributeStorage : public AttributeStorage {
AffineMap value;
};
/// An attribute representing a reference to a type.
struct TypeAttributeStorage : public AttributeStorage {
Type *value;
};
/// An attribute representing a reference to a function.
struct FunctionAttributeStorage : public AttributeStorage {
Function *value;
};
/// A base attribute representing a reference to a vector or tensor constant.
struct ElementsAttributeStorage : public AttributeStorage {
VectorOrTensorType *type;
};
/// An attribute representing a reference to a vector or tensor constant,
/// inwhich all elements have the same value.
struct SplatElementsAttributeStorage : public ElementsAttributeStorage {
Attribute elt;
};
/// An attribute representing a reference to a dense vector or tensor object.
struct DenseElementsAttributeStorage : public ElementsAttributeStorage {
ArrayRef<char> data;
};
/// An attribute representing a reference to a dense integer vector or tensor
/// object.
struct DenseIntElementsAttributeStorage : public DenseElementsAttributeStorage {
size_t bitsWidth;
};
/// An attribute representing a reference to a dense float vector or tensor
/// object.
struct DenseFPElementsAttributeStorage : public DenseElementsAttributeStorage {
};
/// An attribute representing a reference to a tensor constant with opaque
/// content.
struct OpaqueElementsAttributeStorage : public ElementsAttributeStorage {
StringRef bytes;
};
/// An attribute representing a reference to a sparse vector or tensor object.
struct SparseElementsAttributeStorage : public ElementsAttributeStorage {
DenseIntElementsAttr indices;
DenseElementsAttr values;
};
} // namespace detail
} // namespace mlir
#endif // ATTRIBUTEDETAIL_H_

214
mlir/lib/IR/Attributes.cpp Normal file
View File

@ -0,0 +1,214 @@
//===- Attributes.cpp - MLIR Affine Expr Classes --------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/IR/Attributes.h"
#include "AttributeDetail.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Types.h"
using namespace mlir;
using namespace mlir::detail;
Attribute::Kind Attribute::getKind() const { return attr->kind; }
bool Attribute::isOrContainsFunction() const {
return attr->isOrContainsFunctionCache;
}
BoolAttr::BoolAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
bool BoolAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
IntegerAttr::IntegerAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
int64_t IntegerAttr::getValue() const {
return static_cast<ImplType *>(attr)->value;
}
FloatAttr::FloatAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
APFloat FloatAttr::getValue() const {
return static_cast<ImplType *>(attr)->getValue();
}
double FloatAttr::getDouble() const { return getValue().convertToDouble(); }
StringAttr::StringAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
StringRef StringAttr::getValue() const {
return static_cast<ImplType *>(attr)->value;
}
ArrayAttr::ArrayAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
ArrayRef<Attribute> ArrayAttr::getValue() const {
return static_cast<ImplType *>(attr)->value;
}
AffineMapAttr::AffineMapAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
AffineMap AffineMapAttr::getValue() const {
return static_cast<ImplType *>(attr)->value;
}
TypeAttr::TypeAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
Type *TypeAttr::getValue() const {
return static_cast<ImplType *>(attr)->value;
}
FunctionAttr::FunctionAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
Function *FunctionAttr::getValue() const {
return static_cast<ImplType *>(attr)->value;
}
FunctionType *FunctionAttr::getType() const { return getValue()->getType(); }
ElementsAttr::ElementsAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
VectorOrTensorType *ElementsAttr::getType() const {
return static_cast<ImplType *>(attr)->type;
}
SplatElementsAttr::SplatElementsAttr(Attribute::ImplType *ptr)
: ElementsAttr(ptr) {}
Attribute SplatElementsAttr::getValue() const {
return static_cast<ImplType *>(attr)->elt;
}
DenseElementsAttr::DenseElementsAttr(Attribute::ImplType *ptr)
: ElementsAttr(ptr) {}
void DenseElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
switch (getKind()) {
case Attribute::Kind::DenseIntElements:
cast<DenseIntElementsAttr>().getValues(values);
return;
case Attribute::Kind::DenseFPElements:
cast<DenseFPElementsAttr>().getValues(values);
return;
default:
llvm_unreachable("unexpected element type");
}
}
ArrayRef<char> DenseElementsAttr::getRawData() const {
return static_cast<ImplType *>(attr)->data;
}
DenseIntElementsAttr::DenseIntElementsAttr(Attribute::ImplType *ptr)
: DenseElementsAttr(ptr) {}
/// Writes the lowest `bitWidth` bits of `value` to bit position `bitPos`
/// starting from `rawData`.
void DenseIntElementsAttr::writeBits(char *data, size_t bitPos, size_t bitWidth,
uint64_t value) {
// Read the destination bytes which will be written to.
uint64_t dst = 0;
auto dstData = reinterpret_cast<char *>(&dst);
auto endPos = bitPos + bitWidth;
auto start = data + bitPos / 8;
auto end = data + endPos / 8 + (endPos % 8 != 0);
std::copy(start, end, dstData);
// Clean up the invalid bits in the destination bytes.
dst &= ~(-1UL << (bitPos % 8));
// Get the valid bits of the source value, shift them to right position,
// then add them to the destination bytes.
value <<= bitPos % 8;
dst |= value;
// Write the destination bytes back.
ArrayRef<char> range({dstData, (size_t)(end - start)});
std::copy(range.begin(), range.end(), start);
}
/// Reads the next `bitWidth` bits from the bit position `bitPos` of `rawData`
/// and put them in the lowest bits.
uint64_t DenseIntElementsAttr::readBits(const char *rawData, size_t bitPos,
size_t bitsWidth) {
uint64_t dst = 0;
auto dstData = reinterpret_cast<char *>(&dst);
auto endPos = bitPos + bitsWidth;
auto start = rawData + bitPos / 8;
auto end = rawData + endPos / 8 + (endPos % 8 != 0);
std::copy(start, end, dstData);
dst >>= bitPos % 8;
dst &= ~(-1UL << bitsWidth);
return dst;
}
void DenseIntElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
auto bitsWidth = static_cast<ImplType *>(attr)->bitsWidth;
auto elementNum = getType()->getNumElements();
auto context = getType()->getContext();
values.reserve(elementNum);
if (bitsWidth == 64) {
ArrayRef<int64_t> vs(
{reinterpret_cast<const int64_t *>(getRawData().data()),
getRawData().size() / 8});
for (auto value : vs) {
auto attr = IntegerAttr::get(value, context);
values.push_back(attr);
}
} else {
const auto *rawData = getRawData().data();
for (size_t pos = 0; pos < elementNum * bitsWidth; pos += bitsWidth) {
uint64_t bits = readBits(rawData, pos, bitsWidth);
APInt value(bitsWidth, bits, /*isSigned=*/true);
auto attr = IntegerAttr::get(value.getSExtValue(), context);
values.push_back(attr);
}
}
}
DenseFPElementsAttr::DenseFPElementsAttr(Attribute::ImplType *ptr)
: DenseElementsAttr(ptr) {}
void DenseFPElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
auto elementNum = getType()->getNumElements();
auto context = getType()->getContext();
ArrayRef<double> vs({reinterpret_cast<const double *>(getRawData().data()),
getRawData().size() / 8});
values.reserve(elementNum);
for (auto v : vs) {
auto attr = FloatAttr::get(v, context);
values.push_back(attr);
}
}
OpaqueElementsAttr::OpaqueElementsAttr(Attribute::ImplType *ptr)
: ElementsAttr(ptr) {}
StringRef OpaqueElementsAttr::getValue() const {
return static_cast<ImplType *>(attr)->bytes;
}
SparseElementsAttr::SparseElementsAttr(Attribute::ImplType *ptr)
: ElementsAttr(ptr) {}
DenseIntElementsAttr SparseElementsAttr::getIndices() const {
return static_cast<ImplType *>(attr)->indices;
}
DenseElementsAttr SparseElementsAttr::getValues() const {
return static_cast<ImplType *>(attr)->values;
}

View File

@ -112,60 +112,60 @@ UnrankedTensorType *Builder::getTensorType(Type *elementType) {
// Attributes.
//===----------------------------------------------------------------------===//
BoolAttr *Builder::getBoolAttr(bool value) {
BoolAttr Builder::getBoolAttr(bool value) {
return BoolAttr::get(value, context);
}
IntegerAttr *Builder::getIntegerAttr(int64_t value) {
IntegerAttr Builder::getIntegerAttr(int64_t value) {
return IntegerAttr::get(value, context);
}
FloatAttr *Builder::getFloatAttr(double value) {
FloatAttr Builder::getFloatAttr(double value) {
return FloatAttr::get(APFloat(value), context);
}
FloatAttr *Builder::getFloatAttr(const APFloat &value) {
FloatAttr Builder::getFloatAttr(const APFloat &value) {
return FloatAttr::get(value, context);
}
StringAttr *Builder::getStringAttr(StringRef bytes) {
StringAttr Builder::getStringAttr(StringRef bytes) {
return StringAttr::get(bytes, context);
}
ArrayAttr *Builder::getArrayAttr(ArrayRef<Attribute *> value) {
ArrayAttr Builder::getArrayAttr(ArrayRef<Attribute> value) {
return ArrayAttr::get(value, context);
}
AffineMapAttr *Builder::getAffineMapAttr(AffineMap map) {
AffineMapAttr Builder::getAffineMapAttr(AffineMap map) {
return AffineMapAttr::get(map);
}
TypeAttr *Builder::getTypeAttr(Type *type) {
TypeAttr Builder::getTypeAttr(Type *type) {
return TypeAttr::get(type, context);
}
FunctionAttr *Builder::getFunctionAttr(const Function *value) {
FunctionAttr Builder::getFunctionAttr(const Function *value) {
return FunctionAttr::get(value, context);
}
ElementsAttr *Builder::getSplatElementsAttr(VectorOrTensorType *type,
Attribute *elt) {
ElementsAttr Builder::getSplatElementsAttr(VectorOrTensorType *type,
Attribute elt) {
return SplatElementsAttr::get(type, elt);
}
ElementsAttr *Builder::getDenseElementsAttr(VectorOrTensorType *type,
ArrayRef<char> data) {
ElementsAttr Builder::getDenseElementsAttr(VectorOrTensorType *type,
ArrayRef<char> data) {
return DenseElementsAttr::get(type, data);
}
ElementsAttr *Builder::getSparseElementsAttr(VectorOrTensorType *type,
DenseIntElementsAttr *indices,
DenseElementsAttr *values) {
ElementsAttr Builder::getSparseElementsAttr(VectorOrTensorType *type,
DenseIntElementsAttr indices,
DenseElementsAttr values) {
return SparseElementsAttr::get(type, indices, values);
}
ElementsAttr *Builder::getOpaqueElementsAttr(VectorOrTensorType *type,
StringRef bytes) {
ElementsAttr Builder::getOpaqueElementsAttr(VectorOrTensorType *type,
StringRef bytes) {
return OpaqueElementsAttr::get(type, bytes);
}

View File

@ -86,13 +86,13 @@ bool AffineApplyOp::parse(OpAsmParser *parser, OperationState *result) {
auto &builder = parser->getBuilder();
auto *affineIntTy = builder.getIndexType();
AffineMapAttr *mapAttr;
AffineMapAttr mapAttr;
unsigned numDims;
if (parser->parseAttribute(mapAttr, "map", result->attributes) ||
parseDimAndSymbolList(parser, result->operands, numDims) ||
parser->parseOptionalAttributeDict(result->attributes))
return true;
auto map = mapAttr->getValue();
auto map = mapAttr.getValue();
if (map.getNumDims() != numDims ||
numDims + map.getNumSymbols() != result->operands.size()) {
@ -113,12 +113,12 @@ void AffineApplyOp::print(OpAsmPrinter *p) const {
bool AffineApplyOp::verify() const {
// Check that affine map attribute was specified.
auto *affineMapAttr = getAttrOfType<AffineMapAttr>("map");
auto affineMapAttr = getAttrOfType<AffineMapAttr>("map");
if (!affineMapAttr)
return emitOpError("requires an affine map");
// Check input and output dimensions match.
auto map = affineMapAttr->getValue();
auto map = affineMapAttr.getValue();
// Verify that operand count matches affine map dimension and symbol count.
if (getNumOperands() != map.getNumDims() + map.getNumSymbols())
@ -155,8 +155,8 @@ bool AffineApplyOp::isValidSymbol() const {
return true;
}
bool AffineApplyOp::constantFold(ArrayRef<Attribute *> operandConstants,
SmallVectorImpl<Attribute *> &results,
bool AffineApplyOp::constantFold(ArrayRef<Attribute> operandConstants,
SmallVectorImpl<Attribute> &results,
MLIRContext *context) const {
auto map = getAffineMap();
if (map.constantFold(operandConstants, results))
@ -171,21 +171,21 @@ bool AffineApplyOp::constantFold(ArrayRef<Attribute *> operandConstants,
/// Builds a constant op with the specified attribute value and result type.
void ConstantOp::build(Builder *builder, OperationState *result,
Attribute *value, Type *type) {
Attribute value, Type *type) {
result->addAttribute("value", value);
result->types.push_back(type);
}
void ConstantOp::print(OpAsmPrinter *p) const {
*p << "constant " << *getValue();
*p << "constant " << getValue();
p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value");
if (!isa<FunctionAttr>(getValue()))
if (!getValue().isa<FunctionAttr>())
*p << " : " << *getType();
}
bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
Attribute *valueAttr;
Attribute valueAttr;
Type *type;
if (parser->parseAttribute(valueAttr, "value", result->attributes) ||
@ -194,8 +194,8 @@ bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
// 'constant' taking a function reference doesn't get a redundant type
// specifier. The attribute itself carries it.
if (auto *fnAttr = dyn_cast<FunctionAttr>(valueAttr))
return parser->addTypeToList(fnAttr->getValue()->getType(), result->types);
if (auto fnAttr = valueAttr.dyn_cast<FunctionAttr>())
return parser->addTypeToList(fnAttr.getValue()->getType(), result->types);
return parser->parseColonType(type) ||
parser->addTypeToList(type, result->types);
@ -204,32 +204,32 @@ bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
/// The constant op requires an attribute, and furthermore requires that it
/// matches the return type.
bool ConstantOp::verify() const {
auto *value = getValue();
auto value = getValue();
if (!value)
return emitOpError("requires a 'value' attribute");
auto *type = this->getType();
if (isa<IntegerType>(type) || type->isIndex()) {
if (!isa<IntegerAttr>(value))
if (!value.isa<IntegerAttr>())
return emitOpError(
"requires 'value' to be an integer for an integer result type");
return false;
}
if (isa<FloatType>(type)) {
if (!isa<FloatAttr>(value))
if (!value.isa<FloatAttr>())
return emitOpError("requires 'value' to be a floating point constant");
return false;
}
if (type->isTFString()) {
if (!isa<StringAttr>(value))
if (!value.isa<StringAttr>())
return emitOpError("requires 'value' to be a string constant");
return false;
}
if (isa<FunctionType>(type)) {
if (!isa<FunctionAttr>(value))
if (!value.isa<FunctionAttr>())
return emitOpError("requires 'value' to be a function reference");
return false;
}
@ -238,8 +238,8 @@ bool ConstantOp::verify() const {
"requires a result type that aligns with the 'value' attribute");
}
Attribute *ConstantOp::constantFold(ArrayRef<Attribute *> operands,
MLIRContext *context) const {
Attribute ConstantOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const {
assert(operands.empty() && "constant has no operands");
return getValue();
}

View File

@ -18,6 +18,7 @@
#include "mlir/IR/MLIRContext.h"
#include "AffineExprDetail.h"
#include "AffineMapDetail.h"
#include "AttributeDetail.h"
#include "AttributeListStorage.h"
#include "IntegerSetDetail.h"
#include "mlir/IR/AffineExpr.h"
@ -169,35 +170,35 @@ struct MemRefTypeKeyInfo : DenseMapInfo<MemRefType *> {
}
};
struct FloatAttrKeyInfo : DenseMapInfo<FloatAttr *> {
struct FloatAttrKeyInfo : DenseMapInfo<FloatAttributeStorage *> {
// Float attributes are uniqued based on wrapped APFloat.
using KeyTy = APFloat;
using DenseMapInfo<FloatAttr *>::getHashValue;
using DenseMapInfo<FloatAttr *>::isEqual;
using DenseMapInfo<FloatAttributeStorage *>::getHashValue;
using DenseMapInfo<FloatAttributeStorage *>::isEqual;
static unsigned getHashValue(KeyTy key) { return llvm::hash_value(key); }
static bool isEqual(const KeyTy &lhs, const FloatAttr *rhs) {
static bool isEqual(const KeyTy &lhs, const FloatAttributeStorage *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
return lhs.bitwiseIsEqual(rhs->getValue());
}
};
struct ArrayAttrKeyInfo : DenseMapInfo<ArrayAttr *> {
struct ArrayAttrKeyInfo : DenseMapInfo<ArrayAttributeStorage *> {
// Array attributes are uniqued based on their elements.
using KeyTy = ArrayRef<Attribute *>;
using DenseMapInfo<ArrayAttr *>::getHashValue;
using DenseMapInfo<ArrayAttr *>::isEqual;
using KeyTy = ArrayRef<Attribute>;
using DenseMapInfo<ArrayAttributeStorage *>::getHashValue;
using DenseMapInfo<ArrayAttributeStorage *>::isEqual;
static unsigned getHashValue(KeyTy key) {
return hash_combine_range(key.begin(), key.end());
}
static bool isEqual(const KeyTy &lhs, const ArrayAttr *rhs) {
static bool isEqual(const KeyTy &lhs, const ArrayAttributeStorage *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
return lhs == rhs->getValue();
return lhs == rhs->value;
}
};
@ -218,37 +219,39 @@ struct AttributeListKeyInfo : DenseMapInfo<AttributeListStorage *> {
}
};
struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttr *> {
struct DenseElementsAttrInfo : DenseMapInfo<DenseElementsAttributeStorage *> {
using KeyTy = std::pair<VectorOrTensorType *, ArrayRef<char>>;
using DenseMapInfo<DenseElementsAttr *>::getHashValue;
using DenseMapInfo<DenseElementsAttr *>::isEqual;
using DenseMapInfo<DenseElementsAttributeStorage *>::getHashValue;
using DenseMapInfo<DenseElementsAttributeStorage *>::isEqual;
static unsigned getHashValue(KeyTy key) {
return hash_combine(
key.first, hash_combine_range(key.second.begin(), key.second.end()));
}
static bool isEqual(const KeyTy &lhs, const DenseElementsAttr *rhs) {
static bool isEqual(const KeyTy &lhs,
const DenseElementsAttributeStorage *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
return lhs == std::make_pair(rhs->getType(), rhs->getRawData());
return lhs == std::make_pair(rhs->type, rhs->data);
}
};
struct OpaqueElementsAttrInfo : DenseMapInfo<OpaqueElementsAttr *> {
struct OpaqueElementsAttrInfo : DenseMapInfo<OpaqueElementsAttributeStorage *> {
using KeyTy = std::pair<VectorOrTensorType *, StringRef>;
using DenseMapInfo<OpaqueElementsAttr *>::getHashValue;
using DenseMapInfo<OpaqueElementsAttr *>::isEqual;
using DenseMapInfo<OpaqueElementsAttributeStorage *>::getHashValue;
using DenseMapInfo<OpaqueElementsAttributeStorage *>::isEqual;
static unsigned getHashValue(KeyTy key) {
return hash_combine(
key.first, hash_combine_range(key.second.begin(), key.second.end()));
}
static bool isEqual(const KeyTy &lhs, const OpaqueElementsAttr *rhs) {
static bool isEqual(const KeyTy &lhs,
const OpaqueElementsAttributeStorage *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
return lhs == std::make_pair(rhs->getType(), rhs->getValue());
return lhs == std::make_pair(rhs->type, rhs->bytes);
}
};
} // end anonymous namespace.
@ -343,28 +346,29 @@ public:
MemRefTypeSet memrefs;
// Attribute uniquing.
BoolAttr *boolAttrs[2] = {nullptr};
DenseMap<int64_t, IntegerAttr *> integerAttrs;
DenseSet<FloatAttr *, FloatAttrKeyInfo> floatAttrs;
StringMap<StringAttr *> stringAttrs;
using ArrayAttrSet = DenseSet<ArrayAttr *, ArrayAttrKeyInfo>;
BoolAttributeStorage *boolAttrs[2] = {nullptr};
DenseMap<int64_t, IntegerAttributeStorage *> integerAttrs;
DenseSet<FloatAttributeStorage *, FloatAttrKeyInfo> floatAttrs;
StringMap<StringAttributeStorage *> stringAttrs;
using ArrayAttrSet = DenseSet<ArrayAttributeStorage *, ArrayAttrKeyInfo>;
ArrayAttrSet arrayAttrs;
DenseMap<AffineMap, AffineMapAttr *> affineMapAttrs;
DenseMap<Type *, TypeAttr *> typeAttrs;
DenseMap<AffineMap, AffineMapAttributeStorage *> affineMapAttrs;
DenseMap<Type *, TypeAttributeStorage *> typeAttrs;
using AttributeListSet =
DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
AttributeListSet attributeLists;
DenseMap<const Function *, FunctionAttr *> functionAttrs;
DenseMap<std::pair<VectorOrTensorType *, Attribute *>, SplatElementsAttr *>
DenseMap<const Function *, FunctionAttributeStorage *> functionAttrs;
DenseMap<std::pair<VectorOrTensorType *, Attribute>,
SplatElementsAttributeStorage *>
splatElementsAttrs;
using DenseElementsAttrSet =
DenseSet<DenseElementsAttr *, DenseElementsAttrInfo>;
DenseSet<DenseElementsAttributeStorage *, DenseElementsAttrInfo>;
DenseElementsAttrSet denseElementsAttrs;
using OpaqueElementsAttrSet =
DenseSet<OpaqueElementsAttr *, OpaqueElementsAttrInfo>;
DenseSet<OpaqueElementsAttributeStorage *, OpaqueElementsAttrInfo>;
OpaqueElementsAttrSet opaqueElementsAttrs;
DenseMap<std::tuple<Type *, DenseElementsAttr *, DenseElementsAttr *>,
SparseElementsAttr *>
DenseMap<std::tuple<Type *, Attribute, Attribute>,
SparseElementsAttributeStorage *>
sparseElementsAttrs;
public:
@ -716,31 +720,36 @@ MemRefType *MemRefType::get(ArrayRef<int> shape, Type *elementType,
// Attribute uniquing
//===----------------------------------------------------------------------===//
BoolAttr *BoolAttr::get(bool value, MLIRContext *context) {
BoolAttr BoolAttr::get(bool value, MLIRContext *context) {
auto *&result = context->getImpl().boolAttrs[value];
if (result)
return result;
result = context->getImpl().allocator.Allocate<BoolAttr>();
new (result) BoolAttr(value);
result = context->getImpl().allocator.Allocate<BoolAttributeStorage>();
new (result) BoolAttributeStorage{{Attribute::Kind::Bool,
/*isOrContainsFunction=*/false},
value};
return result;
}
IntegerAttr *IntegerAttr::get(int64_t value, MLIRContext *context) {
IntegerAttr IntegerAttr::get(int64_t value, MLIRContext *context) {
auto *&result = context->getImpl().integerAttrs[value];
if (result)
return result;
result = context->getImpl().allocator.Allocate<IntegerAttr>();
new (result) IntegerAttr(value);
result = context->getImpl().allocator.Allocate<IntegerAttributeStorage>();
new (result) IntegerAttributeStorage{{Attribute::Kind::Integer,
/*isOrContainsFunction=*/false},
value};
result->value = value;
return result;
}
FloatAttr *FloatAttr::get(double value, MLIRContext *context) {
FloatAttr FloatAttr::get(double value, MLIRContext *context) {
return get(APFloat(value), context);
}
FloatAttr *FloatAttr::get(const APFloat &value, MLIRContext *context) {
FloatAttr FloatAttr::get(const APFloat &value, MLIRContext *context) {
auto &impl = context->getImpl();
// Look to see if the float attribute has been created already.
@ -755,33 +764,35 @@ FloatAttr *FloatAttr::get(const APFloat &value, MLIRContext *context) {
// Here one word's bitwidth equals to that of uint64_t.
auto elements = ArrayRef<uint64_t>(apint.getRawData(), apint.getNumWords());
auto byteSize = FloatAttr::totalSizeToAlloc<uint64_t>(elements.size());
auto rawMem = impl.allocator.Allocate(byteSize, alignof(FloatAttr));
auto result = ::new (rawMem) FloatAttr(value.getSemantics(), elements.size());
auto byteSize =
FloatAttributeStorage::totalSizeToAlloc<uint64_t>(elements.size());
auto rawMem =
impl.allocator.Allocate(byteSize, alignof(FloatAttributeStorage));
auto result = ::new (rawMem) FloatAttributeStorage{
{Attribute::Kind::Float, /*isOrContainsFunction=*/false},
{},
value.getSemantics(),
elements.size()};
std::uninitialized_copy(elements.begin(), elements.end(),
result->getTrailingObjects<uint64_t>());
return *existing.first = result;
}
APFloat FloatAttr::getValue() const {
auto val = APInt(APFloat::getSizeInBits(semantics),
{getTrailingObjects<uint64_t>(), numObjects});
return APFloat(semantics, val);
}
StringAttr *StringAttr::get(StringRef bytes, MLIRContext *context) {
StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) {
auto it = context->getImpl().stringAttrs.insert({bytes, nullptr}).first;
if (it->second)
return it->second;
auto result = context->getImpl().allocator.Allocate<StringAttr>();
new (result) StringAttr(it->first());
auto result = context->getImpl().allocator.Allocate<StringAttributeStorage>();
new (result) StringAttributeStorage{{Attribute::Kind::String,
/*isOrContainsFunction=*/false},
it->first()};
it->second = result;
return result;
}
ArrayAttr *ArrayAttr::get(ArrayRef<Attribute *> value, MLIRContext *context) {
ArrayAttr ArrayAttr::get(ArrayRef<Attribute> value, MLIRContext *context) {
auto &impl = context->getImpl();
// Look to see if we already have this.
@ -792,61 +803,66 @@ ArrayAttr *ArrayAttr::get(ArrayRef<Attribute *> value, MLIRContext *context) {
return *existing.first;
// On the first use, we allocate them into the bump pointer.
auto *result = impl.allocator.Allocate<ArrayAttr>();
auto *result = impl.allocator.Allocate<ArrayAttributeStorage>();
// Copy the elements into the bump pointer.
value = impl.copyInto(value);
// Check to see if any of the elements have a function attr.
bool hasFunctionAttr = false;
for (auto *elt : value)
if (elt->isOrContainsFunction()) {
for (auto elt : value)
if (elt.isOrContainsFunction()) {
hasFunctionAttr = true;
break;
}
// Initialize the memory using placement new.
new (result) ArrayAttr(value, hasFunctionAttr);
new (result)
ArrayAttributeStorage{{Attribute::Kind::Array, hasFunctionAttr}, value};
// Cache and return it.
return *existing.first = result;
}
AffineMapAttr *AffineMapAttr::get(AffineMap value) {
AffineMapAttr AffineMapAttr::get(AffineMap value) {
auto *context = value.getResult(0).getContext();
auto &result = context->getImpl().affineMapAttrs[value];
if (result)
return result;
result = context->getImpl().allocator.Allocate<AffineMapAttr>();
new (result) AffineMapAttr(value);
result = context->getImpl().allocator.Allocate<AffineMapAttributeStorage>();
new (result) AffineMapAttributeStorage{{Attribute::Kind::AffineMap,
/*isOrContainsFunction=*/false},
value};
return result;
}
TypeAttr *TypeAttr::get(Type *type, MLIRContext *context) {
TypeAttr TypeAttr::get(Type *type, MLIRContext *context) {
auto *&result = context->getImpl().typeAttrs[type];
if (result)
return result;
result = context->getImpl().allocator.Allocate<TypeAttr>();
new (result) TypeAttr(type);
result = context->getImpl().allocator.Allocate<TypeAttributeStorage>();
new (result) TypeAttributeStorage{{Attribute::Kind::Type,
/*isOrContainsFunction=*/false},
type};
return result;
}
FunctionAttr *FunctionAttr::get(const Function *value, MLIRContext *context) {
FunctionAttr FunctionAttr::get(const Function *value, MLIRContext *context) {
assert(value && "Cannot get FunctionAttr for a null function");
auto *&result = context->getImpl().functionAttrs[value];
if (result)
return result;
result = context->getImpl().allocator.Allocate<FunctionAttr>();
new (result) FunctionAttr(const_cast<Function *>(value));
result = context->getImpl().allocator.Allocate<FunctionAttributeStorage>();
new (result) FunctionAttributeStorage{{Attribute::Kind::Function,
/*isOrContainsFunction=*/true},
const_cast<Function *>(value)};
return result;
}
FunctionType *FunctionAttr::getType() const { return getValue()->getType(); }
/// This function is used by the internals of the Function class to null out
/// attributes refering to functions that are about to be deleted.
void FunctionAttr::dropFunctionReference(Function *value) {
@ -935,30 +951,29 @@ AttributeListStorage *AttributeListStorage::get(ArrayRef<NamedAttribute> attrs,
return *existing.first = result;
}
OpaqueElementsAttr *OpaqueElementsAttr::get(VectorOrTensorType *type,
StringRef bytes) {
assert(isValidTensorElementType(type->getElementType()) &&
"Input element type should be a valid tensor element type");
SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType *type,
Attribute elt) {
auto &impl = type->getContext()->getImpl();
// Look to see if this constant is already defined.
OpaqueElementsAttrInfo::KeyTy key({type, bytes});
auto existing = impl.opaqueElementsAttrs.insert_as(nullptr, key);
// Look to see if we already have this.
auto *&result = impl.splatElementsAttrs[{type, elt}];
// If we already have it, return that value.
if (!existing.second)
return *existing.first;
if (result)
return result;
// Otherwise, allocate a new one, unique it and return it.
auto *result = impl.allocator.Allocate<OpaqueElementsAttr>();
bytes = bytes.copy(impl.allocator);
new (result) OpaqueElementsAttr(type, bytes);
return *existing.first = result;
// Otherwise, allocate them into the bump pointer.
result = impl.allocator.Allocate<SplatElementsAttributeStorage>();
new (result) SplatElementsAttributeStorage{{{Attribute::Kind::SplatElements,
/*isOrContainsFunction=*/false},
type},
elt};
return result;
}
DenseElementsAttr *DenseElementsAttr::get(VectorOrTensorType *type,
ArrayRef<char> data) {
DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType *type,
ArrayRef<char> data) {
auto bitsRequired = (long)type->getBitWidth() * type->getNumElements();
(void)(bitsRequired);
assert((bitsRequired <= data.size() * 8L) &&
@ -981,18 +996,25 @@ DenseElementsAttr *DenseElementsAttr::get(VectorOrTensorType *type,
case Type::Kind::F16:
case Type::Kind::F32:
case Type::Kind::F64: {
auto *result = impl.allocator.Allocate<DenseFPElementsAttr>();
auto *result = impl.allocator.Allocate<DenseFPElementsAttributeStorage>();
auto *copy = (char *)impl.allocator.Allocate(data.size(), 64);
std::uninitialized_copy(data.begin(), data.end(), copy);
new (result) DenseFPElementsAttr(type, {copy, data.size()});
new (result) DenseFPElementsAttributeStorage{
{{{Attribute::Kind::DenseFPElements, /*isOrContainsFunction=*/false},
type},
{copy, data.size()}}};
return *existing.first = result;
}
case Type::Kind::Integer: {
auto width = cast<IntegerType>(eltType)->getWidth();
auto *result = impl.allocator.Allocate<DenseIntElementsAttr>();
auto width = ::cast<IntegerType>(eltType)->getWidth();
auto *result = impl.allocator.Allocate<DenseIntElementsAttributeStorage>();
auto *copy = (char *)impl.allocator.Allocate(data.size(), 64);
std::uninitialized_copy(data.begin(), data.end(), copy);
new (result) DenseIntElementsAttr(type, {copy, data.size()}, width);
new (result) DenseIntElementsAttributeStorage{
{{{Attribute::Kind::DenseIntElements, /*isOrContainsFunction=*/false},
type},
{copy, data.size()}},
width};
return *existing.first = result;
}
default:
@ -1000,118 +1022,33 @@ DenseElementsAttr *DenseElementsAttr::get(VectorOrTensorType *type,
}
}
/// Writes the lowest `bitWidth` bits of `value` to bit position `bitPos`
/// starting from `rawData`.
void DenseIntElementsAttr::writeBits(char *data, size_t bitPos, size_t bitWidth,
uint64_t value) {
// Read the destination bytes which will be written to.
uint64_t dst = 0;
auto dstData = reinterpret_cast<char *>(&dst);
auto endPos = bitPos + bitWidth;
auto start = data + bitPos / 8;
auto end = data + endPos / 8 + (endPos % 8 != 0);
std::copy(start, end, dstData);
OpaqueElementsAttr OpaqueElementsAttr::get(VectorOrTensorType *type,
StringRef bytes) {
assert(isValidTensorElementType(type->getElementType()) &&
"Input element type should be a valid tensor element type");
// Clean up the invalid bits in the destination bytes.
dst &= ~(-1UL << (bitPos % 8));
// Get the valid bits of the source value, shift them to right position,
// then add them to the destination bytes.
value <<= bitPos % 8;
dst |= value;
// Write the destination bytes back.
ArrayRef<char> range({dstData, (size_t)(end - start)});
std::copy(range.begin(), range.end(), start);
}
/// Reads the next `bitWidth` bits from the bit position `bitPos` of `rawData`
/// and put them in the lowest bits.
uint64_t DenseIntElementsAttr::readBits(const char *rawData, size_t bitPos,
size_t bitsWidth) {
uint64_t dst = 0;
auto dstData = reinterpret_cast<char *>(&dst);
auto endPos = bitPos + bitsWidth;
auto start = rawData + bitPos / 8;
auto end = rawData + endPos / 8 + (endPos % 8 != 0);
std::copy(start, end, dstData);
dst >>= bitPos % 8;
dst &= ~(-1UL << bitsWidth);
return dst;
}
void DenseElementsAttr::getValues(SmallVectorImpl<Attribute *> &values) const {
switch (getKind()) {
case Attribute::Kind::DenseIntElements:
cast<DenseIntElementsAttr>(this)->getValues(values);
return;
case Attribute::Kind::DenseFPElements:
cast<DenseFPElementsAttr>(this)->getValues(values);
return;
default:
llvm_unreachable("unexpected element type");
}
}
void DenseIntElementsAttr::getValues(
SmallVectorImpl<Attribute *> &values) const {
auto elementNum = getType()->getNumElements();
auto context = getType()->getContext();
values.reserve(elementNum);
if (bitsWidth == 64) {
ArrayRef<int64_t> vs(
{reinterpret_cast<const int64_t *>(getRawData().data()),
getRawData().size() / 8});
for (auto value : vs) {
auto *attr = IntegerAttr::get(value, context);
values.push_back(attr);
}
} else {
const auto *rawData = getRawData().data();
for (size_t pos = 0; pos < elementNum * bitsWidth; pos += bitsWidth) {
uint64_t bits = readBits(rawData, pos, bitsWidth);
APInt value(bitsWidth, bits, /*isSigned=*/true);
auto *attr = IntegerAttr::get(value.getSExtValue(), context);
values.push_back(attr);
}
}
}
void DenseFPElementsAttr::getValues(
SmallVectorImpl<Attribute *> &values) const {
auto elementNum = getType()->getNumElements();
auto context = getType()->getContext();
ArrayRef<double> vs({reinterpret_cast<const double *>(getRawData().data()),
getRawData().size() / 8});
values.reserve(elementNum);
for (auto v : vs) {
auto *attr = FloatAttr::get(v, context);
values.push_back(attr);
}
}
SplatElementsAttr *SplatElementsAttr::get(VectorOrTensorType *type,
Attribute *elt) {
auto &impl = type->getContext()->getImpl();
// Look to see if we already have this.
auto *&result = impl.splatElementsAttrs[{type, elt}];
// Look to see if this constant is already defined.
OpaqueElementsAttrInfo::KeyTy key({type, bytes});
auto existing = impl.opaqueElementsAttrs.insert_as(nullptr, key);
// If we already have it, return that value.
if (result)
return result;
if (!existing.second)
return *existing.first;
// Otherwise, allocate them into the bump pointer.
result = impl.allocator.Allocate<SplatElementsAttr>();
new (result) SplatElementsAttr(type, elt);
return result;
// Otherwise, allocate a new one, unique it and return it.
auto *result = impl.allocator.Allocate<OpaqueElementsAttributeStorage>();
bytes = bytes.copy(impl.allocator);
new (result) OpaqueElementsAttributeStorage{
{{Attribute::Kind::OpaqueElements, /*isOrContainsFunction=*/false}, type},
bytes};
return *existing.first = result;
}
SparseElementsAttr *SparseElementsAttr::get(VectorOrTensorType *type,
DenseIntElementsAttr *indices,
DenseElementsAttr *values) {
SparseElementsAttr SparseElementsAttr::get(VectorOrTensorType *type,
DenseIntElementsAttr indices,
DenseElementsAttr values) {
auto &impl = type->getContext()->getImpl();
// Look to see if we already have this.
@ -1123,8 +1060,12 @@ SparseElementsAttr *SparseElementsAttr::get(VectorOrTensorType *type,
return result;
// Otherwise, allocate them into the bump pointer.
result = impl.allocator.Allocate<SparseElementsAttr>();
new (result) SparseElementsAttr(type, indices, values);
result = impl.allocator.Allocate<SparseElementsAttributeStorage>();
new (result) SparseElementsAttributeStorage{{{Attribute::Kind::SparseElements,
/*isOrContainsFunction=*/false},
type},
indices,
values};
return result;
}

View File

@ -148,7 +148,7 @@ ArrayRef<NamedAttribute> Operation::getAttrs() const {
/// If an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
void Operation::setAttr(Identifier name, Attribute *value) {
void Operation::setAttr(Identifier name, Attribute value) {
assert(value && "attributes may never be null");
auto origAttrs = getAttrs();
@ -225,8 +225,8 @@ void Operation::erase() {
/// Attempt to constant fold this operation with the specified constant
/// operand values. If successful, this returns false and fills in the
/// results vector. If not, this returns true and results is unspecified.
bool Operation::constantFold(ArrayRef<Attribute *> operands,
SmallVectorImpl<Attribute *> &results) const {
bool Operation::constantFold(ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute> &results) const {
// If we have a registered operation definition matching this one, use it to
// try to constant fold the operation.
if (auto *abstractOp = getAbstractOperation())

View File

@ -195,7 +195,7 @@ public:
// Attribute parsing.
Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
FunctionType *type);
Attribute *parseAttribute();
Attribute parseAttribute();
ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);
@ -204,8 +204,8 @@ public:
AffineMap parseAffineMapReference();
IntegerSet parseIntegerSetInline();
IntegerSet parseIntegerSetReference();
DenseElementsAttr *parseDenseElementsAttr(VectorOrTensorType *type);
DenseElementsAttr *parseDenseElementsAttr(Type *eltType, bool isVector);
DenseElementsAttr parseDenseElementsAttr(VectorOrTensorType *type);
DenseElementsAttr parseDenseElementsAttr(Type *eltType, bool isVector);
VectorOrTensorType *parseVectorOrTensorType();
private:
@ -684,7 +684,7 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int> &dims) {
case Token::floatliteral:
case Token::integer:
case Token::minus: {
auto *result = p.parseAttribute();
auto result = p.parseAttribute();
if (!result)
return p.emitError("expected tensor element");
// check result matches the element type.
@ -693,16 +693,16 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int> &dims) {
case Type::Kind::F16:
case Type::Kind::F32:
case Type::Kind::F64: {
if (!isa<FloatAttr>(result))
if (!result.isa<FloatAttr>())
return p.emitError("expected tensor literal element has float type");
double value = cast<FloatAttr>(result)->getDouble();
double value = result.cast<FloatAttr>().getDouble();
addToStorage(*(uint64_t *)(&value));
break;
}
case Type::Kind::Integer: {
if (!isa<IntegerAttr>(result))
if (!result.isa<IntegerAttr>())
return p.emitError("expected tensor literal element has integer type");
auto value = cast<IntegerAttr>(result)->getValue();
auto value = result.cast<IntegerAttr>().getValue();
// If we couldn't successfully round trip the value, it means some bits
// are truncated and we should give up here.
llvm::APInt apint(bitsWidth, (uint64_t)value, /*isSigned=*/true);
@ -804,7 +804,7 @@ Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
/// | `sparse<` (tensor-type | vector-type)`,`
/// attribute-value`, ` attribute-value `>`
///
Attribute *Parser::parseAttribute() {
Attribute Parser::parseAttribute() {
switch (getToken().getKind()) {
case Token::kw_true:
consumeToken(Token::kw_true);
@ -859,7 +859,7 @@ Attribute *Parser::parseAttribute() {
case Token::l_square: {
consumeToken(Token::l_square);
SmallVector<Attribute *, 4> elements;
SmallVector<Attribute, 4> elements;
auto parseElt = [&]() -> ParseResult {
elements.push_back(parseAttribute());
@ -928,7 +928,7 @@ Attribute *Parser::parseAttribute() {
case Token::floatliteral:
case Token::integer:
case Token::minus: {
auto *scalar = parseAttribute();
auto scalar = parseAttribute();
if (parseToken(Token::greater, "expected '>'"))
return nullptr;
return builder.getSplatElementsAttr(type, scalar);
@ -973,7 +973,7 @@ Attribute *Parser::parseAttribute() {
case Token::l_square: {
/// Parse indices
auto *indicesEltType = builder.getIntegerType(32);
auto *indices =
auto indices =
parseDenseElementsAttr(indicesEltType, isa<VectorType>(type));
if (parseToken(Token::comma, "expected ','"))
@ -981,12 +981,12 @@ Attribute *Parser::parseAttribute() {
/// Parse values.
auto *valuesEltType = type->getElementType();
auto *values =
auto values =
parseDenseElementsAttr(valuesEltType, isa<VectorType>(type));
/// Sanity check.
auto *indicesType = indices->getType();
auto *valuesType = values->getType();
auto *indicesType = indices.getType();
auto *valuesType = values.getType();
auto sameShape = (indicesType->getRank() == 1) ||
(type->getRank() == indicesType->getDimSize(1));
auto sameElementNum =
@ -1009,7 +1009,7 @@ Attribute *Parser::parseAttribute() {
// Build the sparse elements attribute by the indices and values.
return builder.getSparseElementsAttr(
type, cast<DenseIntElementsAttr>(indices), values);
type, indices.cast<DenseIntElementsAttr>(), values);
}
default:
return (emitError("expected '[' to start sparse tensor literal"),
@ -1035,8 +1035,7 @@ Attribute *Parser::parseAttribute() {
///
/// This method returns a constructed dense elements attribute with the shape
/// from the parsing result.
DenseElementsAttr *Parser::parseDenseElementsAttr(Type *eltType,
bool isVector) {
DenseElementsAttr Parser::parseDenseElementsAttr(Type *eltType, bool isVector) {
TensorLiteralParser literalParser(*this, eltType);
if (literalParser.parse())
return nullptr;
@ -1047,8 +1046,8 @@ DenseElementsAttr *Parser::parseDenseElementsAttr(Type *eltType,
} else {
type = builder.getTensorType(literalParser.getShape(), eltType);
}
return (DenseElementsAttr *)builder.getDenseElementsAttr(
type, literalParser.getValues());
return builder.getDenseElementsAttr(type, literalParser.getValues())
.cast<DenseElementsAttr>();
}
/// Dense elements attribute.
@ -1061,7 +1060,7 @@ DenseElementsAttr *Parser::parseDenseElementsAttr(Type *eltType,
/// This method compares the shapes from the parsing result and that from the
/// input argument. It returns a constructed dense elements attribute if both
/// match.
DenseElementsAttr *Parser::parseDenseElementsAttr(VectorOrTensorType *type) {
DenseElementsAttr Parser::parseDenseElementsAttr(VectorOrTensorType *type) {
auto *eltTy = type->getElementType();
TensorLiteralParser literalParser(*this, eltTy);
if (literalParser.parse())
@ -1076,8 +1075,8 @@ DenseElementsAttr *Parser::parseDenseElementsAttr(VectorOrTensorType *type) {
s << "])";
return (emitError(s.str()), nullptr);
}
return (DenseElementsAttr *)builder.getDenseElementsAttr(
type, literalParser.getValues());
return builder.getDenseElementsAttr(type, literalParser.getValues())
.cast<DenseElementsAttr>();
}
/// Vector or tensor type for elements attribute.
@ -2133,7 +2132,7 @@ public:
/// Parse an arbitrary attribute and return it in result. This also adds
/// the attribute to the specified attribute list with the specified name.
/// this captures the location of the attribute in 'loc' if it is non-null.
bool parseAttribute(Attribute *&result, const char *attrName,
bool parseAttribute(Attribute &result, const char *attrName,
SmallVectorImpl<NamedAttribute> &attrs) override {
result = parser.parseAttribute();
if (!result)
@ -3336,27 +3335,27 @@ ParseResult ModuleParser::parseMLFunc() {
/// Given an attribute that could refer to a function attribute in the
/// remapping table, walk it and rewrite it to use the mapped function. If it
/// doesn't refer to anything in the table, then it is returned unmodified.
static Attribute *
remapFunctionAttrs(Attribute *input,
DenseMap<FunctionAttr *, FunctionAttr *> &remappingTable,
static Attribute
remapFunctionAttrs(Attribute input,
DenseMap<Attribute, FunctionAttr> &remappingTable,
MLIRContext *context) {
// Most attributes are trivially unrelated to function attributes, skip them
// rapidly.
if (!input->isOrContainsFunction())
if (!input.isOrContainsFunction())
return input;
// If we have a function attribute, remap it.
if (auto *fnAttr = dyn_cast<FunctionAttr>(input)) {
if (auto fnAttr = input.dyn_cast<FunctionAttr>()) {
auto it = remappingTable.find(fnAttr);
return it != remappingTable.end() ? it->second : input;
}
// Otherwise, we must have an array attribute, remap the elements.
auto *arrayAttr = cast<ArrayAttr>(input);
SmallVector<Attribute *, 8> remappedElts;
auto arrayAttr = input.cast<ArrayAttr>();
SmallVector<Attribute, 8> remappedElts;
bool anyChange = false;
for (auto *elt : arrayAttr->getValue()) {
auto *newElt = remapFunctionAttrs(elt, remappingTable, context);
for (auto elt : arrayAttr.getValue()) {
auto newElt = remapFunctionAttrs(elt, remappingTable, context);
remappedElts.push_back(newElt);
anyChange |= (elt != newElt);
}
@ -3370,11 +3369,11 @@ remapFunctionAttrs(Attribute *input,
/// Remap function attributes to resolve forward references to their actual
/// definition.
static void remapFunctionAttrsInOperation(
Operation *op, DenseMap<FunctionAttr *, FunctionAttr *> &remappingTable) {
Operation *op, DenseMap<Attribute, FunctionAttr> &remappingTable) {
for (auto attr : op->getAttrs()) {
// Do the remapping, if we got the same thing back, then it must contain
// functions that aren't getting remapped.
auto *newVal =
auto newVal =
remapFunctionAttrs(attr.second, remappingTable, op->getContext());
if (newVal == attr.second)
continue;
@ -3391,7 +3390,7 @@ static void remapFunctionAttrsInOperation(
ParseResult ModuleParser::finalizeModule() {
// Resolve all forward references, building a remapping table of attributes.
DenseMap<FunctionAttr *, FunctionAttr *> remappingTable;
DenseMap<Attribute, FunctionAttr> remappingTable;
for (auto forwardRef : getState().functionForwardRefs) {
auto name = forwardRef.first;
@ -3428,13 +3427,13 @@ ParseResult ModuleParser::finalizeModule() {
continue;
struct MLFnWalker : public StmtWalker<MLFnWalker> {
MLFnWalker(DenseMap<FunctionAttr *, FunctionAttr *> &remappingTable)
MLFnWalker(DenseMap<Attribute, FunctionAttr> &remappingTable)
: remappingTable(remappingTable) {}
void visitOperationStmt(OperationStmt *opStmt) {
remapFunctionAttrsInOperation(opStmt, remappingTable);
}
DenseMap<FunctionAttr *, FunctionAttr *> &remappingTable;
DenseMap<Attribute, FunctionAttr> &remappingTable;
};
MLFnWalker(remappingTable).walk(mlFn);

View File

@ -44,13 +44,13 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
// AddFOp
//===----------------------------------------------------------------------===//
Attribute *AddFOp::constantFold(ArrayRef<Attribute *> operands,
MLIRContext *context) const {
Attribute AddFOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const {
assert(operands.size() == 2 && "addf takes two operands");
if (auto *lhs = dyn_cast_or_null<FloatAttr>(operands[0])) {
if (auto *rhs = dyn_cast_or_null<FloatAttr>(operands[1]))
return FloatAttr::get(lhs->getValue() + rhs->getValue(), context);
if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
if (auto rhs = operands[1].dyn_cast_or_null<FloatAttr>())
return FloatAttr::get(lhs.getValue() + rhs.getValue(), context);
}
return nullptr;
@ -60,13 +60,13 @@ Attribute *AddFOp::constantFold(ArrayRef<Attribute *> operands,
// AddIOp
//===----------------------------------------------------------------------===//
Attribute *AddIOp::constantFold(ArrayRef<Attribute *> operands,
MLIRContext *context) const {
Attribute AddIOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const {
assert(operands.size() == 2 && "addi takes two operands");
if (auto *lhs = dyn_cast_or_null<IntegerAttr>(operands[0])) {
if (auto *rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
return IntegerAttr::get(lhs->getValue() + rhs->getValue(), context);
if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>())
return IntegerAttr::get(lhs.getValue() + rhs.getValue(), context);
}
return nullptr;
@ -192,12 +192,12 @@ void CallOp::print(OpAsmPrinter *p) const {
bool CallOp::verify() const {
// Check that the callee attribute was specified.
auto *fnAttr = getAttrOfType<FunctionAttr>("callee");
auto fnAttr = getAttrOfType<FunctionAttr>("callee");
if (!fnAttr)
return emitOpError("requires a 'callee' function attribute");
// Verify that the operand and result types match the callee.
auto *fnType = fnAttr->getValue()->getType();
auto *fnType = fnAttr.getValue()->getType();
if (fnType->getNumInputs() != getNumOperands())
return emitOpError("incorrect number of operands for callee");
@ -329,7 +329,7 @@ void DimOp::print(OpAsmPrinter *p) const {
bool DimOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType operandInfo;
IntegerAttr *indexAttr;
IntegerAttr indexAttr;
Type *type;
return parser->parseOperand(operandInfo) || parser->parseComma() ||
@ -346,7 +346,7 @@ bool DimOp::verify() const {
auto indexAttr = getAttrOfType<IntegerAttr>("index");
if (!indexAttr)
return emitOpError("requires an integer attribute named 'index'");
uint64_t index = (uint64_t)indexAttr->getValue();
uint64_t index = (uint64_t)indexAttr.getValue();
auto *type = getOperand()->getType();
if (auto *tensorType = dyn_cast<RankedTensorType>(type)) {
@ -365,8 +365,8 @@ bool DimOp::verify() const {
return false;
}
Attribute *DimOp::constantFold(ArrayRef<Attribute *> operands,
MLIRContext *context) const {
Attribute DimOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const {
// Constant fold dim when the size along the index referred to is a constant.
auto *opType = getOperand()->getType();
int indexSize = -1;
@ -671,13 +671,13 @@ bool MemRefCastOp::verify() const {
// MulFOp
//===----------------------------------------------------------------------===//
Attribute *MulFOp::constantFold(ArrayRef<Attribute *> operands,
MLIRContext *context) const {
Attribute MulFOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const {
assert(operands.size() == 2 && "mulf takes two operands");
if (auto *lhs = dyn_cast_or_null<FloatAttr>(operands[0])) {
if (auto *rhs = dyn_cast_or_null<FloatAttr>(operands[1]))
return FloatAttr::get(lhs->getValue() * rhs->getValue(), context);
if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
if (auto rhs = operands[1].dyn_cast_or_null<FloatAttr>())
return FloatAttr::get(lhs.getValue() * rhs.getValue(), context);
}
return nullptr;
@ -687,23 +687,23 @@ Attribute *MulFOp::constantFold(ArrayRef<Attribute *> operands,
// MulIOp
//===----------------------------------------------------------------------===//
Attribute *MulIOp::constantFold(ArrayRef<Attribute *> operands,
MLIRContext *context) const {
Attribute MulIOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const {
assert(operands.size() == 2 && "muli takes two operands");
if (auto *lhs = dyn_cast_or_null<IntegerAttr>(operands[0])) {
if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
// 0*x == 0
if (lhs->getValue() == 0)
if (lhs.getValue() == 0)
return lhs;
if (auto *rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>())
// TODO: Handle the overflow case.
return IntegerAttr::get(lhs->getValue() * rhs->getValue(), context);
return IntegerAttr::get(lhs.getValue() * rhs.getValue(), context);
}
// x*0 == 0
if (auto *rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
if (rhs->getValue() == 0)
if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>())
if (rhs.getValue() == 0)
return rhs;
return nullptr;
@ -817,13 +817,13 @@ bool StoreOp::verify() const {
// SubFOp
//===----------------------------------------------------------------------===//
Attribute *SubFOp::constantFold(ArrayRef<Attribute *> operands,
MLIRContext *context) const {
Attribute SubFOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const {
assert(operands.size() == 2 && "subf takes two operands");
if (auto *lhs = dyn_cast_or_null<FloatAttr>(operands[0])) {
if (auto *rhs = dyn_cast_or_null<FloatAttr>(operands[1]))
return FloatAttr::get(lhs->getValue() - rhs->getValue(), context);
if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
if (auto rhs = operands[1].dyn_cast_or_null<FloatAttr>())
return FloatAttr::get(lhs.getValue() - rhs.getValue(), context);
}
return nullptr;
@ -833,13 +833,13 @@ Attribute *SubFOp::constantFold(ArrayRef<Attribute *> operands,
// SubIOp
//===----------------------------------------------------------------------===//
Attribute *SubIOp::constantFold(ArrayRef<Attribute *> operands,
MLIRContext *context) const {
Attribute SubIOp::constantFold(ArrayRef<Attribute> operands,
MLIRContext *context) const {
assert(operands.size() == 2 && "subi takes two operands");
if (auto *lhs = dyn_cast_or_null<IntegerAttr>(operands[0])) {
if (auto *rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
return IntegerAttr::get(lhs->getValue() - rhs->getValue(), context);
if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>())
return IntegerAttr::get(lhs.getValue() - rhs.getValue(), context);
}
return nullptr;

View File

@ -31,7 +31,7 @@ struct ConstantFold : public FunctionPass, StmtWalker<ConstantFold> {
SmallVector<SSAValue *, 8> existingConstants;
// Operation statements that were folded and that need to be erased.
std::vector<OperationStmt *> opStmtsToErase;
using ConstantFactoryType = std::function<SSAValue *(Attribute *, Type *)>;
using ConstantFactoryType = std::function<SSAValue *(Attribute, Type *)>;
bool foldOperation(Operation *op,
SmallVectorImpl<SSAValue *> &existingConstants,
@ -60,9 +60,9 @@ bool ConstantFold::foldOperation(Operation *op,
// Check to see if each of the operands is a trivial constant. If so, get
// the value. If not, ignore the instruction.
SmallVector<Attribute *, 8> operandConstants;
SmallVector<Attribute, 8> operandConstants;
for (auto *operand : op->getOperands()) {
Attribute *operandCst = nullptr;
Attribute operandCst = nullptr;
if (auto *operandOp = operand->getDefiningOperation()) {
if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
operandCst = operandConstantOp->getValue();
@ -71,7 +71,7 @@ bool ConstantFold::foldOperation(Operation *op,
}
// Attempt to constant fold the operation.
SmallVector<Attribute *, 8> resultConstants;
SmallVector<Attribute, 8> resultConstants;
if (op->constantFold(operandConstants, resultConstants))
return true;
@ -106,7 +106,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
for (auto instIt = bb.begin(), e = bb.end(); instIt != e;) {
auto &inst = *instIt++;
auto constantFactory = [&](Attribute *value, Type *type) -> SSAValue * {
auto constantFactory = [&](Attribute value, Type *type) -> SSAValue * {
builder.setInsertionPoint(&inst);
return builder.create<ConstantOp>(inst.getLoc(), value, type);
};
@ -134,7 +134,7 @@ PassResult ConstantFold::runOnCFGFunction(CFGFunction *f) {
// Override the walker's operation statement visit for constant folding.
void ConstantFold::visitOperationStmt(OperationStmt *stmt) {
auto constantFactory = [&](Attribute *value, Type *type) -> SSAValue * {
auto constantFactory = [&](Attribute value, Type *type) -> SSAValue * {
MLFuncBuilder builder(stmt);
return builder.create<ConstantOp>(stmt->getLoc(), value, type);
};

View File

@ -71,8 +71,8 @@ void SimplifyAffineStructures::visitIfStmt(IfStmt *ifStmt) {
void SimplifyAffineStructures::visitOperationStmt(OperationStmt *opStmt) {
for (auto attr : opStmt->getAttrs()) {
if (auto *mapAttr = dyn_cast<AffineMapAttr>(attr.second)) {
MutableAffineMap mMap(mapAttr->getValue());
if (auto mapAttr = attr.second.dyn_cast<AffineMapAttr>()) {
MutableAffineMap mMap(mapAttr.getValue());
mMap.simplify();
auto map = mMap.getAffineMap();
opStmt->setAttr(attr.first, AffineMapAttr::get(map));

View File

@ -79,7 +79,7 @@ private:
/// As part of canonicalization, we move constants to the top of the entry
/// block of the current function and de-duplicate them. This keeps track of
/// constants we have done this for.
DenseMap<std::pair<Attribute *, Type *>, Operation *> uniquedConstants;
DenseMap<std::pair<Attribute, Type *>, Operation *> uniquedConstants;
};
}; // end anonymous namespace
@ -107,7 +107,7 @@ public:
void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
WorklistRewriter &rewriter) {
// These are scratch vectors used in the constant folding loop below.
SmallVector<Attribute *, 8> operandConstants, resultConstants;
SmallVector<Attribute, 8> operandConstants, resultConstants;
while (!worklist.empty()) {
auto *op = popFromWorklist();
@ -175,7 +175,7 @@ void GreedyPatternRewriteDriver::simplifyFunction(Function *currentFunction,
// the operation knows how to constant fold itself.
operandConstants.clear();
for (auto *operand : op->getOperands()) {
Attribute *operandCst = nullptr;
Attribute operandCst;
if (auto *operandOp = operand->getDefiningOperation()) {
if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
operandCst = operandConstantOp->getValue();

View File

@ -353,11 +353,11 @@ bool mlir::constantFoldBounds(ForStmt *forStmt) {
// Check to see if each of the operands is the result of a constant. If so,
// get the value. If not, ignore it.
SmallVector<Attribute *, 8> operandConstants;
SmallVector<Attribute, 8> operandConstants;
auto boundOperands = lower ? forStmt->getLowerBoundOperands()
: forStmt->getUpperBoundOperands();
for (const auto *operand : boundOperands) {
Attribute *operandCst = nullptr;
Attribute operandCst;
if (auto *operandOp = operand->getDefiningOperation()) {
if (auto operandConstantOp = operandOp->dyn_cast<ConstantOp>())
operandCst = operandConstantOp->getValue();
@ -369,15 +369,15 @@ bool mlir::constantFoldBounds(ForStmt *forStmt) {
lower ? forStmt->getLowerBoundMap() : forStmt->getUpperBoundMap();
assert(boundMap.getNumResults() >= 1 &&
"bound maps should have at least one result");
SmallVector<Attribute *, 4> foldedResults;
SmallVector<Attribute, 4> foldedResults;
if (boundMap.constantFold(operandConstants, foldedResults))
return true;
// Compute the max or min as applicable over the results.
assert(!foldedResults.empty() && "bounds should have at least one result");
auto maxOrMin = cast<IntegerAttr>(foldedResults[0])->getValue();
auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
for (unsigned i = 1; i < foldedResults.size(); i++) {
auto foldedResult = cast<IntegerAttr>(foldedResults[i])->getValue();
auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
maxOrMin = lower ? std::max(maxOrMin, foldedResult)
: std::min(maxOrMin, foldedResult);
}

View File

@ -154,7 +154,7 @@ void OpEmitter::emitAttrGetters() {
<< val.getName() << "() const {\n";
os << " return this->getAttrOfType<"
<< attr.getValueAsString("AttrType").trim() << ">(\"" << name
<< "\")->getValue();\n }\n";
<< "\").getValue();\n }\n";
}
}
@ -207,9 +207,9 @@ void OpEmitter::emitVerifier() {
// Verify the attributes have the correct type.
for (const auto attr : attrs) {
auto name = attr.first->getName();
os << " if (!dyn_cast_or_null<"
<< attr.second->getValueAsString("AttrType") << ">(this->getAttr(\""
<< name << "\"))) return emitOpError(\"requires "
os << " if (!this->getAttr(\"" << name << "\").dyn_cast_or_null<"
<< attr.second->getValueAsString("AttrType") << ">("
<< ")) return emitOpError(\"requires "
<< attr.second->getValueAsString("PrimitiveType").trim()
<< " attribute '" << name << "'\");\n";
}