[mlir] Add support for defining Traits and Interfaces on Attributes/Types.

This revisions add mechanisms to Attribute/Type for attaching traits and interfaces. The mechanisms are modeled 1-1 after those for operations to keep the system consistent. AttrBase and TypeBase now accepts a trailing list of `Trait` types that will be attached to the object. These traits should inherit from AttributeTrait::TraitBase and TypeTrait::TraitBase respectively as necessary. A followup commit will refactor the interface gen mechanisms in ODS to support Attribute/Type interface generation and add tests for the mechanisms.

Differential Revision: https://reviews.llvm.org/D81883
This commit is contained in:
River Riddle 2020-06-30 15:42:39 -07:00
parent ffa63dde8e
commit 9fbb2de8e4
9 changed files with 305 additions and 82 deletions

View File

@ -21,6 +21,45 @@ namespace mlir {
class MLIRContext;
class Type;
//===----------------------------------------------------------------------===//
// AbstractAttribute
//===----------------------------------------------------------------------===//
/// This class contains all of the static information common to all instances of
/// a registered Attribute.
class AbstractAttribute {
public:
/// Look up the specified abstract attribute in the MLIRContext and return a
/// reference to it.
static const AbstractAttribute &lookup(TypeID typeID, MLIRContext *context);
/// This method is used by Dialect objects when they register the list of
/// attributes they contain.
template <typename T> static AbstractAttribute get(Dialect &dialect) {
return AbstractAttribute(dialect, T::getInterfaceMap());
}
/// Return the dialect this attribute was registered to.
Dialect &getDialect() const { return const_cast<Dialect &>(dialect); }
/// Returns an instance of the concept object for the given interface if it
/// was registered to this attribute, null otherwise. This should not be used
/// directly.
template <typename T> typename T::Concept *getInterface() const {
return interfaceMap.lookup<T>();
}
private:
AbstractAttribute(Dialect &dialect, detail::InterfaceMap &&interfaceMap)
: dialect(dialect), interfaceMap(std::move(interfaceMap)) {}
/// This is the dialect that this attribute was registered to.
Dialect &dialect;
/// This is a collection of the interfaces registered to this attribute.
detail::InterfaceMap interfaceMap;
};
//===----------------------------------------------------------------------===//
// AttributeStorage
//===----------------------------------------------------------------------===//
@ -39,10 +78,10 @@ public:
/// Get the type of this attribute.
Type getType() const;
/// Get the dialect of this attribute.
Dialect &getDialect() const {
assert(dialect && "Malformed attribute storage object.");
return const_cast<Dialect &>(*dialect);
/// Return the abstract descriptor for this attribute.
const AbstractAttribute &getAbstractAttribute() const {
assert(abstractAttribute && "Malformed attribute storage object.");
return *abstractAttribute;
}
protected:
@ -56,13 +95,15 @@ protected:
/// Set the type of this attribute.
void setType(Type type);
// Set the dialect for this storage instance. This is used by the
// Set the abstract attribute for this storage instance. This is used by the
// AttributeUniquer when initializing a newly constructed storage object.
void initializeDialect(Dialect &newDialect) { dialect = &newDialect; }
void initialize(const AbstractAttribute &abstractAttr) {
abstractAttribute = &abstractAttr;
}
private:
/// The dialect for this attribute.
Dialect *dialect;
/// The abstract descriptor for this attribute.
const AbstractAttribute *abstractAttribute;
/// The opaque type of the attribute value.
const void *type;

View File

@ -64,9 +64,10 @@ public:
/// Utility class for implementing attributes.
template <typename ConcreteType, typename BaseType = Attribute,
typename StorageType = AttributeStorage>
typename StorageType = AttributeStorage,
template <typename T> class... Traits>
using AttrBase = detail::StorageUserBase<ConcreteType, BaseType, StorageType,
detail::AttributeUniquer>;
detail::AttributeUniquer, Traits...>;
using ImplType = AttributeStorage;
using ValueType = void;
@ -119,6 +120,11 @@ public:
friend ::llvm::hash_code hash_value(Attribute arg);
/// Return the abstract descriptor for this attribute.
const AbstractAttribute &getAbstractAttribute() const {
return impl->getAbstractAttribute();
}
protected:
ImplType *impl;
};
@ -128,6 +134,46 @@ inline raw_ostream &operator<<(raw_ostream &os, Attribute attr) {
return os;
}
//===----------------------------------------------------------------------===//
// AttributeTraitBase
//===----------------------------------------------------------------------===//
namespace AttributeTrait {
/// This class represents the base of an attribute trait.
template <typename ConcreteType, template <typename> class TraitType>
using TraitBase = detail::StorageUserTraitBase<ConcreteType, TraitType>;
} // namespace AttributeTrait
//===----------------------------------------------------------------------===//
// AttributeInterface
//===----------------------------------------------------------------------===//
/// This class represents the base of an attribute interface. See the definition
/// of `detail::Interface` for requirements on the `Traits` type.
template <typename ConcreteType, typename Traits>
class AttributeInterface
: public detail::Interface<ConcreteType, Attribute, Traits, Attribute,
AttributeTrait::TraitBase> {
public:
using Base = AttributeInterface<ConcreteType, Traits>;
using InterfaceBase = detail::Interface<ConcreteType, Type, Traits, Type,
AttributeTrait::TraitBase>;
using InterfaceBase::InterfaceBase;
private:
/// Returns the impl interface instance for the given type.
static typename InterfaceBase::Concept *getInterfaceFor(Attribute attr) {
return attr.getAbstractAttribute().getInterface<ConcreteType>();
}
/// Allow access to 'getInterfaceFor'.
friend InterfaceBase;
};
//===----------------------------------------------------------------------===//
// StandardAttributes
//===----------------------------------------------------------------------===//
namespace StandardAttributes {
enum Kind {
AffineMap = Attribute::FIRST_STANDARD_ATTR,

View File

@ -190,13 +190,19 @@ protected:
/// This method is used by derived classes to add their types to the set.
template <typename... Args> void addTypes() {
(void)std::initializer_list<int>{0, (addSymbol(Args::getTypeID()), 0)...};
(void)std::initializer_list<int>{
0, (addType(Args::getTypeID(), AbstractType::get<Args>(*this)), 0)...};
}
void addType(TypeID typeID, AbstractType &&typeInfo);
/// This method is used by derived classes to add their attributes to the set.
template <typename... Args> void addAttributes() {
(void)std::initializer_list<int>{0, (addSymbol(Args::getTypeID()), 0)...};
(void)std::initializer_list<int>{
0,
(addAttribute(Args::getTypeID(), AbstractAttribute::get<Args>(*this)),
0)...};
}
void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo);
/// Enable support for unregistered operations.
void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; }
@ -214,9 +220,6 @@ protected:
}
private:
// Register a symbol(e.g. type) with its given unique class identifier.
void addSymbol(TypeID typeID);
Dialect(const Dialect &) = delete;
void operator=(Dialect &) = delete;

View File

@ -13,6 +13,7 @@
#ifndef MLIR_IR_STORAGEUNIQUERSUPPORT_H
#define MLIR_IR_STORAGEUNIQUERSUPPORT_H
#include "mlir/Support/InterfaceSupport.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/StorageUniquer.h"
#include "mlir/Support/TypeID.h"
@ -27,17 +28,41 @@ namespace detail {
/// avoid the need to include Location.h.
const AttributeStorage *generateUnknownStorageLocation(MLIRContext *ctx);
//===----------------------------------------------------------------------===//
// StorageUserTraitBase
//===----------------------------------------------------------------------===//
/// Helper class for implementing traits for storage classes. Clients are not
/// expected to interact with this directly, so its members are all protected.
template <typename ConcreteType, template <typename> class TraitType>
class StorageUserTraitBase {
protected:
/// Return the derived instance.
ConcreteType getInstance() const {
// We have to cast up to the trait type, then to the concrete type because
// the concrete type will multiply derive from the (content free) TraitBase
// class, and we need to be able to disambiguate the path for the C++
// compiler.
auto *trait = static_cast<const TraitType<ConcreteType> *>(this);
return *static_cast<const ConcreteType *>(trait);
}
};
//===----------------------------------------------------------------------===//
// StorageUserBase
//===----------------------------------------------------------------------===//
/// Utility class for implementing users of storage classes uniqued by a
/// StorageUniquer. Clients are not expected to interact with this class
/// directly.
template <typename ConcreteT, typename BaseT, typename StorageT,
typename UniquerT>
class StorageUserBase : public BaseT {
typename UniquerT, template <typename T> class... Traits>
class StorageUserBase : public BaseT, public Traits<ConcreteT>... {
public:
using BaseT::BaseT;
/// Utility declarations for the concrete attribute class.
using Base = StorageUserBase<ConcreteT, BaseT, StorageT, UniquerT>;
using Base = StorageUserBase<ConcreteT, BaseT, StorageT, UniquerT, Traits...>;
using ImplType = StorageT;
/// Return a unique identifier for the concrete type.
@ -51,6 +76,12 @@ public:
return ConcreteT::kindof(val.getKind());
}
/// Returns an interface map for the interfaces registered to this storage
/// user. This should not be used directly.
static detail::InterfaceMap getInterfaceMap() {
return detail::InterfaceMap::template get<Traits<ConcreteT>...>();
}
protected:
/// Get or create a new ConcreteT instance within the ctx. This
/// function is guaranteed to return a non null object and will assert if

View File

@ -20,12 +20,51 @@ namespace mlir {
class Dialect;
class MLIRContext;
//===----------------------------------------------------------------------===//
// AbstractType
//===----------------------------------------------------------------------===//
/// This class contains all of the static information common to all instances of
/// a registered Type.
class AbstractType {
public:
/// Look up the specified abstract type in the MLIRContext and return a
/// reference to it.
static const AbstractType &lookup(TypeID typeID, MLIRContext *context);
/// This method is used by Dialect objects when they register the list of
/// types they contain.
template <typename T> static AbstractType get(Dialect &dialect) {
return AbstractType(dialect, T::getInterfaceMap());
}
/// Return the dialect this type was registered to.
Dialect &getDialect() const { return const_cast<Dialect &>(dialect); }
/// Returns an instance of the concept object for the given interface if it
/// was registered to this type, null otherwise. This should not be used
/// directly.
template <typename T> typename T::Concept *getInterface() const {
return interfaceMap.lookup<T>();
}
private:
AbstractType(Dialect &dialect, detail::InterfaceMap &&interfaceMap)
: dialect(dialect), interfaceMap(std::move(interfaceMap)) {}
/// This is the dialect that this type was registered to.
Dialect &dialect;
/// This is a collection of the interfaces registered to this type.
detail::InterfaceMap interfaceMap;
};
//===----------------------------------------------------------------------===//
// TypeStorage
//===----------------------------------------------------------------------===//
namespace detail {
class TypeUniquer;
struct TypeUniquer;
} // end namespace detail
/// Base storage class appearing in a Type.
@ -33,32 +72,33 @@ class TypeStorage : public StorageUniquer::BaseStorage {
friend detail::TypeUniquer;
friend StorageUniquer;
protected:
/// This constructor is used by derived classes as part of the TypeUniquer.
/// When using this constructor, the initializeDialect function must be
/// invoked afterwards for the storage to be valid.
TypeStorage(unsigned subclassData = 0)
: dialect(nullptr), subclassData(subclassData) {}
public:
/// Get the dialect that this type is registered to.
Dialect &getDialect() {
assert(dialect && "Malformed type storage object.");
return *dialect;
/// Return the abstract type descriptor for this type.
const AbstractType &getAbstractType() {
assert(abstractType && "Malformed type storage object.");
return *abstractType;
}
/// Get the subclass data.
unsigned getSubclassData() const { return subclassData; }
/// Set the subclass data.
void setSubclassData(unsigned val) { subclassData = val; }
private:
// Set the dialect for this storage instance. This is used by the TypeUniquer
// when initializing a newly constructed type storage object.
void initializeDialect(Dialect &newDialect) { dialect = &newDialect; }
protected:
/// This constructor is used by derived classes as part of the TypeUniquer.
TypeStorage(unsigned subclassData = 0)
: abstractType(nullptr), subclassData(subclassData) {}
/// The dialect for this type.
Dialect *dialect;
private:
/// Set the abstract type for this storage instance. This is used by the
/// TypeUniquer when initializing a newly constructed type storage object.
void initialize(const AbstractType &abstractTy) {
abstractType = &abstractTy;
}
/// The abstract description for this type.
const AbstractType *abstractType;
/// Space for subclasses to store data.
unsigned subclassData;
@ -72,36 +112,26 @@ using DefaultTypeStorage = TypeStorage;
// TypeStorageAllocator
//===----------------------------------------------------------------------===//
// This is a utility allocator used to allocate memory for instances of derived
// Types.
/// This is a utility allocator used to allocate memory for instances of derived
/// Types.
using TypeStorageAllocator = StorageUniquer::StorageAllocator;
//===----------------------------------------------------------------------===//
// TypeUniquer
//===----------------------------------------------------------------------===//
namespace detail {
// A utility class to get, or create, unique instances of types within an
// MLIRContext. This class manages all creation and uniquing of types.
class TypeUniquer {
public:
/// A utility class to get, or create, unique instances of types within an
/// MLIRContext. This class manages all creation and uniquing of types.
struct TypeUniquer {
/// Get an uniqued instance of a type T.
template <typename T, typename... Args>
static T get(MLIRContext *ctx, unsigned kind, Args &&... args) {
return ctx->getTypeUniquer().get<typename T::ImplType>(
[&](TypeStorage *storage) {
storage->initializeDialect(lookupDialectForType<T>(ctx));
storage->initialize(AbstractType::lookup(T::getTypeID(), ctx));
},
kind, std::forward<Args>(args)...);
}
private:
/// Get the dialect that the type 'T' was registered with.
template <typename T> static Dialect &lookupDialectForType(MLIRContext *ctx) {
return lookupDialectForType(ctx, T::getTypeID());
}
/// Get the dialect that registered the type with the provided typeid.
static Dialect &lookupDialectForType(MLIRContext *ctx, TypeID typeID);
};
} // namespace detail

View File

@ -101,9 +101,10 @@ public:
/// Utility class for implementing types.
template <typename ConcreteType, typename BaseType,
typename StorageType = DefaultTypeStorage>
typename StorageType = DefaultTypeStorage,
template <typename T> class... Traits>
using TypeBase = detail::StorageUserBase<ConcreteType, BaseType, StorageType,
detail::TypeUniquer>;
detail::TypeUniquer, Traits...>;
using ImplType = TypeStorage;
@ -196,6 +197,9 @@ public:
return Type(reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
}
/// Return the abstract type descriptor for this type.
const AbstractType &getAbstractType() { return impl->getAbstractType(); }
protected:
ImplType *impl;
};
@ -205,6 +209,45 @@ inline raw_ostream &operator<<(raw_ostream &os, Type type) {
return os;
}
//===----------------------------------------------------------------------===//
// TypeTraitBase
//===----------------------------------------------------------------------===//
namespace TypeTrait {
/// This class represents the base of a type trait.
template <typename ConcreteType, template <typename> class TraitType>
using TraitBase = detail::StorageUserTraitBase<ConcreteType, TraitType>;
} // namespace TypeTrait
//===----------------------------------------------------------------------===//
// TypeInterface
//===----------------------------------------------------------------------===//
/// This class represents the base of a type interface. See the definition of
/// `detail::Interface` for requirements on the `Traits` type.
template <typename ConcreteType, typename Traits>
class TypeInterface : public detail::Interface<ConcreteType, Type, Traits, Type,
TypeTrait::TraitBase> {
public:
using Base = TypeInterface<ConcreteType, Traits>;
using InterfaceBase =
detail::Interface<ConcreteType, Type, Traits, Type, TypeTrait::TraitBase>;
using InterfaceBase::InterfaceBase;
private:
/// Returns the impl interface instance for the given type.
static typename InterfaceBase::Concept *getInterfaceFor(Type type) {
return type.getAbstractType().getInterface<ConcreteType>();
}
/// Allow access to 'getInterfaceFor'.
friend InterfaceBase;
};
//===----------------------------------------------------------------------===//
// FunctionType
//===----------------------------------------------------------------------===//
/// Function types map from a list of inputs to a list of results.
class FunctionType
: public Type::TypeBase<FunctionType, Type, detail::FunctionTypeStorage> {
@ -232,6 +275,10 @@ public:
static bool kindof(unsigned kind) { return kind == Kind::Function; }
};
//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//
/// Opaque types represent types of non-registered dialects. These are types
/// represented in their raw string form, and can only usefully be tested for
/// type equality.

View File

@ -47,7 +47,9 @@ Type Attribute::getType() const { return impl->getType(); }
MLIRContext *Attribute::getContext() const { return getType().getContext(); }
/// Get the dialect this attribute is registered to.
Dialect &Attribute::getDialect() const { return impl->getDialect(); }
Dialect &Attribute::getDialect() const {
return impl->getAbstractAttribute().getDialect();
}
//===----------------------------------------------------------------------===//
// AffineMapAttr

View File

@ -282,13 +282,12 @@ public:
/// operations.
llvm::StringMap<AbstractOperation> registeredOperations;
/// This is a mapping from type id to Dialect for registered attributes and
/// types.
DenseMap<TypeID, Dialect *> registeredDialectSymbols;
/// These are identifiers uniqued into this MLIRContext.
llvm::StringSet<llvm::BumpPtrAllocator &> identifiers;
/// An allocator used for AbstractAttribute and AbstractType objects.
llvm::BumpPtrAllocator abstractDialectSymbolAllocator;
//===--------------------------------------------------------------------===//
// Affine uniquing
//===--------------------------------------------------------------------===//
@ -311,6 +310,8 @@ public:
//===--------------------------------------------------------------------===//
// Type uniquing
//===--------------------------------------------------------------------===//
DenseMap<TypeID, const AbstractType *> registeredTypes;
StorageUniquer typeUniquer;
/// Cached Type Instances.
@ -322,6 +323,8 @@ public:
//===--------------------------------------------------------------------===//
// Attribute uniquing
//===--------------------------------------------------------------------===//
DenseMap<TypeID, const AbstractAttribute *> registeredAttributes;
StorageUniquer attributeUniquer;
/// Cached Attribute Instances.
@ -569,16 +572,39 @@ void Dialect::addOperation(AbstractOperation opInfo) {
}
}
/// Register a dialect-specific symbol(e.g. type) with the current context.
void Dialect::addSymbol(TypeID typeID) {
void Dialect::addType(TypeID typeID, AbstractType &&typeInfo) {
auto &impl = context->getImpl();
// Lock access to the context registry.
ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
if (!impl.registeredDialectSymbols.insert({typeID, this}).second) {
llvm::errs() << "error: dialect symbol already registered.\n";
abort();
}
auto *newInfo =
new (impl.abstractDialectSymbolAllocator.Allocate<AbstractType>())
AbstractType(std::move(typeInfo));
if (!impl.registeredTypes.insert({typeID, newInfo}).second)
llvm::report_fatal_error("Dialect Type already registered.");
}
void Dialect::addAttribute(TypeID typeID, AbstractAttribute &&attrInfo) {
auto &impl = context->getImpl();
// Lock access to the context registry.
ScopedWriterLock registryLock(impl.contextMutex, impl.threadingIsEnabled);
auto *newInfo =
new (impl.abstractDialectSymbolAllocator.Allocate<AbstractAttribute>())
AbstractAttribute(std::move(attrInfo));
if (!impl.registeredAttributes.insert({typeID, newInfo}).second)
llvm::report_fatal_error("Dialect Attribute already registered.");
}
/// Get the dialect that registered the attribute with the provided typeid.
const AbstractAttribute &AbstractAttribute::lookup(TypeID typeID,
MLIRContext *context) {
auto &impl = context->getImpl();
auto it = impl.registeredAttributes.find(typeID);
if (it == impl.registeredAttributes.end())
llvm::report_fatal_error("Trying to create an Attribute that was not "
"registered in this MLIRContext.");
return *it->second;
}
/// Look up the specified operation in the operation set and return a pointer
@ -595,6 +621,16 @@ const AbstractOperation *AbstractOperation::lookup(StringRef opName,
return nullptr;
}
/// Get the dialect that registered the type with the provided typeid.
const AbstractType &AbstractType::lookup(TypeID typeID, MLIRContext *context) {
auto &impl = context->getImpl();
auto it = impl.registeredTypes.find(typeID);
if (it == impl.registeredTypes.end())
llvm::report_fatal_error(
"Trying to create a Type that was not registered in this MLIRContext.");
return *it->second;
}
//===----------------------------------------------------------------------===//
// Identifier uniquing
//===----------------------------------------------------------------------===//
@ -628,24 +664,10 @@ Identifier Identifier::get(StringRef str, MLIRContext *context) {
// Type uniquing
//===----------------------------------------------------------------------===//
static Dialect &lookupDialectForSymbol(MLIRContext *ctx, TypeID typeID) {
auto &impl = ctx->getImpl();
auto it = impl.registeredDialectSymbols.find(typeID);
if (it == impl.registeredDialectSymbols.end())
llvm::report_fatal_error(
"Trying to create a type that was not registered in this MLIRContext.");
return *it->second;
}
/// Returns the storage uniquer used for constructing type storage instances.
/// This should not be used directly.
StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }
/// Get the dialect that registered the type with the provided typeid.
Dialect &TypeUniquer::lookupDialectForType(MLIRContext *ctx, TypeID typeID) {
return lookupDialectForSymbol(ctx, typeID);
}
FloatType FloatType::get(StandardTypes::Kind kind, MLIRContext *context) {
assert(kindof(kind) && "Not a FP kind.");
switch (kind) {
@ -738,7 +760,7 @@ StorageUniquer &MLIRContext::getAttributeUniquer() {
void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage,
MLIRContext *ctx,
TypeID attrID) {
storage->initializeDialect(lookupDialectForSymbol(ctx, attrID));
storage->initialize(AbstractAttribute::lookup(attrID, ctx));
// If the attribute did not provide a type, then default to NoneType.
if (!storage->getType())

View File

@ -21,8 +21,9 @@ using namespace mlir::detail;
unsigned Type::getKind() const { return impl->getKind(); }
/// Get the dialect this type is registered to.
Dialect &Type::getDialect() const { return impl->getDialect(); }
Dialect &Type::getDialect() const {
return impl->getAbstractType().getDialect();
}
MLIRContext *Type::getContext() const { return getDialect().getContext(); }