Implement initial support for dialect specific types.
Dialect specific types are registered similarly to operations, i.e. registerType<...> within the dialect. Unlike operations, there is no notion of a "verbose" type, that is *all* types must be registered to a dialect. Casting support(isa/dyn_cast/etc.) is implemented by reserving a range of type kinds in the top level Type class as opposed to string comparison like operations. To support derived types a few hooks need to be implemented: In the concrete type class: - static char typeID; * A unique identifier for the type used during registration. In the Dialect: - typeParseHook and typePrintHook must be implemented to provide parser support. The syntax for dialect extended types is as follows: dialect-type: '!' dialect-namespace '<' '"' type-specific-data '"' '>' The 'type-specific-data' is information used to identify different types within the dialect, e.g: - !tf<"variant"> // Tensor Flow Variant Type - !tf<"string"> // Tensor Flow String Type TensorFlow/TensorFlowControl types are now implemented as dialect specific types as a proof of concept. PiperOrigin-RevId: 227580052
This commit is contained in:
parent
0c4ee54198
commit
8abc06f3d5
|
@ -76,14 +76,6 @@ public:
|
|||
|
||||
IndexType getIndexType();
|
||||
|
||||
OtherType getTFControlType();
|
||||
OtherType getTFStringType();
|
||||
OtherType getTFResourceType();
|
||||
OtherType getTFVariantType();
|
||||
OtherType getTFComplex64Type();
|
||||
OtherType getTFComplex128Type();
|
||||
OtherType getTFF32REFType();
|
||||
|
||||
IntegerType getI1Type();
|
||||
IntegerType getIntegerType(unsigned width);
|
||||
FunctionType getFunctionType(ArrayRef<Type> inputs, ArrayRef<Type> results);
|
||||
|
@ -94,6 +86,11 @@ public:
|
|||
RankedTensorType getTensorType(ArrayRef<int> shape, Type elementType);
|
||||
UnrankedTensorType getTensorType(Type elementType);
|
||||
|
||||
/// Get or construct an instance of the type 'ty' with provided arguments.
|
||||
template <typename Ty, typename... Args> Ty getType(Args... args) {
|
||||
return Ty::get(context, args...);
|
||||
}
|
||||
|
||||
// Attributes.
|
||||
BoolAttr getBoolAttr(bool value);
|
||||
IntegerAttr getIntegerAttr(Type type, int64_t value);
|
||||
|
|
|
@ -25,9 +25,13 @@
|
|||
#include "mlir/IR/OperationSupport.h"
|
||||
|
||||
namespace mlir {
|
||||
class Type;
|
||||
|
||||
using DialectConstantFoldHook = std::function<bool(
|
||||
const OperationInst *, ArrayRef<Attribute>, SmallVectorImpl<Attribute> &)>;
|
||||
using DialectTypeParserHook =
|
||||
std::function<Type(StringRef, Location, MLIRContext *)>;
|
||||
using DialectTypePrinterHook = std::function<void(Type, raw_ostream &)>;
|
||||
|
||||
/// Dialects are groups of MLIR operations and behavior associated with the
|
||||
/// entire group. For example, hooks into other systems for constant folding,
|
||||
|
@ -53,9 +57,13 @@ public:
|
|||
[](const OperationInst *op, ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<Attribute> &results) { return true; };
|
||||
|
||||
// TODO: Hook to return the list of named types that are known.
|
||||
/// Registered parsing/printing hooks for types registered to the dialect.
|
||||
DialectTypeParserHook typeParseHook = nullptr;
|
||||
/// Note: The data printed for the provided type must not include any '"'
|
||||
/// characters.
|
||||
DialectTypePrinterHook typePrintHook = nullptr;
|
||||
|
||||
// TODO: Hook to return list of dialect defined types, like tf_control.
|
||||
// TODO: Hook to return the list of named types that are known.
|
||||
|
||||
virtual ~Dialect();
|
||||
|
||||
|
@ -95,6 +103,31 @@ protected:
|
|||
|
||||
void addOperation(AbstractOperation opInfo);
|
||||
|
||||
/// This method is used by derived classes to add their types to the set.
|
||||
template <typename... Args> void addTypes() {
|
||||
VariadicTypeAdder<Args...>::addToSet(*this);
|
||||
}
|
||||
|
||||
// It would be nice to define this as variadic functions instead of a nested
|
||||
// variadic type, but we can't do that: function template partial
|
||||
// specialization is not allowed, and we can't define an overload set
|
||||
// because we don't have any arguments of the types we are pushing around.
|
||||
template <typename First, typename... Rest> class VariadicTypeAdder {
|
||||
public:
|
||||
static void addToSet(Dialect &dialect) {
|
||||
VariadicTypeAdder<First>::addToSet(dialect);
|
||||
VariadicTypeAdder<Rest...>::addToSet(dialect);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename First> class VariadicTypeAdder<First> {
|
||||
public:
|
||||
static void addToSet(Dialect &dialect) { dialect.addType(&First::typeID); }
|
||||
};
|
||||
|
||||
// Register a type with its given unqiue type identifer.
|
||||
void addType(const void *const typeID);
|
||||
|
||||
private:
|
||||
Dialect(const Dialect &) = delete;
|
||||
void operator=(Dialect &) = delete;
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
//===- DialectTypeRegistry.def - MLIR Dialect Type Registry -----*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file enumerates the different dialects that define custom classes
|
||||
// within the type system.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
DEFINE_TYPE_KIND_RANGE(TENSORFLOW_CONTROL)
|
||||
DEFINE_TYPE_KIND_RANGE(TENSORFLOW)
|
||||
|
||||
#undef DEFINE_TYPE_KIND_RANGE
|
|
@ -1,4 +1,4 @@
|
|||
//===- TypeSupport.h -------------------------------------------*- C++ -*-===//
|
||||
//===- TypeSupport.h --------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
|
@ -15,7 +15,7 @@
|
|||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// This file defines support types for the type system.
|
||||
// This file defines support types for registering dialect extended types.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
@ -28,6 +28,7 @@
|
|||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
class Dialect;
|
||||
class MLIRContext;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -46,48 +47,45 @@ protected:
|
|||
/// This constructor is used by derived classes as part of the TypeUniquer.
|
||||
/// When using this constructor, the initializeTypeInfo function must be
|
||||
/// invoked afterwards for the storage to be valid.
|
||||
TypeStorage(unsigned subclassData = 0) : kind(0), context(nullptr) {
|
||||
setSubclassData(subclassData);
|
||||
}
|
||||
TypeStorage(unsigned subclassData = 0)
|
||||
: dialect(nullptr), kind(0), subclassData(subclassData) {}
|
||||
|
||||
public:
|
||||
/// Get the dialect that this type is registered to.
|
||||
const Dialect &getDialect() const {
|
||||
assert(dialect && "Malformed type storage object.");
|
||||
return *dialect;
|
||||
}
|
||||
|
||||
/// Get the kind classification of this type.
|
||||
unsigned getKind() const { return kind; }
|
||||
|
||||
/// Get the context this type storage was uniqued in.
|
||||
MLIRContext *getContext() const { return context; }
|
||||
|
||||
/// Get the subclass data.
|
||||
unsigned getSubclassData() const { return subclassData; }
|
||||
|
||||
/// Set the subclass data for this type. The value provided must fit within
|
||||
/// the bitsize of the subclass data.
|
||||
void setSubclassData(unsigned val) {
|
||||
subclassData = val;
|
||||
// Ensure we don't have any accidental truncation.
|
||||
assert(getSubclassData() == val && "Subclass data too large for field");
|
||||
}
|
||||
/// Set the subclass data.
|
||||
void setSubclassData(unsigned val) { subclassData = val; }
|
||||
|
||||
private:
|
||||
// Constructor used for simple type storage that have no subclass data. This
|
||||
// constructor should not be used by derived storage classes.
|
||||
TypeStorage(unsigned kind, MLIRContext *ctx)
|
||||
: kind(kind), context(ctx), subclassData(0) {}
|
||||
TypeStorage(const Dialect &dialect, unsigned kind)
|
||||
: dialect(&dialect), kind(kind), subclassData(0) {}
|
||||
|
||||
// Initialize an existing type storage with a kind and a context. This is used
|
||||
// by the TypeUniquer when initializing a newly constructed derived type
|
||||
// storage object.
|
||||
void initializeTypeInfo(unsigned newKind, MLIRContext *ctx) {
|
||||
void initializeTypeInfo(const Dialect &newDialect, unsigned newKind) {
|
||||
dialect = &newDialect;
|
||||
kind = newKind;
|
||||
context = ctx;
|
||||
}
|
||||
|
||||
/// The registered information for the current type.
|
||||
const Dialect *dialect;
|
||||
|
||||
/// Classification of the subclass, used for type checking.
|
||||
unsigned kind;
|
||||
|
||||
/// The context the storage was uniqued in.
|
||||
MLIRContext *context;
|
||||
|
||||
/// Space for subclasses to store data.
|
||||
unsigned subclassData;
|
||||
};
|
||||
|
@ -179,11 +177,14 @@ public:
|
|||
if (storage)
|
||||
return T(storage);
|
||||
|
||||
// Get the dialect this type was registered to.
|
||||
auto &dialect = lookupDialectForType<T>();
|
||||
|
||||
// Otherwise, construct and initialize the derived storage for this type
|
||||
// instance.
|
||||
TypeStorageAllocator allocator(ctx);
|
||||
storage = ImplType::construct(allocator, args...);
|
||||
storage->initializeTypeInfo(kind, ctx);
|
||||
storage->initializeTypeInfo(dialect, kind);
|
||||
|
||||
// Insert the new type storage instance into the context.
|
||||
insert(hashValue, storage);
|
||||
|
@ -198,13 +199,22 @@ public:
|
|||
std::is_same<typename T::ImplType, detail::DefaultTypeStorage>::value,
|
||||
T>::type
|
||||
get(unsigned kind) {
|
||||
return T(getSimple(kind));
|
||||
auto &dialect = lookupDialectForType<T>();
|
||||
return T(getSimple(dialect, kind));
|
||||
}
|
||||
|
||||
private:
|
||||
/// Get the dialect that the type 'T' was registered with.
|
||||
template <typename T> const Dialect &lookupDialectForType() {
|
||||
return lookupDialectForType(&T::typeID);
|
||||
}
|
||||
|
||||
/// Get the dialect that registered the type with the provided typeid.
|
||||
const Dialect &lookupDialectForType(const void *const typeID);
|
||||
|
||||
/// Get or create a uniqued type by its kind. This overload is used for
|
||||
/// simple types that are only uniqued by kind.
|
||||
detail::TypeStorage *getSimple(unsigned kind);
|
||||
detail::TypeStorage *getSimple(const Dialect &dialect, unsigned kind);
|
||||
|
||||
/// Utilities for generating a derived storage key.
|
||||
/// Overload for if the key can be directly constructed from the provided
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#ifndef MLIR_IR_TYPES_H
|
||||
#define MLIR_IR_TYPES_H
|
||||
|
||||
#include "mlir/IR/TypeSupport.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/DenseMapInfo.h"
|
||||
|
@ -33,21 +34,9 @@ class IndexType;
|
|||
class IntegerType;
|
||||
class Location;
|
||||
class MLIRContext;
|
||||
class OtherType;
|
||||
|
||||
namespace detail {
|
||||
|
||||
struct TypeStorage;
|
||||
|
||||
struct IntegerTypeStorage;
|
||||
struct FunctionTypeStorage;
|
||||
struct VectorOrTensorTypeStorage;
|
||||
struct VectorTypeStorage;
|
||||
struct TensorTypeStorage;
|
||||
struct RankedTensorTypeStorage;
|
||||
struct UnrankedTensorTypeStorage;
|
||||
struct MemRefTypeStorage;
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/// Instances of the Type class are immutable and uniqued. They wrap a pointer
|
||||
|
@ -65,6 +54,8 @@ struct MemRefTypeStorage;
|
|||
/// Derived type classes are expected to implement several required
|
||||
/// implementaiton hooks:
|
||||
/// * Required:
|
||||
/// - static char typeID;
|
||||
/// * A unique identifier for this type used during registration.
|
||||
///
|
||||
/// - static bool kindof(unsigned kind);
|
||||
/// * Returns if the provided type kind corresponds to an instance of the
|
||||
|
@ -100,7 +91,11 @@ struct MemRefTypeStorage;
|
|||
class Type {
|
||||
public:
|
||||
/// Integer identifier for all the concrete type kinds.
|
||||
enum class Kind {
|
||||
/// Note: This is not an enum class as each dialect will likely define a
|
||||
/// separate enumeration for the specific types that they define. Not being an
|
||||
/// enum class also simplifies the handling of type kinds by not requiring
|
||||
/// casts for each use.
|
||||
enum Kind {
|
||||
// Target pointer sized integer, used (e.g.) in affine mappings.
|
||||
Index,
|
||||
|
||||
|
@ -132,6 +127,13 @@ public:
|
|||
RankedTensor,
|
||||
UnrankedTensor,
|
||||
MemRef,
|
||||
|
||||
LAST_BUILTIN_TYPE = 0xff,
|
||||
|
||||
// Reserve type kinds for dialect specific type system extensions.
|
||||
#define DEFINE_TYPE_KIND_RANGE(Dialect) \
|
||||
FIRST_##Dialect##_TYPE, LAST_##Dialect##_TYPE = FIRST_##Dialect##_TYPE + 0xff,
|
||||
#include "DialectTypeRegistry.def"
|
||||
};
|
||||
|
||||
using ImplType = detail::TypeStorage;
|
||||
|
@ -158,25 +160,21 @@ public:
|
|||
template <typename U> U cast() const;
|
||||
|
||||
/// Return the classification for this type.
|
||||
Kind getKind() const;
|
||||
unsigned getKind() const;
|
||||
|
||||
/// Return the LLVMContext in which this type was uniqued.
|
||||
MLIRContext *getContext() const;
|
||||
|
||||
// Convenience predicates. This is only for 'other' and floating point types,
|
||||
/// Get the dialect this type is registered to.
|
||||
const Dialect &getDialect() const;
|
||||
|
||||
// Convenience predicates. This is only for floating point types,
|
||||
// derived types should use isa/dyn_cast.
|
||||
bool isIndex() const { return getKind() == Kind::Index; }
|
||||
bool isTFControl() const { return getKind() == Kind::TFControl; }
|
||||
bool isTFResource() const { return getKind() == Kind::TFResource; }
|
||||
bool isTFVariant() const { return getKind() == Kind::TFVariant; }
|
||||
bool isTFComplex64() const { return getKind() == Kind::TFComplex64; }
|
||||
bool isTFComplex128() const { return getKind() == Kind::TFComplex128; }
|
||||
bool isTFF32REF() const { return getKind() == Kind::TFF32REF; }
|
||||
bool isTFString() const { return getKind() == Kind::TFString; }
|
||||
bool isBF16() const { return getKind() == Kind::BF16; }
|
||||
bool isF16() const { return getKind() == Kind::F16; }
|
||||
bool isF32() const { return getKind() == Kind::F32; }
|
||||
bool isF64() const { return getKind() == Kind::F64; }
|
||||
bool isIndex() const;
|
||||
bool isBF16() const;
|
||||
bool isF16() const;
|
||||
bool isF32() const;
|
||||
bool isF64() const;
|
||||
|
||||
/// Return true if this is an integer type with the specified width.
|
||||
bool isInteger(unsigned width) const;
|
||||
|
@ -199,13 +197,6 @@ public:
|
|||
static FloatType getF16(MLIRContext *ctx);
|
||||
static FloatType getF32(MLIRContext *ctx);
|
||||
static FloatType getF64(MLIRContext *ctx);
|
||||
static OtherType getTFControl(MLIRContext *ctx);
|
||||
static OtherType getTFString(MLIRContext *ctx);
|
||||
static OtherType getTFResource(MLIRContext *ctx);
|
||||
static OtherType getTFVariant(MLIRContext *ctx);
|
||||
static OtherType getTFComplex64(MLIRContext *ctx);
|
||||
static OtherType getTFComplex128(MLIRContext *ctx);
|
||||
static OtherType getTFF32REF(MLIRContext *ctx);
|
||||
|
||||
/// Print the current type.
|
||||
void print(raw_ostream &os) const;
|
||||
|
@ -233,6 +224,27 @@ inline raw_ostream &operator<<(raw_ostream &os, Type type) {
|
|||
return os;
|
||||
}
|
||||
|
||||
/// Standard Type Utilities.
|
||||
|
||||
namespace detail {
|
||||
|
||||
struct IntegerTypeStorage;
|
||||
struct FunctionTypeStorage;
|
||||
struct VectorOrTensorTypeStorage;
|
||||
struct VectorTypeStorage;
|
||||
struct TensorTypeStorage;
|
||||
struct RankedTensorTypeStorage;
|
||||
struct UnrankedTensorTypeStorage;
|
||||
struct MemRefTypeStorage;
|
||||
|
||||
} // namespace detail
|
||||
|
||||
inline bool Type::isIndex() const { return getKind() == Kind::Index; }
|
||||
inline bool Type::isBF16() const { return getKind() == Kind::BF16; }
|
||||
inline bool Type::isF16() const { return getKind() == Kind::F16; }
|
||||
inline bool Type::isF32() const { return getKind() == Kind::F32; }
|
||||
inline bool Type::isF64() const { return getKind() == Kind::F64; }
|
||||
|
||||
/// Integer types can have arbitrary bitwidth up to a large fixed limit.
|
||||
class IntegerType : public Type {
|
||||
public:
|
||||
|
@ -254,7 +266,10 @@ public:
|
|||
unsigned getWidth() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool kindof(Kind kind) { return kind == Kind::Integer; }
|
||||
static bool kindof(unsigned kind) { return kind == Kind::Integer; }
|
||||
|
||||
/// Unique identifier for this type class.
|
||||
static char typeID;
|
||||
|
||||
/// Integer representation maximal bitwidth.
|
||||
static constexpr unsigned kMaxWidth = 4096;
|
||||
|
@ -290,7 +305,7 @@ public:
|
|||
static FloatType get(Kind kind, MLIRContext *context);
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool kindof(Kind kind) {
|
||||
static bool kindof(unsigned kind) {
|
||||
return kind >= Kind::FIRST_FLOATING_POINT_TYPE &&
|
||||
kind <= Kind::LAST_FLOATING_POINT_TYPE;
|
||||
}
|
||||
|
@ -300,6 +315,9 @@ public:
|
|||
|
||||
/// Return the floating semantics of this float type.
|
||||
const llvm::fltSemantics &getFloatSemantics() const;
|
||||
|
||||
/// Unique identifier for this type class.
|
||||
static char typeID;
|
||||
};
|
||||
|
||||
inline FloatType Type::getBF16(MLIRContext *ctx) {
|
||||
|
@ -325,46 +343,15 @@ public:
|
|||
static IndexType get(MLIRContext *context);
|
||||
|
||||
/// Support method to enable LLVM-style type casting.
|
||||
static bool kindof(Kind kind) { return kind == Kind::Index; }
|
||||
};
|
||||
static bool kindof(unsigned kind) { return kind == Kind::Index; }
|
||||
|
||||
/// This is a type for the random collection of special base types.
|
||||
class OtherType : public Type {
|
||||
public:
|
||||
using Type::Type;
|
||||
|
||||
static OtherType get(Kind kind, MLIRContext *context);
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool kindof(Kind kind) {
|
||||
return kind >= Kind::FIRST_OTHER_TYPE && kind <= Kind::LAST_OTHER_TYPE;
|
||||
}
|
||||
/// Unique identifier for this type class.
|
||||
static char typeID;
|
||||
};
|
||||
|
||||
inline IndexType Type::getIndex(MLIRContext *ctx) {
|
||||
return IndexType::get(ctx);
|
||||
}
|
||||
inline OtherType Type::getTFControl(MLIRContext *ctx) {
|
||||
return OtherType::get(Kind::TFControl, ctx);
|
||||
}
|
||||
inline OtherType Type::getTFResource(MLIRContext *ctx) {
|
||||
return OtherType::get(Kind::TFResource, ctx);
|
||||
}
|
||||
inline OtherType Type::getTFString(MLIRContext *ctx) {
|
||||
return OtherType::get(Kind::TFString, ctx);
|
||||
}
|
||||
inline OtherType Type::getTFVariant(MLIRContext *ctx) {
|
||||
return OtherType::get(Kind::TFVariant, ctx);
|
||||
}
|
||||
inline OtherType Type::getTFComplex64(MLIRContext *ctx) {
|
||||
return OtherType::get(Kind::TFComplex64, ctx);
|
||||
}
|
||||
inline OtherType Type::getTFComplex128(MLIRContext *ctx) {
|
||||
return OtherType::get(Kind::TFComplex128, ctx);
|
||||
}
|
||||
inline OtherType Type::getTFF32REF(MLIRContext *ctx) {
|
||||
return OtherType::get(Kind::TFF32REF, ctx);
|
||||
}
|
||||
|
||||
/// Function types map from a list of inputs to a list of results.
|
||||
class FunctionType : public Type {
|
||||
|
@ -390,7 +377,10 @@ public:
|
|||
ArrayRef<Type> getResults() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool kindof(Kind kind) { return kind == Kind::Function; }
|
||||
static bool kindof(unsigned kind) { return kind == Kind::Function; }
|
||||
|
||||
/// Unique identifier for this type class.
|
||||
static char typeID;
|
||||
};
|
||||
|
||||
/// This is a common base class between Vector, UnrankedTensor, and RankedTensor
|
||||
|
@ -438,7 +428,7 @@ public:
|
|||
long getSizeInBits() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool kindof(Kind kind) {
|
||||
static bool kindof(unsigned kind) {
|
||||
return kind == Kind::Vector || kind == Kind::RankedTensor ||
|
||||
kind == Kind::UnrankedTensor;
|
||||
}
|
||||
|
@ -469,7 +459,10 @@ public:
|
|||
ArrayRef<int> getShape() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool kindof(Kind kind) { return kind == Kind::Vector; }
|
||||
static bool kindof(unsigned kind) { return kind == Kind::Vector; }
|
||||
|
||||
/// Unique identifier for this type class.
|
||||
static char typeID;
|
||||
};
|
||||
|
||||
/// Tensor types represent multi-dimensional arrays, and have two variants:
|
||||
|
@ -479,20 +472,19 @@ public:
|
|||
using ImplType = detail::TensorTypeStorage;
|
||||
using VectorOrTensorType::VectorOrTensorType;
|
||||
|
||||
/// Return true if the specified element type is a TensorFlow type that is ok
|
||||
/// in a tensor.
|
||||
static bool isValidTFElementType(Type type) {
|
||||
return type.isa<FloatType>() || type.isa<IntegerType>() ||
|
||||
type.isa<OtherType>();
|
||||
}
|
||||
|
||||
/// Return true if the specified element type is ok in a tensor.
|
||||
static bool isValidElementType(Type type) {
|
||||
return isValidTFElementType(type) || type.isa<VectorType>();
|
||||
// TODO(riverriddle): TensorFlow types are currently considered valid for
|
||||
// legacy reasons.
|
||||
return type.isIntOrFloat() || type.isa<VectorType>() ||
|
||||
(type.getKind() >=
|
||||
static_cast<unsigned>(Kind::FIRST_TENSORFLOW_TYPE) &&
|
||||
type.getKind() <=
|
||||
static_cast<unsigned>(Kind::LAST_TENSORFLOW_TYPE));
|
||||
}
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool kindof(Kind kind) {
|
||||
static bool kindof(unsigned kind) {
|
||||
return kind == Kind::RankedTensor || kind == Kind::UnrankedTensor;
|
||||
}
|
||||
};
|
||||
|
@ -518,7 +510,10 @@ public:
|
|||
|
||||
ArrayRef<int> getShape() const;
|
||||
|
||||
static bool kindof(Kind kind) { return kind == Kind::RankedTensor; }
|
||||
static bool kindof(unsigned kind) { return kind == Kind::RankedTensor; }
|
||||
|
||||
/// Unique identifier for this type class.
|
||||
static char typeID;
|
||||
};
|
||||
|
||||
/// Unranked tensor types represent multi-dimensional arrays that have an
|
||||
|
@ -540,7 +535,10 @@ public:
|
|||
|
||||
ArrayRef<int> getShape() const { return ArrayRef<int>(); }
|
||||
|
||||
static bool kindof(Kind kind) { return kind == Kind::UnrankedTensor; }
|
||||
static bool kindof(unsigned kind) { return kind == Kind::UnrankedTensor; }
|
||||
|
||||
/// Unique identifier for this type class.
|
||||
static char typeID;
|
||||
};
|
||||
|
||||
/// MemRef types represent a region of memory that have a shape with a fixed
|
||||
|
@ -591,7 +589,10 @@ public:
|
|||
/// Returns the number of dimensions with dynamic size.
|
||||
unsigned getNumDynamicDims() const;
|
||||
|
||||
static bool kindof(Kind kind) { return kind == Kind::MemRef; }
|
||||
static bool kindof(unsigned kind) { return kind == Kind::MemRef; }
|
||||
|
||||
/// Unique identifier for this type class.
|
||||
static char typeID;
|
||||
|
||||
private:
|
||||
static MemRefType getSafe(ArrayRef<int> shape, Type elementType,
|
||||
|
@ -651,4 +652,4 @@ public:
|
|||
|
||||
} // namespace llvm
|
||||
|
||||
#endif // MLIR_IR_TYPES_H
|
||||
#endif // MLIR_IR_TYPES_H
|
||||
|
|
|
@ -486,6 +486,14 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
|
|||
|
||||
void ModulePrinter::printType(Type type) {
|
||||
switch (type.getKind()) {
|
||||
default: {
|
||||
auto &dialect = type.getDialect();
|
||||
os << "!" << dialect.getNamespace() << "<\"";
|
||||
assert(dialect.typePrintHook && "Expected dialect type printing hook.");
|
||||
dialect.typePrintHook(type, os);
|
||||
os << "\">";
|
||||
return;
|
||||
}
|
||||
case Type::Kind::Index:
|
||||
os << "index";
|
||||
return;
|
||||
|
@ -501,27 +509,6 @@ void ModulePrinter::printType(Type type) {
|
|||
case Type::Kind::F64:
|
||||
os << "f64";
|
||||
return;
|
||||
case Type::Kind::TFControl:
|
||||
os << "tf_control";
|
||||
return;
|
||||
case Type::Kind::TFResource:
|
||||
os << "tf_resource";
|
||||
return;
|
||||
case Type::Kind::TFVariant:
|
||||
os << "tf_variant";
|
||||
return;
|
||||
case Type::Kind::TFComplex64:
|
||||
os << "tf_complex64";
|
||||
return;
|
||||
case Type::Kind::TFComplex128:
|
||||
os << "tf_complex128";
|
||||
return;
|
||||
case Type::Kind::TFF32REF:
|
||||
os << "tf_f32ref";
|
||||
return;
|
||||
case Type::Kind::TFString:
|
||||
os << "tf_string";
|
||||
return;
|
||||
|
||||
case Type::Kind::Integer: {
|
||||
auto integer = type.cast<IntegerType>();
|
||||
|
|
|
@ -66,24 +66,6 @@ FloatType Builder::getF64Type() { return Type::getF64(context); }
|
|||
|
||||
IndexType Builder::getIndexType() { return Type::getIndex(context); }
|
||||
|
||||
OtherType Builder::getTFControlType() { return Type::getTFControl(context); }
|
||||
|
||||
OtherType Builder::getTFResourceType() { return Type::getTFResource(context); }
|
||||
|
||||
OtherType Builder::getTFVariantType() { return Type::getTFVariant(context); }
|
||||
|
||||
OtherType Builder::getTFComplex64Type() {
|
||||
return Type::getTFComplex64(context);
|
||||
}
|
||||
|
||||
OtherType Builder::getTFComplex128Type() {
|
||||
return Type::getTFComplex128(context);
|
||||
}
|
||||
|
||||
OtherType Builder::getTFF32REFType() { return Type::getTFF32REF(context); }
|
||||
|
||||
OtherType Builder::getTFStringType() { return Type::getTFString(context); }
|
||||
|
||||
IntegerType Builder::getI1Type() { return Type::getInteger(1, context); }
|
||||
|
||||
IntegerType Builder::getIntegerType(unsigned width) {
|
||||
|
|
|
@ -34,6 +34,8 @@ using namespace mlir;
|
|||
BuiltinDialect::BuiltinDialect(MLIRContext *context)
|
||||
: Dialect(/*namePrefix=*/"", context) {
|
||||
addOperations<AffineApplyOp, BranchOp, CondBranchOp, ConstantOp, ReturnOp>();
|
||||
addTypes<IndexType, FloatType, IntegerType, FunctionType, VectorType,
|
||||
RankedTensorType, UnrankedTensorType, MemRefType>();
|
||||
}
|
||||
|
||||
void mlir::printDimAndSymbolList(OperationInst::const_operand_iterator begin,
|
||||
|
@ -360,12 +362,6 @@ bool ConstantOp::verify() const {
|
|||
return false;
|
||||
}
|
||||
|
||||
if (type.isTFString()) {
|
||||
if (!value.isa<StringAttr>())
|
||||
return emitOpError("requires 'value' to be a string constant");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (type.isa<FunctionType>()) {
|
||||
if (!value.isa<FunctionAttr>())
|
||||
return emitOpError("requires 'value' to be a function reference");
|
||||
|
|
|
@ -354,6 +354,9 @@ public:
|
|||
/// operations.
|
||||
StringMap<AbstractOperation> registeredOperations;
|
||||
|
||||
/// This is a mapping from type identifier to Dialect for registered types.
|
||||
DenseMap<const void *, Dialect *> registeredTypes;
|
||||
|
||||
/// These are identifiers uniqued into this MLIRContext.
|
||||
llvm::StringMap<char, llvm::BumpPtrAllocator &> identifiers;
|
||||
|
||||
|
@ -557,6 +560,16 @@ void Dialect::addOperation(AbstractOperation opInfo) {
|
|||
}
|
||||
}
|
||||
|
||||
/// Register a dialect-specific type with the current context.
|
||||
void Dialect::addType(const void *const typeID) {
|
||||
auto &impl = context->getImpl();
|
||||
if (impl.registeredTypes.count(typeID)) {
|
||||
llvm::errs() << "error: type already registered.\n";
|
||||
abort();
|
||||
}
|
||||
impl.registeredTypes.try_emplace(typeID, this);
|
||||
}
|
||||
|
||||
/// Look up the specified operation in the operation set and return a pointer
|
||||
/// to it if present. Otherwise, return a null pointer.
|
||||
const AbstractOperation *AbstractOperation::lookup(StringRef opName,
|
||||
|
@ -717,7 +730,7 @@ llvm::BumpPtrAllocator &TypeStorageAllocator::getAllocator() {
|
|||
|
||||
/// Get or create a uniqued type by it's kind. This overload is used for
|
||||
/// simple types that are only uniqued by kind.
|
||||
TypeStorage *TypeUniquer::getSimple(unsigned kind) {
|
||||
TypeStorage *TypeUniquer::getSimple(const Dialect &dialect, unsigned kind) {
|
||||
auto &impl = ctx->getImpl();
|
||||
|
||||
// Check for an existing instance with this kind.
|
||||
|
@ -727,7 +740,14 @@ TypeStorage *TypeUniquer::getSimple(unsigned kind) {
|
|||
|
||||
// Otherwise, create a new instance and return it.
|
||||
result = impl.allocator.Allocate<DefaultTypeStorage>();
|
||||
return new (result) DefaultTypeStorage{kind, ctx};
|
||||
return new (result) DefaultTypeStorage{dialect, kind};
|
||||
}
|
||||
|
||||
/// Get the dialect that registered the type with the provided typeid.
|
||||
const Dialect &TypeUniquer::lookupDialectForType(const void *const typeID) {
|
||||
auto &impl = ctx->getImpl();
|
||||
assert(impl.registeredTypes.count(typeID) && "typeID is not registered.");
|
||||
return *impl.registeredTypes[typeID];
|
||||
}
|
||||
|
||||
/// Look up a uniqued type with a lookup key. This is used if the type defines
|
||||
|
@ -757,10 +777,6 @@ IndexType IndexType::get(MLIRContext *context) {
|
|||
return constructUniqueType<IndexType>(context, Kind::Index);
|
||||
}
|
||||
|
||||
OtherType OtherType::get(Kind kind, MLIRContext *context) {
|
||||
return constructUniqueType<OtherType>(context, kind);
|
||||
}
|
||||
|
||||
static IntegerType getIntegerType(unsigned width, MLIRContext *context,
|
||||
llvm::Optional<Location> location) {
|
||||
if (width > IntegerType::kMaxWidth) {
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
#include "mlir/IR/Types.h"
|
||||
#include "TypeDetail.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
@ -25,11 +26,12 @@
|
|||
using namespace mlir;
|
||||
using namespace mlir::detail;
|
||||
|
||||
Type::Kind Type::getKind() const {
|
||||
return static_cast<Type::Kind>(type->getKind());
|
||||
}
|
||||
unsigned Type::getKind() const { return type->getKind(); }
|
||||
|
||||
MLIRContext *Type::getContext() const { return type->getContext(); }
|
||||
/// Get the dialect this type is registered to.
|
||||
const Dialect &Type::getDialect() const { return type->getDialect(); }
|
||||
|
||||
MLIRContext *Type::getContext() const { return getDialect().getContext(); }
|
||||
|
||||
unsigned Type::getSubclassData() const { return type->getSubclassData(); }
|
||||
void Type::setSubclassData(unsigned val) { type->setSubclassData(val); }
|
||||
|
@ -207,3 +209,13 @@ unsigned MemRefType::getNumDynamicDims() const {
|
|||
}
|
||||
return numDynamicDims;
|
||||
}
|
||||
|
||||
// Define type identifiers.
|
||||
char IndexType::typeID = 0;
|
||||
char FloatType::typeID = 0;
|
||||
char IntegerType::typeID = 0;
|
||||
char FunctionType::typeID = 0;
|
||||
char VectorType::typeID = 0;
|
||||
char RankedTensorType::typeID = 0;
|
||||
char UnrankedTensorType::typeID = 0;
|
||||
char MemRefType::typeID = 0;
|
||||
|
|
|
@ -125,6 +125,8 @@ Token Lexer::lexToken() {
|
|||
case '@':
|
||||
return lexAtIdentifier(tokStart);
|
||||
|
||||
case '!':
|
||||
LLVM_FALLTHROUGH;
|
||||
case '^':
|
||||
LLVM_FALLTHROUGH;
|
||||
case '#':
|
||||
|
@ -237,6 +239,10 @@ Token Lexer::lexPrefixedIdentifier(const char *tokStart) {
|
|||
kind = Token::caret_identifier;
|
||||
errorKind = "invalid block name";
|
||||
break;
|
||||
case '!':
|
||||
kind = Token::exclamation_identifier;
|
||||
errorKind = "invalid dialect type namespace";
|
||||
break;
|
||||
default:
|
||||
llvm_unreachable("invalid caller");
|
||||
}
|
||||
|
|
|
@ -181,6 +181,7 @@ public:
|
|||
VectorType parseVectorType();
|
||||
ParseResult parseXInDimensionList();
|
||||
ParseResult parseDimensionListRanked(SmallVectorImpl<int> &dimensions);
|
||||
Type parseDialectType();
|
||||
Type parseTensorType();
|
||||
Type parseMemRefType();
|
||||
Type parseFunctionType();
|
||||
|
@ -286,7 +287,7 @@ ParseResult Parser::parseCommaSeparatedListUntil(
|
|||
/// type ::= integer-type
|
||||
/// | index-type
|
||||
/// | float-type
|
||||
/// | other-type
|
||||
/// | dialect-type
|
||||
/// | vector-type
|
||||
/// | tensor-type
|
||||
/// | memref-type
|
||||
|
@ -294,7 +295,6 @@ ParseResult Parser::parseCommaSeparatedListUntil(
|
|||
///
|
||||
/// index-type ::= `index`
|
||||
/// float-type ::= `f16` | `bf16` | `f32` | `f64`
|
||||
/// other-type ::= `tf_control`
|
||||
///
|
||||
Type Parser::parseType() {
|
||||
switch (getToken().getKind()) {
|
||||
|
@ -337,25 +337,9 @@ Type Parser::parseType() {
|
|||
consumeToken(Token::kw_index);
|
||||
return builder.getIndexType();
|
||||
|
||||
// other-type
|
||||
case Token::kw_tf_control:
|
||||
consumeToken(Token::kw_tf_control);
|
||||
return builder.getTFControlType();
|
||||
case Token::kw_tf_resource:
|
||||
consumeToken(Token::kw_tf_resource);
|
||||
return builder.getTFResourceType();
|
||||
case Token::kw_tf_variant:
|
||||
consumeToken(Token::kw_tf_variant);
|
||||
return builder.getTFVariantType();
|
||||
case Token::kw_tf_complex64:
|
||||
consumeToken(Token::kw_tf_complex64);
|
||||
return builder.getTFComplex64Type();
|
||||
case Token::kw_tf_complex128:
|
||||
consumeToken(Token::kw_tf_complex128);
|
||||
return builder.getTFComplex128Type();
|
||||
case Token::kw_tf_string:
|
||||
consumeToken(Token::kw_tf_string);
|
||||
return builder.getTFStringType();
|
||||
// dialect-specific type
|
||||
case Token::exclamation_identifier:
|
||||
return parseDialectType();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -450,6 +434,51 @@ ParseResult Parser::parseDimensionListRanked(SmallVectorImpl<int> &dimensions) {
|
|||
return ParseSuccess;
|
||||
}
|
||||
|
||||
/// Parse a dialect-specific type.
|
||||
///
|
||||
/// dialect-type ::= `!` dialect-namespace `<` '"' type-data '"' `>`
|
||||
///
|
||||
Type Parser::parseDialectType() {
|
||||
assert(getToken().is(Token::exclamation_identifier));
|
||||
|
||||
// Parse the dialect namespace.
|
||||
StringRef dialectName = getTokenSpelling().drop_front();
|
||||
consumeToken(Token::exclamation_identifier);
|
||||
|
||||
auto *dialect = state.context->getRegisteredDialect(dialectName);
|
||||
if (!dialect)
|
||||
return (emitError("no registered dialect with namespace: " + dialectName),
|
||||
nullptr);
|
||||
|
||||
// Make sure that the dialect provides a parsing hook.
|
||||
if (!dialect->typeParseHook)
|
||||
return (emitError("dialect '" + dialect->getNamespace() +
|
||||
"' provides no type parsing hook"),
|
||||
nullptr);
|
||||
|
||||
// Consume the '<'.
|
||||
if (parseToken(Token::less, "expected '<' in dialect type"))
|
||||
return nullptr;
|
||||
|
||||
// Parse the type specific data.
|
||||
if (getToken().isNot(Token::string))
|
||||
return (emitError("expected string literal type data in dialect type"),
|
||||
nullptr);
|
||||
|
||||
auto typeData = getToken().getStringValue();
|
||||
auto loc = getEncodedSourceLocation(getToken().getLoc());
|
||||
consumeToken(Token::string);
|
||||
|
||||
Type result = dialect->typeParseHook(typeData, loc, state.context);
|
||||
if (!result)
|
||||
return nullptr;
|
||||
|
||||
// Consume the '>'.
|
||||
if (parseToken(Token::greater, "expected '>' in dialect type"))
|
||||
return nullptr;
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Parse a tensor type.
|
||||
///
|
||||
/// tensor-type ::= `tensor` `<` dimension-list element-type `>`
|
||||
|
|
|
@ -50,11 +50,12 @@ TOK_MARKER(eof)
|
|||
TOK_MARKER(error)
|
||||
|
||||
// Identifiers.
|
||||
TOK_IDENTIFIER(bare_identifier) // foo
|
||||
TOK_IDENTIFIER(at_identifier) // @foo
|
||||
TOK_IDENTIFIER(hash_identifier) // #foo
|
||||
TOK_IDENTIFIER(percent_identifier) // %foo
|
||||
TOK_IDENTIFIER(caret_identifier) // ^foo
|
||||
TOK_IDENTIFIER(bare_identifier) // foo
|
||||
TOK_IDENTIFIER(at_identifier) // @foo
|
||||
TOK_IDENTIFIER(hash_identifier) // #foo
|
||||
TOK_IDENTIFIER(percent_identifier) // %foo
|
||||
TOK_IDENTIFIER(caret_identifier) // ^foo
|
||||
TOK_IDENTIFIER(exclamation_identifier) // !foo
|
||||
|
||||
// Literals
|
||||
TOK_LITERAL(floatliteral) // 2.0
|
||||
|
@ -109,13 +110,6 @@ TOK_KEYWORD(opaque)
|
|||
TOK_KEYWORD(size)
|
||||
TOK_KEYWORD(step)
|
||||
TOK_KEYWORD(tensor)
|
||||
TOK_KEYWORD(tf_control)
|
||||
TOK_KEYWORD(tf_resource)
|
||||
TOK_KEYWORD(tf_variant)
|
||||
TOK_KEYWORD(tf_complex64)
|
||||
TOK_KEYWORD(tf_complex128)
|
||||
TOK_KEYWORD(tf_string)
|
||||
TOK_KEYWORD(tf_f32ref)
|
||||
TOK_KEYWORD(to)
|
||||
TOK_KEYWORD(true)
|
||||
TOK_KEYWORD(sparse)
|
||||
|
|
|
@ -537,10 +537,10 @@ func @opaquetensorattr() -> () {
|
|||
// CHECK: "opaqueFloatTensor"() {bar: opaque<tensor<2x1x4xf32>, "0x68656C6C6F">} : () -> ()
|
||||
"opaqueFloatTensor"(){bar: opaque<tensor<2x1x4xf32>, "0x68656C6C6F">} : () -> ()
|
||||
|
||||
// CHECK: "opaqueStringTensor"() {bar: opaque<tensor<2x1x4xtf_string>, "0x68656C6C6F">} : () -> ()
|
||||
"opaqueStringTensor"(){bar: opaque<tensor<2x1x4xtf_string>, "0x68656C6C6F">} : () -> ()
|
||||
// CHECK: "opaqueResourceTensor"() {bar: opaque<tensor<2x1x4xtf_resource>, "0x68656C6C6F">} : () -> ()
|
||||
"opaqueResourceTensor"(){bar: opaque<tensor<2x1x4xtf_resource>, "0x68656C6C6F">} : () -> ()
|
||||
// CHECK: "opaqueStringTensor"() {bar: opaque<tensor<2x1x4x!tf<"string">>, "0x68656C6C6F">} : () -> ()
|
||||
"opaqueStringTensor"(){bar: opaque<tensor<2x1x4x!tf<"string">>, "0x68656C6C6F">} : () -> ()
|
||||
// CHECK: "opaqueResourceTensor"() {bar: opaque<tensor<2x1x4x!tf<"resource">>, "0x68656C6C6F">} : () -> ()
|
||||
"opaqueResourceTensor"(){bar: opaque<tensor<2x1x4x!tf<"resource">>, "0x68656C6C6F">} : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue