Split the standard types from builtin types and move them into separate source files(StandardTypes.cpp/h). After this cl only FunctionType and IndexType are builtin types, but IndexType will likely become a standard type when the ml/cfgfunc merger is done. Mechanical NFC.
PiperOrigin-RevId: 227750918
This commit is contained in:
parent
ae1a6619df
commit
54948a4380
|
@ -32,6 +32,7 @@ class Type;
|
|||
class PrimitiveType;
|
||||
class IntegerType;
|
||||
class FunctionType;
|
||||
class MemRefType;
|
||||
class VectorType;
|
||||
class RankedTensorType;
|
||||
class UnrankedTensorType;
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
DEFINE_TYPE_KIND_RANGE(STANDARD)
|
||||
DEFINE_TYPE_KIND_RANGE(TENSORFLOW_CONTROL)
|
||||
DEFINE_TYPE_KIND_RANGE(TENSORFLOW)
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include <type_traits>
|
||||
|
||||
|
|
|
@ -0,0 +1,384 @@
|
|||
//===- StandardTypes.h - MLIR Standard Type Classes -------------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#ifndef MLIR_IR_STANDARDTYPES_H
|
||||
#define MLIR_IR_STANDARDTYPES_H
|
||||
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
namespace llvm {
|
||||
class fltSemantics;
|
||||
} // namespace llvm
|
||||
|
||||
namespace mlir {
|
||||
class AffineMap;
|
||||
class FloatType;
|
||||
class IndexType;
|
||||
class IntegerType;
|
||||
class Location;
|
||||
class MLIRContext;
|
||||
|
||||
namespace detail {
|
||||
|
||||
struct IntegerTypeStorage;
|
||||
struct VectorOrTensorTypeStorage;
|
||||
struct VectorTypeStorage;
|
||||
struct TensorTypeStorage;
|
||||
struct RankedTensorTypeStorage;
|
||||
struct UnrankedTensorTypeStorage;
|
||||
struct MemRefTypeStorage;
|
||||
|
||||
} // namespace detail
|
||||
|
||||
namespace StandardTypes {
|
||||
enum Kind {
|
||||
// Floating point.
|
||||
BF16 = Type::Kind::FIRST_STANDARD_TYPE,
|
||||
F16,
|
||||
F32,
|
||||
F64,
|
||||
FIRST_FLOATING_POINT_TYPE = BF16,
|
||||
LAST_FLOATING_POINT_TYPE = F64,
|
||||
|
||||
// Derived types.
|
||||
Integer,
|
||||
Vector,
|
||||
RankedTensor,
|
||||
UnrankedTensor,
|
||||
MemRef,
|
||||
};
|
||||
|
||||
} // namespace StandardTypes
|
||||
|
||||
inline bool Type::isBF16() const { return getKind() == StandardTypes::BF16; }
|
||||
inline bool Type::isF16() const { return getKind() == StandardTypes::F16; }
|
||||
inline bool Type::isF32() const { return getKind() == StandardTypes::F32; }
|
||||
inline bool Type::isF64() const { return getKind() == StandardTypes::F64; }
|
||||
|
||||
/// Integer types can have arbitrary bitwidth up to a large fixed limit.
|
||||
class IntegerType : public Type {
|
||||
public:
|
||||
using ImplType = detail::IntegerTypeStorage;
|
||||
using Type::Type;
|
||||
|
||||
/// Get or create a new IntegerType of the given width within the context.
|
||||
/// Assume the width is within the allowed range and assert on failures.
|
||||
/// Use getChecked to handle failures gracefully.
|
||||
static IntegerType get(unsigned width, MLIRContext *context);
|
||||
|
||||
/// Get or create a new IntegerType of the given width within the context,
|
||||
/// defined at the given, potentially unknown, location. If the width is
|
||||
/// outside the allowed range, emit errors and return a null type.
|
||||
static IntegerType getChecked(unsigned width, MLIRContext *context,
|
||||
Location location);
|
||||
|
||||
/// Return the bitwidth of this integer type.
|
||||
unsigned getWidth() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool kindof(unsigned kind) { return kind == StandardTypes::Integer; }
|
||||
|
||||
/// Unique identifier for this type class.
|
||||
static char typeID;
|
||||
|
||||
/// Integer representation maximal bitwidth.
|
||||
static constexpr unsigned kMaxWidth = 4096;
|
||||
};
|
||||
|
||||
inline IntegerType Type::getInteger(unsigned width, MLIRContext *ctx) {
|
||||
return IntegerType::get(width, ctx);
|
||||
}
|
||||
|
||||
/// Return true if this is an integer type with the specified width.
|
||||
inline bool Type::isInteger(unsigned width) const {
|
||||
if (auto intTy = dyn_cast<IntegerType>())
|
||||
return intTy.getWidth() == width;
|
||||
return false;
|
||||
}
|
||||
|
||||
inline bool Type::isIntOrIndex() const {
|
||||
return isa<IndexType>() || isa<IntegerType>();
|
||||
}
|
||||
|
||||
inline bool Type::isIntOrIndexOrFloat() const {
|
||||
return isa<IndexType>() || isa<IntegerType>() || isa<FloatType>();
|
||||
}
|
||||
|
||||
inline bool Type::isIntOrFloat() const {
|
||||
return isa<IntegerType>() || isa<FloatType>();
|
||||
}
|
||||
|
||||
class FloatType : public Type {
|
||||
public:
|
||||
using Type::Type;
|
||||
|
||||
static FloatType get(StandardTypes::Kind kind, MLIRContext *context);
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool kindof(unsigned kind) {
|
||||
return kind >= StandardTypes::FIRST_FLOATING_POINT_TYPE &&
|
||||
kind <= StandardTypes::LAST_FLOATING_POINT_TYPE;
|
||||
}
|
||||
|
||||
/// Return the bitwidth of this float type.
|
||||
unsigned getWidth() const;
|
||||
|
||||
/// Return the floating semantics of this float type.
|
||||
const llvm::fltSemantics &getFloatSemantics() const;
|
||||
|
||||
/// Unique identifier for this type class.
|
||||
static char typeID;
|
||||
};
|
||||
|
||||
inline FloatType Type::getBF16(MLIRContext *ctx) {
|
||||
return FloatType::get(StandardTypes::BF16, ctx);
|
||||
}
|
||||
inline FloatType Type::getF16(MLIRContext *ctx) {
|
||||
return FloatType::get(StandardTypes::F16, ctx);
|
||||
}
|
||||
inline FloatType Type::getF32(MLIRContext *ctx) {
|
||||
return FloatType::get(StandardTypes::F32, ctx);
|
||||
}
|
||||
inline FloatType Type::getF64(MLIRContext *ctx) {
|
||||
return FloatType::get(StandardTypes::F64, ctx);
|
||||
}
|
||||
|
||||
/// This is a common base class between Vector, UnrankedTensor, and RankedTensor
|
||||
/// types, because many operations work on values of these aggregate types.
|
||||
class VectorOrTensorType : public Type {
|
||||
public:
|
||||
using ImplType = detail::VectorOrTensorTypeStorage;
|
||||
using Type::Type;
|
||||
|
||||
/// Return the element type.
|
||||
Type getElementType() const;
|
||||
|
||||
/// If an element type is an integer or a float, return its width. Abort
|
||||
/// otherwise.
|
||||
unsigned getElementTypeBitWidth() const;
|
||||
|
||||
/// If this is ranked tensor or vector type, return the number of elements. If
|
||||
/// it is an unranked tensor, abort.
|
||||
unsigned getNumElements() const;
|
||||
|
||||
/// If this is ranked tensor or vector type, return the rank. If it is an
|
||||
/// unranked tensor, return -1.
|
||||
int getRank() const;
|
||||
|
||||
/// If this is ranked tensor or vector type, return the shape. If it is an
|
||||
/// unranked tensor, abort.
|
||||
ArrayRef<int> getShape() const;
|
||||
|
||||
/// If this is unranked tensor or any dimension has unknown size (<0),
|
||||
/// it doesn't have static shape. If all dimensions have known size (>= 0),
|
||||
/// it has static shape.
|
||||
bool hasStaticShape() const;
|
||||
|
||||
/// If this is ranked tensor or vector type, return the size of the specified
|
||||
/// dimension. It aborts if the tensor is unranked (this can be checked by
|
||||
/// the getRank call method).
|
||||
int getDimSize(unsigned i) const;
|
||||
|
||||
/// Get the total amount of bits occupied by a value of this type. This does
|
||||
/// not take into account any memory layout or widening constraints, e.g. a
|
||||
/// vector<3xi57> is reported to occupy 3x57=171 bit, even though in practice
|
||||
/// it will likely be stored as in a 4xi64 vector register. Fail an assertion
|
||||
/// if the size cannot be computed statically, i.e. if the tensor has a
|
||||
/// dynamic shape or if its elemental type does not have a known bit width.
|
||||
long getSizeInBits() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool kindof(unsigned kind) {
|
||||
return kind == StandardTypes::Vector ||
|
||||
kind == StandardTypes::RankedTensor ||
|
||||
kind == StandardTypes::UnrankedTensor;
|
||||
}
|
||||
};
|
||||
|
||||
/// Vector types represent multi-dimensional SIMD vectors, and have a fixed
|
||||
/// known constant shape with one or more dimension.
|
||||
class VectorType : public VectorOrTensorType {
|
||||
public:
|
||||
using ImplType = detail::VectorTypeStorage;
|
||||
using VectorOrTensorType::VectorOrTensorType;
|
||||
|
||||
/// Get or create a new VectorType of the provided shape and element type.
|
||||
/// Assumes the arguments define a well-formed VectorType.
|
||||
static VectorType get(ArrayRef<int> shape, Type elementType);
|
||||
|
||||
/// Get or create a new VectorType of the provided shape and element type
|
||||
/// declared at the given, potentially unknown, location. If the VectorType
|
||||
/// defined by the arguments would be ill-formed, emit errors and return
|
||||
/// nullptr-wrapping type.
|
||||
static VectorType getChecked(ArrayRef<int> shape, Type elementType,
|
||||
Location location);
|
||||
|
||||
/// Returns true of the given type can be used as an element of a vector type.
|
||||
/// In particular, vectors can consist of integer or float primitives.
|
||||
static bool isValidElementType(Type t) { return t.isIntOrFloat(); }
|
||||
|
||||
ArrayRef<int> getShape() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool kindof(unsigned kind) { return kind == StandardTypes::Vector; }
|
||||
|
||||
/// Unique identifier for this type class.
|
||||
static char typeID;
|
||||
};
|
||||
|
||||
/// Tensor types represent multi-dimensional arrays, and have two variants:
|
||||
/// RankedTensorType and UnrankedTensorType.
|
||||
class TensorType : public VectorOrTensorType {
|
||||
public:
|
||||
using ImplType = detail::TensorTypeStorage;
|
||||
using VectorOrTensorType::VectorOrTensorType;
|
||||
|
||||
/// Return true if the specified element type is ok in a tensor.
|
||||
static bool isValidElementType(Type type) {
|
||||
// Note: Non standard/builtin types are allowed to exist within tensor
|
||||
// types. Dialects are expected to verify that tensor types have a valid
|
||||
// element type within that dialect.
|
||||
return type.isIntOrFloat() || type.isa<VectorType>() ||
|
||||
(type.getKind() >= Type::Kind::LAST_STANDARD_TYPE);
|
||||
}
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool kindof(unsigned kind) {
|
||||
return kind == StandardTypes::RankedTensor ||
|
||||
kind == StandardTypes::UnrankedTensor;
|
||||
}
|
||||
};
|
||||
|
||||
/// Ranked tensor types represent multi-dimensional arrays that have a shape
|
||||
/// with a fixed number of dimensions. Each shape element can be a positive
|
||||
/// integer or unknown (represented -1).
|
||||
class RankedTensorType : public TensorType {
|
||||
public:
|
||||
using ImplType = detail::RankedTensorTypeStorage;
|
||||
using TensorType::TensorType;
|
||||
|
||||
/// Get or create a new RankedTensorType of the provided shape and element
|
||||
/// type. Assumes the arguments define a well-formed type.
|
||||
static RankedTensorType get(ArrayRef<int> shape, Type elementType);
|
||||
|
||||
/// Get or create a new RankedTensorType of the provided shape and element
|
||||
/// type declared at the given, potentially unknown, location. If the
|
||||
/// RankedTensorType defined by the arguments would be ill-formed, emit errors
|
||||
/// and return a nullptr-wrapping type.
|
||||
static RankedTensorType getChecked(ArrayRef<int> shape, Type elementType,
|
||||
Location location);
|
||||
|
||||
ArrayRef<int> getShape() const;
|
||||
|
||||
static bool kindof(unsigned kind) {
|
||||
return kind == StandardTypes::RankedTensor;
|
||||
}
|
||||
|
||||
/// Unique identifier for this type class.
|
||||
static char typeID;
|
||||
};
|
||||
|
||||
/// Unranked tensor types represent multi-dimensional arrays that have an
|
||||
/// unknown shape.
|
||||
class UnrankedTensorType : public TensorType {
|
||||
public:
|
||||
using ImplType = detail::UnrankedTensorTypeStorage;
|
||||
using TensorType::TensorType;
|
||||
|
||||
/// Get or create a new UnrankedTensorType of the provided shape and element
|
||||
/// type. Assumes the arguments define a well-formed type.
|
||||
static UnrankedTensorType get(Type elementType);
|
||||
|
||||
/// Get or create a new UnrankedTensorType of the provided shape and element
|
||||
/// type declared at the given, potentially unknown, location. If the
|
||||
/// UnrankedTensorType defined by the arguments would be ill-formed, emit
|
||||
/// errors and return a nullptr-wrapping type.
|
||||
static UnrankedTensorType getChecked(Type elementType, Location location);
|
||||
|
||||
ArrayRef<int> getShape() const { return ArrayRef<int>(); }
|
||||
|
||||
static bool kindof(unsigned kind) {
|
||||
return kind == StandardTypes::UnrankedTensor;
|
||||
}
|
||||
|
||||
/// Unique identifier for this type class.
|
||||
static char typeID;
|
||||
};
|
||||
|
||||
/// MemRef types represent a region of memory that have a shape with a fixed
|
||||
/// number of dimensions. Each shape element can be a positive integer or
|
||||
/// unknown (represented by any negative integer). MemRef types also have an
|
||||
/// affine map composition, represented as an array AffineMap pointers.
|
||||
class MemRefType : public Type {
|
||||
public:
|
||||
using ImplType = detail::MemRefTypeStorage;
|
||||
using Type::Type;
|
||||
|
||||
/// Get or create a new MemRefType based on shape, element type, affine
|
||||
/// map composition, and memory space. Assumes the arguments define a
|
||||
/// well-formed MemRef type. Use getChecked to gracefully handle MemRefType
|
||||
/// construction failures.
|
||||
static MemRefType get(ArrayRef<int> shape, Type elementType,
|
||||
ArrayRef<AffineMap> affineMapComposition,
|
||||
unsigned memorySpace);
|
||||
|
||||
/// Get or create a new MemRefType based on shape, element type, affine
|
||||
/// map composition, and memory space declared at the given location.
|
||||
/// If the location is unknown, the last argument should be an instance of
|
||||
/// UnknownLoc. If the MemRefType defined by the arguments would be
|
||||
/// ill-formed, emits errors (to the handler registered with the context or to
|
||||
/// the error stream) and returns nullptr.
|
||||
static MemRefType getChecked(ArrayRef<int> shape, Type elementType,
|
||||
ArrayRef<AffineMap> affineMapComposition,
|
||||
unsigned memorySpace, Location location);
|
||||
|
||||
unsigned getRank() const { return getShape().size(); }
|
||||
|
||||
/// Returns an array of memref shape dimension sizes.
|
||||
ArrayRef<int> getShape() const;
|
||||
|
||||
/// Return the size of the specified dimension, or -1 if unspecified.
|
||||
int getDimSize(unsigned i) const { return getShape()[i]; }
|
||||
|
||||
/// Returns the elemental type for this memref shape.
|
||||
Type getElementType() const;
|
||||
|
||||
/// Returns an array of affine map pointers representing the memref affine
|
||||
/// map composition.
|
||||
ArrayRef<AffineMap> getAffineMaps() const;
|
||||
|
||||
/// Returns the memory space in which data referred to by this memref resides.
|
||||
unsigned getMemorySpace() const;
|
||||
|
||||
/// Returns the number of dimensions with dynamic size.
|
||||
unsigned getNumDynamicDims() const;
|
||||
|
||||
static bool kindof(unsigned kind) { return kind == StandardTypes::MemRef; }
|
||||
|
||||
/// Unique identifier for this type class.
|
||||
static char typeID;
|
||||
|
||||
private:
|
||||
static MemRefType getSafe(ArrayRef<int> shape, Type elementType,
|
||||
ArrayRef<AffineMap> affineMapComposition,
|
||||
unsigned memorySpace, Optional<Location> location);
|
||||
};
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_IR_STANDARDTYPES_H
|
|
@ -23,12 +23,7 @@
|
|||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/DenseMapInfo.h"
|
||||
|
||||
namespace llvm {
|
||||
class fltSemantics;
|
||||
} // namespace llvm
|
||||
|
||||
namespace mlir {
|
||||
class AffineMap;
|
||||
class FloatType;
|
||||
class IndexType;
|
||||
class IntegerType;
|
||||
|
@ -36,6 +31,7 @@ class Location;
|
|||
class MLIRContext;
|
||||
|
||||
namespace detail {
|
||||
struct FunctionTypeStorage;
|
||||
struct TypeStorage;
|
||||
} // namespace detail
|
||||
|
||||
|
@ -96,39 +92,13 @@ public:
|
|||
/// enum class also simplifies the handling of type kinds by not requiring
|
||||
/// casts for each use.
|
||||
enum Kind {
|
||||
// Builtin types.
|
||||
Function,
|
||||
|
||||
// TODO(riverriddle) Index shouldn't really be a builtin.
|
||||
// Target pointer sized integer, used (e.g.) in affine mappings.
|
||||
Index,
|
||||
|
||||
// TensorFlow types.
|
||||
TFControl,
|
||||
TFResource,
|
||||
TFVariant,
|
||||
TFComplex64,
|
||||
TFComplex128,
|
||||
TFF32REF,
|
||||
TFString,
|
||||
|
||||
/// These are marker for the first and last 'other' type.
|
||||
FIRST_OTHER_TYPE = TFControl,
|
||||
LAST_OTHER_TYPE = TFString,
|
||||
|
||||
// Floating point.
|
||||
BF16,
|
||||
F16,
|
||||
F32,
|
||||
F64,
|
||||
FIRST_FLOATING_POINT_TYPE = BF16,
|
||||
LAST_FLOATING_POINT_TYPE = F64,
|
||||
|
||||
// Derived types.
|
||||
Integer,
|
||||
Function,
|
||||
Vector,
|
||||
RankedTensor,
|
||||
UnrankedTensor,
|
||||
MemRef,
|
||||
|
||||
LAST_BUILTIN_TYPE = 0xff,
|
||||
LAST_BUILTIN_TYPE = Index,
|
||||
|
||||
// Reserve type kinds for dialect specific type system extensions.
|
||||
#define DEFINE_TYPE_KIND_RANGE(Dialect) \
|
||||
|
@ -224,135 +194,6 @@ inline raw_ostream &operator<<(raw_ostream &os, Type type) {
|
|||
return os;
|
||||
}
|
||||
|
||||
/// Standard Type Utilities.
|
||||
|
||||
namespace detail {
|
||||
|
||||
struct IntegerTypeStorage;
|
||||
struct FunctionTypeStorage;
|
||||
struct VectorOrTensorTypeStorage;
|
||||
struct VectorTypeStorage;
|
||||
struct TensorTypeStorage;
|
||||
struct RankedTensorTypeStorage;
|
||||
struct UnrankedTensorTypeStorage;
|
||||
struct MemRefTypeStorage;
|
||||
|
||||
} // namespace detail
|
||||
|
||||
inline bool Type::isIndex() const { return getKind() == Kind::Index; }
|
||||
inline bool Type::isBF16() const { return getKind() == Kind::BF16; }
|
||||
inline bool Type::isF16() const { return getKind() == Kind::F16; }
|
||||
inline bool Type::isF32() const { return getKind() == Kind::F32; }
|
||||
inline bool Type::isF64() const { return getKind() == Kind::F64; }
|
||||
|
||||
/// Integer types can have arbitrary bitwidth up to a large fixed limit.
|
||||
class IntegerType : public Type {
|
||||
public:
|
||||
using ImplType = detail::IntegerTypeStorage;
|
||||
using Type::Type;
|
||||
|
||||
/// Get or create a new IntegerType of the given width within the context.
|
||||
/// Assume the width is within the allowed range and assert on failures.
|
||||
/// Use getChecked to handle failures gracefully.
|
||||
static IntegerType get(unsigned width, MLIRContext *context);
|
||||
|
||||
/// Get or create a new IntegerType of the given width within the context,
|
||||
/// defined at the given, potentially unknown, location. If the width is
|
||||
/// outside the allowed range, emit errors and return a null type.
|
||||
static IntegerType getChecked(unsigned width, MLIRContext *context,
|
||||
Location location);
|
||||
|
||||
/// Return the bitwidth of this integer type.
|
||||
unsigned getWidth() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool kindof(unsigned kind) { return kind == Kind::Integer; }
|
||||
|
||||
/// Unique identifier for this type class.
|
||||
static char typeID;
|
||||
|
||||
/// Integer representation maximal bitwidth.
|
||||
static constexpr unsigned kMaxWidth = 4096;
|
||||
};
|
||||
|
||||
inline IntegerType Type::getInteger(unsigned width, MLIRContext *ctx) {
|
||||
return IntegerType::get(width, ctx);
|
||||
}
|
||||
|
||||
/// Return true if this is an integer type with the specified width.
|
||||
inline bool Type::isInteger(unsigned width) const {
|
||||
if (auto intTy = dyn_cast<IntegerType>())
|
||||
return intTy.getWidth() == width;
|
||||
return false;
|
||||
}
|
||||
|
||||
inline bool Type::isIntOrIndex() const {
|
||||
return isa<IndexType>() || isa<IntegerType>();
|
||||
}
|
||||
|
||||
inline bool Type::isIntOrIndexOrFloat() const {
|
||||
return isa<IndexType>() || isa<IntegerType>() || isa<FloatType>();
|
||||
}
|
||||
|
||||
inline bool Type::isIntOrFloat() const {
|
||||
return isa<IntegerType>() || isa<FloatType>();
|
||||
}
|
||||
|
||||
class FloatType : public Type {
|
||||
public:
|
||||
using Type::Type;
|
||||
|
||||
static FloatType get(Kind kind, MLIRContext *context);
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool kindof(unsigned kind) {
|
||||
return kind >= Kind::FIRST_FLOATING_POINT_TYPE &&
|
||||
kind <= Kind::LAST_FLOATING_POINT_TYPE;
|
||||
}
|
||||
|
||||
/// Return the bitwidth of this float type.
|
||||
unsigned getWidth() const;
|
||||
|
||||
/// Return the floating semantics of this float type.
|
||||
const llvm::fltSemantics &getFloatSemantics() const;
|
||||
|
||||
/// Unique identifier for this type class.
|
||||
static char typeID;
|
||||
};
|
||||
|
||||
inline FloatType Type::getBF16(MLIRContext *ctx) {
|
||||
return FloatType::get(Kind::BF16, ctx);
|
||||
}
|
||||
inline FloatType Type::getF16(MLIRContext *ctx) {
|
||||
return FloatType::get(Kind::F16, ctx);
|
||||
}
|
||||
inline FloatType Type::getF32(MLIRContext *ctx) {
|
||||
return FloatType::get(Kind::F32, ctx);
|
||||
}
|
||||
inline FloatType Type::getF64(MLIRContext *ctx) {
|
||||
return FloatType::get(Kind::F64, ctx);
|
||||
}
|
||||
|
||||
/// Index is special integer-like type with unknown platform-dependent bit width
|
||||
/// used in subscripts and loop induction variables.
|
||||
class IndexType : public Type {
|
||||
public:
|
||||
using Type::Type;
|
||||
|
||||
/// Crete an IndexType instance, unique in the given context.
|
||||
static IndexType get(MLIRContext *context);
|
||||
|
||||
/// Support method to enable LLVM-style type casting.
|
||||
static bool kindof(unsigned kind) { return kind == Kind::Index; }
|
||||
|
||||
/// Unique identifier for this type class.
|
||||
static char typeID;
|
||||
};
|
||||
|
||||
inline IndexType Type::getIndex(MLIRContext *ctx) {
|
||||
return IndexType::get(ctx);
|
||||
}
|
||||
|
||||
/// Function types map from a list of inputs to a list of results.
|
||||
class FunctionType : public Type {
|
||||
public:
|
||||
|
@ -383,222 +224,27 @@ public:
|
|||
static char typeID;
|
||||
};
|
||||
|
||||
/// This is a common base class between Vector, UnrankedTensor, and RankedTensor
|
||||
/// types, because many operations work on values of these aggregate types.
|
||||
class VectorOrTensorType : public Type {
|
||||
inline bool Type::isIndex() const { return getKind() == Kind::Index; }
|
||||
|
||||
/// Index is special integer-like type with unknown platform-dependent bit width
|
||||
/// used in subscripts and loop induction variables.
|
||||
class IndexType : public Type {
|
||||
public:
|
||||
using ImplType = detail::VectorOrTensorTypeStorage;
|
||||
using Type::Type;
|
||||
|
||||
/// Return the element type.
|
||||
Type getElementType() const;
|
||||
/// Crete an IndexType instance, unique in the given context.
|
||||
static IndexType get(MLIRContext *context);
|
||||
|
||||
/// If an element type is an integer or a float, return its width. Abort
|
||||
/// otherwise.
|
||||
unsigned getElementTypeBitWidth() const;
|
||||
|
||||
/// If this is ranked tensor or vector type, return the number of elements. If
|
||||
/// it is an unranked tensor, abort.
|
||||
unsigned getNumElements() const;
|
||||
|
||||
/// If this is ranked tensor or vector type, return the rank. If it is an
|
||||
/// unranked tensor, return -1.
|
||||
int getRank() const;
|
||||
|
||||
/// If this is ranked tensor or vector type, return the shape. If it is an
|
||||
/// unranked tensor, abort.
|
||||
ArrayRef<int> getShape() const;
|
||||
|
||||
/// If this is unranked tensor or any dimension has unknown size (<0),
|
||||
/// it doesn't have static shape. If all dimensions have known size (>= 0),
|
||||
/// it has static shape.
|
||||
bool hasStaticShape() const;
|
||||
|
||||
/// If this is ranked tensor or vector type, return the size of the specified
|
||||
/// dimension. It aborts if the tensor is unranked (this can be checked by
|
||||
/// the getRank call method).
|
||||
int getDimSize(unsigned i) const;
|
||||
|
||||
/// Get the total amount of bits occupied by a value of this type. This does
|
||||
/// not take into account any memory layout or widening constraints, e.g. a
|
||||
/// vector<3xi57> is reported to occupy 3x57=171 bit, even though in practice
|
||||
/// it will likely be stored as in a 4xi64 vector register. Fail an assertion
|
||||
/// if the size cannot be computed statically, i.e. if the tensor has a
|
||||
/// dynamic shape or if its elemental type does not have a known bit width.
|
||||
long getSizeInBits() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool kindof(unsigned kind) {
|
||||
return kind == Kind::Vector || kind == Kind::RankedTensor ||
|
||||
kind == Kind::UnrankedTensor;
|
||||
}
|
||||
};
|
||||
|
||||
/// Vector types represent multi-dimensional SIMD vectors, and have a fixed
|
||||
/// known constant shape with one or more dimension.
|
||||
class VectorType : public VectorOrTensorType {
|
||||
public:
|
||||
using ImplType = detail::VectorTypeStorage;
|
||||
using VectorOrTensorType::VectorOrTensorType;
|
||||
|
||||
/// Get or create a new VectorType of the provided shape and element type.
|
||||
/// Assumes the arguments define a well-formed VectorType.
|
||||
static VectorType get(ArrayRef<int> shape, Type elementType);
|
||||
|
||||
/// Get or create a new VectorType of the provided shape and element type
|
||||
/// declared at the given, potentially unknown, location. If the VectorType
|
||||
/// defined by the arguments would be ill-formed, emit errors and return
|
||||
/// nullptr-wrapping type.
|
||||
static VectorType getChecked(ArrayRef<int> shape, Type elementType,
|
||||
Location location);
|
||||
|
||||
/// Returns true of the given type can be used as an element of a vector type.
|
||||
/// In particular, vectors can consist of integer or float primitives.
|
||||
static bool isValidElementType(Type t) { return t.isIntOrFloat(); }
|
||||
|
||||
ArrayRef<int> getShape() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool kindof(unsigned kind) { return kind == Kind::Vector; }
|
||||
/// Support method to enable LLVM-style type casting.
|
||||
static bool kindof(unsigned kind) { return kind == Kind::Index; }
|
||||
|
||||
/// Unique identifier for this type class.
|
||||
static char typeID;
|
||||
};
|
||||
|
||||
/// Tensor types represent multi-dimensional arrays, and have two variants:
|
||||
/// RankedTensorType and UnrankedTensorType.
|
||||
class TensorType : public VectorOrTensorType {
|
||||
public:
|
||||
using ImplType = detail::TensorTypeStorage;
|
||||
using VectorOrTensorType::VectorOrTensorType;
|
||||
|
||||
/// Return true if the specified element type is ok in a tensor.
|
||||
static bool isValidElementType(Type type) {
|
||||
// TODO(riverriddle): TensorFlow types are currently considered valid for
|
||||
// legacy reasons.
|
||||
return type.isIntOrFloat() || type.isa<VectorType>() ||
|
||||
(type.getKind() >=
|
||||
static_cast<unsigned>(Kind::FIRST_TENSORFLOW_TYPE) &&
|
||||
type.getKind() <=
|
||||
static_cast<unsigned>(Kind::LAST_TENSORFLOW_TYPE));
|
||||
}
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool kindof(unsigned kind) {
|
||||
return kind == Kind::RankedTensor || kind == Kind::UnrankedTensor;
|
||||
}
|
||||
};
|
||||
|
||||
/// Ranked tensor types represent multi-dimensional arrays that have a shape
|
||||
/// with a fixed number of dimensions. Each shape element can be a positive
|
||||
/// integer or unknown (represented -1).
|
||||
class RankedTensorType : public TensorType {
|
||||
public:
|
||||
using ImplType = detail::RankedTensorTypeStorage;
|
||||
using TensorType::TensorType;
|
||||
|
||||
/// Get or create a new RankedTensorType of the provided shape and element
|
||||
/// type. Assumes the arguments define a well-formed type.
|
||||
static RankedTensorType get(ArrayRef<int> shape, Type elementType);
|
||||
|
||||
/// Get or create a new RankedTensorType of the provided shape and element
|
||||
/// type declared at the given, potentially unknown, location. If the
|
||||
/// RankedTensorType defined by the arguments would be ill-formed, emit errors
|
||||
/// and return a nullptr-wrapping type.
|
||||
static RankedTensorType getChecked(ArrayRef<int> shape, Type elementType,
|
||||
Location location);
|
||||
|
||||
ArrayRef<int> getShape() const;
|
||||
|
||||
static bool kindof(unsigned kind) { return kind == Kind::RankedTensor; }
|
||||
|
||||
/// Unique identifier for this type class.
|
||||
static char typeID;
|
||||
};
|
||||
|
||||
/// Unranked tensor types represent multi-dimensional arrays that have an
|
||||
/// unknown shape.
|
||||
class UnrankedTensorType : public TensorType {
|
||||
public:
|
||||
using ImplType = detail::UnrankedTensorTypeStorage;
|
||||
using TensorType::TensorType;
|
||||
|
||||
/// Get or create a new UnrankedTensorType of the provided shape and element
|
||||
/// type. Assumes the arguments define a well-formed type.
|
||||
static UnrankedTensorType get(Type elementType);
|
||||
|
||||
/// Get or create a new UnrankedTensorType of the provided shape and element
|
||||
/// type declared at the given, potentially unknown, location. If the
|
||||
/// UnrankedTensorType defined by the arguments would be ill-formed, emit
|
||||
/// errors and return a nullptr-wrapping type.
|
||||
static UnrankedTensorType getChecked(Type elementType, Location location);
|
||||
|
||||
ArrayRef<int> getShape() const { return ArrayRef<int>(); }
|
||||
|
||||
static bool kindof(unsigned kind) { return kind == Kind::UnrankedTensor; }
|
||||
|
||||
/// Unique identifier for this type class.
|
||||
static char typeID;
|
||||
};
|
||||
|
||||
/// MemRef types represent a region of memory that have a shape with a fixed
|
||||
/// number of dimensions. Each shape element can be a positive integer or
|
||||
/// unknown (represented by any negative integer). MemRef types also have an
|
||||
/// affine map composition, represented as an array AffineMap pointers.
|
||||
class MemRefType : public Type {
|
||||
public:
|
||||
using ImplType = detail::MemRefTypeStorage;
|
||||
using Type::Type;
|
||||
|
||||
/// Get or create a new MemRefType based on shape, element type, affine
|
||||
/// map composition, and memory space. Assumes the arguments define a
|
||||
/// well-formed MemRef type. Use getChecked to gracefully handle MemRefType
|
||||
/// construction failures.
|
||||
static MemRefType get(ArrayRef<int> shape, Type elementType,
|
||||
ArrayRef<AffineMap> affineMapComposition,
|
||||
unsigned memorySpace);
|
||||
|
||||
/// Get or create a new MemRefType based on shape, element type, affine
|
||||
/// map composition, and memory space declared at the given location.
|
||||
/// If the location is unknown, the last argument should be an instance of
|
||||
/// UnknownLoc. If the MemRefType defined by the arguments would be
|
||||
/// ill-formed, emits errors (to the handler registered with the context or to
|
||||
/// the error stream) and returns nullptr.
|
||||
static MemRefType getChecked(ArrayRef<int> shape, Type elementType,
|
||||
ArrayRef<AffineMap> affineMapComposition,
|
||||
unsigned memorySpace, Location location);
|
||||
|
||||
unsigned getRank() const { return getShape().size(); }
|
||||
|
||||
/// Returns an array of memref shape dimension sizes.
|
||||
ArrayRef<int> getShape() const;
|
||||
|
||||
/// Return the size of the specified dimension, or -1 if unspecified.
|
||||
int getDimSize(unsigned i) const { return getShape()[i]; }
|
||||
|
||||
/// Returns the elemental type for this memref shape.
|
||||
Type getElementType() const;
|
||||
|
||||
/// Returns an array of affine map pointers representing the memref affine
|
||||
/// map composition.
|
||||
ArrayRef<AffineMap> getAffineMaps() const;
|
||||
|
||||
/// Returns the memory space in which data referred to by this memref resides.
|
||||
unsigned getMemorySpace() const;
|
||||
|
||||
/// Returns the number of dimensions with dynamic size.
|
||||
unsigned getNumDynamicDims() const;
|
||||
|
||||
static bool kindof(unsigned kind) { return kind == Kind::MemRef; }
|
||||
|
||||
/// Unique identifier for this type class.
|
||||
static char typeID;
|
||||
|
||||
private:
|
||||
static MemRefType getSafe(ArrayRef<int> shape, Type elementType,
|
||||
ArrayRef<AffineMap> affineMapComposition,
|
||||
unsigned memorySpace, Optional<Location> location);
|
||||
};
|
||||
inline IndexType Type::getIndex(MLIRContext *ctx) {
|
||||
return IndexType::get(ctx);
|
||||
}
|
||||
|
||||
// Make Type hashable.
|
||||
inline ::llvm::hash_code hash_value(Type arg) {
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
namespace mlir {
|
||||
class AffineMap;
|
||||
|
|
|
@ -26,6 +26,7 @@
|
|||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@
|
|||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
|
@ -497,20 +497,20 @@ void ModulePrinter::printType(Type type) {
|
|||
case Type::Kind::Index:
|
||||
os << "index";
|
||||
return;
|
||||
case Type::Kind::BF16:
|
||||
case StandardTypes::BF16:
|
||||
os << "bf16";
|
||||
return;
|
||||
case Type::Kind::F16:
|
||||
case StandardTypes::F16:
|
||||
os << "f16";
|
||||
return;
|
||||
case Type::Kind::F32:
|
||||
case StandardTypes::F32:
|
||||
os << "f32";
|
||||
return;
|
||||
case Type::Kind::F64:
|
||||
case StandardTypes::F64:
|
||||
os << "f64";
|
||||
return;
|
||||
|
||||
case Type::Kind::Integer: {
|
||||
case StandardTypes::Integer: {
|
||||
auto integer = type.cast<IntegerType>();
|
||||
os << 'i' << integer.getWidth();
|
||||
return;
|
||||
|
@ -530,7 +530,7 @@ void ModulePrinter::printType(Type type) {
|
|||
}
|
||||
return;
|
||||
}
|
||||
case Type::Kind::Vector: {
|
||||
case StandardTypes::Vector: {
|
||||
auto v = type.cast<VectorType>();
|
||||
os << "vector<";
|
||||
for (auto dim : v.getShape())
|
||||
|
@ -538,7 +538,7 @@ void ModulePrinter::printType(Type type) {
|
|||
os << v.getElementType() << '>';
|
||||
return;
|
||||
}
|
||||
case Type::Kind::RankedTensor: {
|
||||
case StandardTypes::RankedTensor: {
|
||||
auto v = type.cast<RankedTensorType>();
|
||||
os << "tensor<";
|
||||
for (auto dim : v.getShape()) {
|
||||
|
@ -551,14 +551,14 @@ void ModulePrinter::printType(Type type) {
|
|||
os << v.getElementType() << '>';
|
||||
return;
|
||||
}
|
||||
case Type::Kind::UnrankedTensor: {
|
||||
case StandardTypes::UnrankedTensor: {
|
||||
auto v = type.cast<UnrankedTensorType>();
|
||||
os << "tensor<*x";
|
||||
printType(v.getElementType());
|
||||
os << '>';
|
||||
return;
|
||||
}
|
||||
case Type::Kind::MemRef: {
|
||||
case StandardTypes::MemRef: {
|
||||
auto v = type.cast<MemRefType>();
|
||||
os << "memref<";
|
||||
for (auto dim : v.getShape()) {
|
||||
|
|
|
@ -26,7 +26,7 @@
|
|||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/Support/TrailingObjects.h"
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@
|
|||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/Location.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
using namespace mlir;
|
||||
|
||||
Builder::Builder(Module *module) : context(module->getContext()) {}
|
||||
|
|
|
@ -20,7 +20,7 @@
|
|||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/MathExtras.h"
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
|
|
|
@ -767,7 +767,7 @@ void TypeUniquer::insert(unsigned hashValue, TypeStorage *storage) {
|
|||
/// Construct a unique instance of `DerivedType` of the given `kind` in the
|
||||
/// given `context` by passing `args` to the type storage.
|
||||
template <typename DerivedType, typename... Args>
|
||||
static DerivedType constructUniqueType(MLIRContext *context, Type::Kind kind,
|
||||
static DerivedType constructUniqueType(MLIRContext *context, unsigned kind,
|
||||
Args... args) {
|
||||
return TypeUniquer(context).get<DerivedType>(static_cast<unsigned>(kind),
|
||||
args...);
|
||||
|
@ -787,7 +787,8 @@ static IntegerType getIntegerType(unsigned width, MLIRContext *context,
|
|||
return {};
|
||||
}
|
||||
|
||||
return constructUniqueType<IntegerType>(context, Type::Kind::Integer, width);
|
||||
return constructUniqueType<IntegerType>(context, StandardTypes::Integer,
|
||||
width);
|
||||
}
|
||||
|
||||
IntegerType IntegerType::getChecked(unsigned width, MLIRContext *context,
|
||||
|
@ -801,9 +802,10 @@ IntegerType IntegerType::get(unsigned width, MLIRContext *context) {
|
|||
return type;
|
||||
}
|
||||
|
||||
FloatType FloatType::get(Type::Kind kind, MLIRContext *context) {
|
||||
assert(kind >= Kind::FIRST_FLOATING_POINT_TYPE &&
|
||||
kind <= Kind::LAST_FLOATING_POINT_TYPE && "Not an FP type kind");
|
||||
FloatType FloatType::get(StandardTypes::Kind kind, MLIRContext *context) {
|
||||
assert(kind >= StandardTypes::FIRST_FLOATING_POINT_TYPE &&
|
||||
kind <= StandardTypes::LAST_FLOATING_POINT_TYPE &&
|
||||
"Not an FP type kind");
|
||||
return constructUniqueType<FloatType>(context, kind);
|
||||
}
|
||||
|
||||
|
@ -841,7 +843,7 @@ static VectorType getVectorType(ArrayRef<int> shape, Type elementType,
|
|||
return {};
|
||||
}
|
||||
|
||||
return constructUniqueType<VectorType>(context, Type::Kind::Vector, shape,
|
||||
return constructUniqueType<VectorType>(context, StandardTypes::Vector, shape,
|
||||
elementType);
|
||||
}
|
||||
|
||||
|
@ -883,7 +885,7 @@ static RankedTensorType getRankedTensorType(ArrayRef<int> shape,
|
|||
|
||||
auto *context = elementType.getContext();
|
||||
return constructUniqueType<RankedTensorType>(
|
||||
context, Type::Kind::RankedTensor, shape, elementType);
|
||||
context, StandardTypes::RankedTensor, shape, elementType);
|
||||
}
|
||||
|
||||
RankedTensorType RankedTensorType::get(ArrayRef<int> shape, Type elementType) {
|
||||
|
@ -909,7 +911,7 @@ static UnrankedTensorType getUnrankedTensorType(Type elementType,
|
|||
|
||||
auto *context = elementType.getContext();
|
||||
return constructUniqueType<UnrankedTensorType>(
|
||||
context, Type::Kind::UnrankedTensor, elementType);
|
||||
context, StandardTypes::UnrankedTensor, elementType);
|
||||
}
|
||||
|
||||
UnrankedTensorType UnrankedTensorType::get(Type elementType) {
|
||||
|
@ -975,7 +977,7 @@ static MemRefType getMemRefType(ArrayRef<int> shape, Type elementType,
|
|||
}
|
||||
affineMapComposition = cleanedAffineMapComposition;
|
||||
|
||||
return constructUniqueType<MemRefType>(context, Type::Kind::MemRef, shape,
|
||||
return constructUniqueType<MemRefType>(context, StandardTypes::MemRef, shape,
|
||||
elementType, affineMapComposition,
|
||||
memorySpace);
|
||||
}
|
||||
|
@ -1343,10 +1345,10 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
|
|||
// Otherwise, allocate a new one, unique it and return it.
|
||||
auto eltType = type.getElementType();
|
||||
switch (eltType.getKind()) {
|
||||
case Type::Kind::BF16:
|
||||
case Type::Kind::F16:
|
||||
case Type::Kind::F32:
|
||||
case Type::Kind::F64: {
|
||||
case StandardTypes::BF16:
|
||||
case StandardTypes::F16:
|
||||
case StandardTypes::F32:
|
||||
case StandardTypes::F64: {
|
||||
auto *result = impl.allocator.Allocate<DenseFPElementsAttributeStorage>();
|
||||
auto *copy = (char *)impl.allocator.Allocate(data.size(), 64);
|
||||
std::uninitialized_copy(data.begin(), data.end(), copy);
|
||||
|
@ -1356,7 +1358,7 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
|
|||
{copy, data.size()}}};
|
||||
return *existing.first = result;
|
||||
}
|
||||
case Type::Kind::Integer: {
|
||||
case StandardTypes::Integer: {
|
||||
auto width = eltType.cast<IntegerType>().getWidth();
|
||||
auto *result = impl.allocator.Allocate<DenseIntElementsAttributeStorage>();
|
||||
auto *copy = (char *)impl.allocator.Allocate(data.size(), 64);
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
using namespace mlir;
|
||||
|
||||
/// Form the OperationName for an op with the specified string. This either is
|
||||
|
@ -293,8 +294,7 @@ bool OpTrait::impl::verifyIsTerminator(const OperationInst *op) {
|
|||
bool OpTrait::impl::verifyResultsAreBoolLike(const OperationInst *op) {
|
||||
for (auto *result : op->getResults()) {
|
||||
auto elementType = getTensorOrVectorElementType(result->getType());
|
||||
auto intType = elementType.dyn_cast<IntegerType>();
|
||||
bool isBoolType = intType && intType.getWidth() == 1;
|
||||
bool isBoolType = elementType.isInteger(1);
|
||||
if (!isBoolType)
|
||||
return op->emitOpError("requires a bool result type");
|
||||
}
|
||||
|
|
|
@ -0,0 +1,196 @@
|
|||
//===- Types.cpp - MLIR Type Classes --------------------------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "TypeDetail.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::detail;
|
||||
|
||||
unsigned IntegerType::getWidth() const {
|
||||
return static_cast<ImplType *>(type)->width;
|
||||
}
|
||||
|
||||
unsigned FloatType::getWidth() const {
|
||||
switch (getKind()) {
|
||||
case StandardTypes::BF16:
|
||||
case StandardTypes::F16:
|
||||
return 16;
|
||||
case StandardTypes::F32:
|
||||
return 32;
|
||||
case StandardTypes::F64:
|
||||
return 64;
|
||||
default:
|
||||
llvm_unreachable("unexpected type");
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the floating semantics for the given type.
|
||||
const llvm::fltSemantics &FloatType::getFloatSemantics() const {
|
||||
if (isBF16())
|
||||
// Treat BF16 like a double. This is unfortunate but BF16 fltSemantics is
|
||||
// not defined in LLVM.
|
||||
// TODO(jpienaar): add BF16 to LLVM? fltSemantics are internal to APFloat.cc
|
||||
// else one could add it.
|
||||
// static const fltSemantics semBF16 = {127, -126, 8, 16};
|
||||
return APFloat::IEEEdouble();
|
||||
if (isF16())
|
||||
return APFloat::IEEEhalf();
|
||||
if (isF32())
|
||||
return APFloat::IEEEsingle();
|
||||
if (isF64())
|
||||
return APFloat::IEEEdouble();
|
||||
llvm_unreachable("non-floating point type used");
|
||||
}
|
||||
|
||||
unsigned Type::getIntOrFloatBitWidth() const {
|
||||
assert(isIntOrFloat() && "only ints and floats have a bitwidth");
|
||||
if (auto intType = dyn_cast<IntegerType>()) {
|
||||
return intType.getWidth();
|
||||
}
|
||||
|
||||
auto floatType = cast<FloatType>();
|
||||
return floatType.getWidth();
|
||||
}
|
||||
|
||||
Type VectorOrTensorType::getElementType() const {
|
||||
return static_cast<ImplType *>(type)->elementType;
|
||||
}
|
||||
|
||||
unsigned VectorOrTensorType::getElementTypeBitWidth() const {
|
||||
return getElementType().getIntOrFloatBitWidth();
|
||||
}
|
||||
|
||||
unsigned VectorOrTensorType::getNumElements() const {
|
||||
switch (getKind()) {
|
||||
case StandardTypes::Vector:
|
||||
case StandardTypes::RankedTensor: {
|
||||
auto shape = getShape();
|
||||
unsigned num = 1;
|
||||
for (auto dim : shape)
|
||||
num *= dim;
|
||||
return num;
|
||||
}
|
||||
default:
|
||||
llvm_unreachable("not a VectorOrTensorType or not ranked");
|
||||
}
|
||||
}
|
||||
|
||||
/// If this is ranked tensor or vector type, return the rank. If it is an
|
||||
/// unranked tensor, return -1.
|
||||
int VectorOrTensorType::getRank() const {
|
||||
switch (getKind()) {
|
||||
case StandardTypes::Vector:
|
||||
case StandardTypes::RankedTensor:
|
||||
return getShape().size();
|
||||
case StandardTypes::UnrankedTensor:
|
||||
return -1;
|
||||
default:
|
||||
llvm_unreachable("not a VectorOrTensorType");
|
||||
}
|
||||
}
|
||||
|
||||
int VectorOrTensorType::getDimSize(unsigned i) const {
|
||||
switch (getKind()) {
|
||||
case StandardTypes::Vector:
|
||||
case StandardTypes::RankedTensor:
|
||||
return getShape()[i];
|
||||
default:
|
||||
llvm_unreachable("not a VectorOrTensorType or not ranked");
|
||||
}
|
||||
}
|
||||
|
||||
// Get the number of number of bits require to store a value of the given vector
|
||||
// or tensor types. Compute the value recursively since tensors are allowed to
|
||||
// have vectors as elements.
|
||||
long VectorOrTensorType::getSizeInBits() const {
|
||||
assert(hasStaticShape() &&
|
||||
"cannot get the bit size of an aggregate with a dynamic shape");
|
||||
|
||||
auto elementType = getElementType();
|
||||
if (elementType.isIntOrFloat())
|
||||
return elementType.getIntOrFloatBitWidth() * getNumElements();
|
||||
|
||||
// Tensors can have vectors and other tensors as elements, vectors cannot.
|
||||
assert(!isa<VectorType>() && "unsupported vector element type");
|
||||
auto elementVectorOrTensorType = elementType.dyn_cast<VectorOrTensorType>();
|
||||
assert(elementVectorOrTensorType && "unsupported tensor element type");
|
||||
return getNumElements() * elementVectorOrTensorType.getSizeInBits();
|
||||
}
|
||||
|
||||
ArrayRef<int> VectorOrTensorType::getShape() const {
|
||||
switch (getKind()) {
|
||||
case StandardTypes::Vector:
|
||||
return cast<VectorType>().getShape();
|
||||
case StandardTypes::RankedTensor:
|
||||
return cast<RankedTensorType>().getShape();
|
||||
default:
|
||||
llvm_unreachable("not a VectorOrTensorType or not ranked");
|
||||
}
|
||||
}
|
||||
|
||||
bool VectorOrTensorType::hasStaticShape() const {
|
||||
if (isa<UnrankedTensorType>())
|
||||
return false;
|
||||
auto dims = getShape();
|
||||
return !std::any_of(dims.begin(), dims.end(), [](int i) { return i < 0; });
|
||||
}
|
||||
|
||||
ArrayRef<int> VectorType::getShape() const {
|
||||
return static_cast<ImplType *>(type)->getShape();
|
||||
}
|
||||
|
||||
ArrayRef<int> RankedTensorType::getShape() const {
|
||||
return static_cast<ImplType *>(type)->getShape();
|
||||
}
|
||||
|
||||
ArrayRef<int> MemRefType::getShape() const {
|
||||
return static_cast<ImplType *>(type)->getShape();
|
||||
}
|
||||
|
||||
Type MemRefType::getElementType() const {
|
||||
return static_cast<ImplType *>(type)->elementType;
|
||||
}
|
||||
|
||||
ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
|
||||
return static_cast<ImplType *>(type)->getAffineMaps();
|
||||
}
|
||||
|
||||
unsigned MemRefType::getMemorySpace() const {
|
||||
return static_cast<ImplType *>(type)->memorySpace;
|
||||
}
|
||||
|
||||
unsigned MemRefType::getNumDynamicDims() const {
|
||||
unsigned numDynamicDims = 0;
|
||||
for (int dimSize : getShape()) {
|
||||
if (dimSize == -1)
|
||||
++numDynamicDims;
|
||||
}
|
||||
return numDynamicDims;
|
||||
}
|
||||
|
||||
// Define type identifiers.
|
||||
char FloatType::typeID = 0;
|
||||
char IntegerType::typeID = 0;
|
||||
char VectorType::typeID = 0;
|
||||
char RankedTensorType::typeID = 0;
|
||||
char UnrankedTensorType::typeID = 0;
|
||||
char MemRefType::typeID = 0;
|
|
@ -17,11 +17,7 @@
|
|||
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "TypeDetail.h"
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::detail;
|
||||
|
@ -36,51 +32,7 @@ MLIRContext *Type::getContext() const { return getDialect().getContext(); }
|
|||
unsigned Type::getSubclassData() const { return type->getSubclassData(); }
|
||||
void Type::setSubclassData(unsigned val) { type->setSubclassData(val); }
|
||||
|
||||
unsigned IntegerType::getWidth() const {
|
||||
return static_cast<ImplType *>(type)->width;
|
||||
}
|
||||
|
||||
unsigned FloatType::getWidth() const {
|
||||
switch (getKind()) {
|
||||
case Type::Kind::BF16:
|
||||
case Type::Kind::F16:
|
||||
return 16;
|
||||
case Type::Kind::F32:
|
||||
return 32;
|
||||
case Type::Kind::F64:
|
||||
return 64;
|
||||
default:
|
||||
llvm_unreachable("unexpected type");
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the floating semantics for the given type.
|
||||
const llvm::fltSemantics &FloatType::getFloatSemantics() const {
|
||||
if (isBF16())
|
||||
// Treat BF16 like a double. This is unfortunate but BF16 fltSemantics is
|
||||
// not defined in LLVM.
|
||||
// TODO(jpienaar): add BF16 to LLVM? fltSemantics are internal to APFloat.cc
|
||||
// else one could add it.
|
||||
// static const fltSemantics semBF16 = {127, -126, 8, 16};
|
||||
return APFloat::IEEEdouble();
|
||||
if (isF16())
|
||||
return APFloat::IEEEhalf();
|
||||
if (isF32())
|
||||
return APFloat::IEEEsingle();
|
||||
if (isF64())
|
||||
return APFloat::IEEEdouble();
|
||||
llvm_unreachable("non-floating point type used");
|
||||
}
|
||||
|
||||
unsigned Type::getIntOrFloatBitWidth() const {
|
||||
assert(isIntOrFloat() && "only ints and floats have a bitwidth");
|
||||
if (auto intType = dyn_cast<IntegerType>()) {
|
||||
return intType.getWidth();
|
||||
}
|
||||
|
||||
auto floatType = cast<FloatType>();
|
||||
return floatType.getWidth();
|
||||
}
|
||||
/// Function Type.
|
||||
|
||||
ArrayRef<Type> FunctionType::getInputs() const {
|
||||
return static_cast<ImplType *>(type)->getInputs();
|
||||
|
@ -94,128 +46,6 @@ ArrayRef<Type> FunctionType::getResults() const {
|
|||
return static_cast<ImplType *>(type)->getResults();
|
||||
}
|
||||
|
||||
Type VectorOrTensorType::getElementType() const {
|
||||
return static_cast<ImplType *>(type)->elementType;
|
||||
}
|
||||
|
||||
unsigned VectorOrTensorType::getElementTypeBitWidth() const {
|
||||
return getElementType().getIntOrFloatBitWidth();
|
||||
}
|
||||
|
||||
unsigned VectorOrTensorType::getNumElements() const {
|
||||
switch (getKind()) {
|
||||
case Kind::Vector:
|
||||
case Kind::RankedTensor: {
|
||||
auto shape = getShape();
|
||||
unsigned num = 1;
|
||||
for (auto dim : shape)
|
||||
num *= dim;
|
||||
return num;
|
||||
}
|
||||
default:
|
||||
llvm_unreachable("not a VectorOrTensorType or not ranked");
|
||||
}
|
||||
}
|
||||
|
||||
/// If this is ranked tensor or vector type, return the rank. If it is an
|
||||
/// unranked tensor, return -1.
|
||||
int VectorOrTensorType::getRank() const {
|
||||
switch (getKind()) {
|
||||
case Kind::Vector:
|
||||
case Kind::RankedTensor:
|
||||
return getShape().size();
|
||||
case Kind::UnrankedTensor:
|
||||
return -1;
|
||||
default:
|
||||
llvm_unreachable("not a VectorOrTensorType");
|
||||
}
|
||||
}
|
||||
|
||||
int VectorOrTensorType::getDimSize(unsigned i) const {
|
||||
switch (getKind()) {
|
||||
case Kind::Vector:
|
||||
case Kind::RankedTensor:
|
||||
return getShape()[i];
|
||||
default:
|
||||
llvm_unreachable("not a VectorOrTensorType or not ranked");
|
||||
}
|
||||
}
|
||||
|
||||
// Get the number of number of bits require to store a value of the given vector
|
||||
// or tensor types. Compute the value recursively since tensors are allowed to
|
||||
// have vectors as elements.
|
||||
long VectorOrTensorType::getSizeInBits() const {
|
||||
assert(hasStaticShape() &&
|
||||
"cannot get the bit size of an aggregate with a dynamic shape");
|
||||
|
||||
auto elementType = getElementType();
|
||||
if (elementType.isIntOrFloat())
|
||||
return elementType.getIntOrFloatBitWidth() * getNumElements();
|
||||
|
||||
// Tensors can have vectors and other tensors as elements, vectors cannot.
|
||||
assert(!isa<VectorType>() && "unsupported vector element type");
|
||||
auto elementVectorOrTensorType = elementType.dyn_cast<VectorOrTensorType>();
|
||||
assert(elementVectorOrTensorType && "unsupported tensor element type");
|
||||
return getNumElements() * elementVectorOrTensorType.getSizeInBits();
|
||||
}
|
||||
|
||||
ArrayRef<int> VectorOrTensorType::getShape() const {
|
||||
switch (getKind()) {
|
||||
case Kind::Vector:
|
||||
return cast<VectorType>().getShape();
|
||||
case Kind::RankedTensor:
|
||||
return cast<RankedTensorType>().getShape();
|
||||
default:
|
||||
llvm_unreachable("not a VectorOrTensorType or not ranked");
|
||||
}
|
||||
}
|
||||
|
||||
bool VectorOrTensorType::hasStaticShape() const {
|
||||
if (isa<UnrankedTensorType>())
|
||||
return false;
|
||||
auto dims = getShape();
|
||||
return !std::any_of(dims.begin(), dims.end(), [](int i) { return i < 0; });
|
||||
}
|
||||
|
||||
ArrayRef<int> VectorType::getShape() const {
|
||||
return static_cast<ImplType *>(type)->getShape();
|
||||
}
|
||||
|
||||
ArrayRef<int> RankedTensorType::getShape() const {
|
||||
return static_cast<ImplType *>(type)->getShape();
|
||||
}
|
||||
|
||||
ArrayRef<int> MemRefType::getShape() const {
|
||||
return static_cast<ImplType *>(type)->getShape();
|
||||
}
|
||||
|
||||
Type MemRefType::getElementType() const {
|
||||
return static_cast<ImplType *>(type)->elementType;
|
||||
}
|
||||
|
||||
ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
|
||||
return static_cast<ImplType *>(type)->getAffineMaps();
|
||||
}
|
||||
|
||||
unsigned MemRefType::getMemorySpace() const {
|
||||
return static_cast<ImplType *>(type)->memorySpace;
|
||||
}
|
||||
|
||||
unsigned MemRefType::getNumDynamicDims() const {
|
||||
unsigned numDynamicDims = 0;
|
||||
for (int dimSize : getShape()) {
|
||||
if (dimSize == -1)
|
||||
++numDynamicDims;
|
||||
}
|
||||
return numDynamicDims;
|
||||
}
|
||||
|
||||
// Define type identifiers.
|
||||
char IndexType::typeID = 0;
|
||||
char FloatType::typeID = 0;
|
||||
char IntegerType::typeID = 0;
|
||||
char FunctionType::typeID = 0;
|
||||
char VectorType::typeID = 0;
|
||||
char RankedTensorType::typeID = 0;
|
||||
char UnrankedTensorType::typeID = 0;
|
||||
char MemRefType::typeID = 0;
|
||||
char IndexType::typeID = 0;
|
||||
|
|
|
@ -32,7 +32,7 @@
|
|||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
#include "mlir/Transforms/Utils.h"
|
||||
#include "llvm/ADT/APInt.h"
|
||||
|
@ -711,10 +711,10 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int> &dims) {
|
|||
return ParseResult::ParseFailure;
|
||||
// check result matches the element type.
|
||||
switch (eltTy.getKind()) {
|
||||
case Type::Kind::BF16:
|
||||
case Type::Kind::F16:
|
||||
case Type::Kind::F32:
|
||||
case Type::Kind::F64: {
|
||||
case StandardTypes::BF16:
|
||||
case StandardTypes::F16:
|
||||
case StandardTypes::F32:
|
||||
case StandardTypes::F64: {
|
||||
// Bitcast the APFloat value to APInt and store the bit representation.
|
||||
auto fpAttrResult = result.dyn_cast<FloatAttr>();
|
||||
if (!fpAttrResult)
|
||||
|
@ -731,7 +731,7 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int> &dims) {
|
|||
addToStorage(apInt.getRawData()[0]);
|
||||
break;
|
||||
}
|
||||
case Type::Kind::Integer: {
|
||||
case StandardTypes::Integer: {
|
||||
if (!result.isa<IntegerAttr>())
|
||||
return p.emitError("expected tensor literal element has integer type");
|
||||
auto value = result.cast<IntegerAttr>().getValue();
|
||||
|
|
|
@ -23,7 +23,7 @@
|
|||
#include "mlir/IR/Matchers.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/Value.h"
|
||||
#include "mlir/Support/MathExtras.h"
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
|
|
|
@ -25,7 +25,6 @@
|
|||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
using namespace mlir;
|
||||
|
||||
|
|
|
@ -164,13 +164,13 @@ llvm::IntegerType *ModuleLowerer::convertIntegerType(IntegerType type) {
|
|||
llvm::Type *ModuleLowerer::convertFloatType(FloatType type) {
|
||||
MLIRContext *context = type.getContext();
|
||||
switch (type.getKind()) {
|
||||
case Type::Kind::F32:
|
||||
case StandardTypes::F32:
|
||||
return builder.getFloatTy();
|
||||
case Type::Kind::F64:
|
||||
case StandardTypes::F64:
|
||||
return builder.getDoubleTy();
|
||||
case Type::Kind::F16:
|
||||
case StandardTypes::F16:
|
||||
return builder.getHalfTy();
|
||||
case Type::Kind::BF16:
|
||||
case StandardTypes::BF16:
|
||||
return context->emitError(UnknownLoc::get(context),
|
||||
"unsupported type: BF16"),
|
||||
nullptr;
|
||||
|
|
|
@ -24,6 +24,7 @@
|
|||
#include "mlir/Analysis/SliceAnalysis.h"
|
||||
#include "mlir/Analysis/VectorAnalysis.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Pass.h"
|
||||
#include "mlir/Support/Functional.h"
|
||||
#include "mlir/Support/STLExtras.h"
|
||||
|
|
Loading…
Reference in New Issue