From 54948a43802b642d29478db14c12310cf174ea10 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Thu, 3 Jan 2019 14:29:52 -0800 Subject: [PATCH] Split the standard types from builtin types and move them into separate source files(StandardTypes.cpp/h). After this cl only FunctionType and IndexType are builtin types, but IndexType will likely become a standard type when the ml/cfgfunc merger is done. Mechanical NFC. PiperOrigin-RevId: 227750918 --- mlir/include/mlir/IR/Builders.h | 1 + mlir/include/mlir/IR/DialectTypeRegistry.def | 1 + mlir/include/mlir/IR/Matchers.h | 2 +- mlir/include/mlir/IR/StandardTypes.h | 384 +++++++++++++++++ mlir/include/mlir/IR/Types.h | 390 +----------------- mlir/include/mlir/StandardOps/StandardOps.h | 1 + .../mlir/SuperVectorOps/SuperVectorOps.h | 1 + mlir/lib/IR/AsmPrinter.cpp | 20 +- mlir/lib/IR/AttributeDetail.h | 2 +- mlir/lib/IR/Builders.cpp | 2 +- mlir/lib/IR/BuiltinOps.cpp | 2 +- mlir/lib/IR/MLIRContext.cpp | 30 +- mlir/lib/IR/Operation.cpp | 4 +- mlir/lib/IR/StandardTypes.cpp | 196 +++++++++ mlir/lib/IR/Types.cpp | 174 +------- mlir/lib/Parser/Parser.cpp | 12 +- mlir/lib/StandardOps/StandardOps.cpp | 2 +- mlir/lib/SuperVectorOps/SuperVectorOps.cpp | 1 - mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp | 8 +- .../Vectorization/VectorizerTestPass.cpp | 1 + 20 files changed, 648 insertions(+), 586 deletions(-) create mode 100644 mlir/include/mlir/IR/StandardTypes.h create mode 100644 mlir/lib/IR/StandardTypes.cpp diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 7d74cc3f34da..1af0f48f728f 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -32,6 +32,7 @@ class Type; class PrimitiveType; class IntegerType; class FunctionType; +class MemRefType; class VectorType; class RankedTensorType; class UnrankedTensorType; diff --git a/mlir/include/mlir/IR/DialectTypeRegistry.def b/mlir/include/mlir/IR/DialectTypeRegistry.def index 470f93fc1aa6..40d8c313c067 100644 --- a/mlir/include/mlir/IR/DialectTypeRegistry.def +++ b/mlir/include/mlir/IR/DialectTypeRegistry.def @@ -20,6 +20,7 @@ // //===----------------------------------------------------------------------===// +DEFINE_TYPE_KIND_RANGE(STANDARD) DEFINE_TYPE_KIND_RANGE(TENSORFLOW_CONTROL) DEFINE_TYPE_KIND_RANGE(TENSORFLOW) diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h index da02d8650111..3fd13a4b78ff 100644 --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -26,7 +26,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Types.h" +#include "mlir/IR/StandardTypes.h" #include "mlir/IR/Value.h" #include diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h new file mode 100644 index 000000000000..988cf8ef7bed --- /dev/null +++ b/mlir/include/mlir/IR/StandardTypes.h @@ -0,0 +1,384 @@ +//===- StandardTypes.h - MLIR Standard Type Classes -------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef MLIR_IR_STANDARDTYPES_H +#define MLIR_IR_STANDARDTYPES_H + +#include "mlir/IR/Types.h" +#include "mlir/Support/LLVM.h" + +namespace llvm { +class fltSemantics; +} // namespace llvm + +namespace mlir { +class AffineMap; +class FloatType; +class IndexType; +class IntegerType; +class Location; +class MLIRContext; + +namespace detail { + +struct IntegerTypeStorage; +struct VectorOrTensorTypeStorage; +struct VectorTypeStorage; +struct TensorTypeStorage; +struct RankedTensorTypeStorage; +struct UnrankedTensorTypeStorage; +struct MemRefTypeStorage; + +} // namespace detail + +namespace StandardTypes { +enum Kind { + // Floating point. + BF16 = Type::Kind::FIRST_STANDARD_TYPE, + F16, + F32, + F64, + FIRST_FLOATING_POINT_TYPE = BF16, + LAST_FLOATING_POINT_TYPE = F64, + + // Derived types. + Integer, + Vector, + RankedTensor, + UnrankedTensor, + MemRef, +}; + +} // namespace StandardTypes + +inline bool Type::isBF16() const { return getKind() == StandardTypes::BF16; } +inline bool Type::isF16() const { return getKind() == StandardTypes::F16; } +inline bool Type::isF32() const { return getKind() == StandardTypes::F32; } +inline bool Type::isF64() const { return getKind() == StandardTypes::F64; } + +/// Integer types can have arbitrary bitwidth up to a large fixed limit. +class IntegerType : public Type { +public: + using ImplType = detail::IntegerTypeStorage; + using Type::Type; + + /// Get or create a new IntegerType of the given width within the context. + /// Assume the width is within the allowed range and assert on failures. + /// Use getChecked to handle failures gracefully. + static IntegerType get(unsigned width, MLIRContext *context); + + /// Get or create a new IntegerType of the given width within the context, + /// defined at the given, potentially unknown, location. If the width is + /// outside the allowed range, emit errors and return a null type. + static IntegerType getChecked(unsigned width, MLIRContext *context, + Location location); + + /// Return the bitwidth of this integer type. + unsigned getWidth() const; + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool kindof(unsigned kind) { return kind == StandardTypes::Integer; } + + /// Unique identifier for this type class. + static char typeID; + + /// Integer representation maximal bitwidth. + static constexpr unsigned kMaxWidth = 4096; +}; + +inline IntegerType Type::getInteger(unsigned width, MLIRContext *ctx) { + return IntegerType::get(width, ctx); +} + +/// Return true if this is an integer type with the specified width. +inline bool Type::isInteger(unsigned width) const { + if (auto intTy = dyn_cast()) + return intTy.getWidth() == width; + return false; +} + +inline bool Type::isIntOrIndex() const { + return isa() || isa(); +} + +inline bool Type::isIntOrIndexOrFloat() const { + return isa() || isa() || isa(); +} + +inline bool Type::isIntOrFloat() const { + return isa() || isa(); +} + +class FloatType : public Type { +public: + using Type::Type; + + static FloatType get(StandardTypes::Kind kind, MLIRContext *context); + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool kindof(unsigned kind) { + return kind >= StandardTypes::FIRST_FLOATING_POINT_TYPE && + kind <= StandardTypes::LAST_FLOATING_POINT_TYPE; + } + + /// Return the bitwidth of this float type. + unsigned getWidth() const; + + /// Return the floating semantics of this float type. + const llvm::fltSemantics &getFloatSemantics() const; + + /// Unique identifier for this type class. + static char typeID; +}; + +inline FloatType Type::getBF16(MLIRContext *ctx) { + return FloatType::get(StandardTypes::BF16, ctx); +} +inline FloatType Type::getF16(MLIRContext *ctx) { + return FloatType::get(StandardTypes::F16, ctx); +} +inline FloatType Type::getF32(MLIRContext *ctx) { + return FloatType::get(StandardTypes::F32, ctx); +} +inline FloatType Type::getF64(MLIRContext *ctx) { + return FloatType::get(StandardTypes::F64, ctx); +} + +/// This is a common base class between Vector, UnrankedTensor, and RankedTensor +/// types, because many operations work on values of these aggregate types. +class VectorOrTensorType : public Type { +public: + using ImplType = detail::VectorOrTensorTypeStorage; + using Type::Type; + + /// Return the element type. + Type getElementType() const; + + /// If an element type is an integer or a float, return its width. Abort + /// otherwise. + unsigned getElementTypeBitWidth() const; + + /// If this is ranked tensor or vector type, return the number of elements. If + /// it is an unranked tensor, abort. + unsigned getNumElements() const; + + /// If this is ranked tensor or vector type, return the rank. If it is an + /// unranked tensor, return -1. + int getRank() const; + + /// If this is ranked tensor or vector type, return the shape. If it is an + /// unranked tensor, abort. + ArrayRef getShape() const; + + /// If this is unranked tensor or any dimension has unknown size (<0), + /// it doesn't have static shape. If all dimensions have known size (>= 0), + /// it has static shape. + bool hasStaticShape() const; + + /// If this is ranked tensor or vector type, return the size of the specified + /// dimension. It aborts if the tensor is unranked (this can be checked by + /// the getRank call method). + int getDimSize(unsigned i) const; + + /// Get the total amount of bits occupied by a value of this type. This does + /// not take into account any memory layout or widening constraints, e.g. a + /// vector<3xi57> is reported to occupy 3x57=171 bit, even though in practice + /// it will likely be stored as in a 4xi64 vector register. Fail an assertion + /// if the size cannot be computed statically, i.e. if the tensor has a + /// dynamic shape or if its elemental type does not have a known bit width. + long getSizeInBits() const; + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool kindof(unsigned kind) { + return kind == StandardTypes::Vector || + kind == StandardTypes::RankedTensor || + kind == StandardTypes::UnrankedTensor; + } +}; + +/// Vector types represent multi-dimensional SIMD vectors, and have a fixed +/// known constant shape with one or more dimension. +class VectorType : public VectorOrTensorType { +public: + using ImplType = detail::VectorTypeStorage; + using VectorOrTensorType::VectorOrTensorType; + + /// Get or create a new VectorType of the provided shape and element type. + /// Assumes the arguments define a well-formed VectorType. + static VectorType get(ArrayRef shape, Type elementType); + + /// Get or create a new VectorType of the provided shape and element type + /// declared at the given, potentially unknown, location. If the VectorType + /// defined by the arguments would be ill-formed, emit errors and return + /// nullptr-wrapping type. + static VectorType getChecked(ArrayRef shape, Type elementType, + Location location); + + /// Returns true of the given type can be used as an element of a vector type. + /// In particular, vectors can consist of integer or float primitives. + static bool isValidElementType(Type t) { return t.isIntOrFloat(); } + + ArrayRef getShape() const; + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool kindof(unsigned kind) { return kind == StandardTypes::Vector; } + + /// Unique identifier for this type class. + static char typeID; +}; + +/// Tensor types represent multi-dimensional arrays, and have two variants: +/// RankedTensorType and UnrankedTensorType. +class TensorType : public VectorOrTensorType { +public: + using ImplType = detail::TensorTypeStorage; + using VectorOrTensorType::VectorOrTensorType; + + /// Return true if the specified element type is ok in a tensor. + static bool isValidElementType(Type type) { + // Note: Non standard/builtin types are allowed to exist within tensor + // types. Dialects are expected to verify that tensor types have a valid + // element type within that dialect. + return type.isIntOrFloat() || type.isa() || + (type.getKind() >= Type::Kind::LAST_STANDARD_TYPE); + } + + /// Methods for support type inquiry through isa, cast, and dyn_cast. + static bool kindof(unsigned kind) { + return kind == StandardTypes::RankedTensor || + kind == StandardTypes::UnrankedTensor; + } +}; + +/// Ranked tensor types represent multi-dimensional arrays that have a shape +/// with a fixed number of dimensions. Each shape element can be a positive +/// integer or unknown (represented -1). +class RankedTensorType : public TensorType { +public: + using ImplType = detail::RankedTensorTypeStorage; + using TensorType::TensorType; + + /// Get or create a new RankedTensorType of the provided shape and element + /// type. Assumes the arguments define a well-formed type. + static RankedTensorType get(ArrayRef shape, Type elementType); + + /// Get or create a new RankedTensorType of the provided shape and element + /// type declared at the given, potentially unknown, location. If the + /// RankedTensorType defined by the arguments would be ill-formed, emit errors + /// and return a nullptr-wrapping type. + static RankedTensorType getChecked(ArrayRef shape, Type elementType, + Location location); + + ArrayRef getShape() const; + + static bool kindof(unsigned kind) { + return kind == StandardTypes::RankedTensor; + } + + /// Unique identifier for this type class. + static char typeID; +}; + +/// Unranked tensor types represent multi-dimensional arrays that have an +/// unknown shape. +class UnrankedTensorType : public TensorType { +public: + using ImplType = detail::UnrankedTensorTypeStorage; + using TensorType::TensorType; + + /// Get or create a new UnrankedTensorType of the provided shape and element + /// type. Assumes the arguments define a well-formed type. + static UnrankedTensorType get(Type elementType); + + /// Get or create a new UnrankedTensorType of the provided shape and element + /// type declared at the given, potentially unknown, location. If the + /// UnrankedTensorType defined by the arguments would be ill-formed, emit + /// errors and return a nullptr-wrapping type. + static UnrankedTensorType getChecked(Type elementType, Location location); + + ArrayRef getShape() const { return ArrayRef(); } + + static bool kindof(unsigned kind) { + return kind == StandardTypes::UnrankedTensor; + } + + /// Unique identifier for this type class. + static char typeID; +}; + +/// MemRef types represent a region of memory that have a shape with a fixed +/// number of dimensions. Each shape element can be a positive integer or +/// unknown (represented by any negative integer). MemRef types also have an +/// affine map composition, represented as an array AffineMap pointers. +class MemRefType : public Type { +public: + using ImplType = detail::MemRefTypeStorage; + using Type::Type; + + /// Get or create a new MemRefType based on shape, element type, affine + /// map composition, and memory space. Assumes the arguments define a + /// well-formed MemRef type. Use getChecked to gracefully handle MemRefType + /// construction failures. + static MemRefType get(ArrayRef shape, Type elementType, + ArrayRef affineMapComposition, + unsigned memorySpace); + + /// Get or create a new MemRefType based on shape, element type, affine + /// map composition, and memory space declared at the given location. + /// If the location is unknown, the last argument should be an instance of + /// UnknownLoc. If the MemRefType defined by the arguments would be + /// ill-formed, emits errors (to the handler registered with the context or to + /// the error stream) and returns nullptr. + static MemRefType getChecked(ArrayRef shape, Type elementType, + ArrayRef affineMapComposition, + unsigned memorySpace, Location location); + + unsigned getRank() const { return getShape().size(); } + + /// Returns an array of memref shape dimension sizes. + ArrayRef getShape() const; + + /// Return the size of the specified dimension, or -1 if unspecified. + int getDimSize(unsigned i) const { return getShape()[i]; } + + /// Returns the elemental type for this memref shape. + Type getElementType() const; + + /// Returns an array of affine map pointers representing the memref affine + /// map composition. + ArrayRef getAffineMaps() const; + + /// Returns the memory space in which data referred to by this memref resides. + unsigned getMemorySpace() const; + + /// Returns the number of dimensions with dynamic size. + unsigned getNumDynamicDims() const; + + static bool kindof(unsigned kind) { return kind == StandardTypes::MemRef; } + + /// Unique identifier for this type class. + static char typeID; + +private: + static MemRefType getSafe(ArrayRef shape, Type elementType, + ArrayRef affineMapComposition, + unsigned memorySpace, Optional location); +}; + +} // end namespace mlir + +#endif // MLIR_IR_STANDARDTYPES_H diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h index ceb2afb1c778..c417cd0a4ed6 100644 --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -23,12 +23,7 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMapInfo.h" -namespace llvm { -class fltSemantics; -} // namespace llvm - namespace mlir { -class AffineMap; class FloatType; class IndexType; class IntegerType; @@ -36,6 +31,7 @@ class Location; class MLIRContext; namespace detail { +struct FunctionTypeStorage; struct TypeStorage; } // namespace detail @@ -96,39 +92,13 @@ public: /// enum class also simplifies the handling of type kinds by not requiring /// casts for each use. enum Kind { + // Builtin types. + Function, + + // TODO(riverriddle) Index shouldn't really be a builtin. // Target pointer sized integer, used (e.g.) in affine mappings. Index, - - // TensorFlow types. - TFControl, - TFResource, - TFVariant, - TFComplex64, - TFComplex128, - TFF32REF, - TFString, - - /// These are marker for the first and last 'other' type. - FIRST_OTHER_TYPE = TFControl, - LAST_OTHER_TYPE = TFString, - - // Floating point. - BF16, - F16, - F32, - F64, - FIRST_FLOATING_POINT_TYPE = BF16, - LAST_FLOATING_POINT_TYPE = F64, - - // Derived types. - Integer, - Function, - Vector, - RankedTensor, - UnrankedTensor, - MemRef, - - LAST_BUILTIN_TYPE = 0xff, + LAST_BUILTIN_TYPE = Index, // Reserve type kinds for dialect specific type system extensions. #define DEFINE_TYPE_KIND_RANGE(Dialect) \ @@ -224,135 +194,6 @@ inline raw_ostream &operator<<(raw_ostream &os, Type type) { return os; } -/// Standard Type Utilities. - -namespace detail { - -struct IntegerTypeStorage; -struct FunctionTypeStorage; -struct VectorOrTensorTypeStorage; -struct VectorTypeStorage; -struct TensorTypeStorage; -struct RankedTensorTypeStorage; -struct UnrankedTensorTypeStorage; -struct MemRefTypeStorage; - -} // namespace detail - -inline bool Type::isIndex() const { return getKind() == Kind::Index; } -inline bool Type::isBF16() const { return getKind() == Kind::BF16; } -inline bool Type::isF16() const { return getKind() == Kind::F16; } -inline bool Type::isF32() const { return getKind() == Kind::F32; } -inline bool Type::isF64() const { return getKind() == Kind::F64; } - -/// Integer types can have arbitrary bitwidth up to a large fixed limit. -class IntegerType : public Type { -public: - using ImplType = detail::IntegerTypeStorage; - using Type::Type; - - /// Get or create a new IntegerType of the given width within the context. - /// Assume the width is within the allowed range and assert on failures. - /// Use getChecked to handle failures gracefully. - static IntegerType get(unsigned width, MLIRContext *context); - - /// Get or create a new IntegerType of the given width within the context, - /// defined at the given, potentially unknown, location. If the width is - /// outside the allowed range, emit errors and return a null type. - static IntegerType getChecked(unsigned width, MLIRContext *context, - Location location); - - /// Return the bitwidth of this integer type. - unsigned getWidth() const; - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool kindof(unsigned kind) { return kind == Kind::Integer; } - - /// Unique identifier for this type class. - static char typeID; - - /// Integer representation maximal bitwidth. - static constexpr unsigned kMaxWidth = 4096; -}; - -inline IntegerType Type::getInteger(unsigned width, MLIRContext *ctx) { - return IntegerType::get(width, ctx); -} - -/// Return true if this is an integer type with the specified width. -inline bool Type::isInteger(unsigned width) const { - if (auto intTy = dyn_cast()) - return intTy.getWidth() == width; - return false; -} - -inline bool Type::isIntOrIndex() const { - return isa() || isa(); -} - -inline bool Type::isIntOrIndexOrFloat() const { - return isa() || isa() || isa(); -} - -inline bool Type::isIntOrFloat() const { - return isa() || isa(); -} - -class FloatType : public Type { -public: - using Type::Type; - - static FloatType get(Kind kind, MLIRContext *context); - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool kindof(unsigned kind) { - return kind >= Kind::FIRST_FLOATING_POINT_TYPE && - kind <= Kind::LAST_FLOATING_POINT_TYPE; - } - - /// Return the bitwidth of this float type. - unsigned getWidth() const; - - /// Return the floating semantics of this float type. - const llvm::fltSemantics &getFloatSemantics() const; - - /// Unique identifier for this type class. - static char typeID; -}; - -inline FloatType Type::getBF16(MLIRContext *ctx) { - return FloatType::get(Kind::BF16, ctx); -} -inline FloatType Type::getF16(MLIRContext *ctx) { - return FloatType::get(Kind::F16, ctx); -} -inline FloatType Type::getF32(MLIRContext *ctx) { - return FloatType::get(Kind::F32, ctx); -} -inline FloatType Type::getF64(MLIRContext *ctx) { - return FloatType::get(Kind::F64, ctx); -} - -/// Index is special integer-like type with unknown platform-dependent bit width -/// used in subscripts and loop induction variables. -class IndexType : public Type { -public: - using Type::Type; - - /// Crete an IndexType instance, unique in the given context. - static IndexType get(MLIRContext *context); - - /// Support method to enable LLVM-style type casting. - static bool kindof(unsigned kind) { return kind == Kind::Index; } - - /// Unique identifier for this type class. - static char typeID; -}; - -inline IndexType Type::getIndex(MLIRContext *ctx) { - return IndexType::get(ctx); -} - /// Function types map from a list of inputs to a list of results. class FunctionType : public Type { public: @@ -383,222 +224,27 @@ public: static char typeID; }; -/// This is a common base class between Vector, UnrankedTensor, and RankedTensor -/// types, because many operations work on values of these aggregate types. -class VectorOrTensorType : public Type { +inline bool Type::isIndex() const { return getKind() == Kind::Index; } + +/// Index is special integer-like type with unknown platform-dependent bit width +/// used in subscripts and loop induction variables. +class IndexType : public Type { public: - using ImplType = detail::VectorOrTensorTypeStorage; using Type::Type; - /// Return the element type. - Type getElementType() const; + /// Crete an IndexType instance, unique in the given context. + static IndexType get(MLIRContext *context); - /// If an element type is an integer or a float, return its width. Abort - /// otherwise. - unsigned getElementTypeBitWidth() const; - - /// If this is ranked tensor or vector type, return the number of elements. If - /// it is an unranked tensor, abort. - unsigned getNumElements() const; - - /// If this is ranked tensor or vector type, return the rank. If it is an - /// unranked tensor, return -1. - int getRank() const; - - /// If this is ranked tensor or vector type, return the shape. If it is an - /// unranked tensor, abort. - ArrayRef getShape() const; - - /// If this is unranked tensor or any dimension has unknown size (<0), - /// it doesn't have static shape. If all dimensions have known size (>= 0), - /// it has static shape. - bool hasStaticShape() const; - - /// If this is ranked tensor or vector type, return the size of the specified - /// dimension. It aborts if the tensor is unranked (this can be checked by - /// the getRank call method). - int getDimSize(unsigned i) const; - - /// Get the total amount of bits occupied by a value of this type. This does - /// not take into account any memory layout or widening constraints, e.g. a - /// vector<3xi57> is reported to occupy 3x57=171 bit, even though in practice - /// it will likely be stored as in a 4xi64 vector register. Fail an assertion - /// if the size cannot be computed statically, i.e. if the tensor has a - /// dynamic shape or if its elemental type does not have a known bit width. - long getSizeInBits() const; - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool kindof(unsigned kind) { - return kind == Kind::Vector || kind == Kind::RankedTensor || - kind == Kind::UnrankedTensor; - } -}; - -/// Vector types represent multi-dimensional SIMD vectors, and have a fixed -/// known constant shape with one or more dimension. -class VectorType : public VectorOrTensorType { -public: - using ImplType = detail::VectorTypeStorage; - using VectorOrTensorType::VectorOrTensorType; - - /// Get or create a new VectorType of the provided shape and element type. - /// Assumes the arguments define a well-formed VectorType. - static VectorType get(ArrayRef shape, Type elementType); - - /// Get or create a new VectorType of the provided shape and element type - /// declared at the given, potentially unknown, location. If the VectorType - /// defined by the arguments would be ill-formed, emit errors and return - /// nullptr-wrapping type. - static VectorType getChecked(ArrayRef shape, Type elementType, - Location location); - - /// Returns true of the given type can be used as an element of a vector type. - /// In particular, vectors can consist of integer or float primitives. - static bool isValidElementType(Type t) { return t.isIntOrFloat(); } - - ArrayRef getShape() const; - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool kindof(unsigned kind) { return kind == Kind::Vector; } + /// Support method to enable LLVM-style type casting. + static bool kindof(unsigned kind) { return kind == Kind::Index; } /// Unique identifier for this type class. static char typeID; }; -/// Tensor types represent multi-dimensional arrays, and have two variants: -/// RankedTensorType and UnrankedTensorType. -class TensorType : public VectorOrTensorType { -public: - using ImplType = detail::TensorTypeStorage; - using VectorOrTensorType::VectorOrTensorType; - - /// Return true if the specified element type is ok in a tensor. - static bool isValidElementType(Type type) { - // TODO(riverriddle): TensorFlow types are currently considered valid for - // legacy reasons. - return type.isIntOrFloat() || type.isa() || - (type.getKind() >= - static_cast(Kind::FIRST_TENSORFLOW_TYPE) && - type.getKind() <= - static_cast(Kind::LAST_TENSORFLOW_TYPE)); - } - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool kindof(unsigned kind) { - return kind == Kind::RankedTensor || kind == Kind::UnrankedTensor; - } -}; - -/// Ranked tensor types represent multi-dimensional arrays that have a shape -/// with a fixed number of dimensions. Each shape element can be a positive -/// integer or unknown (represented -1). -class RankedTensorType : public TensorType { -public: - using ImplType = detail::RankedTensorTypeStorage; - using TensorType::TensorType; - - /// Get or create a new RankedTensorType of the provided shape and element - /// type. Assumes the arguments define a well-formed type. - static RankedTensorType get(ArrayRef shape, Type elementType); - - /// Get or create a new RankedTensorType of the provided shape and element - /// type declared at the given, potentially unknown, location. If the - /// RankedTensorType defined by the arguments would be ill-formed, emit errors - /// and return a nullptr-wrapping type. - static RankedTensorType getChecked(ArrayRef shape, Type elementType, - Location location); - - ArrayRef getShape() const; - - static bool kindof(unsigned kind) { return kind == Kind::RankedTensor; } - - /// Unique identifier for this type class. - static char typeID; -}; - -/// Unranked tensor types represent multi-dimensional arrays that have an -/// unknown shape. -class UnrankedTensorType : public TensorType { -public: - using ImplType = detail::UnrankedTensorTypeStorage; - using TensorType::TensorType; - - /// Get or create a new UnrankedTensorType of the provided shape and element - /// type. Assumes the arguments define a well-formed type. - static UnrankedTensorType get(Type elementType); - - /// Get or create a new UnrankedTensorType of the provided shape and element - /// type declared at the given, potentially unknown, location. If the - /// UnrankedTensorType defined by the arguments would be ill-formed, emit - /// errors and return a nullptr-wrapping type. - static UnrankedTensorType getChecked(Type elementType, Location location); - - ArrayRef getShape() const { return ArrayRef(); } - - static bool kindof(unsigned kind) { return kind == Kind::UnrankedTensor; } - - /// Unique identifier for this type class. - static char typeID; -}; - -/// MemRef types represent a region of memory that have a shape with a fixed -/// number of dimensions. Each shape element can be a positive integer or -/// unknown (represented by any negative integer). MemRef types also have an -/// affine map composition, represented as an array AffineMap pointers. -class MemRefType : public Type { -public: - using ImplType = detail::MemRefTypeStorage; - using Type::Type; - - /// Get or create a new MemRefType based on shape, element type, affine - /// map composition, and memory space. Assumes the arguments define a - /// well-formed MemRef type. Use getChecked to gracefully handle MemRefType - /// construction failures. - static MemRefType get(ArrayRef shape, Type elementType, - ArrayRef affineMapComposition, - unsigned memorySpace); - - /// Get or create a new MemRefType based on shape, element type, affine - /// map composition, and memory space declared at the given location. - /// If the location is unknown, the last argument should be an instance of - /// UnknownLoc. If the MemRefType defined by the arguments would be - /// ill-formed, emits errors (to the handler registered with the context or to - /// the error stream) and returns nullptr. - static MemRefType getChecked(ArrayRef shape, Type elementType, - ArrayRef affineMapComposition, - unsigned memorySpace, Location location); - - unsigned getRank() const { return getShape().size(); } - - /// Returns an array of memref shape dimension sizes. - ArrayRef getShape() const; - - /// Return the size of the specified dimension, or -1 if unspecified. - int getDimSize(unsigned i) const { return getShape()[i]; } - - /// Returns the elemental type for this memref shape. - Type getElementType() const; - - /// Returns an array of affine map pointers representing the memref affine - /// map composition. - ArrayRef getAffineMaps() const; - - /// Returns the memory space in which data referred to by this memref resides. - unsigned getMemorySpace() const; - - /// Returns the number of dimensions with dynamic size. - unsigned getNumDynamicDims() const; - - static bool kindof(unsigned kind) { return kind == Kind::MemRef; } - - /// Unique identifier for this type class. - static char typeID; - -private: - static MemRefType getSafe(ArrayRef shape, Type elementType, - ArrayRef affineMapComposition, - unsigned memorySpace, Optional location); -}; +inline IndexType Type::getIndex(MLIRContext *ctx) { + return IndexType::get(ctx); +} // Make Type hashable. inline ::llvm::hash_code hash_value(Type arg) { diff --git a/mlir/include/mlir/StandardOps/StandardOps.h b/mlir/include/mlir/StandardOps/StandardOps.h index 7bda1384a176..abb8b5f821fa 100644 --- a/mlir/include/mlir/StandardOps/StandardOps.h +++ b/mlir/include/mlir/StandardOps/StandardOps.h @@ -26,6 +26,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" namespace mlir { class AffineMap; diff --git a/mlir/include/mlir/SuperVectorOps/SuperVectorOps.h b/mlir/include/mlir/SuperVectorOps/SuperVectorOps.h index dcdfafe87205..9b32f3f98109 100644 --- a/mlir/include/mlir/SuperVectorOps/SuperVectorOps.h +++ b/mlir/include/mlir/SuperVectorOps/SuperVectorOps.h @@ -26,6 +26,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" namespace mlir { diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 984e0d702d85..8dd89f92bd5d 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -30,7 +30,7 @@ #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Module.h" #include "mlir/IR/OpImplementation.h" -#include "mlir/IR/Types.h" +#include "mlir/IR/StandardTypes.h" #include "mlir/Support/STLExtras.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/DenseMap.h" @@ -497,20 +497,20 @@ void ModulePrinter::printType(Type type) { case Type::Kind::Index: os << "index"; return; - case Type::Kind::BF16: + case StandardTypes::BF16: os << "bf16"; return; - case Type::Kind::F16: + case StandardTypes::F16: os << "f16"; return; - case Type::Kind::F32: + case StandardTypes::F32: os << "f32"; return; - case Type::Kind::F64: + case StandardTypes::F64: os << "f64"; return; - case Type::Kind::Integer: { + case StandardTypes::Integer: { auto integer = type.cast(); os << 'i' << integer.getWidth(); return; @@ -530,7 +530,7 @@ void ModulePrinter::printType(Type type) { } return; } - case Type::Kind::Vector: { + case StandardTypes::Vector: { auto v = type.cast(); os << "vector<"; for (auto dim : v.getShape()) @@ -538,7 +538,7 @@ void ModulePrinter::printType(Type type) { os << v.getElementType() << '>'; return; } - case Type::Kind::RankedTensor: { + case StandardTypes::RankedTensor: { auto v = type.cast(); os << "tensor<"; for (auto dim : v.getShape()) { @@ -551,14 +551,14 @@ void ModulePrinter::printType(Type type) { os << v.getElementType() << '>'; return; } - case Type::Kind::UnrankedTensor: { + case StandardTypes::UnrankedTensor: { auto v = type.cast(); os << "tensor<*x"; printType(v.getElementType()); os << '>'; return; } - case Type::Kind::MemRef: { + case StandardTypes::MemRef: { auto v = type.cast(); os << "memref<"; for (auto dim : v.getShape()) { diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h index 635d5940a163..433d7b95edb8 100644 --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -26,7 +26,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Types.h" +#include "mlir/IR/StandardTypes.h" #include "llvm/ADT/APFloat.h" #include "llvm/Support/TrailingObjects.h" diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 346b164b524d..ef5f00c255ce 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -22,7 +22,7 @@ #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Location.h" #include "mlir/IR/Module.h" -#include "mlir/IR/Types.h" +#include "mlir/IR/StandardTypes.h" using namespace mlir; Builder::Builder(Module *module) : context(module->getContext()) {} diff --git a/mlir/lib/IR/BuiltinOps.cpp b/mlir/lib/IR/BuiltinOps.cpp index 0ba66a0754ff..a78d28ce7158 100644 --- a/mlir/lib/IR/BuiltinOps.cpp +++ b/mlir/lib/IR/BuiltinOps.cpp @@ -20,7 +20,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/OpImplementation.h" -#include "mlir/IR/Types.h" +#include "mlir/IR/StandardTypes.h" #include "mlir/IR/Value.h" #include "mlir/Support/MathExtras.h" #include "mlir/Support/STLExtras.h" diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 56c4ef30f0ea..e4947a4dbc1f 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -767,7 +767,7 @@ void TypeUniquer::insert(unsigned hashValue, TypeStorage *storage) { /// Construct a unique instance of `DerivedType` of the given `kind` in the /// given `context` by passing `args` to the type storage. template -static DerivedType constructUniqueType(MLIRContext *context, Type::Kind kind, +static DerivedType constructUniqueType(MLIRContext *context, unsigned kind, Args... args) { return TypeUniquer(context).get(static_cast(kind), args...); @@ -787,7 +787,8 @@ static IntegerType getIntegerType(unsigned width, MLIRContext *context, return {}; } - return constructUniqueType(context, Type::Kind::Integer, width); + return constructUniqueType(context, StandardTypes::Integer, + width); } IntegerType IntegerType::getChecked(unsigned width, MLIRContext *context, @@ -801,9 +802,10 @@ IntegerType IntegerType::get(unsigned width, MLIRContext *context) { return type; } -FloatType FloatType::get(Type::Kind kind, MLIRContext *context) { - assert(kind >= Kind::FIRST_FLOATING_POINT_TYPE && - kind <= Kind::LAST_FLOATING_POINT_TYPE && "Not an FP type kind"); +FloatType FloatType::get(StandardTypes::Kind kind, MLIRContext *context) { + assert(kind >= StandardTypes::FIRST_FLOATING_POINT_TYPE && + kind <= StandardTypes::LAST_FLOATING_POINT_TYPE && + "Not an FP type kind"); return constructUniqueType(context, kind); } @@ -841,7 +843,7 @@ static VectorType getVectorType(ArrayRef shape, Type elementType, return {}; } - return constructUniqueType(context, Type::Kind::Vector, shape, + return constructUniqueType(context, StandardTypes::Vector, shape, elementType); } @@ -883,7 +885,7 @@ static RankedTensorType getRankedTensorType(ArrayRef shape, auto *context = elementType.getContext(); return constructUniqueType( - context, Type::Kind::RankedTensor, shape, elementType); + context, StandardTypes::RankedTensor, shape, elementType); } RankedTensorType RankedTensorType::get(ArrayRef shape, Type elementType) { @@ -909,7 +911,7 @@ static UnrankedTensorType getUnrankedTensorType(Type elementType, auto *context = elementType.getContext(); return constructUniqueType( - context, Type::Kind::UnrankedTensor, elementType); + context, StandardTypes::UnrankedTensor, elementType); } UnrankedTensorType UnrankedTensorType::get(Type elementType) { @@ -975,7 +977,7 @@ static MemRefType getMemRefType(ArrayRef shape, Type elementType, } affineMapComposition = cleanedAffineMapComposition; - return constructUniqueType(context, Type::Kind::MemRef, shape, + return constructUniqueType(context, StandardTypes::MemRef, shape, elementType, affineMapComposition, memorySpace); } @@ -1343,10 +1345,10 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type, // Otherwise, allocate a new one, unique it and return it. auto eltType = type.getElementType(); switch (eltType.getKind()) { - case Type::Kind::BF16: - case Type::Kind::F16: - case Type::Kind::F32: - case Type::Kind::F64: { + case StandardTypes::BF16: + case StandardTypes::F16: + case StandardTypes::F32: + case StandardTypes::F64: { auto *result = impl.allocator.Allocate(); auto *copy = (char *)impl.allocator.Allocate(data.size(), 64); std::uninitialized_copy(data.begin(), data.end(), copy); @@ -1356,7 +1358,7 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type, {copy, data.size()}}}; return *existing.first = result; } - case Type::Kind::Integer: { + case StandardTypes::Integer: { auto width = eltType.cast().getWidth(); auto *result = impl.allocator.Allocate(); auto *copy = (char *)impl.allocator.Allocate(data.size(), 64); diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index fa949a30aedc..4dc83481f811 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -21,6 +21,7 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/StandardTypes.h" using namespace mlir; /// Form the OperationName for an op with the specified string. This either is @@ -293,8 +294,7 @@ bool OpTrait::impl::verifyIsTerminator(const OperationInst *op) { bool OpTrait::impl::verifyResultsAreBoolLike(const OperationInst *op) { for (auto *result : op->getResults()) { auto elementType = getTensorOrVectorElementType(result->getType()); - auto intType = elementType.dyn_cast(); - bool isBoolType = intType && intType.getWidth() == 1; + bool isBoolType = elementType.isInteger(1); if (!isBoolType) return op->emitOpError("requires a bool result type"); } diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp new file mode 100644 index 000000000000..fa423ed68cf0 --- /dev/null +++ b/mlir/lib/IR/StandardTypes.cpp @@ -0,0 +1,196 @@ +//===- Types.cpp - MLIR Type Classes --------------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/IR/StandardTypes.h" +#include "TypeDetail.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/Support/STLExtras.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using namespace mlir::detail; + +unsigned IntegerType::getWidth() const { + return static_cast(type)->width; +} + +unsigned FloatType::getWidth() const { + switch (getKind()) { + case StandardTypes::BF16: + case StandardTypes::F16: + return 16; + case StandardTypes::F32: + return 32; + case StandardTypes::F64: + return 64; + default: + llvm_unreachable("unexpected type"); + } +} + +/// Returns the floating semantics for the given type. +const llvm::fltSemantics &FloatType::getFloatSemantics() const { + if (isBF16()) + // Treat BF16 like a double. This is unfortunate but BF16 fltSemantics is + // not defined in LLVM. + // TODO(jpienaar): add BF16 to LLVM? fltSemantics are internal to APFloat.cc + // else one could add it. + // static const fltSemantics semBF16 = {127, -126, 8, 16}; + return APFloat::IEEEdouble(); + if (isF16()) + return APFloat::IEEEhalf(); + if (isF32()) + return APFloat::IEEEsingle(); + if (isF64()) + return APFloat::IEEEdouble(); + llvm_unreachable("non-floating point type used"); +} + +unsigned Type::getIntOrFloatBitWidth() const { + assert(isIntOrFloat() && "only ints and floats have a bitwidth"); + if (auto intType = dyn_cast()) { + return intType.getWidth(); + } + + auto floatType = cast(); + return floatType.getWidth(); +} + +Type VectorOrTensorType::getElementType() const { + return static_cast(type)->elementType; +} + +unsigned VectorOrTensorType::getElementTypeBitWidth() const { + return getElementType().getIntOrFloatBitWidth(); +} + +unsigned VectorOrTensorType::getNumElements() const { + switch (getKind()) { + case StandardTypes::Vector: + case StandardTypes::RankedTensor: { + auto shape = getShape(); + unsigned num = 1; + for (auto dim : shape) + num *= dim; + return num; + } + default: + llvm_unreachable("not a VectorOrTensorType or not ranked"); + } +} + +/// If this is ranked tensor or vector type, return the rank. If it is an +/// unranked tensor, return -1. +int VectorOrTensorType::getRank() const { + switch (getKind()) { + case StandardTypes::Vector: + case StandardTypes::RankedTensor: + return getShape().size(); + case StandardTypes::UnrankedTensor: + return -1; + default: + llvm_unreachable("not a VectorOrTensorType"); + } +} + +int VectorOrTensorType::getDimSize(unsigned i) const { + switch (getKind()) { + case StandardTypes::Vector: + case StandardTypes::RankedTensor: + return getShape()[i]; + default: + llvm_unreachable("not a VectorOrTensorType or not ranked"); + } +} + +// Get the number of number of bits require to store a value of the given vector +// or tensor types. Compute the value recursively since tensors are allowed to +// have vectors as elements. +long VectorOrTensorType::getSizeInBits() const { + assert(hasStaticShape() && + "cannot get the bit size of an aggregate with a dynamic shape"); + + auto elementType = getElementType(); + if (elementType.isIntOrFloat()) + return elementType.getIntOrFloatBitWidth() * getNumElements(); + + // Tensors can have vectors and other tensors as elements, vectors cannot. + assert(!isa() && "unsupported vector element type"); + auto elementVectorOrTensorType = elementType.dyn_cast(); + assert(elementVectorOrTensorType && "unsupported tensor element type"); + return getNumElements() * elementVectorOrTensorType.getSizeInBits(); +} + +ArrayRef VectorOrTensorType::getShape() const { + switch (getKind()) { + case StandardTypes::Vector: + return cast().getShape(); + case StandardTypes::RankedTensor: + return cast().getShape(); + default: + llvm_unreachable("not a VectorOrTensorType or not ranked"); + } +} + +bool VectorOrTensorType::hasStaticShape() const { + if (isa()) + return false; + auto dims = getShape(); + return !std::any_of(dims.begin(), dims.end(), [](int i) { return i < 0; }); +} + +ArrayRef VectorType::getShape() const { + return static_cast(type)->getShape(); +} + +ArrayRef RankedTensorType::getShape() const { + return static_cast(type)->getShape(); +} + +ArrayRef MemRefType::getShape() const { + return static_cast(type)->getShape(); +} + +Type MemRefType::getElementType() const { + return static_cast(type)->elementType; +} + +ArrayRef MemRefType::getAffineMaps() const { + return static_cast(type)->getAffineMaps(); +} + +unsigned MemRefType::getMemorySpace() const { + return static_cast(type)->memorySpace; +} + +unsigned MemRefType::getNumDynamicDims() const { + unsigned numDynamicDims = 0; + for (int dimSize : getShape()) { + if (dimSize == -1) + ++numDynamicDims; + } + return numDynamicDims; +} + +// Define type identifiers. +char FloatType::typeID = 0; +char IntegerType::typeID = 0; +char VectorType::typeID = 0; +char RankedTensorType::typeID = 0; +char UnrankedTensorType::typeID = 0; +char MemRefType::typeID = 0; diff --git a/mlir/lib/IR/Types.cpp b/mlir/lib/IR/Types.cpp index 07cfaf084d2f..21dd04e855f8 100644 --- a/mlir/lib/IR/Types.cpp +++ b/mlir/lib/IR/Types.cpp @@ -17,11 +17,7 @@ #include "mlir/IR/Types.h" #include "TypeDetail.h" -#include "mlir/IR/AffineMap.h" #include "mlir/IR/Dialect.h" -#include "mlir/Support/STLExtras.h" -#include "llvm/ADT/APFloat.h" -#include "llvm/Support/raw_ostream.h" using namespace mlir; using namespace mlir::detail; @@ -36,51 +32,7 @@ MLIRContext *Type::getContext() const { return getDialect().getContext(); } unsigned Type::getSubclassData() const { return type->getSubclassData(); } void Type::setSubclassData(unsigned val) { type->setSubclassData(val); } -unsigned IntegerType::getWidth() const { - return static_cast(type)->width; -} - -unsigned FloatType::getWidth() const { - switch (getKind()) { - case Type::Kind::BF16: - case Type::Kind::F16: - return 16; - case Type::Kind::F32: - return 32; - case Type::Kind::F64: - return 64; - default: - llvm_unreachable("unexpected type"); - } -} - -/// Returns the floating semantics for the given type. -const llvm::fltSemantics &FloatType::getFloatSemantics() const { - if (isBF16()) - // Treat BF16 like a double. This is unfortunate but BF16 fltSemantics is - // not defined in LLVM. - // TODO(jpienaar): add BF16 to LLVM? fltSemantics are internal to APFloat.cc - // else one could add it. - // static const fltSemantics semBF16 = {127, -126, 8, 16}; - return APFloat::IEEEdouble(); - if (isF16()) - return APFloat::IEEEhalf(); - if (isF32()) - return APFloat::IEEEsingle(); - if (isF64()) - return APFloat::IEEEdouble(); - llvm_unreachable("non-floating point type used"); -} - -unsigned Type::getIntOrFloatBitWidth() const { - assert(isIntOrFloat() && "only ints and floats have a bitwidth"); - if (auto intType = dyn_cast()) { - return intType.getWidth(); - } - - auto floatType = cast(); - return floatType.getWidth(); -} +/// Function Type. ArrayRef FunctionType::getInputs() const { return static_cast(type)->getInputs(); @@ -94,128 +46,6 @@ ArrayRef FunctionType::getResults() const { return static_cast(type)->getResults(); } -Type VectorOrTensorType::getElementType() const { - return static_cast(type)->elementType; -} - -unsigned VectorOrTensorType::getElementTypeBitWidth() const { - return getElementType().getIntOrFloatBitWidth(); -} - -unsigned VectorOrTensorType::getNumElements() const { - switch (getKind()) { - case Kind::Vector: - case Kind::RankedTensor: { - auto shape = getShape(); - unsigned num = 1; - for (auto dim : shape) - num *= dim; - return num; - } - default: - llvm_unreachable("not a VectorOrTensorType or not ranked"); - } -} - -/// If this is ranked tensor or vector type, return the rank. If it is an -/// unranked tensor, return -1. -int VectorOrTensorType::getRank() const { - switch (getKind()) { - case Kind::Vector: - case Kind::RankedTensor: - return getShape().size(); - case Kind::UnrankedTensor: - return -1; - default: - llvm_unreachable("not a VectorOrTensorType"); - } -} - -int VectorOrTensorType::getDimSize(unsigned i) const { - switch (getKind()) { - case Kind::Vector: - case Kind::RankedTensor: - return getShape()[i]; - default: - llvm_unreachable("not a VectorOrTensorType or not ranked"); - } -} - -// Get the number of number of bits require to store a value of the given vector -// or tensor types. Compute the value recursively since tensors are allowed to -// have vectors as elements. -long VectorOrTensorType::getSizeInBits() const { - assert(hasStaticShape() && - "cannot get the bit size of an aggregate with a dynamic shape"); - - auto elementType = getElementType(); - if (elementType.isIntOrFloat()) - return elementType.getIntOrFloatBitWidth() * getNumElements(); - - // Tensors can have vectors and other tensors as elements, vectors cannot. - assert(!isa() && "unsupported vector element type"); - auto elementVectorOrTensorType = elementType.dyn_cast(); - assert(elementVectorOrTensorType && "unsupported tensor element type"); - return getNumElements() * elementVectorOrTensorType.getSizeInBits(); -} - -ArrayRef VectorOrTensorType::getShape() const { - switch (getKind()) { - case Kind::Vector: - return cast().getShape(); - case Kind::RankedTensor: - return cast().getShape(); - default: - llvm_unreachable("not a VectorOrTensorType or not ranked"); - } -} - -bool VectorOrTensorType::hasStaticShape() const { - if (isa()) - return false; - auto dims = getShape(); - return !std::any_of(dims.begin(), dims.end(), [](int i) { return i < 0; }); -} - -ArrayRef VectorType::getShape() const { - return static_cast(type)->getShape(); -} - -ArrayRef RankedTensorType::getShape() const { - return static_cast(type)->getShape(); -} - -ArrayRef MemRefType::getShape() const { - return static_cast(type)->getShape(); -} - -Type MemRefType::getElementType() const { - return static_cast(type)->elementType; -} - -ArrayRef MemRefType::getAffineMaps() const { - return static_cast(type)->getAffineMaps(); -} - -unsigned MemRefType::getMemorySpace() const { - return static_cast(type)->memorySpace; -} - -unsigned MemRefType::getNumDynamicDims() const { - unsigned numDynamicDims = 0; - for (int dimSize : getShape()) { - if (dimSize == -1) - ++numDynamicDims; - } - return numDynamicDims; -} - // Define type identifiers. -char IndexType::typeID = 0; -char FloatType::typeID = 0; -char IntegerType::typeID = 0; char FunctionType::typeID = 0; -char VectorType::typeID = 0; -char RankedTensorType::typeID = 0; -char UnrankedTensorType::typeID = 0; -char MemRefType::typeID = 0; +char IndexType::typeID = 0; diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 54ac6177b18a..9a4cf61124fc 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -32,7 +32,7 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/IR/OpImplementation.h" -#include "mlir/IR/Types.h" +#include "mlir/IR/StandardTypes.h" #include "mlir/Support/STLExtras.h" #include "mlir/Transforms/Utils.h" #include "llvm/ADT/APInt.h" @@ -711,10 +711,10 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl &dims) { return ParseResult::ParseFailure; // check result matches the element type. switch (eltTy.getKind()) { - case Type::Kind::BF16: - case Type::Kind::F16: - case Type::Kind::F32: - case Type::Kind::F64: { + case StandardTypes::BF16: + case StandardTypes::F16: + case StandardTypes::F32: + case StandardTypes::F64: { // Bitcast the APFloat value to APInt and store the bit representation. auto fpAttrResult = result.dyn_cast(); if (!fpAttrResult) @@ -731,7 +731,7 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl &dims) { addToStorage(apInt.getRawData()[0]); break; } - case Type::Kind::Integer: { + case StandardTypes::Integer: { if (!result.isa()) return p.emitError("expected tensor literal element has integer type"); auto value = result.cast().getValue(); diff --git a/mlir/lib/StandardOps/StandardOps.cpp b/mlir/lib/StandardOps/StandardOps.cpp index 4dfb3c5f3a81..fa827661897f 100644 --- a/mlir/lib/StandardOps/StandardOps.cpp +++ b/mlir/lib/StandardOps/StandardOps.cpp @@ -23,7 +23,7 @@ #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" -#include "mlir/IR/Types.h" +#include "mlir/IR/StandardTypes.h" #include "mlir/IR/Value.h" #include "mlir/Support/MathExtras.h" #include "mlir/Support/STLExtras.h" diff --git a/mlir/lib/SuperVectorOps/SuperVectorOps.cpp b/mlir/lib/SuperVectorOps/SuperVectorOps.cpp index aecb0b840001..4a106b066d68 100644 --- a/mlir/lib/SuperVectorOps/SuperVectorOps.cpp +++ b/mlir/lib/SuperVectorOps/SuperVectorOps.cpp @@ -25,7 +25,6 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/OpImplementation.h" -#include "mlir/IR/Types.h" #include "mlir/Support/LLVM.h" using namespace mlir; diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp index 841219dd4732..4b70f46c1c05 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp @@ -164,13 +164,13 @@ llvm::IntegerType *ModuleLowerer::convertIntegerType(IntegerType type) { llvm::Type *ModuleLowerer::convertFloatType(FloatType type) { MLIRContext *context = type.getContext(); switch (type.getKind()) { - case Type::Kind::F32: + case StandardTypes::F32: return builder.getFloatTy(); - case Type::Kind::F64: + case StandardTypes::F64: return builder.getDoubleTy(); - case Type::Kind::F16: + case StandardTypes::F16: return builder.getHalfTy(); - case Type::Kind::BF16: + case StandardTypes::BF16: return context->emitError(UnknownLoc::get(context), "unsupported type: BF16"), nullptr; diff --git a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp index f4020f3e1c79..9dfcda4081ff 100644 --- a/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp +++ b/mlir/lib/Transforms/Vectorization/VectorizerTestPass.cpp @@ -24,6 +24,7 @@ #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/VectorAnalysis.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/StandardTypes.h" #include "mlir/Pass.h" #include "mlir/Support/Functional.h" #include "mlir/Support/STLExtras.h"