Split the standard types from builtin types and move them into separate source files(StandardTypes.cpp/h). After this cl only FunctionType and IndexType are builtin types, but IndexType will likely become a standard type when the ml/cfgfunc merger is done. Mechanical NFC.

PiperOrigin-RevId: 227750918
This commit is contained in:
River Riddle 2019-01-03 14:29:52 -08:00 committed by jpienaar
parent ae1a6619df
commit 54948a4380
20 changed files with 648 additions and 586 deletions

View File

@ -32,6 +32,7 @@ class Type;
class PrimitiveType;
class IntegerType;
class FunctionType;
class MemRefType;
class VectorType;
class RankedTensorType;
class UnrankedTensorType;

View File

@ -20,6 +20,7 @@
//
//===----------------------------------------------------------------------===//
DEFINE_TYPE_KIND_RANGE(STANDARD)
DEFINE_TYPE_KIND_RANGE(TENSORFLOW_CONTROL)
DEFINE_TYPE_KIND_RANGE(TENSORFLOW)

View File

@ -26,7 +26,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Value.h"
#include <type_traits>

View File

@ -0,0 +1,384 @@
//===- StandardTypes.h - MLIR Standard Type Classes -------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_IR_STANDARDTYPES_H
#define MLIR_IR_STANDARDTYPES_H
#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"
namespace llvm {
class fltSemantics;
} // namespace llvm
namespace mlir {
class AffineMap;
class FloatType;
class IndexType;
class IntegerType;
class Location;
class MLIRContext;
namespace detail {
struct IntegerTypeStorage;
struct VectorOrTensorTypeStorage;
struct VectorTypeStorage;
struct TensorTypeStorage;
struct RankedTensorTypeStorage;
struct UnrankedTensorTypeStorage;
struct MemRefTypeStorage;
} // namespace detail
namespace StandardTypes {
enum Kind {
// Floating point.
BF16 = Type::Kind::FIRST_STANDARD_TYPE,
F16,
F32,
F64,
FIRST_FLOATING_POINT_TYPE = BF16,
LAST_FLOATING_POINT_TYPE = F64,
// Derived types.
Integer,
Vector,
RankedTensor,
UnrankedTensor,
MemRef,
};
} // namespace StandardTypes
inline bool Type::isBF16() const { return getKind() == StandardTypes::BF16; }
inline bool Type::isF16() const { return getKind() == StandardTypes::F16; }
inline bool Type::isF32() const { return getKind() == StandardTypes::F32; }
inline bool Type::isF64() const { return getKind() == StandardTypes::F64; }
/// Integer types can have arbitrary bitwidth up to a large fixed limit.
class IntegerType : public Type {
public:
using ImplType = detail::IntegerTypeStorage;
using Type::Type;
/// Get or create a new IntegerType of the given width within the context.
/// Assume the width is within the allowed range and assert on failures.
/// Use getChecked to handle failures gracefully.
static IntegerType get(unsigned width, MLIRContext *context);
/// Get or create a new IntegerType of the given width within the context,
/// defined at the given, potentially unknown, location. If the width is
/// outside the allowed range, emit errors and return a null type.
static IntegerType getChecked(unsigned width, MLIRContext *context,
Location location);
/// Return the bitwidth of this integer type.
unsigned getWidth() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) { return kind == StandardTypes::Integer; }
/// Unique identifier for this type class.
static char typeID;
/// Integer representation maximal bitwidth.
static constexpr unsigned kMaxWidth = 4096;
};
inline IntegerType Type::getInteger(unsigned width, MLIRContext *ctx) {
return IntegerType::get(width, ctx);
}
/// Return true if this is an integer type with the specified width.
inline bool Type::isInteger(unsigned width) const {
if (auto intTy = dyn_cast<IntegerType>())
return intTy.getWidth() == width;
return false;
}
inline bool Type::isIntOrIndex() const {
return isa<IndexType>() || isa<IntegerType>();
}
inline bool Type::isIntOrIndexOrFloat() const {
return isa<IndexType>() || isa<IntegerType>() || isa<FloatType>();
}
inline bool Type::isIntOrFloat() const {
return isa<IntegerType>() || isa<FloatType>();
}
class FloatType : public Type {
public:
using Type::Type;
static FloatType get(StandardTypes::Kind kind, MLIRContext *context);
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) {
return kind >= StandardTypes::FIRST_FLOATING_POINT_TYPE &&
kind <= StandardTypes::LAST_FLOATING_POINT_TYPE;
}
/// Return the bitwidth of this float type.
unsigned getWidth() const;
/// Return the floating semantics of this float type.
const llvm::fltSemantics &getFloatSemantics() const;
/// Unique identifier for this type class.
static char typeID;
};
inline FloatType Type::getBF16(MLIRContext *ctx) {
return FloatType::get(StandardTypes::BF16, ctx);
}
inline FloatType Type::getF16(MLIRContext *ctx) {
return FloatType::get(StandardTypes::F16, ctx);
}
inline FloatType Type::getF32(MLIRContext *ctx) {
return FloatType::get(StandardTypes::F32, ctx);
}
inline FloatType Type::getF64(MLIRContext *ctx) {
return FloatType::get(StandardTypes::F64, ctx);
}
/// This is a common base class between Vector, UnrankedTensor, and RankedTensor
/// types, because many operations work on values of these aggregate types.
class VectorOrTensorType : public Type {
public:
using ImplType = detail::VectorOrTensorTypeStorage;
using Type::Type;
/// Return the element type.
Type getElementType() const;
/// If an element type is an integer or a float, return its width. Abort
/// otherwise.
unsigned getElementTypeBitWidth() const;
/// If this is ranked tensor or vector type, return the number of elements. If
/// it is an unranked tensor, abort.
unsigned getNumElements() const;
/// If this is ranked tensor or vector type, return the rank. If it is an
/// unranked tensor, return -1.
int getRank() const;
/// If this is ranked tensor or vector type, return the shape. If it is an
/// unranked tensor, abort.
ArrayRef<int> getShape() const;
/// If this is unranked tensor or any dimension has unknown size (<0),
/// it doesn't have static shape. If all dimensions have known size (>= 0),
/// it has static shape.
bool hasStaticShape() const;
/// If this is ranked tensor or vector type, return the size of the specified
/// dimension. It aborts if the tensor is unranked (this can be checked by
/// the getRank call method).
int getDimSize(unsigned i) const;
/// Get the total amount of bits occupied by a value of this type. This does
/// not take into account any memory layout or widening constraints, e.g. a
/// vector<3xi57> is reported to occupy 3x57=171 bit, even though in practice
/// it will likely be stored as in a 4xi64 vector register. Fail an assertion
/// if the size cannot be computed statically, i.e. if the tensor has a
/// dynamic shape or if its elemental type does not have a known bit width.
long getSizeInBits() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) {
return kind == StandardTypes::Vector ||
kind == StandardTypes::RankedTensor ||
kind == StandardTypes::UnrankedTensor;
}
};
/// Vector types represent multi-dimensional SIMD vectors, and have a fixed
/// known constant shape with one or more dimension.
class VectorType : public VectorOrTensorType {
public:
using ImplType = detail::VectorTypeStorage;
using VectorOrTensorType::VectorOrTensorType;
/// Get or create a new VectorType of the provided shape and element type.
/// Assumes the arguments define a well-formed VectorType.
static VectorType get(ArrayRef<int> shape, Type elementType);
/// Get or create a new VectorType of the provided shape and element type
/// declared at the given, potentially unknown, location. If the VectorType
/// defined by the arguments would be ill-formed, emit errors and return
/// nullptr-wrapping type.
static VectorType getChecked(ArrayRef<int> shape, Type elementType,
Location location);
/// Returns true of the given type can be used as an element of a vector type.
/// In particular, vectors can consist of integer or float primitives.
static bool isValidElementType(Type t) { return t.isIntOrFloat(); }
ArrayRef<int> getShape() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) { return kind == StandardTypes::Vector; }
/// Unique identifier for this type class.
static char typeID;
};
/// Tensor types represent multi-dimensional arrays, and have two variants:
/// RankedTensorType and UnrankedTensorType.
class TensorType : public VectorOrTensorType {
public:
using ImplType = detail::TensorTypeStorage;
using VectorOrTensorType::VectorOrTensorType;
/// Return true if the specified element type is ok in a tensor.
static bool isValidElementType(Type type) {
// Note: Non standard/builtin types are allowed to exist within tensor
// types. Dialects are expected to verify that tensor types have a valid
// element type within that dialect.
return type.isIntOrFloat() || type.isa<VectorType>() ||
(type.getKind() >= Type::Kind::LAST_STANDARD_TYPE);
}
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) {
return kind == StandardTypes::RankedTensor ||
kind == StandardTypes::UnrankedTensor;
}
};
/// Ranked tensor types represent multi-dimensional arrays that have a shape
/// with a fixed number of dimensions. Each shape element can be a positive
/// integer or unknown (represented -1).
class RankedTensorType : public TensorType {
public:
using ImplType = detail::RankedTensorTypeStorage;
using TensorType::TensorType;
/// Get or create a new RankedTensorType of the provided shape and element
/// type. Assumes the arguments define a well-formed type.
static RankedTensorType get(ArrayRef<int> shape, Type elementType);
/// Get or create a new RankedTensorType of the provided shape and element
/// type declared at the given, potentially unknown, location. If the
/// RankedTensorType defined by the arguments would be ill-formed, emit errors
/// and return a nullptr-wrapping type.
static RankedTensorType getChecked(ArrayRef<int> shape, Type elementType,
Location location);
ArrayRef<int> getShape() const;
static bool kindof(unsigned kind) {
return kind == StandardTypes::RankedTensor;
}
/// Unique identifier for this type class.
static char typeID;
};
/// Unranked tensor types represent multi-dimensional arrays that have an
/// unknown shape.
class UnrankedTensorType : public TensorType {
public:
using ImplType = detail::UnrankedTensorTypeStorage;
using TensorType::TensorType;
/// Get or create a new UnrankedTensorType of the provided shape and element
/// type. Assumes the arguments define a well-formed type.
static UnrankedTensorType get(Type elementType);
/// Get or create a new UnrankedTensorType of the provided shape and element
/// type declared at the given, potentially unknown, location. If the
/// UnrankedTensorType defined by the arguments would be ill-formed, emit
/// errors and return a nullptr-wrapping type.
static UnrankedTensorType getChecked(Type elementType, Location location);
ArrayRef<int> getShape() const { return ArrayRef<int>(); }
static bool kindof(unsigned kind) {
return kind == StandardTypes::UnrankedTensor;
}
/// Unique identifier for this type class.
static char typeID;
};
/// MemRef types represent a region of memory that have a shape with a fixed
/// number of dimensions. Each shape element can be a positive integer or
/// unknown (represented by any negative integer). MemRef types also have an
/// affine map composition, represented as an array AffineMap pointers.
class MemRefType : public Type {
public:
using ImplType = detail::MemRefTypeStorage;
using Type::Type;
/// Get or create a new MemRefType based on shape, element type, affine
/// map composition, and memory space. Assumes the arguments define a
/// well-formed MemRef type. Use getChecked to gracefully handle MemRefType
/// construction failures.
static MemRefType get(ArrayRef<int> shape, Type elementType,
ArrayRef<AffineMap> affineMapComposition,
unsigned memorySpace);
/// Get or create a new MemRefType based on shape, element type, affine
/// map composition, and memory space declared at the given location.
/// If the location is unknown, the last argument should be an instance of
/// UnknownLoc. If the MemRefType defined by the arguments would be
/// ill-formed, emits errors (to the handler registered with the context or to
/// the error stream) and returns nullptr.
static MemRefType getChecked(ArrayRef<int> shape, Type elementType,
ArrayRef<AffineMap> affineMapComposition,
unsigned memorySpace, Location location);
unsigned getRank() const { return getShape().size(); }
/// Returns an array of memref shape dimension sizes.
ArrayRef<int> getShape() const;
/// Return the size of the specified dimension, or -1 if unspecified.
int getDimSize(unsigned i) const { return getShape()[i]; }
/// Returns the elemental type for this memref shape.
Type getElementType() const;
/// Returns an array of affine map pointers representing the memref affine
/// map composition.
ArrayRef<AffineMap> getAffineMaps() const;
/// Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpace() const;
/// Returns the number of dimensions with dynamic size.
unsigned getNumDynamicDims() const;
static bool kindof(unsigned kind) { return kind == StandardTypes::MemRef; }
/// Unique identifier for this type class.
static char typeID;
private:
static MemRefType getSafe(ArrayRef<int> shape, Type elementType,
ArrayRef<AffineMap> affineMapComposition,
unsigned memorySpace, Optional<Location> location);
};
} // end namespace mlir
#endif // MLIR_IR_STANDARDTYPES_H

View File

@ -23,12 +23,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMapInfo.h"
namespace llvm {
class fltSemantics;
} // namespace llvm
namespace mlir {
class AffineMap;
class FloatType;
class IndexType;
class IntegerType;
@ -36,6 +31,7 @@ class Location;
class MLIRContext;
namespace detail {
struct FunctionTypeStorage;
struct TypeStorage;
} // namespace detail
@ -96,39 +92,13 @@ public:
/// enum class also simplifies the handling of type kinds by not requiring
/// casts for each use.
enum Kind {
// Builtin types.
Function,
// TODO(riverriddle) Index shouldn't really be a builtin.
// Target pointer sized integer, used (e.g.) in affine mappings.
Index,
// TensorFlow types.
TFControl,
TFResource,
TFVariant,
TFComplex64,
TFComplex128,
TFF32REF,
TFString,
/// These are marker for the first and last 'other' type.
FIRST_OTHER_TYPE = TFControl,
LAST_OTHER_TYPE = TFString,
// Floating point.
BF16,
F16,
F32,
F64,
FIRST_FLOATING_POINT_TYPE = BF16,
LAST_FLOATING_POINT_TYPE = F64,
// Derived types.
Integer,
Function,
Vector,
RankedTensor,
UnrankedTensor,
MemRef,
LAST_BUILTIN_TYPE = 0xff,
LAST_BUILTIN_TYPE = Index,
// Reserve type kinds for dialect specific type system extensions.
#define DEFINE_TYPE_KIND_RANGE(Dialect) \
@ -224,135 +194,6 @@ inline raw_ostream &operator<<(raw_ostream &os, Type type) {
return os;
}
/// Standard Type Utilities.
namespace detail {
struct IntegerTypeStorage;
struct FunctionTypeStorage;
struct VectorOrTensorTypeStorage;
struct VectorTypeStorage;
struct TensorTypeStorage;
struct RankedTensorTypeStorage;
struct UnrankedTensorTypeStorage;
struct MemRefTypeStorage;
} // namespace detail
inline bool Type::isIndex() const { return getKind() == Kind::Index; }
inline bool Type::isBF16() const { return getKind() == Kind::BF16; }
inline bool Type::isF16() const { return getKind() == Kind::F16; }
inline bool Type::isF32() const { return getKind() == Kind::F32; }
inline bool Type::isF64() const { return getKind() == Kind::F64; }
/// Integer types can have arbitrary bitwidth up to a large fixed limit.
class IntegerType : public Type {
public:
using ImplType = detail::IntegerTypeStorage;
using Type::Type;
/// Get or create a new IntegerType of the given width within the context.
/// Assume the width is within the allowed range and assert on failures.
/// Use getChecked to handle failures gracefully.
static IntegerType get(unsigned width, MLIRContext *context);
/// Get or create a new IntegerType of the given width within the context,
/// defined at the given, potentially unknown, location. If the width is
/// outside the allowed range, emit errors and return a null type.
static IntegerType getChecked(unsigned width, MLIRContext *context,
Location location);
/// Return the bitwidth of this integer type.
unsigned getWidth() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) { return kind == Kind::Integer; }
/// Unique identifier for this type class.
static char typeID;
/// Integer representation maximal bitwidth.
static constexpr unsigned kMaxWidth = 4096;
};
inline IntegerType Type::getInteger(unsigned width, MLIRContext *ctx) {
return IntegerType::get(width, ctx);
}
/// Return true if this is an integer type with the specified width.
inline bool Type::isInteger(unsigned width) const {
if (auto intTy = dyn_cast<IntegerType>())
return intTy.getWidth() == width;
return false;
}
inline bool Type::isIntOrIndex() const {
return isa<IndexType>() || isa<IntegerType>();
}
inline bool Type::isIntOrIndexOrFloat() const {
return isa<IndexType>() || isa<IntegerType>() || isa<FloatType>();
}
inline bool Type::isIntOrFloat() const {
return isa<IntegerType>() || isa<FloatType>();
}
class FloatType : public Type {
public:
using Type::Type;
static FloatType get(Kind kind, MLIRContext *context);
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) {
return kind >= Kind::FIRST_FLOATING_POINT_TYPE &&
kind <= Kind::LAST_FLOATING_POINT_TYPE;
}
/// Return the bitwidth of this float type.
unsigned getWidth() const;
/// Return the floating semantics of this float type.
const llvm::fltSemantics &getFloatSemantics() const;
/// Unique identifier for this type class.
static char typeID;
};
inline FloatType Type::getBF16(MLIRContext *ctx) {
return FloatType::get(Kind::BF16, ctx);
}
inline FloatType Type::getF16(MLIRContext *ctx) {
return FloatType::get(Kind::F16, ctx);
}
inline FloatType Type::getF32(MLIRContext *ctx) {
return FloatType::get(Kind::F32, ctx);
}
inline FloatType Type::getF64(MLIRContext *ctx) {
return FloatType::get(Kind::F64, ctx);
}
/// Index is special integer-like type with unknown platform-dependent bit width
/// used in subscripts and loop induction variables.
class IndexType : public Type {
public:
using Type::Type;
/// Crete an IndexType instance, unique in the given context.
static IndexType get(MLIRContext *context);
/// Support method to enable LLVM-style type casting.
static bool kindof(unsigned kind) { return kind == Kind::Index; }
/// Unique identifier for this type class.
static char typeID;
};
inline IndexType Type::getIndex(MLIRContext *ctx) {
return IndexType::get(ctx);
}
/// Function types map from a list of inputs to a list of results.
class FunctionType : public Type {
public:
@ -383,222 +224,27 @@ public:
static char typeID;
};
/// This is a common base class between Vector, UnrankedTensor, and RankedTensor
/// types, because many operations work on values of these aggregate types.
class VectorOrTensorType : public Type {
inline bool Type::isIndex() const { return getKind() == Kind::Index; }
/// Index is special integer-like type with unknown platform-dependent bit width
/// used in subscripts and loop induction variables.
class IndexType : public Type {
public:
using ImplType = detail::VectorOrTensorTypeStorage;
using Type::Type;
/// Return the element type.
Type getElementType() const;
/// Crete an IndexType instance, unique in the given context.
static IndexType get(MLIRContext *context);
/// If an element type is an integer or a float, return its width. Abort
/// otherwise.
unsigned getElementTypeBitWidth() const;
/// If this is ranked tensor or vector type, return the number of elements. If
/// it is an unranked tensor, abort.
unsigned getNumElements() const;
/// If this is ranked tensor or vector type, return the rank. If it is an
/// unranked tensor, return -1.
int getRank() const;
/// If this is ranked tensor or vector type, return the shape. If it is an
/// unranked tensor, abort.
ArrayRef<int> getShape() const;
/// If this is unranked tensor or any dimension has unknown size (<0),
/// it doesn't have static shape. If all dimensions have known size (>= 0),
/// it has static shape.
bool hasStaticShape() const;
/// If this is ranked tensor or vector type, return the size of the specified
/// dimension. It aborts if the tensor is unranked (this can be checked by
/// the getRank call method).
int getDimSize(unsigned i) const;
/// Get the total amount of bits occupied by a value of this type. This does
/// not take into account any memory layout or widening constraints, e.g. a
/// vector<3xi57> is reported to occupy 3x57=171 bit, even though in practice
/// it will likely be stored as in a 4xi64 vector register. Fail an assertion
/// if the size cannot be computed statically, i.e. if the tensor has a
/// dynamic shape or if its elemental type does not have a known bit width.
long getSizeInBits() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) {
return kind == Kind::Vector || kind == Kind::RankedTensor ||
kind == Kind::UnrankedTensor;
}
};
/// Vector types represent multi-dimensional SIMD vectors, and have a fixed
/// known constant shape with one or more dimension.
class VectorType : public VectorOrTensorType {
public:
using ImplType = detail::VectorTypeStorage;
using VectorOrTensorType::VectorOrTensorType;
/// Get or create a new VectorType of the provided shape and element type.
/// Assumes the arguments define a well-formed VectorType.
static VectorType get(ArrayRef<int> shape, Type elementType);
/// Get or create a new VectorType of the provided shape and element type
/// declared at the given, potentially unknown, location. If the VectorType
/// defined by the arguments would be ill-formed, emit errors and return
/// nullptr-wrapping type.
static VectorType getChecked(ArrayRef<int> shape, Type elementType,
Location location);
/// Returns true of the given type can be used as an element of a vector type.
/// In particular, vectors can consist of integer or float primitives.
static bool isValidElementType(Type t) { return t.isIntOrFloat(); }
ArrayRef<int> getShape() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) { return kind == Kind::Vector; }
/// Support method to enable LLVM-style type casting.
static bool kindof(unsigned kind) { return kind == Kind::Index; }
/// Unique identifier for this type class.
static char typeID;
};
/// Tensor types represent multi-dimensional arrays, and have two variants:
/// RankedTensorType and UnrankedTensorType.
class TensorType : public VectorOrTensorType {
public:
using ImplType = detail::TensorTypeStorage;
using VectorOrTensorType::VectorOrTensorType;
/// Return true if the specified element type is ok in a tensor.
static bool isValidElementType(Type type) {
// TODO(riverriddle): TensorFlow types are currently considered valid for
// legacy reasons.
return type.isIntOrFloat() || type.isa<VectorType>() ||
(type.getKind() >=
static_cast<unsigned>(Kind::FIRST_TENSORFLOW_TYPE) &&
type.getKind() <=
static_cast<unsigned>(Kind::LAST_TENSORFLOW_TYPE));
}
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) {
return kind == Kind::RankedTensor || kind == Kind::UnrankedTensor;
}
};
/// Ranked tensor types represent multi-dimensional arrays that have a shape
/// with a fixed number of dimensions. Each shape element can be a positive
/// integer or unknown (represented -1).
class RankedTensorType : public TensorType {
public:
using ImplType = detail::RankedTensorTypeStorage;
using TensorType::TensorType;
/// Get or create a new RankedTensorType of the provided shape and element
/// type. Assumes the arguments define a well-formed type.
static RankedTensorType get(ArrayRef<int> shape, Type elementType);
/// Get or create a new RankedTensorType of the provided shape and element
/// type declared at the given, potentially unknown, location. If the
/// RankedTensorType defined by the arguments would be ill-formed, emit errors
/// and return a nullptr-wrapping type.
static RankedTensorType getChecked(ArrayRef<int> shape, Type elementType,
Location location);
ArrayRef<int> getShape() const;
static bool kindof(unsigned kind) { return kind == Kind::RankedTensor; }
/// Unique identifier for this type class.
static char typeID;
};
/// Unranked tensor types represent multi-dimensional arrays that have an
/// unknown shape.
class UnrankedTensorType : public TensorType {
public:
using ImplType = detail::UnrankedTensorTypeStorage;
using TensorType::TensorType;
/// Get or create a new UnrankedTensorType of the provided shape and element
/// type. Assumes the arguments define a well-formed type.
static UnrankedTensorType get(Type elementType);
/// Get or create a new UnrankedTensorType of the provided shape and element
/// type declared at the given, potentially unknown, location. If the
/// UnrankedTensorType defined by the arguments would be ill-formed, emit
/// errors and return a nullptr-wrapping type.
static UnrankedTensorType getChecked(Type elementType, Location location);
ArrayRef<int> getShape() const { return ArrayRef<int>(); }
static bool kindof(unsigned kind) { return kind == Kind::UnrankedTensor; }
/// Unique identifier for this type class.
static char typeID;
};
/// MemRef types represent a region of memory that have a shape with a fixed
/// number of dimensions. Each shape element can be a positive integer or
/// unknown (represented by any negative integer). MemRef types also have an
/// affine map composition, represented as an array AffineMap pointers.
class MemRefType : public Type {
public:
using ImplType = detail::MemRefTypeStorage;
using Type::Type;
/// Get or create a new MemRefType based on shape, element type, affine
/// map composition, and memory space. Assumes the arguments define a
/// well-formed MemRef type. Use getChecked to gracefully handle MemRefType
/// construction failures.
static MemRefType get(ArrayRef<int> shape, Type elementType,
ArrayRef<AffineMap> affineMapComposition,
unsigned memorySpace);
/// Get or create a new MemRefType based on shape, element type, affine
/// map composition, and memory space declared at the given location.
/// If the location is unknown, the last argument should be an instance of
/// UnknownLoc. If the MemRefType defined by the arguments would be
/// ill-formed, emits errors (to the handler registered with the context or to
/// the error stream) and returns nullptr.
static MemRefType getChecked(ArrayRef<int> shape, Type elementType,
ArrayRef<AffineMap> affineMapComposition,
unsigned memorySpace, Location location);
unsigned getRank() const { return getShape().size(); }
/// Returns an array of memref shape dimension sizes.
ArrayRef<int> getShape() const;
/// Return the size of the specified dimension, or -1 if unspecified.
int getDimSize(unsigned i) const { return getShape()[i]; }
/// Returns the elemental type for this memref shape.
Type getElementType() const;
/// Returns an array of affine map pointers representing the memref affine
/// map composition.
ArrayRef<AffineMap> getAffineMaps() const;
/// Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpace() const;
/// Returns the number of dimensions with dynamic size.
unsigned getNumDynamicDims() const;
static bool kindof(unsigned kind) { return kind == Kind::MemRef; }
/// Unique identifier for this type class.
static char typeID;
private:
static MemRefType getSafe(ArrayRef<int> shape, Type elementType,
ArrayRef<AffineMap> affineMapComposition,
unsigned memorySpace, Optional<Location> location);
};
inline IndexType Type::getIndex(MLIRContext *ctx) {
return IndexType::get(ctx);
}
// Make Type hashable.
inline ::llvm::hash_code hash_value(Type arg) {

View File

@ -26,6 +26,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
namespace mlir {
class AffineMap;

View File

@ -26,6 +26,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/StandardTypes.h"
namespace mlir {

View File

@ -30,7 +30,7 @@
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/DenseMap.h"
@ -497,20 +497,20 @@ void ModulePrinter::printType(Type type) {
case Type::Kind::Index:
os << "index";
return;
case Type::Kind::BF16:
case StandardTypes::BF16:
os << "bf16";
return;
case Type::Kind::F16:
case StandardTypes::F16:
os << "f16";
return;
case Type::Kind::F32:
case StandardTypes::F32:
os << "f32";
return;
case Type::Kind::F64:
case StandardTypes::F64:
os << "f64";
return;
case Type::Kind::Integer: {
case StandardTypes::Integer: {
auto integer = type.cast<IntegerType>();
os << 'i' << integer.getWidth();
return;
@ -530,7 +530,7 @@ void ModulePrinter::printType(Type type) {
}
return;
}
case Type::Kind::Vector: {
case StandardTypes::Vector: {
auto v = type.cast<VectorType>();
os << "vector<";
for (auto dim : v.getShape())
@ -538,7 +538,7 @@ void ModulePrinter::printType(Type type) {
os << v.getElementType() << '>';
return;
}
case Type::Kind::RankedTensor: {
case StandardTypes::RankedTensor: {
auto v = type.cast<RankedTensorType>();
os << "tensor<";
for (auto dim : v.getShape()) {
@ -551,14 +551,14 @@ void ModulePrinter::printType(Type type) {
os << v.getElementType() << '>';
return;
}
case Type::Kind::UnrankedTensor: {
case StandardTypes::UnrankedTensor: {
auto v = type.cast<UnrankedTensorType>();
os << "tensor<*x";
printType(v.getElementType());
os << '>';
return;
}
case Type::Kind::MemRef: {
case StandardTypes::MemRef: {
auto v = type.cast<MemRefType>();
os << "memref<";
for (auto dim : v.getShape()) {

View File

@ -26,7 +26,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/Support/TrailingObjects.h"

View File

@ -22,7 +22,7 @@
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/StandardTypes.h"
using namespace mlir;
Builder::Builder(Module *module) : context(module->getContext()) {}

View File

@ -20,7 +20,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Support/STLExtras.h"

View File

@ -767,7 +767,7 @@ void TypeUniquer::insert(unsigned hashValue, TypeStorage *storage) {
/// Construct a unique instance of `DerivedType` of the given `kind` in the
/// given `context` by passing `args` to the type storage.
template <typename DerivedType, typename... Args>
static DerivedType constructUniqueType(MLIRContext *context, Type::Kind kind,
static DerivedType constructUniqueType(MLIRContext *context, unsigned kind,
Args... args) {
return TypeUniquer(context).get<DerivedType>(static_cast<unsigned>(kind),
args...);
@ -787,7 +787,8 @@ static IntegerType getIntegerType(unsigned width, MLIRContext *context,
return {};
}
return constructUniqueType<IntegerType>(context, Type::Kind::Integer, width);
return constructUniqueType<IntegerType>(context, StandardTypes::Integer,
width);
}
IntegerType IntegerType::getChecked(unsigned width, MLIRContext *context,
@ -801,9 +802,10 @@ IntegerType IntegerType::get(unsigned width, MLIRContext *context) {
return type;
}
FloatType FloatType::get(Type::Kind kind, MLIRContext *context) {
assert(kind >= Kind::FIRST_FLOATING_POINT_TYPE &&
kind <= Kind::LAST_FLOATING_POINT_TYPE && "Not an FP type kind");
FloatType FloatType::get(StandardTypes::Kind kind, MLIRContext *context) {
assert(kind >= StandardTypes::FIRST_FLOATING_POINT_TYPE &&
kind <= StandardTypes::LAST_FLOATING_POINT_TYPE &&
"Not an FP type kind");
return constructUniqueType<FloatType>(context, kind);
}
@ -841,7 +843,7 @@ static VectorType getVectorType(ArrayRef<int> shape, Type elementType,
return {};
}
return constructUniqueType<VectorType>(context, Type::Kind::Vector, shape,
return constructUniqueType<VectorType>(context, StandardTypes::Vector, shape,
elementType);
}
@ -883,7 +885,7 @@ static RankedTensorType getRankedTensorType(ArrayRef<int> shape,
auto *context = elementType.getContext();
return constructUniqueType<RankedTensorType>(
context, Type::Kind::RankedTensor, shape, elementType);
context, StandardTypes::RankedTensor, shape, elementType);
}
RankedTensorType RankedTensorType::get(ArrayRef<int> shape, Type elementType) {
@ -909,7 +911,7 @@ static UnrankedTensorType getUnrankedTensorType(Type elementType,
auto *context = elementType.getContext();
return constructUniqueType<UnrankedTensorType>(
context, Type::Kind::UnrankedTensor, elementType);
context, StandardTypes::UnrankedTensor, elementType);
}
UnrankedTensorType UnrankedTensorType::get(Type elementType) {
@ -975,7 +977,7 @@ static MemRefType getMemRefType(ArrayRef<int> shape, Type elementType,
}
affineMapComposition = cleanedAffineMapComposition;
return constructUniqueType<MemRefType>(context, Type::Kind::MemRef, shape,
return constructUniqueType<MemRefType>(context, StandardTypes::MemRef, shape,
elementType, affineMapComposition,
memorySpace);
}
@ -1343,10 +1345,10 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
// Otherwise, allocate a new one, unique it and return it.
auto eltType = type.getElementType();
switch (eltType.getKind()) {
case Type::Kind::BF16:
case Type::Kind::F16:
case Type::Kind::F32:
case Type::Kind::F64: {
case StandardTypes::BF16:
case StandardTypes::F16:
case StandardTypes::F32:
case StandardTypes::F64: {
auto *result = impl.allocator.Allocate<DenseFPElementsAttributeStorage>();
auto *copy = (char *)impl.allocator.Allocate(data.size(), 64);
std::uninitialized_copy(data.begin(), data.end(), copy);
@ -1356,7 +1358,7 @@ DenseElementsAttr DenseElementsAttr::get(VectorOrTensorType type,
{copy, data.size()}}};
return *existing.first = result;
}
case Type::Kind::Integer: {
case StandardTypes::Integer: {
auto width = eltType.cast<IntegerType>().getWidth();
auto *result = impl.allocator.Allocate<DenseIntElementsAttributeStorage>();
auto *copy = (char *)impl.allocator.Allocate(data.size(), 64);

View File

@ -21,6 +21,7 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
using namespace mlir;
/// Form the OperationName for an op with the specified string. This either is
@ -293,8 +294,7 @@ bool OpTrait::impl::verifyIsTerminator(const OperationInst *op) {
bool OpTrait::impl::verifyResultsAreBoolLike(const OperationInst *op) {
for (auto *result : op->getResults()) {
auto elementType = getTensorOrVectorElementType(result->getType());
auto intType = elementType.dyn_cast<IntegerType>();
bool isBoolType = intType && intType.getWidth() == 1;
bool isBoolType = elementType.isInteger(1);
if (!isBoolType)
return op->emitOpError("requires a bool result type");
}

View File

@ -0,0 +1,196 @@
//===- Types.cpp - MLIR Type Classes --------------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "mlir/IR/StandardTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::detail;
unsigned IntegerType::getWidth() const {
return static_cast<ImplType *>(type)->width;
}
unsigned FloatType::getWidth() const {
switch (getKind()) {
case StandardTypes::BF16:
case StandardTypes::F16:
return 16;
case StandardTypes::F32:
return 32;
case StandardTypes::F64:
return 64;
default:
llvm_unreachable("unexpected type");
}
}
/// Returns the floating semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() const {
if (isBF16())
// Treat BF16 like a double. This is unfortunate but BF16 fltSemantics is
// not defined in LLVM.
// TODO(jpienaar): add BF16 to LLVM? fltSemantics are internal to APFloat.cc
// else one could add it.
// static const fltSemantics semBF16 = {127, -126, 8, 16};
return APFloat::IEEEdouble();
if (isF16())
return APFloat::IEEEhalf();
if (isF32())
return APFloat::IEEEsingle();
if (isF64())
return APFloat::IEEEdouble();
llvm_unreachable("non-floating point type used");
}
unsigned Type::getIntOrFloatBitWidth() const {
assert(isIntOrFloat() && "only ints and floats have a bitwidth");
if (auto intType = dyn_cast<IntegerType>()) {
return intType.getWidth();
}
auto floatType = cast<FloatType>();
return floatType.getWidth();
}
Type VectorOrTensorType::getElementType() const {
return static_cast<ImplType *>(type)->elementType;
}
unsigned VectorOrTensorType::getElementTypeBitWidth() const {
return getElementType().getIntOrFloatBitWidth();
}
unsigned VectorOrTensorType::getNumElements() const {
switch (getKind()) {
case StandardTypes::Vector:
case StandardTypes::RankedTensor: {
auto shape = getShape();
unsigned num = 1;
for (auto dim : shape)
num *= dim;
return num;
}
default:
llvm_unreachable("not a VectorOrTensorType or not ranked");
}
}
/// If this is ranked tensor or vector type, return the rank. If it is an
/// unranked tensor, return -1.
int VectorOrTensorType::getRank() const {
switch (getKind()) {
case StandardTypes::Vector:
case StandardTypes::RankedTensor:
return getShape().size();
case StandardTypes::UnrankedTensor:
return -1;
default:
llvm_unreachable("not a VectorOrTensorType");
}
}
int VectorOrTensorType::getDimSize(unsigned i) const {
switch (getKind()) {
case StandardTypes::Vector:
case StandardTypes::RankedTensor:
return getShape()[i];
default:
llvm_unreachable("not a VectorOrTensorType or not ranked");
}
}
// Get the number of number of bits require to store a value of the given vector
// or tensor types. Compute the value recursively since tensors are allowed to
// have vectors as elements.
long VectorOrTensorType::getSizeInBits() const {
assert(hasStaticShape() &&
"cannot get the bit size of an aggregate with a dynamic shape");
auto elementType = getElementType();
if (elementType.isIntOrFloat())
return elementType.getIntOrFloatBitWidth() * getNumElements();
// Tensors can have vectors and other tensors as elements, vectors cannot.
assert(!isa<VectorType>() && "unsupported vector element type");
auto elementVectorOrTensorType = elementType.dyn_cast<VectorOrTensorType>();
assert(elementVectorOrTensorType && "unsupported tensor element type");
return getNumElements() * elementVectorOrTensorType.getSizeInBits();
}
ArrayRef<int> VectorOrTensorType::getShape() const {
switch (getKind()) {
case StandardTypes::Vector:
return cast<VectorType>().getShape();
case StandardTypes::RankedTensor:
return cast<RankedTensorType>().getShape();
default:
llvm_unreachable("not a VectorOrTensorType or not ranked");
}
}
bool VectorOrTensorType::hasStaticShape() const {
if (isa<UnrankedTensorType>())
return false;
auto dims = getShape();
return !std::any_of(dims.begin(), dims.end(), [](int i) { return i < 0; });
}
ArrayRef<int> VectorType::getShape() const {
return static_cast<ImplType *>(type)->getShape();
}
ArrayRef<int> RankedTensorType::getShape() const {
return static_cast<ImplType *>(type)->getShape();
}
ArrayRef<int> MemRefType::getShape() const {
return static_cast<ImplType *>(type)->getShape();
}
Type MemRefType::getElementType() const {
return static_cast<ImplType *>(type)->elementType;
}
ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
return static_cast<ImplType *>(type)->getAffineMaps();
}
unsigned MemRefType::getMemorySpace() const {
return static_cast<ImplType *>(type)->memorySpace;
}
unsigned MemRefType::getNumDynamicDims() const {
unsigned numDynamicDims = 0;
for (int dimSize : getShape()) {
if (dimSize == -1)
++numDynamicDims;
}
return numDynamicDims;
}
// Define type identifiers.
char FloatType::typeID = 0;
char IntegerType::typeID = 0;
char VectorType::typeID = 0;
char RankedTensorType::typeID = 0;
char UnrankedTensorType::typeID = 0;
char MemRefType::typeID = 0;

View File

@ -17,11 +17,7 @@
#include "mlir/IR/Types.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace mlir::detail;
@ -36,51 +32,7 @@ MLIRContext *Type::getContext() const { return getDialect().getContext(); }
unsigned Type::getSubclassData() const { return type->getSubclassData(); }
void Type::setSubclassData(unsigned val) { type->setSubclassData(val); }
unsigned IntegerType::getWidth() const {
return static_cast<ImplType *>(type)->width;
}
unsigned FloatType::getWidth() const {
switch (getKind()) {
case Type::Kind::BF16:
case Type::Kind::F16:
return 16;
case Type::Kind::F32:
return 32;
case Type::Kind::F64:
return 64;
default:
llvm_unreachable("unexpected type");
}
}
/// Returns the floating semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() const {
if (isBF16())
// Treat BF16 like a double. This is unfortunate but BF16 fltSemantics is
// not defined in LLVM.
// TODO(jpienaar): add BF16 to LLVM? fltSemantics are internal to APFloat.cc
// else one could add it.
// static const fltSemantics semBF16 = {127, -126, 8, 16};
return APFloat::IEEEdouble();
if (isF16())
return APFloat::IEEEhalf();
if (isF32())
return APFloat::IEEEsingle();
if (isF64())
return APFloat::IEEEdouble();
llvm_unreachable("non-floating point type used");
}
unsigned Type::getIntOrFloatBitWidth() const {
assert(isIntOrFloat() && "only ints and floats have a bitwidth");
if (auto intType = dyn_cast<IntegerType>()) {
return intType.getWidth();
}
auto floatType = cast<FloatType>();
return floatType.getWidth();
}
/// Function Type.
ArrayRef<Type> FunctionType::getInputs() const {
return static_cast<ImplType *>(type)->getInputs();
@ -94,128 +46,6 @@ ArrayRef<Type> FunctionType::getResults() const {
return static_cast<ImplType *>(type)->getResults();
}
Type VectorOrTensorType::getElementType() const {
return static_cast<ImplType *>(type)->elementType;
}
unsigned VectorOrTensorType::getElementTypeBitWidth() const {
return getElementType().getIntOrFloatBitWidth();
}
unsigned VectorOrTensorType::getNumElements() const {
switch (getKind()) {
case Kind::Vector:
case Kind::RankedTensor: {
auto shape = getShape();
unsigned num = 1;
for (auto dim : shape)
num *= dim;
return num;
}
default:
llvm_unreachable("not a VectorOrTensorType or not ranked");
}
}
/// If this is ranked tensor or vector type, return the rank. If it is an
/// unranked tensor, return -1.
int VectorOrTensorType::getRank() const {
switch (getKind()) {
case Kind::Vector:
case Kind::RankedTensor:
return getShape().size();
case Kind::UnrankedTensor:
return -1;
default:
llvm_unreachable("not a VectorOrTensorType");
}
}
int VectorOrTensorType::getDimSize(unsigned i) const {
switch (getKind()) {
case Kind::Vector:
case Kind::RankedTensor:
return getShape()[i];
default:
llvm_unreachable("not a VectorOrTensorType or not ranked");
}
}
// Get the number of number of bits require to store a value of the given vector
// or tensor types. Compute the value recursively since tensors are allowed to
// have vectors as elements.
long VectorOrTensorType::getSizeInBits() const {
assert(hasStaticShape() &&
"cannot get the bit size of an aggregate with a dynamic shape");
auto elementType = getElementType();
if (elementType.isIntOrFloat())
return elementType.getIntOrFloatBitWidth() * getNumElements();
// Tensors can have vectors and other tensors as elements, vectors cannot.
assert(!isa<VectorType>() && "unsupported vector element type");
auto elementVectorOrTensorType = elementType.dyn_cast<VectorOrTensorType>();
assert(elementVectorOrTensorType && "unsupported tensor element type");
return getNumElements() * elementVectorOrTensorType.getSizeInBits();
}
ArrayRef<int> VectorOrTensorType::getShape() const {
switch (getKind()) {
case Kind::Vector:
return cast<VectorType>().getShape();
case Kind::RankedTensor:
return cast<RankedTensorType>().getShape();
default:
llvm_unreachable("not a VectorOrTensorType or not ranked");
}
}
bool VectorOrTensorType::hasStaticShape() const {
if (isa<UnrankedTensorType>())
return false;
auto dims = getShape();
return !std::any_of(dims.begin(), dims.end(), [](int i) { return i < 0; });
}
ArrayRef<int> VectorType::getShape() const {
return static_cast<ImplType *>(type)->getShape();
}
ArrayRef<int> RankedTensorType::getShape() const {
return static_cast<ImplType *>(type)->getShape();
}
ArrayRef<int> MemRefType::getShape() const {
return static_cast<ImplType *>(type)->getShape();
}
Type MemRefType::getElementType() const {
return static_cast<ImplType *>(type)->elementType;
}
ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
return static_cast<ImplType *>(type)->getAffineMaps();
}
unsigned MemRefType::getMemorySpace() const {
return static_cast<ImplType *>(type)->memorySpace;
}
unsigned MemRefType::getNumDynamicDims() const {
unsigned numDynamicDims = 0;
for (int dimSize : getShape()) {
if (dimSize == -1)
++numDynamicDims;
}
return numDynamicDims;
}
// Define type identifiers.
char IndexType::typeID = 0;
char FloatType::typeID = 0;
char IntegerType::typeID = 0;
char FunctionType::typeID = 0;
char VectorType::typeID = 0;
char RankedTensorType::typeID = 0;
char UnrankedTensorType::typeID = 0;
char MemRefType::typeID = 0;
char IndexType::typeID = 0;

View File

@ -32,7 +32,7 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/STLExtras.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/APInt.h"
@ -711,10 +711,10 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int> &dims) {
return ParseResult::ParseFailure;
// check result matches the element type.
switch (eltTy.getKind()) {
case Type::Kind::BF16:
case Type::Kind::F16:
case Type::Kind::F32:
case Type::Kind::F64: {
case StandardTypes::BF16:
case StandardTypes::F16:
case StandardTypes::F32:
case StandardTypes::F64: {
// Bitcast the APFloat value to APInt and store the bit representation.
auto fpAttrResult = result.dyn_cast<FloatAttr>();
if (!fpAttrResult)
@ -731,7 +731,7 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int> &dims) {
addToStorage(apInt.getRawData()[0]);
break;
}
case Type::Kind::Integer: {
case StandardTypes::Integer: {
if (!result.isa<IntegerAttr>())
return p.emitError("expected tensor literal element has integer type");
auto value = result.cast<IntegerAttr>().getValue();

View File

@ -23,7 +23,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Support/STLExtras.h"

View File

@ -25,7 +25,6 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LLVM.h"
using namespace mlir;

View File

@ -164,13 +164,13 @@ llvm::IntegerType *ModuleLowerer::convertIntegerType(IntegerType type) {
llvm::Type *ModuleLowerer::convertFloatType(FloatType type) {
MLIRContext *context = type.getContext();
switch (type.getKind()) {
case Type::Kind::F32:
case StandardTypes::F32:
return builder.getFloatTy();
case Type::Kind::F64:
case StandardTypes::F64:
return builder.getDoubleTy();
case Type::Kind::F16:
case StandardTypes::F16:
return builder.getHalfTy();
case Type::Kind::BF16:
case StandardTypes::BF16:
return context->emitError(UnknownLoc::get(context),
"unsupported type: BF16"),
nullptr;

View File

@ -24,6 +24,7 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/VectorAnalysis.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Pass.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/STLExtras.h"