Implement initial support for dialect specific types.

Dialect specific types are registered similarly to operations, i.e. registerType<...> within the dialect. Unlike operations, there is no notion of a "verbose" type, that is *all* types must be registered to a dialect. Casting support(isa/dyn_cast/etc.) is implemented by reserving a range of type kinds in the top level Type class as opposed to string comparison like operations.

To support derived types a few hooks need to be implemented:

In the concrete type class:
    - static char typeID;
      * A unique identifier for the type used during registration.

In the Dialect:
    - typeParseHook and typePrintHook must be implemented to provide parser support.

The syntax for dialect extended types is as follows:
 dialect-type:  '!' dialect-namespace '<' '"' type-specific-data '"' '>'

The 'type-specific-data' is information used to identify different types within the dialect, e.g:
 - !tf<"variant"> // Tensor Flow Variant Type
 - !tf<"string">  // Tensor Flow String Type

TensorFlow/TensorFlowControl types are now implemented as dialect specific types as a proof
 of concept.

PiperOrigin-RevId: 227580052
This commit is contained in:
River Riddle 2019-01-02 14:16:40 -08:00 committed by jpienaar
parent 0c4ee54198
commit 8abc06f3d5
14 changed files with 302 additions and 213 deletions

View File

@ -76,14 +76,6 @@ public:
IndexType getIndexType();
OtherType getTFControlType();
OtherType getTFStringType();
OtherType getTFResourceType();
OtherType getTFVariantType();
OtherType getTFComplex64Type();
OtherType getTFComplex128Type();
OtherType getTFF32REFType();
IntegerType getI1Type();
IntegerType getIntegerType(unsigned width);
FunctionType getFunctionType(ArrayRef<Type> inputs, ArrayRef<Type> results);
@ -94,6 +86,11 @@ public:
RankedTensorType getTensorType(ArrayRef<int> shape, Type elementType);
UnrankedTensorType getTensorType(Type elementType);
/// Get or construct an instance of the type 'ty' with provided arguments.
template <typename Ty, typename... Args> Ty getType(Args... args) {
return Ty::get(context, args...);
}
// Attributes.
BoolAttr getBoolAttr(bool value);
IntegerAttr getIntegerAttr(Type type, int64_t value);

View File

@ -25,9 +25,13 @@
#include "mlir/IR/OperationSupport.h"
namespace mlir {
class Type;
using DialectConstantFoldHook = std::function<bool(
const OperationInst *, ArrayRef<Attribute>, SmallVectorImpl<Attribute> &)>;
using DialectTypeParserHook =
std::function<Type(StringRef, Location, MLIRContext *)>;
using DialectTypePrinterHook = std::function<void(Type, raw_ostream &)>;
/// Dialects are groups of MLIR operations and behavior associated with the
/// entire group. For example, hooks into other systems for constant folding,
@ -53,9 +57,13 @@ public:
[](const OperationInst *op, ArrayRef<Attribute> operands,
SmallVectorImpl<Attribute> &results) { return true; };
// TODO: Hook to return the list of named types that are known.
/// Registered parsing/printing hooks for types registered to the dialect.
DialectTypeParserHook typeParseHook = nullptr;
/// Note: The data printed for the provided type must not include any '"'
/// characters.
DialectTypePrinterHook typePrintHook = nullptr;
// TODO: Hook to return list of dialect defined types, like tf_control.
// TODO: Hook to return the list of named types that are known.
virtual ~Dialect();
@ -95,6 +103,31 @@ protected:
void addOperation(AbstractOperation opInfo);
/// This method is used by derived classes to add their types to the set.
template <typename... Args> void addTypes() {
VariadicTypeAdder<Args...>::addToSet(*this);
}
// It would be nice to define this as variadic functions instead of a nested
// variadic type, but we can't do that: function template partial
// specialization is not allowed, and we can't define an overload set
// because we don't have any arguments of the types we are pushing around.
template <typename First, typename... Rest> class VariadicTypeAdder {
public:
static void addToSet(Dialect &dialect) {
VariadicTypeAdder<First>::addToSet(dialect);
VariadicTypeAdder<Rest...>::addToSet(dialect);
}
};
template <typename First> class VariadicTypeAdder<First> {
public:
static void addToSet(Dialect &dialect) { dialect.addType(&First::typeID); }
};
// Register a type with its given unqiue type identifer.
void addType(const void *const typeID);
private:
Dialect(const Dialect &) = delete;
void operator=(Dialect &) = delete;

View File

@ -0,0 +1,26 @@
//===- DialectTypeRegistry.def - MLIR Dialect Type Registry -----*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file enumerates the different dialects that define custom classes
// within the type system.
//
//===----------------------------------------------------------------------===//
DEFINE_TYPE_KIND_RANGE(TENSORFLOW_CONTROL)
DEFINE_TYPE_KIND_RANGE(TENSORFLOW)
#undef DEFINE_TYPE_KIND_RANGE

View File

@ -1,4 +1,4 @@
//===- TypeSupport.h -------------------------------------------*- C++ -*-===//
//===- TypeSupport.h --------------------------------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
@ -15,7 +15,7 @@
// limitations under the License.
// =============================================================================
//
// This file defines support types for the type system.
// This file defines support types for registering dialect extended types.
//
//===----------------------------------------------------------------------===//
@ -28,6 +28,7 @@
#include <memory>
namespace mlir {
class Dialect;
class MLIRContext;
//===----------------------------------------------------------------------===//
@ -46,48 +47,45 @@ protected:
/// This constructor is used by derived classes as part of the TypeUniquer.
/// When using this constructor, the initializeTypeInfo function must be
/// invoked afterwards for the storage to be valid.
TypeStorage(unsigned subclassData = 0) : kind(0), context(nullptr) {
setSubclassData(subclassData);
}
TypeStorage(unsigned subclassData = 0)
: dialect(nullptr), kind(0), subclassData(subclassData) {}
public:
/// Get the dialect that this type is registered to.
const Dialect &getDialect() const {
assert(dialect && "Malformed type storage object.");
return *dialect;
}
/// Get the kind classification of this type.
unsigned getKind() const { return kind; }
/// Get the context this type storage was uniqued in.
MLIRContext *getContext() const { return context; }
/// Get the subclass data.
unsigned getSubclassData() const { return subclassData; }
/// Set the subclass data for this type. The value provided must fit within
/// the bitsize of the subclass data.
void setSubclassData(unsigned val) {
subclassData = val;
// Ensure we don't have any accidental truncation.
assert(getSubclassData() == val && "Subclass data too large for field");
}
/// Set the subclass data.
void setSubclassData(unsigned val) { subclassData = val; }
private:
// Constructor used for simple type storage that have no subclass data. This
// constructor should not be used by derived storage classes.
TypeStorage(unsigned kind, MLIRContext *ctx)
: kind(kind), context(ctx), subclassData(0) {}
TypeStorage(const Dialect &dialect, unsigned kind)
: dialect(&dialect), kind(kind), subclassData(0) {}
// Initialize an existing type storage with a kind and a context. This is used
// by the TypeUniquer when initializing a newly constructed derived type
// storage object.
void initializeTypeInfo(unsigned newKind, MLIRContext *ctx) {
void initializeTypeInfo(const Dialect &newDialect, unsigned newKind) {
dialect = &newDialect;
kind = newKind;
context = ctx;
}
/// The registered information for the current type.
const Dialect *dialect;
/// Classification of the subclass, used for type checking.
unsigned kind;
/// The context the storage was uniqued in.
MLIRContext *context;
/// Space for subclasses to store data.
unsigned subclassData;
};
@ -179,11 +177,14 @@ public:
if (storage)
return T(storage);
// Get the dialect this type was registered to.
auto &dialect = lookupDialectForType<T>();
// Otherwise, construct and initialize the derived storage for this type
// instance.
TypeStorageAllocator allocator(ctx);
storage = ImplType::construct(allocator, args...);
storage->initializeTypeInfo(kind, ctx);
storage->initializeTypeInfo(dialect, kind);
// Insert the new type storage instance into the context.
insert(hashValue, storage);
@ -198,13 +199,22 @@ public:
std::is_same<typename T::ImplType, detail::DefaultTypeStorage>::value,
T>::type
get(unsigned kind) {
return T(getSimple(kind));
auto &dialect = lookupDialectForType<T>();
return T(getSimple(dialect, kind));
}
private:
/// Get the dialect that the type 'T' was registered with.
template <typename T> const Dialect &lookupDialectForType() {
return lookupDialectForType(&T::typeID);
}
/// Get the dialect that registered the type with the provided typeid.
const Dialect &lookupDialectForType(const void *const typeID);
/// Get or create a uniqued type by its kind. This overload is used for
/// simple types that are only uniqued by kind.
detail::TypeStorage *getSimple(unsigned kind);
detail::TypeStorage *getSimple(const Dialect &dialect, unsigned kind);
/// Utilities for generating a derived storage key.
/// Overload for if the key can be directly constructed from the provided

View File

@ -18,6 +18,7 @@
#ifndef MLIR_IR_TYPES_H
#define MLIR_IR_TYPES_H
#include "mlir/IR/TypeSupport.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMapInfo.h"
@ -33,21 +34,9 @@ class IndexType;
class IntegerType;
class Location;
class MLIRContext;
class OtherType;
namespace detail {
struct TypeStorage;
struct IntegerTypeStorage;
struct FunctionTypeStorage;
struct VectorOrTensorTypeStorage;
struct VectorTypeStorage;
struct TensorTypeStorage;
struct RankedTensorTypeStorage;
struct UnrankedTensorTypeStorage;
struct MemRefTypeStorage;
} // namespace detail
/// Instances of the Type class are immutable and uniqued. They wrap a pointer
@ -65,6 +54,8 @@ struct MemRefTypeStorage;
/// Derived type classes are expected to implement several required
/// implementaiton hooks:
/// * Required:
/// - static char typeID;
/// * A unique identifier for this type used during registration.
///
/// - static bool kindof(unsigned kind);
/// * Returns if the provided type kind corresponds to an instance of the
@ -100,7 +91,11 @@ struct MemRefTypeStorage;
class Type {
public:
/// Integer identifier for all the concrete type kinds.
enum class Kind {
/// Note: This is not an enum class as each dialect will likely define a
/// separate enumeration for the specific types that they define. Not being an
/// enum class also simplifies the handling of type kinds by not requiring
/// casts for each use.
enum Kind {
// Target pointer sized integer, used (e.g.) in affine mappings.
Index,
@ -132,6 +127,13 @@ public:
RankedTensor,
UnrankedTensor,
MemRef,
LAST_BUILTIN_TYPE = 0xff,
// Reserve type kinds for dialect specific type system extensions.
#define DEFINE_TYPE_KIND_RANGE(Dialect) \
FIRST_##Dialect##_TYPE, LAST_##Dialect##_TYPE = FIRST_##Dialect##_TYPE + 0xff,
#include "DialectTypeRegistry.def"
};
using ImplType = detail::TypeStorage;
@ -158,25 +160,21 @@ public:
template <typename U> U cast() const;
/// Return the classification for this type.
Kind getKind() const;
unsigned getKind() const;
/// Return the LLVMContext in which this type was uniqued.
MLIRContext *getContext() const;
// Convenience predicates. This is only for 'other' and floating point types,
/// Get the dialect this type is registered to.
const Dialect &getDialect() const;
// Convenience predicates. This is only for floating point types,
// derived types should use isa/dyn_cast.
bool isIndex() const { return getKind() == Kind::Index; }
bool isTFControl() const { return getKind() == Kind::TFControl; }
bool isTFResource() const { return getKind() == Kind::TFResource; }
bool isTFVariant() const { return getKind() == Kind::TFVariant; }
bool isTFComplex64() const { return getKind() == Kind::TFComplex64; }
bool isTFComplex128() const { return getKind() == Kind::TFComplex128; }
bool isTFF32REF() const { return getKind() == Kind::TFF32REF; }
bool isTFString() const { return getKind() == Kind::TFString; }
bool isBF16() const { return getKind() == Kind::BF16; }
bool isF16() const { return getKind() == Kind::F16; }
bool isF32() const { return getKind() == Kind::F32; }
bool isF64() const { return getKind() == Kind::F64; }
bool isIndex() const;
bool isBF16() const;
bool isF16() const;
bool isF32() const;
bool isF64() const;
/// Return true if this is an integer type with the specified width.
bool isInteger(unsigned width) const;
@ -199,13 +197,6 @@ public:
static FloatType getF16(MLIRContext *ctx);
static FloatType getF32(MLIRContext *ctx);
static FloatType getF64(MLIRContext *ctx);
static OtherType getTFControl(MLIRContext *ctx);
static OtherType getTFString(MLIRContext *ctx);
static OtherType getTFResource(MLIRContext *ctx);
static OtherType getTFVariant(MLIRContext *ctx);
static OtherType getTFComplex64(MLIRContext *ctx);
static OtherType getTFComplex128(MLIRContext *ctx);
static OtherType getTFF32REF(MLIRContext *ctx);
/// Print the current type.
void print(raw_ostream &os) const;
@ -233,6 +224,27 @@ inline raw_ostream &operator<<(raw_ostream &os, Type type) {
return os;
}
/// Standard Type Utilities.
namespace detail {
struct IntegerTypeStorage;
struct FunctionTypeStorage;
struct VectorOrTensorTypeStorage;
struct VectorTypeStorage;
struct TensorTypeStorage;
struct RankedTensorTypeStorage;
struct UnrankedTensorTypeStorage;
struct MemRefTypeStorage;
} // namespace detail
inline bool Type::isIndex() const { return getKind() == Kind::Index; }
inline bool Type::isBF16() const { return getKind() == Kind::BF16; }
inline bool Type::isF16() const { return getKind() == Kind::F16; }
inline bool Type::isF32() const { return getKind() == Kind::F32; }
inline bool Type::isF64() const { return getKind() == Kind::F64; }
/// Integer types can have arbitrary bitwidth up to a large fixed limit.
class IntegerType : public Type {
public:
@ -254,7 +266,10 @@ public:
unsigned getWidth() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(Kind kind) { return kind == Kind::Integer; }
static bool kindof(unsigned kind) { return kind == Kind::Integer; }
/// Unique identifier for this type class.
static char typeID;
/// Integer representation maximal bitwidth.
static constexpr unsigned kMaxWidth = 4096;
@ -290,7 +305,7 @@ public:
static FloatType get(Kind kind, MLIRContext *context);
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(Kind kind) {
static bool kindof(unsigned kind) {
return kind >= Kind::FIRST_FLOATING_POINT_TYPE &&
kind <= Kind::LAST_FLOATING_POINT_TYPE;
}
@ -300,6 +315,9 @@ public:
/// Return the floating semantics of this float type.
const llvm::fltSemantics &getFloatSemantics() const;
/// Unique identifier for this type class.
static char typeID;
};
inline FloatType Type::getBF16(MLIRContext *ctx) {
@ -325,46 +343,15 @@ public:
static IndexType get(MLIRContext *context);
/// Support method to enable LLVM-style type casting.
static bool kindof(Kind kind) { return kind == Kind::Index; }
};
static bool kindof(unsigned kind) { return kind == Kind::Index; }
/// This is a type for the random collection of special base types.
class OtherType : public Type {
public:
using Type::Type;
static OtherType get(Kind kind, MLIRContext *context);
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(Kind kind) {
return kind >= Kind::FIRST_OTHER_TYPE && kind <= Kind::LAST_OTHER_TYPE;
}
/// Unique identifier for this type class.
static char typeID;
};
inline IndexType Type::getIndex(MLIRContext *ctx) {
return IndexType::get(ctx);
}
inline OtherType Type::getTFControl(MLIRContext *ctx) {
return OtherType::get(Kind::TFControl, ctx);
}
inline OtherType Type::getTFResource(MLIRContext *ctx) {
return OtherType::get(Kind::TFResource, ctx);
}
inline OtherType Type::getTFString(MLIRContext *ctx) {
return OtherType::get(Kind::TFString, ctx);
}
inline OtherType Type::getTFVariant(MLIRContext *ctx) {
return OtherType::get(Kind::TFVariant, ctx);
}
inline OtherType Type::getTFComplex64(MLIRContext *ctx) {
return OtherType::get(Kind::TFComplex64, ctx);
}
inline OtherType Type::getTFComplex128(MLIRContext *ctx) {
return OtherType::get(Kind::TFComplex128, ctx);
}
inline OtherType Type::getTFF32REF(MLIRContext *ctx) {
return OtherType::get(Kind::TFF32REF, ctx);
}
/// Function types map from a list of inputs to a list of results.
class FunctionType : public Type {
@ -390,7 +377,10 @@ public:
ArrayRef<Type> getResults() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(Kind kind) { return kind == Kind::Function; }
static bool kindof(unsigned kind) { return kind == Kind::Function; }
/// Unique identifier for this type class.
static char typeID;
};
/// This is a common base class between Vector, UnrankedTensor, and RankedTensor
@ -438,7 +428,7 @@ public:
long getSizeInBits() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(Kind kind) {
static bool kindof(unsigned kind) {
return kind == Kind::Vector || kind == Kind::RankedTensor ||
kind == Kind::UnrankedTensor;
}
@ -469,7 +459,10 @@ public:
ArrayRef<int> getShape() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(Kind kind) { return kind == Kind::Vector; }
static bool kindof(unsigned kind) { return kind == Kind::Vector; }
/// Unique identifier for this type class.
static char typeID;
};
/// Tensor types represent multi-dimensional arrays, and have two variants:
@ -479,20 +472,19 @@ public:
using ImplType = detail::TensorTypeStorage;
using VectorOrTensorType::VectorOrTensorType;
/// Return true if the specified element type is a TensorFlow type that is ok
/// in a tensor.
static bool isValidTFElementType(Type type) {
return type.isa<FloatType>() || type.isa<IntegerType>() ||
type.isa<OtherType>();
}
/// Return true if the specified element type is ok in a tensor.
static bool isValidElementType(Type type) {
return isValidTFElementType(type) || type.isa<VectorType>();
// TODO(riverriddle): TensorFlow types are currently considered valid for
// legacy reasons.
return type.isIntOrFloat() || type.isa<VectorType>() ||
(type.getKind() >=
static_cast<unsigned>(Kind::FIRST_TENSORFLOW_TYPE) &&
type.getKind() <=
static_cast<unsigned>(Kind::LAST_TENSORFLOW_TYPE));
}
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(Kind kind) {
static bool kindof(unsigned kind) {
return kind == Kind::RankedTensor || kind == Kind::UnrankedTensor;
}
};
@ -518,7 +510,10 @@ public:
ArrayRef<int> getShape() const;
static bool kindof(Kind kind) { return kind == Kind::RankedTensor; }
static bool kindof(unsigned kind) { return kind == Kind::RankedTensor; }
/// Unique identifier for this type class.
static char typeID;
};
/// Unranked tensor types represent multi-dimensional arrays that have an
@ -540,7 +535,10 @@ public:
ArrayRef<int> getShape() const { return ArrayRef<int>(); }
static bool kindof(Kind kind) { return kind == Kind::UnrankedTensor; }
static bool kindof(unsigned kind) { return kind == Kind::UnrankedTensor; }
/// Unique identifier for this type class.
static char typeID;
};
/// MemRef types represent a region of memory that have a shape with a fixed
@ -591,7 +589,10 @@ public:
/// Returns the number of dimensions with dynamic size.
unsigned getNumDynamicDims() const;
static bool kindof(Kind kind) { return kind == Kind::MemRef; }
static bool kindof(unsigned kind) { return kind == Kind::MemRef; }
/// Unique identifier for this type class.
static char typeID;
private:
static MemRefType getSafe(ArrayRef<int> shape, Type elementType,
@ -651,4 +652,4 @@ public:
} // namespace llvm
#endif // MLIR_IR_TYPES_H
#endif // MLIR_IR_TYPES_H

View File

@ -486,6 +486,14 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
void ModulePrinter::printType(Type type) {
switch (type.getKind()) {
default: {
auto &dialect = type.getDialect();
os << "!" << dialect.getNamespace() << "<\"";
assert(dialect.typePrintHook && "Expected dialect type printing hook.");
dialect.typePrintHook(type, os);
os << "\">";
return;
}
case Type::Kind::Index:
os << "index";
return;
@ -501,27 +509,6 @@ void ModulePrinter::printType(Type type) {
case Type::Kind::F64:
os << "f64";
return;
case Type::Kind::TFControl:
os << "tf_control";
return;
case Type::Kind::TFResource:
os << "tf_resource";
return;
case Type::Kind::TFVariant:
os << "tf_variant";
return;
case Type::Kind::TFComplex64:
os << "tf_complex64";
return;
case Type::Kind::TFComplex128:
os << "tf_complex128";
return;
case Type::Kind::TFF32REF:
os << "tf_f32ref";
return;
case Type::Kind::TFString:
os << "tf_string";
return;
case Type::Kind::Integer: {
auto integer = type.cast<IntegerType>();

View File

@ -66,24 +66,6 @@ FloatType Builder::getF64Type() { return Type::getF64(context); }
IndexType Builder::getIndexType() { return Type::getIndex(context); }
OtherType Builder::getTFControlType() { return Type::getTFControl(context); }
OtherType Builder::getTFResourceType() { return Type::getTFResource(context); }
OtherType Builder::getTFVariantType() { return Type::getTFVariant(context); }
OtherType Builder::getTFComplex64Type() {
return Type::getTFComplex64(context);
}
OtherType Builder::getTFComplex128Type() {
return Type::getTFComplex128(context);
}
OtherType Builder::getTFF32REFType() { return Type::getTFF32REF(context); }
OtherType Builder::getTFStringType() { return Type::getTFString(context); }
IntegerType Builder::getI1Type() { return Type::getInteger(1, context); }
IntegerType Builder::getIntegerType(unsigned width) {

View File

@ -34,6 +34,8 @@ using namespace mlir;
BuiltinDialect::BuiltinDialect(MLIRContext *context)
: Dialect(/*namePrefix=*/"", context) {
addOperations<AffineApplyOp, BranchOp, CondBranchOp, ConstantOp, ReturnOp>();
addTypes<IndexType, FloatType, IntegerType, FunctionType, VectorType,
RankedTensorType, UnrankedTensorType, MemRefType>();
}
void mlir::printDimAndSymbolList(OperationInst::const_operand_iterator begin,
@ -360,12 +362,6 @@ bool ConstantOp::verify() const {
return false;
}
if (type.isTFString()) {
if (!value.isa<StringAttr>())
return emitOpError("requires 'value' to be a string constant");
return false;
}
if (type.isa<FunctionType>()) {
if (!value.isa<FunctionAttr>())
return emitOpError("requires 'value' to be a function reference");

View File

@ -354,6 +354,9 @@ public:
/// operations.
StringMap<AbstractOperation> registeredOperations;
/// This is a mapping from type identifier to Dialect for registered types.
DenseMap<const void *, Dialect *> registeredTypes;
/// These are identifiers uniqued into this MLIRContext.
llvm::StringMap<char, llvm::BumpPtrAllocator &> identifiers;
@ -557,6 +560,16 @@ void Dialect::addOperation(AbstractOperation opInfo) {
}
}
/// Register a dialect-specific type with the current context.
void Dialect::addType(const void *const typeID) {
auto &impl = context->getImpl();
if (impl.registeredTypes.count(typeID)) {
llvm::errs() << "error: type already registered.\n";
abort();
}
impl.registeredTypes.try_emplace(typeID, this);
}
/// Look up the specified operation in the operation set and return a pointer
/// to it if present. Otherwise, return a null pointer.
const AbstractOperation *AbstractOperation::lookup(StringRef opName,
@ -717,7 +730,7 @@ llvm::BumpPtrAllocator &TypeStorageAllocator::getAllocator() {
/// Get or create a uniqued type by it's kind. This overload is used for
/// simple types that are only uniqued by kind.
TypeStorage *TypeUniquer::getSimple(unsigned kind) {
TypeStorage *TypeUniquer::getSimple(const Dialect &dialect, unsigned kind) {
auto &impl = ctx->getImpl();
// Check for an existing instance with this kind.
@ -727,7 +740,14 @@ TypeStorage *TypeUniquer::getSimple(unsigned kind) {
// Otherwise, create a new instance and return it.
result = impl.allocator.Allocate<DefaultTypeStorage>();
return new (result) DefaultTypeStorage{kind, ctx};
return new (result) DefaultTypeStorage{dialect, kind};
}
/// Get the dialect that registered the type with the provided typeid.
const Dialect &TypeUniquer::lookupDialectForType(const void *const typeID) {
auto &impl = ctx->getImpl();
assert(impl.registeredTypes.count(typeID) && "typeID is not registered.");
return *impl.registeredTypes[typeID];
}
/// Look up a uniqued type with a lookup key. This is used if the type defines
@ -757,10 +777,6 @@ IndexType IndexType::get(MLIRContext *context) {
return constructUniqueType<IndexType>(context, Kind::Index);
}
OtherType OtherType::get(Kind kind, MLIRContext *context) {
return constructUniqueType<OtherType>(context, kind);
}
static IntegerType getIntegerType(unsigned width, MLIRContext *context,
llvm::Optional<Location> location) {
if (width > IntegerType::kMaxWidth) {

View File

@ -18,6 +18,7 @@
#include "mlir/IR/Types.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/Support/raw_ostream.h"
@ -25,11 +26,12 @@
using namespace mlir;
using namespace mlir::detail;
Type::Kind Type::getKind() const {
return static_cast<Type::Kind>(type->getKind());
}
unsigned Type::getKind() const { return type->getKind(); }
MLIRContext *Type::getContext() const { return type->getContext(); }
/// Get the dialect this type is registered to.
const Dialect &Type::getDialect() const { return type->getDialect(); }
MLIRContext *Type::getContext() const { return getDialect().getContext(); }
unsigned Type::getSubclassData() const { return type->getSubclassData(); }
void Type::setSubclassData(unsigned val) { type->setSubclassData(val); }
@ -207,3 +209,13 @@ unsigned MemRefType::getNumDynamicDims() const {
}
return numDynamicDims;
}
// Define type identifiers.
char IndexType::typeID = 0;
char FloatType::typeID = 0;
char IntegerType::typeID = 0;
char FunctionType::typeID = 0;
char VectorType::typeID = 0;
char RankedTensorType::typeID = 0;
char UnrankedTensorType::typeID = 0;
char MemRefType::typeID = 0;

View File

@ -125,6 +125,8 @@ Token Lexer::lexToken() {
case '@':
return lexAtIdentifier(tokStart);
case '!':
LLVM_FALLTHROUGH;
case '^':
LLVM_FALLTHROUGH;
case '#':
@ -237,6 +239,10 @@ Token Lexer::lexPrefixedIdentifier(const char *tokStart) {
kind = Token::caret_identifier;
errorKind = "invalid block name";
break;
case '!':
kind = Token::exclamation_identifier;
errorKind = "invalid dialect type namespace";
break;
default:
llvm_unreachable("invalid caller");
}

View File

@ -181,6 +181,7 @@ public:
VectorType parseVectorType();
ParseResult parseXInDimensionList();
ParseResult parseDimensionListRanked(SmallVectorImpl<int> &dimensions);
Type parseDialectType();
Type parseTensorType();
Type parseMemRefType();
Type parseFunctionType();
@ -286,7 +287,7 @@ ParseResult Parser::parseCommaSeparatedListUntil(
/// type ::= integer-type
/// | index-type
/// | float-type
/// | other-type
/// | dialect-type
/// | vector-type
/// | tensor-type
/// | memref-type
@ -294,7 +295,6 @@ ParseResult Parser::parseCommaSeparatedListUntil(
///
/// index-type ::= `index`
/// float-type ::= `f16` | `bf16` | `f32` | `f64`
/// other-type ::= `tf_control`
///
Type Parser::parseType() {
switch (getToken().getKind()) {
@ -337,25 +337,9 @@ Type Parser::parseType() {
consumeToken(Token::kw_index);
return builder.getIndexType();
// other-type
case Token::kw_tf_control:
consumeToken(Token::kw_tf_control);
return builder.getTFControlType();
case Token::kw_tf_resource:
consumeToken(Token::kw_tf_resource);
return builder.getTFResourceType();
case Token::kw_tf_variant:
consumeToken(Token::kw_tf_variant);
return builder.getTFVariantType();
case Token::kw_tf_complex64:
consumeToken(Token::kw_tf_complex64);
return builder.getTFComplex64Type();
case Token::kw_tf_complex128:
consumeToken(Token::kw_tf_complex128);
return builder.getTFComplex128Type();
case Token::kw_tf_string:
consumeToken(Token::kw_tf_string);
return builder.getTFStringType();
// dialect-specific type
case Token::exclamation_identifier:
return parseDialectType();
}
}
@ -450,6 +434,51 @@ ParseResult Parser::parseDimensionListRanked(SmallVectorImpl<int> &dimensions) {
return ParseSuccess;
}
/// Parse a dialect-specific type.
///
/// dialect-type ::= `!` dialect-namespace `<` '"' type-data '"' `>`
///
Type Parser::parseDialectType() {
assert(getToken().is(Token::exclamation_identifier));
// Parse the dialect namespace.
StringRef dialectName = getTokenSpelling().drop_front();
consumeToken(Token::exclamation_identifier);
auto *dialect = state.context->getRegisteredDialect(dialectName);
if (!dialect)
return (emitError("no registered dialect with namespace: " + dialectName),
nullptr);
// Make sure that the dialect provides a parsing hook.
if (!dialect->typeParseHook)
return (emitError("dialect '" + dialect->getNamespace() +
"' provides no type parsing hook"),
nullptr);
// Consume the '<'.
if (parseToken(Token::less, "expected '<' in dialect type"))
return nullptr;
// Parse the type specific data.
if (getToken().isNot(Token::string))
return (emitError("expected string literal type data in dialect type"),
nullptr);
auto typeData = getToken().getStringValue();
auto loc = getEncodedSourceLocation(getToken().getLoc());
consumeToken(Token::string);
Type result = dialect->typeParseHook(typeData, loc, state.context);
if (!result)
return nullptr;
// Consume the '>'.
if (parseToken(Token::greater, "expected '>' in dialect type"))
return nullptr;
return result;
}
/// Parse a tensor type.
///
/// tensor-type ::= `tensor` `<` dimension-list element-type `>`

View File

@ -50,11 +50,12 @@ TOK_MARKER(eof)
TOK_MARKER(error)
// Identifiers.
TOK_IDENTIFIER(bare_identifier) // foo
TOK_IDENTIFIER(at_identifier) // @foo
TOK_IDENTIFIER(hash_identifier) // #foo
TOK_IDENTIFIER(percent_identifier) // %foo
TOK_IDENTIFIER(caret_identifier) // ^foo
TOK_IDENTIFIER(bare_identifier) // foo
TOK_IDENTIFIER(at_identifier) // @foo
TOK_IDENTIFIER(hash_identifier) // #foo
TOK_IDENTIFIER(percent_identifier) // %foo
TOK_IDENTIFIER(caret_identifier) // ^foo
TOK_IDENTIFIER(exclamation_identifier) // !foo
// Literals
TOK_LITERAL(floatliteral) // 2.0
@ -109,13 +110,6 @@ TOK_KEYWORD(opaque)
TOK_KEYWORD(size)
TOK_KEYWORD(step)
TOK_KEYWORD(tensor)
TOK_KEYWORD(tf_control)
TOK_KEYWORD(tf_resource)
TOK_KEYWORD(tf_variant)
TOK_KEYWORD(tf_complex64)
TOK_KEYWORD(tf_complex128)
TOK_KEYWORD(tf_string)
TOK_KEYWORD(tf_f32ref)
TOK_KEYWORD(to)
TOK_KEYWORD(true)
TOK_KEYWORD(sparse)

View File

@ -537,10 +537,10 @@ func @opaquetensorattr() -> () {
// CHECK: "opaqueFloatTensor"() {bar: opaque<tensor<2x1x4xf32>, "0x68656C6C6F">} : () -> ()
"opaqueFloatTensor"(){bar: opaque<tensor<2x1x4xf32>, "0x68656C6C6F">} : () -> ()
// CHECK: "opaqueStringTensor"() {bar: opaque<tensor<2x1x4xtf_string>, "0x68656C6C6F">} : () -> ()
"opaqueStringTensor"(){bar: opaque<tensor<2x1x4xtf_string>, "0x68656C6C6F">} : () -> ()
// CHECK: "opaqueResourceTensor"() {bar: opaque<tensor<2x1x4xtf_resource>, "0x68656C6C6F">} : () -> ()
"opaqueResourceTensor"(){bar: opaque<tensor<2x1x4xtf_resource>, "0x68656C6C6F">} : () -> ()
// CHECK: "opaqueStringTensor"() {bar: opaque<tensor<2x1x4x!tf<"string">>, "0x68656C6C6F">} : () -> ()
"opaqueStringTensor"(){bar: opaque<tensor<2x1x4x!tf<"string">>, "0x68656C6C6F">} : () -> ()
// CHECK: "opaqueResourceTensor"() {bar: opaque<tensor<2x1x4x!tf<"resource">>, "0x68656C6C6F">} : () -> ()
"opaqueResourceTensor"(){bar: opaque<tensor<2x1x4x!tf<"resource">>, "0x68656C6C6F">} : () -> ()
return
}